Topic 5: Sentinels
Contents
Topic 5: Sentinels¶
In previous programs, we used so-called routable colors, which are associated with a route to direct the flow of wavelets. On WSE-2, task IDs which can receive data wavelets are in the range 0 through 23, corresponding to the IDs of the colors. On WSE-3, task IDs which can receive data wavelets are in the range 0 through 7, corresponding to input queues which are bound to a routable color. We have also used local tasks, which on WSE-2 can be associated with any task ID from 0 to 30, and on WSE-3 can be associated with any task ID from 8 to 30.
This example demonstrates the use of a non-routable control task ID to signal the end of an input tensor. We call a control task ID used in this way a sentinel.
In this example, the host sends to a receiving PE (`sentinel.csl`) the number of wavelets that the receiving PE should expect to receive, followed by the stream of data.
The receiving PE then sends the data to its neighbor (`pe_program.csl`), followed by a control wavelet which specifies the control task ID that the neighbor will activate.
Since sentinel control task IDs are not routable colors, the programmer does
not specify a route, but does need to bind the control task ID to a control
task, which will be activated upon receipt of the sentinel wavelet.
Here, the sentinel activates the `send_result` task, which relays the result of the sum reduction back to the host.
layout.csl¶
// Color map
//
// ID var           ID var  ID var                ID var
//  0 main_color     9      18                    27 reserved (memcpy)
//  1               10      19                    28 reserved (memcpy)
//  2 MEMCPYH2D_1   11      20                    29 reserved
//  3 MEMCPYH2D_2   12      21 reserved (memcpy)  30 reserved (memcpy)
//  4 MEMCPYD2H_1   13      22 reserved (memcpy)  31 reserved
//  5               14      23 reserved (memcpy)  32
//  6               15      24                    33
//  7               16      25                    34
//  8               17      26                    35
// See task maps in sentinel.csl and pe_program.csl
//                +--------------+                  +----------------+
// MEMCPYH2D_1 -> | sentinel.csl | -> main_color -> | pe_program.csl | -> MEMCPYD2H_1
// MEMCPYH2D_2 -> |              |                  +----------------+
//                +--------------+
// IDs for memcpy streaming colors
// Compile-time parameters: commands.sh supplies them via --params, and run.py
// reads the same values back from the compile metadata (out.json).
param MEMCPYH2D_DATA_1_ID: i16;
param MEMCPYH2D_DATA_2_ID: i16;
param MEMCPYD2H_DATA_1_ID: i16;
// number of PEs in a column
param size: i16;
// Sentinel to tell PE that it is time to send the result to the host
// This value is passed to both tile programs as their `sentinel` param;
// pe_program.csl converts it into a control task ID.
const end_computation: u16 = 43;
// Colors
// The three memcpy streaming colors are built from the ID params above.
const MEMCPYH2D_DATA_1: color = @get_color(MEMCPYH2D_DATA_1_ID);
const MEMCPYH2D_DATA_2: color = @get_color(MEMCPYH2D_DATA_2_ID);
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
// Color 0 carries data between the two PEs of a row (see the routes in the layout block).
const main_color: color = @get_color(0);
// memcpy parameter helper for the 2 x size rectangle; memcpy.get_params(x) is
// called in the layout block with the same first coordinate as @set_tile_code.
const memcpy = @import_module("<memcpy/get_params>", .{
.width = 2,
.height = size,
.MEMCPYH2D_1 = MEMCPYH2D_DATA_1,
.MEMCPYH2D_2 = MEMCPYH2D_DATA_2,
.MEMCPYD2H_1 = MEMCPYD2H_DATA_1
});
// Place the program on a rectangle that is 2 PEs wide and `size` PEs tall.
// In every row, the PE at first coordinate 0 runs sentinel.csl and the PE at
// first coordinate 1 runs pe_program.csl. Both receive the same main_color and
// the same sentinel value (end_computation).
layout {
@set_rectangle(2, size);
for (@range(u16, size)) |idx| {
// PE (0, idx): sentinel.csl
@set_tile_code(0, idx, "sentinel.csl", .{
.memcpy_params = memcpy.get_params(0),
.main_color = main_color,
.sentinel = end_computation,
});
// On PE (0, idx), main_color is received from RAMP and transmitted EAST
@set_color_config(0, idx, main_color,.{ .routes = .{ .rx = .{ RAMP }, .tx = .{ EAST }}});
// PE (1, idx): pe_program.csl
@set_tile_code(1, idx, "pe_program.csl", .{
.memcpy_params = memcpy.get_params(1),
.main_color = main_color,
.sentinel = end_computation,
});
// On PE (1, idx), main_color is received from WEST and transmitted to RAMP
@set_color_config(1, idx, main_color, .{ .routes = .{ .rx = .{ WEST }, .tx = .{ RAMP }}});
}
}
pe_program.csl¶
// WSE-2 task ID map
// On WSE-2, data tasks are bound to colors (IDs 0 through 23)
//
// ID var             ID var  ID var                ID var
//  0 main_task_id     9      18                    27 reserved (memcpy)
//  1                 10      19                    28 reserved (memcpy)
//  2                 11      20                    29 reserved
//  3                 12      21 reserved (memcpy)  30 reserved (memcpy)
//  4                 13      22 reserved (memcpy)  31 reserved
//  5                 14      23 reserved (memcpy)  32
//  6                 15      24                    33
//  7                 16      25                    34
//  8                 17      26                    35
// ...
// 43 sentinel_task_id
// WSE-3 task ID map
// On WSE-3, data tasks are bound to input queues (IDs 0 through 7)
//
// ID var                ID var  ID var                ID var
//  0 reserved (memcpy)   9      18                    27 reserved (memcpy)
//  1 reserved (memcpy)  10      19                    28 reserved (memcpy)
//  2 main_task_id       11      20                    29 reserved
//  3                    12      21 reserved (memcpy)  30 reserved (memcpy)
//  4                    13      22 reserved (memcpy)  31 reserved
//  5                    14      23 reserved (memcpy)  32
//  6                    15      24                    33
//  7                    16      25                    34
//  8                    17      26                    35
// ...
// 43 sentinel_task_id
// Parameters for the memcpy module; layout.csl supplies memcpy.get_params(1) for this PE.
param memcpy_params: comptime_struct;
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);
// Sentinel to signal end of data
// layout.csl passes end_computation (43) here.
param sentinel: u16;
// Colors
// layout.csl routes main_color on this PE from WEST to RAMP.
param main_color: color;
// Queue IDs
// main_iq is bound to main_color and d2h_1_oq to MEMCPYD2H_1 in the comptime
// block (WSE-3 only).
const main_iq: input_queue = @get_input_queue(2);
const d2h_1_oq: output_queue = @get_output_queue(3);
// Task IDs
// On WSE-2, data task IDs are created from colors; on WSE-3, from input queues
// NOTE(review): there is no final `else` branch; confirm what main_task_id
// becomes if the target is neither wse2 nor wse3.
const main_task_id: data_task_id =
if (@is_arch("wse2")) @get_data_task_id(main_color)
else if (@is_arch("wse3")) @get_data_task_id(main_iq);
// Control task ID derived from the sentinel value; bound to send_result in the
// comptime block.
const send_result_task_id: control_task_id = @get_control_task_id(sentinel);
// Accumulate all received values along main_color in result[0]
var result = @zeros([1]f32);
// Memory DSD covering the single element of `result`; source of the @fmovs in send_result.
const result_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> result[i] });
// Fabric output DSD: one element sent on the MEMCPYD2H_1 color through d2h_1_oq.
// run.py receives this value with a streaming memcpy_d2h.
const out_dsd = @get_dsd(fabout_dsd, .{
.fabric_color = sys_mod.MEMCPYD2H_1,
.extent = 1,
.output_queue = d2h_1_oq
});
// Data task bound to main_task_id (see the comptime block).
// Adds the f32 payload `data` to the running sum in result[0].
task main_task(data: f32) void {
result[0] += data;
}
// Control task bound to send_result_task_id, i.e. the sentinel (see the comptime block).
// Moves the accumulated value from result_dsd to out_dsd (MEMCPYD2H_1) with an
// asynchronous @fmovs, relaying the sum back to the host.
task send_result() void {
@fmovs(out_dsd, result_dsd, .{ .async = true });
}
// Compile-time bindings of tasks to task IDs, plus queue setup for WSE-3.
comptime {
// main_task is activated through main_task_id (color on WSE-2, input queue on WSE-3)
@bind_data_task(main_task, main_task_id);
// send_result is activated through the control task ID derived from `sentinel`
@bind_control_task(send_result, send_result_task_id);
// On WSE-3, we must explicitly initialize input and output queues
if (@is_arch("wse3")) {
@initialize_queue(main_iq, .{ .color = main_color });
@initialize_queue(d2h_1_oq, .{ .color = sys_mod.MEMCPYD2H_1 });
}
}
run.py¶
#!/usr/bin/env cs_python
import argparse
import json
import numpy as np
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module
# Host script: streams a wavelet count and an input tensor to the PEs at x=0,
# then reads one reduced value per row back from the PEs at x=1 and checks it
# against a NumPy reference sum.

parser = argparse.ArgumentParser()
# --name is the compile output directory (commands.sh passes `out`). It is
# required: without it the script would try to open "None/out.json".
parser.add_argument('--name', required=True, help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
    compile_data = json.load(json_file)
params = compile_data["params"]

# Streaming color IDs and column height, exactly as given to cslc via --params
MEMCPYH2D_DATA_1 = int(params["MEMCPYH2D_DATA_1_ID"])
MEMCPYH2D_DATA_2 = int(params["MEMCPYH2D_DATA_2_ID"])
MEMCPYD2H_DATA_1 = int(params["MEMCPYD2H_DATA_1_ID"])
size = int(params["size"])
print(f"MEMCPYH2D_DATA_1 = {MEMCPYH2D_DATA_1}")
print(f"MEMCPYH2D_DATA_2 = {MEMCPYH2D_DATA_2}")
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")
print(f"size = {size}")

memcpy_dtype = MemcpyDataType.MEMCPY_32BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

runner.load()
runner.run()

num_wvlts = 11
print(f"num_wvlts = number of wavelets for each PE = {num_wvlts}")

print("step 1: streaming H2D_1 sends number of input wavelets to P0")
# One uint32 per row: the number of data wavelets that row's PE should expect
h2d1_u32 = np.ones(size).astype(np.uint32) * num_wvlts
runner.memcpy_h2d(MEMCPYH2D_DATA_1, h2d1_u32.ravel(), 0, 0, 1, size, 1,
                  streaming=True, data_type=memcpy_dtype,
                  order=MemcpyOrder.ROW_MAJOR, nonblock=True)

# Use a deterministic seed so that CI results are predictable
np.random.seed(seed=7)

# Setup a {size} x {num_wvlts} input tensor that is reduced along the second dimension
input_tensor = np.random.rand(size, num_wvlts).astype(np.float32)
expected = np.sum(input_tensor, axis=1)

print("step 2: streaming H2D_2 to P0")
runner.memcpy_h2d(MEMCPYH2D_DATA_2, input_tensor.ravel(), 0, 0, 1, size, num_wvlts,
                  streaming=True, data_type=memcpy_dtype,
                  order=MemcpyOrder.ROW_MAJOR, nonblock=True)

print("step 3: streaming D2H at P1")
result_tensor = np.zeros(size, np.float32)
# Blocking receive: returns once every row's reduced value has arrived
runner.memcpy_d2h(result_tensor, MEMCPYD2H_DATA_1, 1, 0, 1, size, 1,
                  streaming=True, data_type=memcpy_dtype,
                  order=MemcpyOrder.ROW_MAJOR, nonblock=False)
runner.stop()

# Ensure that the result matches our expectation
np.testing.assert_allclose(result_tensor, expected, atol=0.05, rtol=0)
print("SUCCESS!")
commands.sh¶
#!/usr/bin/env bash
# Stop at the first failing command so run.py never runs after a failed compile
set -e
# Compile layout.csl for the WSE-3 architecture into ./out with memcpy enabled.
# The --params values fill the `param` declarations in layout.csl: the three
# memcpy streaming color IDs and `size` (number of PEs in a column).
cslc --arch=wse3 ./layout.csl --fabric-dims=11,12 \
--fabric-offsets=4,1 -o out \
--params=MEMCPYH2D_DATA_1_ID:2 \
--params=MEMCPYH2D_DATA_2_ID:3 \
--params=MEMCPYD2H_DATA_1_ID:4 \
--params=size:4 \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
# Run the host script against the compiled output directory
cs_python run.py --name out