Topic 8: FIFOs

A FIFO DSD is useful for buffering data flowing into or out of a PE, as a way to extend the small hardware queues used for fabric communication. In particular, this can prevent stalls in the communication fabric when input or output happens in bursts. It is also possible to operate on the values while they flow through the FIFO, as this code sample demonstrates.
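
As a minimal sketch, mirroring the declarations in buffer.csl below, a FIFO DSD is created by backing it with an ordinary memory buffer; the resulting FIFO can then be used directly as a DSD operand, so values can be transformed as they pass through:

var fifo_buffer = @zeros([1024]f16);       // backing storage in PE memory
const fifo = @allocate_fifo(fifo_buffer);  // FIFO DSD over that storage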

This example illustrates a typical pattern in the use of FIFOs: a receiver pulls wavelets off the fabric and forwards them to a task that performs some computation. Specifically, incoming data from the host is stored in the FIFO, which relieves the sender from blocking until the receiver has consumed all wavelets. While the incoming wavelets are being asynchronously received into the FIFO buffer, we also start a second asynchronous DSD operation that pulls data out of the FIFO and forwards it to a wavelet-triggered task.
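
Concretely, the two concurrent asynchronous operations look like the following sketch; in_dsd, out_dsd, and dsd_ten are the fabric and memory DSDs defined in buffer.csl below, where these calls appear inside main_task:

// Fabric -> FIFO, adding 10.0 to each element in flight
@faddh(fifo, in_dsd, dsd_ten, .{.async = true});
// FIFO -> fabric (out_color), negating each element in flight
@fnegh(out_dsd, fifo, .{.async = true});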

This example also illustrates another common pattern, in which a PE triggers a wavelet-triggered task with its own wavelets by sending them to the router, which immediately sends them back to the compute element. In our example, this wavelet-triggered task simply computes the cube of the wavelet’s data before sending the result to the host.
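
The loopback is purely a matter of routing: wavelets sent on out_color go up the ramp to the router and are routed straight back down to the compute element, where they trigger the data task. A sketch of the relevant configuration, mirroring the comptime block of buffer.csl below:

// out_color: receive from the ramp and transmit back to the ramp (loopback to this PE)
@set_local_color_config(out_color, .{.routes = .{.rx = .{RAMP}, .tx = .{RAMP}}});
// process_task fires once per wavelet arriving on out_color
@bind_data_task(process_task, process_task_id);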

layout.csl

// color/task ID map
//
//  ID var           ID var      ID var                ID var
//   0 in_color       9 STARTUP  18                    27 reserved (memcpy)
//   1 out_color     10          19                    28 reserved (memcpy)
//   2               11          20                    29 reserved
//   3 result_color  12          21 reserved (memcpy)  30 reserved (memcpy)
//   4 H2D           13          22 reserved (memcpy)  31 reserved
//   5 D2H           14          23 reserved (memcpy)  32
//   6               15          24                    33
//   7               16          25                    34
//   8 main_task_id  17          26                    35
//

//  +------+------+------+
//  | west | core | east |
//  +------+------+------+

// IDs for memcpy streaming colors
param MEMCPYH2D_DATA_1_ID: i16;
param MEMCPYD2H_DATA_1_ID: i16;

param num_elements_to_process: i16;

// Colors
const MEMCPYH2D_DATA_1: color = @get_color(MEMCPYH2D_DATA_1_ID);
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const in_color:         color = @get_color(0);
const out_color:        color = @get_color(1);
const result_color:     color = @get_color(3);

// Task IDs
const main_task_id:    local_task_id = @get_local_task_id(8);
const STARTUP:         local_task_id = @get_local_task_id(9);
const process_task_id: data_task_id  = @get_data_task_id(out_color);

const memcpy = @import_module( "<memcpy/get_params>", .{
    .width = 3,
    .height = 1,
    .MEMCPYH2D_1 = MEMCPYH2D_DATA_1,
    .MEMCPYD2H_1 = MEMCPYD2H_DATA_1
    });


layout {
  @set_rectangle(3,1);

  // west.csl has an H2D
  const memcpy_params_0 = memcpy.get_params(0);
  @set_tile_code(0, 0, "memcpyEdge/west.csl", .{
    .memcpy_params = memcpy_params_0,
    .USER_IN_1 = in_color,
    .STARTUP = STARTUP
  });

  const memcpy_params_1 = memcpy.get_params(1);
  @set_tile_code(1, 0, "buffer.csl", .{
    .memcpy_params = memcpy_params_1,
    .in_color = in_color,
    .out_color = out_color,
    .result_color = result_color,
    .main_task_id = main_task_id,
    .process_task_id = process_task_id,
    .num_elements_to_process = num_elements_to_process
  });

  // east.csl only has a D2H
  const memcpy_params_2 = memcpy.get_params(2);
  @set_tile_code(2, 0, "memcpyEdge/east.csl", .{
    .memcpy_params = memcpy_params_2,
    .USER_OUT_1 = result_color,
    .STARTUP = STARTUP
  });
}

buffer.csl

param memcpy_params: comptime_struct;

param num_elements_to_process: i16;

// Colors
param in_color:         color;
param out_color:        color;
param result_color:     color;

// Task IDs
param process_task_id: data_task_id;  // Data task process_task triggered by out_color wlts
param main_task_id:    local_task_id;

// ----------
// Every PE needs to import the memcpy module; otherwise, the I/O cannot
// propagate the data to the destination.

// memcpy module reserves input queue 0 and output queue 0
const sys_mod = @import_module( "<memcpy/memcpy>", memcpy_params);
// ----------

var fifo_buffer = @zeros([1024]f16);
const fifo = @allocate_fifo(fifo_buffer);

const in_queue = @get_input_queue(1); // input queue 0 is reserved by memcpy
const in_dsd = @get_dsd(fabin_dsd, .{.extent = num_elements_to_process,
                                     .fabric_color = in_color,
                                     .input_queue = in_queue});
comptime {
  @set_local_color_config(in_color, .{.routes = .{.rx = .{WEST}, .tx = .{RAMP}}});
}

const out_queue = @get_output_queue(1);
const out_dsd = @get_dsd(fabout_dsd, .{.extent = num_elements_to_process,
                                       .fabric_color = out_color,
                                       .output_queue = out_queue});

const ten = [1]f16 {10.0};
const dsd_ten = @get_dsd(mem1d_dsd, .{.tensor_access = |i|{num_elements_to_process} -> ten[0]});

task main_task() void {
  // Move from the fabric to the FIFO
  // adding 10.0 to each element at the same time
  @faddh(fifo, in_dsd, dsd_ten, .{.async = true});

  // Move from the FIFO to a process_task
  // negating values at the same time
  @fnegh(out_dsd, fifo, .{.async = true});
}

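// Each wavelet arriving on out_color triggers process_task, which cubes the
// value and forwards it east toward the host on result_color.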
const result_dsd = @get_dsd(fabout_dsd, .{.extent = 1, .fabric_color = result_color});

task process_task(element:f16) void {
  @fmovh(result_dsd, element * element * element);
}

comptime {
  @bind_data_task(process_task, process_task_id); // data task receives wlts along out_color
  @bind_local_task(main_task, main_task_id);
  @activate(main_task_id);

  @set_local_color_config(out_color, .{.routes = .{.rx = .{RAMP}, .tx = .{RAMP}}});
  @set_local_color_config(result_color, .{.routes = .{.rx = .{RAMP}, .tx = .{EAST}}});
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.sdk_utils import memcpy_view, input_array_to_u32
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
MEMCPYH2D_DATA_1 = int(params["MEMCPYH2D_DATA_1_ID"])
MEMCPYD2H_DATA_1 = int(params["MEMCPYD2H_DATA_1_ID"])
size = int(params["num_elements_to_process"])
print(f"MEMCPYH2D_DATA_1 = {MEMCPYH2D_DATA_1}")
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")
print(f"size = {size}")

# maximum length of the fifo
max_fifo_len = 256*20
print(f"maximum size of the buffer in the artificial halo is {max_fifo_len}")
assert size < max_fifo_len, "input size exceeds max. capacity, may stall"

memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

runner.load()
runner.run()

np.random.seed(seed=7)

input_tensor = np.random.random(size).astype(np.float16)
print("step 1: streaming H2D to P0.0")
# "input_tensor" is an 1d array
# The type of input_tensor is f16, we need to extend it to uint32
# There are two kind of extension when using the utility function input_array_to_u32
#    input_array_to_u32(np_arr: np.ndarray, sentinel: Optional[int], fast_dim_sz: int)
# 1) zero extension:
#    sentinel = None
# 2) upper 16-bit is the index of the array:
#    sentinel is Not None
#
# In this example, the upper 16-bit is don't care because buffer.csl only
# reads lower 16-bit
tensors_u32 = input_array_to_u32(input_tensor, 1, size)
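# For reference, zero extension would be requested by passing sentinel=None:
#   tensors_u32 = input_array_to_u32(input_tensor, None, size)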
runner.memcpy_h2d(MEMCPYH2D_DATA_1, tensors_u32, 0, 0, 1, 1, size, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=True)

print("step 2: streaming D2H at P2.0")
# The D2H buffer must be of type u32
out_tensors_u32 = np.zeros(size, np.uint32)
runner.memcpy_d2h(out_tensors_u32, MEMCPYD2H_DATA_1, 2, 0, 1, 1, size, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=False)
# remove the upper 16 bits of each u32 element
result_tensor = memcpy_view(out_tensors_u32, np.dtype(np.float16))

runner.stop()

add_ten_negate = -(input_tensor + 10.0)
expected = add_ten_negate * add_ten_negate * add_ten_negate

np.testing.assert_equal(result_tensor, expected)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

set -e

cslc ./layout.csl \
--fabric-dims=10,3 --fabric-offsets=4,1 \
--params=num_elements_to_process:2048 \
-o out \
--params=MEMCPYH2D_DATA_1_ID:4 \
--params=MEMCPYD2H_DATA_1_ID:5 \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
cs_python run.py --name out