Topic 13: Simprint Library
Contents
Topic 13: Simprint Library¶
When running with the simulator, you can also print values directly to the simulator log (sim.log).
This example modifies the previous example to show the use of the <simprint> library for printing comptime strings and values to the simulator log.
Just like the previous example, this program uses a row of four contiguous PEs.
The first PE sends an array of values to three receiver PEs.
Each PE program contains a global variable named global, initialized to zero.
When the data task recv_task on a receiver PE is activated by an incoming wavelet in_data, global is incremented by 2 * in_data.
On the receiver PEs, each time the task activates, the program writes a line to sim.log indicating that the task has started, along with the value of the received wavelet and the updated value of global.
The program also defines a helper function, simprint_pe_coords, that prints the coordinates of the PE to the simulator log.
Output is flushed to sim.log only when a newline is encountered, so you must explicitly print "\n" to flush it.
After running this example, open sim.log to see the output. The output from <simprint> should look something like this:
@968 PE(0,0): sender beginning main_fn
@996 PE(0,0): sender exiting
@1156 PE(1,0): recv_task: in_data = 0, global = 0
@1158 PE(2,0): recv_task: in_data = 0, global = 0
@1160 PE(3,0): recv_task: in_data = 0, global = 0
@1338 PE(1,0): recv_task: in_data = 1, global = 2
@1340 PE(2,0): recv_task: in_data = 1, global = 2
@1342 PE(3,0): recv_task: in_data = 1, global = 2
@1520 PE(1,0): recv_task: in_data = 2, global = 6
@1522 PE(2,0): recv_task: in_data = 2, global = 6
@1524 PE(3,0): recv_task: in_data = 2, global = 6
@1702 PE(1,0): recv_task: in_data = 3, global = 12
@1704 PE(2,0): recv_task: in_data = 3, global = 12
@1706 PE(3,0): recv_task: in_data = 3, global = 12
@1884 PE(1,0): recv_task: in_data = 4, global = 20
@1886 PE(2,0): recv_task: in_data = 4, global = 20
@1888 PE(3,0): recv_task: in_data = 4, global = 20
Note that each line printed to sim.log is prefixed with the cycle at which the print occurred.
<simprint> is particularly useful for debugging stalling programs. The <debug> library shown in the previous example requires the program to run to completion before its output can be parsed. In contrast, the <simprint> library writes to sim.log each time a newline character is printed, so output produced before a stall is still visible.
layout.csl¶
// Color map
//
// ID var ID var ID var ID var
// 0 comm 9 18 27 reserved (memcpy)
// 1 10 19 28 reserved (memcpy)
// 2 11 20 29 reserved
// 3 12 21 reserved (memcpy) 30 reserved (memcpy)
// 4 13 22 reserved (memcpy) 31 reserved
// 5 14 23 reserved (memcpy) 32
// 6 15 24 33
// 7 16 25 34
// 8 17 26 35
// See task maps in sender.csl and receiver.csl
// Kernel-wide parameters; commands.sh supplies them via cslc --params.
param width: u16; // number of PEs in kernel
param num_elems: u16; // number of elements in each PE's buf
// Colors
// comm carries the data stream from the sender at (0,0) east to every receiver.
const comm: color = @get_color(0);
// memcpy framework parameters for the width x 1 rectangle.
// get_params(x) is called once per PE below to produce that tile's memcpy_params.
const memcpy = @import_module("<memcpy/get_params>", .{
.width = width,
.height = 1,
});
layout {
// A single row of `width` PEs: PE (0,0) is the sender, PEs (1..width-1, 0) are receivers.
@set_rectangle(width, 1);
// Sender
@set_tile_code(0, 0, "sender.csl", .{
.memcpy_params = memcpy.get_params(0),
.comm = comm, .num_elems = num_elems
});
// comm wavelets enter the sender's router from its ramp and leave eastward.
@set_color_config(0, 0, comm, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ EAST }}});
// Receivers
for (@range(u16, 1, width, 1)) |pe_x| {
@set_tile_code(pe_x, 0, "receiver.csl", .{
.memcpy_params = memcpy.get_params(pe_x),
.comm = comm, .num_elems = num_elems
});
// Every receiver takes comm from the west and delivers it to its own ramp.
// All but the last PE also forward it east, so the stream reaches the whole row.
if (pe_x == width - 1) {
@set_color_config(pe_x, 0, comm, .{ .routes = .{ .rx = .{ WEST }, .tx = .{ RAMP }}});
} else {
@set_color_config(pe_x, 0, comm, .{ .routes = .{ .rx = .{ WEST }, .tx = .{ RAMP, EAST }}});
}
}
// export symbol name
// Host-visible names used by run.py: "buf" (target of memcpy_h2d/memcpy_d2h)
// and "main_fn" (target of runner.launch). The trailing `true` on "buf"
// presumably marks it host-writable, which the H2D copy requires; confirm in the CSL docs.
@export_name("buf", [*]u32, true);
@export_name("main_fn", fn()void);
}
sender.csl¶
// WSE-2 task ID map
// On WSE-2, data tasks are bound to colors (IDs 0 through 24)
//
// ID var ID var ID var ID var
// 0 9 18 27 reserved (memcpy)
// 1 10 19 28 reserved (memcpy)
// 2 11 20 29 reserved
// 3 12 21 reserved (memcpy) 30 reserved (memcpy)
// 4 13 22 reserved (memcpy) 31 reserved
// 5 14 23 reserved (memcpy) 32
// 6 15 24 33
// 7 16 25 34
// 8 exit_task_id 17 26 35
// WSE-3 task ID map
// On WSE-3, data tasks are bound to input queues (IDs 0 through 7)
//
// ID var ID var ID var ID var
// 0 reserved (memcpy) 9 18 27 reserved (memcpy)
// 1 reserved (memcpy) 10 19 28 reserved (memcpy)
// 2 11 20 29 reserved
// 3 12 21 reserved (memcpy) 30 reserved (memcpy)
// 4 13 22 reserved (memcpy) 31 reserved
// 5 14 23 reserved (memcpy) 32
// 6 15 24 33
// 7 16 25 34
// 8 exit_task_id 17 26 35
// Per-PE memcpy parameters, supplied by layout.csl via memcpy.get_params(0).
param memcpy_params: comptime_struct;
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);
// Simulator-log printing; text reaches sim.log when "\n" is printed.
const prt = @import_module("<simprint>");
// Number of elements to be sent to receivers
param num_elems: u16;
// Colors
param comm: color;
// Queue IDs
// Output queue used to push buf onto color comm.
const comm_oq: output_queue = @get_output_queue(2);
// Task IDs
// Local task activated when the asynchronous send in main_fn completes.
const exit_task_id: local_task_id = @get_local_task_id(8);
// Host copies values to this array
// We then send the values to the receivers
var buf = @zeros([num_elems]u32);
// Pointer to buf, exported to the host under the name "buf".
var ptr_buf: [*]u32 = &buf;
// Memory DSD covering all num_elems elements of buf.
const buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{num_elems} -> buf[i] });
// Fabric output DSD: writes wavelets on color comm through comm_oq.
const out_dsd = @get_dsd(fabout_dsd, .{
.extent = 1,
.fabric_color = comm,
.output_queue = comm_oq
});
// Host-launched entry point: logs a start message, then starts an asynchronous
// @fmovs from buf onto comm; exit_task is activated when that transfer completes.
fn main_fn() void {
prt.print_string("PE(0,0): sender beginning main_fn\n");
@fmovs(out_dsd, buf_dsd, .{ .async = true, .activate = exit_task });
}
// Logs completion and hands control back to the memcpy command stream.
task exit_task() void {
prt.print_string("PE(0,0): sender exiting\n");
sys_mod.unblock_cmd_stream();
}
comptime {
// Associate exit_task with the local task ID it is activated through.
@bind_local_task(exit_task, exit_task_id);
// On WSE-3, we must explicitly initialize input and output queues
if (@is_arch("wse3")) {
@initialize_queue(comm_oq, .{ .color = comm });
}
// Export buf (as "buf") and main_fn to match the names declared in layout.csl.
@export_symbol(ptr_buf, "buf");
@export_symbol(main_fn);
}
receiver.csl¶
// WSE-2 task ID map
// On WSE-2, data tasks are bound to colors (IDs 0 through 24)
//
// ID var ID var ID var ID var
// 0 recv_task_id 9 18 27 reserved (memcpy)
// 1 10 19 28 reserved (memcpy)
// 2 11 20 29 reserved
// 3 12 21 reserved (memcpy) 30 reserved (memcpy)
// 4 13 22 reserved (memcpy) 31 reserved
// 5 14 23 reserved (memcpy) 32
// 6 15 24 33
// 7 16 25 34
// 8 17 26 35
// WSE-3 task ID map
// On WSE-3, data tasks are bound to input queues (IDs 0 through 7)
//
// ID var ID var ID var ID var
// 0 reserved (memcpy) 9 18 27 reserved (memcpy)
// 1 reserved (memcpy) 10 19 28 reserved (memcpy)
// 2 recv_task_id 11 20 29 reserved
// 3 12 21 reserved (memcpy) 30 reserved (memcpy)
// 4 13 22 reserved (memcpy) 31 reserved
// 5 14 23 reserved (memcpy) 32
// 6 15 24 33
// 7 16 25 34
// 8 17 26 35
// Per-PE memcpy parameters, supplied by layout.csl via memcpy.get_params(pe_x).
param memcpy_params: comptime_struct;
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);
// Provides this PE's coordinates (get_x_coord / get_y_coord) for log output.
const layout_mod = @import_module("<layout>");
// Simulator-log printing; text reaches sim.log when "\n" is printed.
const prt = @import_module("<simprint>");
// Number of elements expected from sender
param num_elems: u16;
// Colors
param comm: color;
// Queue IDs
const comm_iq: input_queue = @get_input_queue(2);
// NOTE(review): comm_oq is only initialized below (WSE-3); no DSD in this file uses it.
const comm_oq: output_queue = @get_output_queue(2);
// Task ID for recv_task, which consumes wavelets arriving on color comm
// from the sender.
// On WSE-2, data task IDs are created from colors; on WSE-3, from input queues
const recv_task_id: data_task_id =
if (@is_arch("wse2")) @get_data_task_id(comm)
else if (@is_arch("wse3")) @get_data_task_id(comm_iq);
// Variable whose value we update in recv_task
var global : u32 = 0;
// Array to store received values
var buf = @zeros([num_elems]u32);
// Pointer to buf, exported to the host under the name "buf".
var ptr_buf: [*]u32 = &buf;
// main_fn does nothing on the receivers; all work happens in recv_task
fn main_fn() void {}
// Track number of wavelets received by recv_task
var num_wlts_recvd: u16 = 0;
// Prints "PE(x,y): " for this PE. No newline is printed here, so this text
// stays buffered and reaches sim.log only when the caller later prints "\n".
fn simprint_pe_coords() void {
prt.print_string("PE(");
prt.print_u16_decimal(layout_mod.get_x_coord());
prt.print_string(",");
prt.print_u16_decimal(layout_mod.get_y_coord());
prt.print_string("): ");
}
// Runs once per wavelet received on comm. Logs the received value, stores it in
// buf, adds 2*in_data to global, then logs the updated global as one sim.log line.
task recv_task(in_data : u32) void {
simprint_pe_coords();
prt.print_string("recv_task: in_data = ");
prt.print_u32_decimal(in_data);
buf[num_wlts_recvd] = in_data; // Store recvd value in buf
global += 2*in_data; // Increment global by 2x received value
prt.print_string(", global = ");
prt.print_u32_decimal(global);
// The newline flushes the whole line built above to sim.log.
prt.print_string("\n");
num_wlts_recvd += 1; // Increment number of received wavelets
// Once we have received all wavelets, we unblock cmd stream
if (num_wlts_recvd == num_elems) {
sys_mod.unblock_cmd_stream();
}
}
comptime {
// Activate recv_task for each wavelet delivered under recv_task_id.
@bind_data_task(recv_task, recv_task_id);
// On WSE-3, we must explicitly initialize input and output queues
if (@is_arch("wse3")) {
@initialize_queue(comm_iq, .{ .color = comm });
@initialize_queue(comm_oq, .{ .color = comm });
}
// Export buf (as "buf") and main_fn to match the names declared in layout.csl.
@export_symbol(ptr_buf, "buf");
@export_symbol(main_fn);
}
run.py¶
#!/usr/bin/env cs_python
import argparse
import json
import numpy as np
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module
# Host driver: copy an input array to the sender PE, launch main_fn, then read
# back `buf` from every PE and check each receiver stored exactly the sent values.
parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
num_elems = int(params["num_elems"])
width = int(params["width"])
print(f"width = {width}")
print(f"num_elems = {num_elems}")

memcpy_dtype = MemcpyDataType.MEMCPY_32BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

sym_buf = runner.get_id("buf")

runner.load()
runner.run()

x = np.arange(num_elems, dtype=np.uint32)

print("step 1: H2D copy buf to sender PE")
runner.memcpy_h2d(sym_buf, x, 0, 0, 1, 1, num_elems,
                  streaming=False, data_type=memcpy_dtype,
                  order=MemcpyOrder.ROW_MAJOR, nonblock=False)

print("step 2: launch main_fn")
runner.launch('main_fn', nonblock=False)

print("step 3: D2H copy buf back from all PEs")
# Pre-fill the destination with a sentinel that can never appear in x
# (num_elems is a u16, so x stays far below the uint32 maximum). Any element
# the D2H copy fails to overwrite then makes the check below fail, instead of
# coincidentally matching x as an np.arange-initialized buffer would.
sentinel = np.iinfo(np.uint32).max
out_buf = np.full(width * num_elems, sentinel, dtype=np.uint32)
runner.memcpy_d2h(out_buf, sym_buf, 0, 0, width, 1, num_elems,
                  streaming=False, data_type=memcpy_dtype,
                  order=MemcpyOrder.ROW_MAJOR, nonblock=False)

runner.stop()

# Receiver PEs write received value to out_buf,
# so out_buf of each PE should be same as x
# Assert that out_buf of each PE matches input array x
np.testing.assert_equal(np.tile(x, (width, 1)), out_buf.reshape(width, num_elems))
print("SUCCESS!")
commands.sh¶
#!/usr/bin/env bash
# Build the simprint example for WSE-3, then run the host script against it.
# Stop at the first command that fails.
set -e

# Compiler output directory; run.py reads <out_dir>/out.json from it.
out_dir=out

# Kernel: width:4 PEs in a row, num_elems:5 values per buffer, placed at
# offset (4,1) in an 11x3 fabric, with memcpy support on one channel.
cslc_args=(
  --arch=wse3
  ./layout.csl
  --fabric-dims=11,3
  --fabric-offsets=4,1
  --params=width:4,num_elems:5
  -o "$out_dir"
  --memcpy
  --channels=1
  --width-west-buf=0
  --width-east-buf=0
)
cslc "${cslc_args[@]}"

cs_python run.py --name "$out_dir"