Topic 11: Debug Library

This example shows a program that uses the tracing mechanism of the <debug> library to record variable values and compile time strings as well as timestamps, for inspection by the host code.

This program uses a row of four contiguous PEs. Two colors, red (color 0) and blue (color 1), are used. On the interior PEs, the routing associated with these colors receives from the WEST and sends down the RAMP and to the EAST; the first PE instead receives from its own RAMP (where the host data enters the fabric), and the last PE sends only down its RAMP. Additionally, for both colors, color_swap_x is set to true. Because these colors differ only in their lowest bit, when a red wavelet comes into a router from WEST, it leaves the router to the EAST as a blue wavelet, and vice versa.

The host code sends four wavelets along the color MEMCPYH2D_DATA_1 into the first PE. The wavelet-triggered task (WTT) bound to MEMCPYH2D_DATA_1 on that PE forwards this data on color blue. When a PE receives a red wavelet, the task red_task is activated, and when a PE receives a blue wavelet, the task blue_task is activated. Because of the color swap, PEs 0 and 2 therefore run blue_task, while PEs 1 and 3 run red_task.

Each PE program contains a global variable named global, initialized to zero. When a red_task is activated by an incoming wavelet in_data, global is incremented by an amount in_data. When a blue_task is activated by an incoming wavelet in_data, global is incremented by an amount 2 * in_data.

The programs running on each PE import two instances of the <debug> library. Each time a task activates, the instance named trace logs a compile time string noting the color of the task, and the updated value of global. The instance named times logs a timestamp at the beginning of a task, and at the end of a task.

The host code uses the function read_trace from cerebras.sdk.debug.debug_util to read the logged values after execution of the device code finishes. Note that the PE coordinates passed to read_trace start from the northwest corner of the fabric, not from the northwest corner of the program rectangle.

layout.csl

// color/ task ID map
//
//  ID var         ID var     ID var                ID var
//   0 red          9         18                    27 reserved (memcpy)
//   1 blue        10         19                    28 reserved (memcpy)
//   2             11         20                    29 reserved
//   3             12         21 reserved (memcpy)  30 reserved (memcpy)
//   4             13         22 reserved (memcpy)  31 reserved
//   5             14         23 reserved (memcpy)  32
//   6 H2D         15         24                    33
//   7 D2H         16         25                    34
//   8             17         26                    35
//

// Number of PEs in the single-row program rectangle.
param width : u16;

// IDs for memcpy streaming colors (supplied with --params at compile time)
param MEMCPYH2D_DATA_1_ID: i16;
param MEMCPYD2H_DATA_1_ID: i16;

// Colors
const MEMCPYH2D_DATA_1: color = @get_color(MEMCPYH2D_DATA_1_ID);
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);

// red and blue differ only in their lowest bit, so enabling color_swap_x on
// their routes turns one into the other as a wavelet leaves a router EAST.
const red:              color = @get_color(0);
const blue:             color = @get_color(1);

// Task IDs: each data task is activated by wavelets arriving on its color.
const h2d_task_id:  data_task_id = @get_data_task_id(MEMCPYH2D_DATA_1);
const red_task_id:  data_task_id = @get_data_task_id(red);
const blue_task_id: data_task_id = @get_data_task_id(blue);

// memcpy configuration for a width x 1 rectangle; per-PE parameters are
// obtained with memcpy.get_params(pe_x) inside the layout block.
const memcpy = @import_module( "<memcpy/get_params>", .{
  .width = width,
  .height = 1,
  .MEMCPYH2D_1 = MEMCPYH2D_DATA_1,
  .MEMCPYD2H_1 = MEMCPYD2H_DATA_1
});

// Place pe_program.csl on every PE of a width x 1 row and route red/blue
// wavelets eastward along the row.
layout {
  @set_rectangle(width, 1);

  for (@range(u16, width)) |pe_x| {

    const memcpy_params = memcpy.get_params(pe_x);

    @set_tile_code(pe_x, 0, "pe_program.csl", .{
      .memcpy_params = memcpy_params,
      .red = red,
      .blue = blue,
      .wtt_h2d_task_id = h2d_task_id,
      .red_task_id = red_task_id,
      .blue_task_id = blue_task_id,
    });

    // Route variants. With color_swap_x = true, a wavelet forwarded EAST
    // leaves on the other color of the red/blue pair.
    //   routes (interior PEs): receive from WEST, deliver to RAMP, pass EAST
    //   end    (last PE):      receive from WEST, deliver to RAMP only
    //   start  (first PE):     receive from own RAMP, deliver to RAMP, pass EAST
    const routes = .{ .rx = .{ WEST }, .tx = .{ RAMP, EAST }, .color_swap_x = true };
    const end = .{ .rx = .{ WEST }, .tx = .{ RAMP }, .color_swap_x = true };
    const start = .{ .rx = .{ RAMP }, .tx = .{ RAMP, EAST }, .color_swap_x = true };

    if (pe_x == 0){
      // 1st PE receives data from streaming H2D; its WTT(H2D) sends that data
      // on color "blue" up through the RAMP. The "start" route hands it back
      // down the RAMP (activating blue_task) and forwards it EAST, where the
      // color swap turns it into "red" for the next PE.
      @set_color_config(pe_x, 0, blue, .{ .routes = start });
      @set_color_config(pe_x, 0, red, .{ .routes = start });
    }else if (pe_x == width - 1) {
      // Last PE consumes the wavelet locally and does not forward it further.
      @set_color_config(pe_x, 0, blue, .{ .routes = end });
      @set_color_config(pe_x, 0, red, .{ .routes = end });
    } else {
      @set_color_config(pe_x, 0, blue, .{ .routes = routes });
      @set_color_config(pe_x, 0, red, .{ .routes = routes });
    }
  }

  // Export symbol "buf" (pointer to i16) so the host can access it by name;
  // pe_program.csl binds it to ptr_buf via @export_symbol.
  @export_name("buf", [*]i16, true);
}

pe_program.csl

// Not a complete program; the top-level source file is layout.csl.

// Per-PE memcpy parameters, produced by memcpy.get_params(pe_x) in layout.csl.
param memcpy_params: comptime_struct;

//Colors
param red:  color;
param blue: color;

// Task IDs
param wtt_h2d_task_id: data_task_id; // Data task wtt_h2d triggered by MEMCPYH2D_DATA_1 wlts
param red_task_id:     data_task_id; // Data task red_task triggered by red wlts
param blue_task_id:    data_task_id; // Data task blue_task triggered by blue wlts

// Instantiate the memcpy module with this PE's parameters so the host can
// stream data in (H2D) and copy "buf" out (D2H).
const sys_mod = @import_module( "<memcpy/memcpy>", memcpy_params);

// Import two instances of <debug>, told apart by their `key`:
// `trace` records comptime string and value of 'global'
// `times` records timestamps at begin and end of tasks
// The host reads each instance back separately by passing the same key to
// debug_util.read_trace. buffer_size bounds how many entries each instance
// can hold; each task below logs 2 entries per instance.
const trace = @import_module(
  "<debug>",
  .{ .key = "trace",
     .buffer_size = 100,
   }
);
const times = @import_module(
  "<debug>",
  .{ .key = "times",
     .buffer_size = 100,
   }
);


// Variable whose value we update in our tasks (running sum, starts at 0)
var global : i16 = 0;

// Task that will be triggered by red wavelet: adds in_data to `global`.
// The `trace` instance gets exactly two entries per activation, in this
// order: the task-color string, then the updated value of `global`. The
// host relies on this alternation when it keeps every second entry.
task red_task(in_data : i16) void {
  // Record timestamp for beginning of task in `times`
  times.trace_timestamp();

  // Record string denoting task color in `trace`
  trace.trace_string("Start red task");

  // Update global variable
  global += in_data;

  // Record updated value of global in `trace`
  trace.trace_i16(global);

  // Record timestamp for end of task in `times`
  times.trace_timestamp();
}

// Task that will be triggered by blue wavelet: adds 2 * in_data to `global`.
// Same trace layout as red_task: string first, then the updated `global`.
task blue_task(in_data : i16) void {
  // Record timestamp for beginning of task in `times`
  times.trace_timestamp();

  // Record string denoting task color in `trace`
  trace.trace_string("Start blue task");

  // Update global variable
  global += in_data * 2;

  // Record updated value of global in `trace`
  trace.trace_i16(global);

  // Record timestamp for end of task in `times`
  times.trace_timestamp();
}

comptime {
  // Associate the appropriate task with the wavelet's color: a wavelet
  // arriving on red runs red_task, one arriving on blue runs blue_task.
  @bind_data_task(red_task, red_task_id);
  @bind_data_task(blue_task, blue_task_id);
}


// One-element staging buffer for the wavelet received from H2D. It is also
// exported (through ptr_buf) as "buf" so the host can read it with a D2H copy.
var buf = @zeros([1]i16);
var ptr_buf: [*]i16 = &buf;

// Memory DSD covering the single element of buf.
const bufDsd = @get_dsd(mem1d_dsd, .{.tensor_access = |i|{1} -> buf[i]});

// Fabric output DSD: sends one wavelet on color blue through output queue 1.
// Because the routes swap red <-> blue at every hop:
// PEs 0, 2 activate blue task; 1, 3 activate red task.
const outDsd = @get_dsd(fabout_dsd, .{
  .extent = 1,
  .fabric_color = blue,
  .output_queue = @get_output_queue(1)
});

// receive data from streaming H2D and forward it on color blue
// The task blocks itself before staging the value in buf[0]. The async
// @mov16 unblocks it on completion, so a new H2D wavelet cannot overwrite
// buf[0] while the previous one is still being sent.
task wtt_h2d(data: i16) void {
  @block(wtt_h2d_task_id);
  buf[0] = data;
  @mov16(outDsd, bufDsd, .{.async=true, .unblock=wtt_h2d_task_id} );
}

comptime {
  // Run wtt_h2d for every wavelet arriving on the H2D streaming color.
  @bind_data_task(wtt_h2d, wtt_h2d_task_id);

  // Publish ptr_buf under the name "buf" declared with @export_name in layout.csl.
  @export_symbol(ptr_buf, "buf");
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.debug.debug_util import debug_util
from cerebras.sdk.sdk_utils import memcpy_view, input_array_to_u32
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

# Command-line arguments: --name is the cslc output directory (it must contain
# out.json); --cmaddr is the IP:port of a CS system.
parser = argparse.ArgumentParser()
# required=True: without --name, dirname would be None and the script would
# only fail later with a confusing FileNotFoundError on "None/out.json".
parser.add_argument('--name', required=True, help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata; its "params" entry carries the compile-time
# parameters (e.g. width and MEMCPYH2D_DATA_1_ID).
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
MEMCPYH2D_DATA_1 = int(params["MEMCPYH2D_DATA_1_ID"])
width = int(params["width"])
print(f"MEMCPYH2D_DATA_1 = {MEMCPYH2D_DATA_1}")
print(f"width = {width}")

# Every transfer moves 16-bit values, each carried in a 32-bit word.
memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

# Resolve the exported device symbol before the program is loaded.
sym_buf = runner.get_id("buf")

runner.load()
runner.run()

# Values 0..num_entries-1 are streamed into the first PE of the row.
num_entries = 4
h2d_values = np.arange(num_entries, dtype=np.int16)

print("step 1: streaming H2D to 1st PE")
h2d_words = input_array_to_u32(h2d_values, 0, num_entries)
runner.memcpy_h2d(
    MEMCPYH2D_DATA_1, h2d_words, 0, 0, 1, 1, num_entries,
    streaming=True, data_type=memcpy_dtype,
    order=MemcpyOrder.COL_MAJOR, nonblock=True)

print("step 2: copy mode D2H buf (need at least one D2H)")
# Destination of a D2H copy has to be a u32 array.
d2h_words = np.zeros(1, dtype=np.uint32)
runner.memcpy_d2h(
    d2h_words, sym_buf, 0, 0, 1, 1, 1,
    streaming=False, data_type=memcpy_dtype,
    order=MemcpyOrder.COL_MAJOR, nonblock=False)
# Reinterpret the words as i16, dropping each word's upper 16 bits.
buf_view = memcpy_view(d2h_words, np.dtype(np.int16))

runner.stop()

debug_mod = debug_util(dirname, cmaddr=args.cmaddr)
# read_trace takes coordinates relative to the fabric's northwest corner, so
# offset by the program rectangle's origin. These values must match
# --fabric-offsets in commands.sh.
core_offset_x = 4
core_offset_y = 1
print(f"=== dump core: core rectangle starts at {core_offset_x}, {core_offset_y}")

result = np.zeros([width, num_entries])
for idx in range(width):
  # Get traces recorded in 'trace': entries alternate between the task-color
  # string and the updated value of 'global'
  trace_output = debug_mod.read_trace(core_offset_x + idx, core_offset_y, 'trace')

  # Copy all recorded trace values of variable 'global' (every second entry)
  result[idx, :] = trace_output[1::2]

  # Get timestamp traces recorded in 'times' (one at start, one at end of each task)
  timestamp_output = debug_mod.read_trace(core_offset_x + idx, core_offset_y, 'times')

  # Print out all traces for PE; an f-string avoids print()'s default
  # separator, which produced the malformed header "PE ( 0 , 0): "
  print(f"PE ({idx}, 0): ")
  print("Trace: ", trace_output)
  print("Times: ", timestamp_output)
  print()

# Expected trace of 'global' for every PE.
# The host streams in 0, 1, 2, 3 (in that order) from the West.
# A red task adds each value to its PE's running sum; a blue task adds 2*value.
# 'global' is traced after every update, so entry j is a multiple of the
# triangular number 0 + 1 + ... + j = j*(j+1)/2:
#   even PEs (blue task): 2 * j*(j+1)/2 -> 0, 2, 6, 12
#   odd PEs  (red task):  1 * j*(j+1)/2 -> 0, 1, 3, 6
pe_idx = np.arange(width).reshape(-1, 1)
step = np.arange(num_entries).reshape(1, -1)
multiplier = 2 - pe_idx % 2   # 2 on even PEs, 1 on odd PEs
oracle = multiplier * step * (step + 1) / 2

# Every traced value of 'global' must match the expectation
np.testing.assert_equal(result, oracle)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

# Stop at the first command that fails.
set -e

# Build a 4x1 program rectangle placed at offset (4, 1) inside an 11x3 fabric,
# using memcpy with H2D on color 6 and D2H on color 7; output goes to "out".
cslc ./layout.csl \
  --fabric-dims=11,3 \
  --fabric-offsets=4,1 \
  --params=width:4 \
  -o out \
  --params=MEMCPYH2D_DATA_1_ID:6 \
  --params=MEMCPYD2H_DATA_1_ID:7 \
  --memcpy \
  --channels=1 \
  --width-west-buf=0 \
  --width-east-buf=0

# Execute the host program against the compiled artifacts.
cs_python run.py --name out