Topic 4: Sentinels

In previous programs, we used so-called routable colors, which are associated with a route that directs the flow of wavelets. On WSE-2, the task IDs that can be associated with routable colors are in the range 0 through 23. This example demonstrates the use of a non-routable control task ID to signal the end of an input tensor; because it marks the end of the data, it is called a sentinel.

In this example, a sentinel wavelet follows the last wavelet of the input tensor. Since sentinel control task IDs are not routable colors, the programmer does not specify a route for them, but must bind the control task ID to a control task, which is activated upon receipt of the sentinel wavelet.

Here, the sentinel activates the send_result task, which relays the result of the sum reduction back to the host.
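
The pattern can be distilled as follows: a data task is bound to a data task ID derived from a routable color, while a control task is bound to a control task ID derived directly from a numeric ID, with no route configured for it. Below is a minimal sketch, with illustrative IDs and task names only (the bindings actually used by this example appear in pe_program.csl further down):

// Minimal sketch of the two binding flavors; IDs and names are illustrative.
const some_color: color = @get_color(0);                    // routable color, needs a route
const data_id: data_task_id = @get_data_task_id(some_color);
const ctrl_id: control_task_id = @get_control_task_id(40);  // non-routable, no route

task on_data(x: f16) void {
  // consume one wavelet of the input tensor
}

task on_ctrl() void {
  // end-of-tensor handling, e.g., send the result to the host
}

comptime {
  @bind_data_task(on_data, data_id);
  @bind_control_task(on_ctrl, ctrl_id);
}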

layout.csl

// color/ task ID map
//
//  ID var          ID var     ID var               ID var               ID var
//   0 main_color    9 STARTUP 18                   27 reserved (memcpy) 36
//   1 output_color 10         19                   28 reserved (memcpy) 37
//   2 H2D_1        11 IN_1    20                   29 reserved          38
//   3 H2D_2        12 IN_2    21 reserved (memcpy) 30 reserved (memcpy) 39
//   4 D2H          13         22 reserved (memcpy) 31 reserved          40
//   5              14         23 reserved (memcpy) 32                   41
//   6              15         24                   33                   42
//   7              16         25                   34                   43 send_result_task_id
//   8              17         26                   35                   44

//  +------+---------+------+------+
//  | west |sentinel | core | east |
//  +------+---------+------+------+

//            +-------+              +-----------+
//  H2D_1 --> | west  | --> IN_1 --> | sentinel  |
//  H2D_2 --> |       | --> IN_2 --> |           |
//            +-------+              +-----------+
//
//           +---------------+                            +-------+
//  IN_1 --> | sentinel.csl  | --> OUT_1 (main_color) --> | core  |
//  IN_2 --> |               |                            +-------+
//           +---------------+

// IDs for memcpy streaming colors
param MEMCPYH2D_DATA_1_ID: i16;
param MEMCPYH2D_DATA_2_ID: i16;
param MEMCPYD2H_DATA_1_ID: i16;

// number of PEs in a column
param size: i16;

// Sentinel to tell PE that it is time to send the result to the host
const end_computation: u16 = 43;

// Colors
const MEMCPYH2D_DATA_1: color = @get_color(MEMCPYH2D_DATA_1_ID);
const MEMCPYH2D_DATA_2: color = @get_color(MEMCPYH2D_DATA_2_ID);
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);

const main_color:   color = @get_color(0);
const output_color: color = @get_color(1);

const IN_1: color = @get_color(11);
const IN_2: color = @get_color(12);

// Task IDs
const STARTUP:             local_task_id   = @get_local_task_id(9);
const main_task_id:        data_task_id    = @get_data_task_id(main_color);
const send_result_task_id: control_task_id = @get_control_task_id(end_computation);
const IN_1_task_id:        data_task_id    = @get_data_task_id(IN_1);
const IN_2_task_id:        data_task_id    = @get_data_task_id(IN_2);


const memcpy = @import_module( "<memcpy/get_params>", .{
    .width = 4,
    .height = size,
    .MEMCPYH2D_1 = MEMCPYH2D_DATA_1,
    .MEMCPYH2D_2 = MEMCPYH2D_DATA_2,
    .MEMCPYD2H_1 = MEMCPYD2H_DATA_1
    });

layout {
  @set_rectangle(4, size);

  // input_route: wavelets arriving from the west are sent down the ramp to the CE;
  // output_route: wavelets coming up the ramp from the CE are forwarded to the east
  const input_route  = .{ .rx = .{ WEST }, .tx = .{ RAMP } };
  const output_route = .{ .rx = .{ RAMP }, .tx = .{ EAST } };

  var idx: i16 = 0;
  while (idx < size) {

    // west.csl has two H2Ds
    @set_tile_code(0, idx, "memcpyEdge/west.csl", .{
      .memcpy_params = memcpy.get_params(0),
      .USER_IN_1 = IN_1,
      .USER_IN_2 = IN_2,
      .STARTUP = STARTUP,
    });

    @set_tile_code(1, idx, "sentinel.csl", .{
      .memcpy_params = memcpy.get_params(1),
      .wtt_in_1_task_id = IN_1_task_id,
      .wtt_in_2_task_id = IN_2_task_id,
      .OUT_1 = main_color,
      .SENTINEL = end_computation,
    });

    @set_color_config(1, idx, IN_1,       .{ .routes = input_route });
    @set_color_config(1, idx, IN_2,       .{ .routes = input_route });
    @set_color_config(1, idx, main_color, .{ .routes = output_route });

    @set_tile_code(2, idx, "pe_program.csl", .{
      .memcpy_params = memcpy.get_params(2),
      .output_color = output_color,
      .main_task_id = main_task_id,
      .send_result_task_id = send_result_task_id
    });

    @set_color_config(2, idx, main_color,   .{ .routes = input_route });
    @set_color_config(2, idx, output_color, .{ .routes = output_route });

    // east.csl only has a D2H
    @set_tile_code(3, idx, "memcpyEdge/east.csl", .{
      .memcpy_params = memcpy.get_params(3),
      .USER_OUT_1 = output_color,
      .STARTUP = STARTUP
    });

    idx += 1;
  }
}

pe_program.csl

param memcpy_params: comptime_struct;

// Colors
param output_color:     color;

// Task IDs
param main_task_id:        data_task_id;    // data task receives data along main_color
param send_result_task_id: control_task_id; // sentinel tells PE to send result to host

// ----------
// Every PE needs to import the memcpy module; otherwise, the I/O cannot
// propagate the data to the destination.

// memcpy module reserves input queue 0 and output queue 0
const sys_mod = @import_module( "<memcpy/memcpy>", memcpy_params);
// ----------

var result: f16 = 0.0;

// DSD for sending a single wavelet, carrying the result, along output_color
const out_dsd = @get_dsd(fabout_dsd, .{.fabric_color = output_color, .extent = 1});

// Data task: activated by each wavelet arriving along main_color;
// accumulates the f16 payload into result
task main_task(data: f16) void {
  result = result + data;
}

// Control task: activated by the sentinel control wavelet; sends the
// accumulated sum toward the host along output_color
task send_result() void {
  @fmovh(out_dsd, result);
}

comptime {
  @bind_data_task(main_task, main_task_id);
  @bind_control_task(send_result, send_result_task_id);
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.sdk_utils import memcpy_view, input_array_to_u32
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
MEMCPYH2D_DATA_1 = int(params["MEMCPYH2D_DATA_1_ID"])
MEMCPYH2D_DATA_2 = int(params["MEMCPYH2D_DATA_2_ID"])
MEMCPYD2H_DATA_1 = int(params["MEMCPYD2H_DATA_1_ID"])
size = int(params["size"])
print(f"MEMCPYH2D_DATA_1 = {MEMCPYH2D_DATA_1}")
print(f"MEMCPYH2D_DATA_2 = {MEMCPYH2D_DATA_2}")
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")
print(f"size = number of PEs in a column = {size}")

# memcpy_dtype is a don't-care in streaming mode
memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

runner.load()
runner.run()

num_wvlts = 11
print(f"num_wvlts = number of wavelets for each PE = {num_wvlts}")

print("step 1: streaming H2D_1 sends number of input wavelets to P0.0")
h2d1_u32 = np.ones(size).astype(np.uint32) * num_wvlts
runner.memcpy_h2d(MEMCPYH2D_DATA_1, h2d1_u32.ravel(), 0, 0, 1, size, 1, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=True)

# Use a deterministic seed so that CI results are predictable
np.random.seed(seed=7)

# Set up a (size x num_wvlts) input tensor that is reduced along the second dimension
input_tensor = np.random.rand(size, num_wvlts).astype(np.float16)
expected = np.sum(input_tensor, axis=1)

print("step 2: streaming H2D_2 to P0.0")
# "input_tensor" is flattened into a 1D array by ravel()
# The dtype of input_tensor is float16; each element must be extended to uint32
# There are two kinds of extension when using the utility function input_array_to_u32
#    input_array_to_u32(np_arr: np.ndarray, sentinel: Optional[int], fast_dim_sz: int)
# 1) zero extension:
#    sentinel = None
# 2) the upper 16 bits hold the index of the element:
#    sentinel is not None
#
# In this example, the upper 16 bits are don't-care because pe_program.csl only
# reads the lower 16 bits (a rough illustration of this packing appears after this listing)
tensors_u32 = input_array_to_u32(input_tensor.ravel(), 1, num_wvlts)
runner.memcpy_h2d(MEMCPYH2D_DATA_2, tensors_u32, 0, 0, 1, size, num_wvlts, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=True)

print("step 3: streaming D2H at P3.0")
# The D2H buffer must be of type u32
out_tensors_u32 = np.zeros(size, np.uint32)
runner.memcpy_d2h(out_tensors_u32, MEMCPYD2H_DATA_1, 3, 0, 1, size, 1, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=False)
# strip the upper 16 bits of each u32 to recover the f16 results
result_tensor = memcpy_view(out_tensors_u32, np.dtype(np.float16))

runner.stop()

# Ensure that the result matches our expectation
np.testing.assert_allclose(result_tensor, expected, atol=0.05, rtol=0)
print("SUCCESS!")
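
For reference, here is a rough sketch of the zero-extension packing described in the comments of run.py. This is not the SDK's implementation of input_array_to_u32; it only illustrates how an f16 value can end up in the lower 16 bits of a u32, which is all that pe_program.csl reads in this example:

import numpy as np

def pack_f16_to_u32_zero_extend(arr: np.ndarray) -> np.ndarray:
    # Illustrative only: reinterpret each f16 as its 16-bit pattern,
    # then zero-extend to 32 bits (the upper 16 bits become zero).
    assert arr.dtype == np.float16
    return arr.view(np.uint16).astype(np.uint32)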

commands.sh

#!/usr/bin/env bash

set -e

cslc ./layout.csl --fabric-dims=11,12 \
--fabric-offsets=4,1 -o out \
--params=size:10 \
--params=MEMCPYH2D_DATA_1_ID:2 \
--params=MEMCPYH2D_DATA_2_ID:3 \
--params=MEMCPYD2H_DATA_1_ID:4 \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
cs_python run.py --name out