Topic 12: Debug Library

This example shows a program that uses the tracing mechanism of the <debug> library to record variable values and compile time strings as well as timestamps, for inspection by the host code.

The program uses a row of four contiguous PEs. The first PE sends an array of values to three receiver PEs. Each PE program contains a global variable named global, initialized to zero. When the data task recv_task on the receiver PE is activated by an incoming wavelet in_data, global is incremented by 2 * in_data.

The programs running on each PE import two instances of the <debug> library. On the receiver PEs, each time a task activates, the instance named trace logs a compile time string noting that the task has begun execution, and the updated value of global. The instance named times logs a timestamp at the beginning of a task, and at the end of a task.

The host code uses the function read_trace from cerebras.sdk.debug.debug_util to read the logged values after execution of the device code finishes. Note that the PE coordinates passed to read_trace start from the northwest corner of the fabric, not from the northwest corner of the program rectangle.

layout.csl

// Color map
//
//  ID var   ID var  ID var                ID var
//   0 comm   9      18                    27 reserved (memcpy)
//   1       10      19                    28 reserved (memcpy)
//   2       11      20                    29 reserved
//   3       12      21 reserved (memcpy)  30 reserved (memcpy)
//   4       13      22 reserved (memcpy)  31 reserved
//   5       14      23 reserved (memcpy)  32
//   6       15      24                    33
//   7       16      25                    34
//   8       17      26                    35

// See task maps in sender.csl and receiver.csl

param width: u16;     // number of PEs in kernel
param num_elems: u16; // number of elements in each PE's buf

// Colors
const comm: color = @get_color(0);

const memcpy = @import_module("<memcpy/get_params>", .{
  .width = width,
  .height = 1,
});

layout {
  @set_rectangle(width, 1);

  // Sender
  @set_tile_code(0, 0, "sender.csl", .{
    .memcpy_params = memcpy.get_params(0),
    .comm = comm, .num_elems = num_elems
  });

  @set_color_config(0, 0, comm, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ EAST }}});

  // Receivers
  for (@range(u16, 1, width, 1)) |pe_x| {

    @set_tile_code(pe_x, 0, "receiver.csl", .{
      .memcpy_params = memcpy.get_params(pe_x),
      .comm = comm, .num_elems = num_elems
    });

    if (pe_x == width - 1) {
      @set_color_config(pe_x, 0, comm, .{ .routes = .{ .rx = .{ WEST }, .tx = .{ RAMP }}});
    } else {
      @set_color_config(pe_x, 0, comm, .{ .routes = .{ .rx = .{ WEST }, .tx = .{ RAMP, EAST }}});
    }
  }

  // export symbol name
  @export_name("buf", [*]u32, true);
  @export_name("main_fn", fn()void);
}

sender.csl

// WSE-2 task ID map
// On WSE-2, data tasks are bound to colors (IDs 0 through 24)
//
//  ID var                ID var  ID var                ID var
//   0                     9      18                    27 reserved (memcpy)
//   1                    10      19                    28 reserved (memcpy)
//   2                    11      20                    29 reserved
//   3                    12      21 reserved (memcpy)  30 reserved (memcpy)
//   4                    13      22 reserved (memcpy)  31 reserved
//   5                    14      23 reserved (memcpy)  32
//   6                    15      24                    33
//   7                    16      25                    34
//   8 exit_task_id       17      26                    35

// WSE-3 task ID map
// On WSE-3, data tasks are bound to input queues (IDs 0 through 7)
//
//  ID var                ID var  ID var                ID var
//   0 reserved (memcpy)   9      18                    27 reserved (memcpy)
//   1 reserved (memcpy)  10      19                    28 reserved (memcpy)
//   2                    11      20                    29 reserved
//   3                    12      21 reserved (memcpy)  30 reserved (memcpy)
//   4                    13      22 reserved (memcpy)  31 reserved
//   5                    14      23 reserved (memcpy)  32
//   6                    15      24                    33
//   7                    16      25                    34
//   8 exit_task_id       17      26                    35

param memcpy_params: comptime_struct;

const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);

// Number of elements to be send to receivers
param num_elems: u16;

// Colors
param comm: color;

// Queue IDs
const comm_oq: output_queue = @get_output_queue(2);

// Task IDs
const exit_task_id: local_task_id = @get_local_task_id(8);

const trace = @import_module(
  "<debug>",
  .{ .key = "trace",
     .buffer_size = 100,
   }
);
const times = @import_module(
  "<debug>",
  .{ .key = "times",
     .buffer_size = 100,
   }
);

// Host copies values to this array
// We then send the values to the receives
var buf = @zeros([num_elems]u32);
var ptr_buf: [*]u32 = &buf;

const buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{num_elems} -> buf[i] });

const out_dsd = @get_dsd(fabout_dsd, .{
  .extent = 1,
  .fabric_color = comm,
  .output_queue = comm_oq
});

fn main_fn() void {
  trace.trace_string("Sender beginning main_fn");
  times.trace_timestamp(); // Record timestamp for main_fn start
  @fmovs(out_dsd, buf_dsd, .{ .async = true, .activate = exit_task });
}

task exit_task() void {
  trace.trace_string("Sender exiting");
  times.trace_timestamp(); // Record timestamp for exit
  sys_mod.unblock_cmd_stream();
}

comptime {
  @bind_local_task(exit_task, exit_task_id);

  // On WSE-3, we must explicitly initialize input and output queues
  if (@is_arch("wse3")) {
    @initialize_queue(comm_oq, .{ .color = comm });
  }

  @export_symbol(ptr_buf, "buf");
  @export_symbol(main_fn);
}

receiver.csl

// WSE-2 task ID map
// On WSE-2, data tasks are bound to colors (IDs 0 through 24)
//
//  ID var                ID var  ID var                ID var
//   0 recv_task_id        9      18                    27 reserved (memcpy)
//   1                    10      19                    28 reserved (memcpy)
//   2                    11      20                    29 reserved
//   3                    12      21 reserved (memcpy)  30 reserved (memcpy)
//   4                    13      22 reserved (memcpy)  31 reserved
//   5                    14      23 reserved (memcpy)  32
//   6                    15      24                    33
//   7                    16      25                    34
//   8                    17      26                    35

// WSE-3 task ID map
// On WSE-3, data tasks are bound to input queues (IDs 0 through 7)
//
//  ID var                ID var  ID var                ID var
//   0 reserved (memcpy)   9      18                    27 reserved (memcpy)
//   1 reserved (memcpy)  10      19                    28 reserved (memcpy)
//   2 recv_task_id       11      20                    29 reserved
//   3                    12      21 reserved (memcpy)  30 reserved (memcpy)
//   4                    13      22 reserved (memcpy)  31 reserved
//   5                    14      23 reserved (memcpy)  32
//   6                    15      24                    33
//   7                    16      25                    34
//   8                    17      26                    35

param memcpy_params: comptime_struct;

const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);

// Number of elements expected from sender
param num_elems: u16;

// Colors
param comm: color;

// Queue IDs
const comm_iq: input_queue = @get_input_queue(2);
const comm_oq: output_queue = @get_output_queue(2);

// Task ID for recv_task, consumed wlts with color comm
// On WSE-2, data task IDs are created from colors; on WSE-3, from input queues
// Task ID for data task that recvs from memcpy
const recv_task_id: data_task_id =
  if      (@is_arch("wse2")) @get_data_task_id(comm)
  else if (@is_arch("wse3")) @get_data_task_id(comm_iq);

// Import two instances of <debug>:
// `trace` records comptime string and value of 'global'
// `times` records timestamps at begin and end of tasks
const trace = @import_module(
  "<debug>",
  .{ .key = "trace",
     .buffer_size = 100,
   }
);
const times = @import_module(
  "<debug>",
  .{ .key = "times",
     .buffer_size = 100,
   }
);

// Variable whose value we update in recv_task
var global : u32 = 0;

// Array to store received values
var buf = @zeros([num_elems]u32);
var ptr_buf: [*]u32 = &buf;

// main_fn does nothing on the senders
fn main_fn() void {}

// Track number of wavelets received by recv_task
var num_wlts_recvd: u16 = 0;

task recv_task(in_data : u32) void {

  times.trace_timestamp(); // Record timestamp for task start in `times`
  trace.trace_string("Start recv_task"); // Record string in `trace`

  buf[num_wlts_recvd] = in_data; // Store recvd value in buf
  global += 2*in_data; // Increment global by 2x received value

  trace.trace_u32(global); // Record updated value of global in `trace`

  num_wlts_recvd += 1; // Increment number of received wavelets
  // Once we have received all wavelets, we unblock cmd stream
  if (num_wlts_recvd == num_elems) {
    sys_mod.unblock_cmd_stream();
  }

  times.trace_timestamp(); // Record timestamp for task end in `times`
}

comptime {
  @bind_data_task(recv_task, recv_task_id);

  // On WSE-3, we must explicitly initialize input and output queues
  if (@is_arch("wse3")) {
    @initialize_queue(comm_iq, .{ .color = comm });
    @initialize_queue(comm_oq, .{ .color = comm });
  }

  @export_symbol(ptr_buf, "buf");
  @export_symbol(main_fn);
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.debug.debug_util import debug_util
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
num_elems = int(params["num_elems"])
width = int(params["width"])
print(f"width = {width}")
print(f"num_elems = {num_elems}")

memcpy_dtype = MemcpyDataType.MEMCPY_32BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

sym_buf = runner.get_id("buf")

runner.load()
runner.run()

x = np.arange(num_elems, dtype=np.uint32)

print("step 1: H2D copy buf to sender PE")
runner.memcpy_h2d(sym_buf, x, 0, 0, 1, 1, num_elems, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)

print("step 2: launch main_fn")
runner.launch('main_fn', nonblock=False)

print("step 3: D2H copy buf back from all PEs")
out_buf = np.arange(width*num_elems, dtype=np.uint32)
runner.memcpy_d2h(out_buf, sym_buf, 0, 0, width, 1, num_elems, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)

runner.stop()

debug_mod = debug_util(dirname, cmaddr=args.cmaddr)
core_offset_x = 4
core_offset_y = 1
print(f"=== dump core: core rectangle starts at {core_offset_x}, {core_offset_y}")

result = np.zeros([width, num_elems])
for idx in range(width):
  # Get traces recorded in 'trace'
  trace_output = debug_mod.read_trace(core_offset_x + idx, core_offset_y, 'trace')

  # On receiver PEs, record value of 'global'
  if idx > 0:
    result[idx, :] = trace_output[1::2]

  # Get timestamp traces recorded in 'times'
  timestamp_output = debug_mod.read_trace(core_offset_x + idx, core_offset_y, 'times')

  # Print out all traces for PE
  print("PE (", idx, ", 0): ")
  print("Trace: ", trace_output)
  print("Times: ", timestamp_output)
  print()

# Receiver PEs adds 2*value to running global sum on its PE.
# Value of global var is recorded after each update.
# Trace values of global var will be: 0, 2, 6, 12, 20
oracle = np.zeros([width, num_elems])
for i in range(1, width, 1):
  for j in range(num_elems):
    oracle[i, j] = 2 * j * (j+1) / 2

# Assert that all trace values of 'global' are as expected
np.testing.assert_equal(result, oracle)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

set -e

cslc --arch=wse2 ./layout.csl --fabric-dims=11,3 \
--fabric-offsets=4,1 --params=width:4,num_elems:5 -o out  \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
cs_python run.py --name out