Host-to-Device Broadcast Test

This example shows how to use row or column broadcast. For example, if the user wants to broadcast a column of data [1.0, 2.0, 3.0, 4.0] to a region of interest starting from (1,1) with width 3 and height 4, one element per PE, the H2D API requires the user to prepare the following 4-by-3 tensor (4 rows for the height, 3 columns for the width),

| 1.0  1.0  1.0 |
| 2.0  2.0  2.0 |
| 3.0  3.0  3.0 |
| 4.0  4.0  4.0 |

and use memcpy_h2d() API to stream 12 elements into the device. This operation wastes host bandwidth by 3x. Now the user can use the new API, memcpy_h2d_rowbcast(), to stream 4 elements only.

The same applies to column broadcasting: the user only needs to provide the data of one row and use the memcpy_h2d_colbcast() API.

The new broadcasting scheme only supports H2D, not D2H.

The kernel of row-col-broadcast is the same as bandwidth-test. The run.py calculates the bandwidth as well. The formula of the bandwidth calculation is the same as bandwidth-test, so the user can see how much time this new API can save.

layout.csl

// c0,c1,c2,c3,c4 are internal colors of sync module
// Their numeric IDs are compile-time parameters (see --params in commands.sh).
param C0_ID: i16;
param C1_ID: i16;
param C2_ID: i16;
param C3_ID: i16;
param C4_ID: i16;

param pe_length: i16; // number of wavelets per PE
param width : i16 ;   // width of the core
param height: i16 ;   // height of the core


// color handles built from the numeric IDs above
const C0 : color = @get_color(C0_ID);
const C1 : color = @get_color(C1_ID);
const C2 : color = @get_color(C2_ID);
const C3 : color = @get_color(C3_ID);
const C4 : color = @get_color(C4_ID);

// entrypoints of sync module
// (local task IDs 15..18; they are handed to sync/layout.csl below)
const STARTUP: local_task_id = @get_local_task_id(15);
const SYNC_Y: local_task_id = @get_local_task_id(16);
const SYNC_BCAST: local_task_id = @get_local_task_id(17);
const EXIT: local_task_id = @get_local_task_id(18);


// memcpy helper: memcpy.get_params(px) is called once per PE column in the
// layout block and its result is forwarded to kernel.csl as memcpyParams
const memcpy = @import_module( "<memcpy/get_params>", .{
    .width = width,
    .height = height,
    });

// sync helper: sync.get_params(px, py) returns the per-PE colors, task IDs
// and first/last flags that kernel.csl forwards to sync/pe.csl
const sync = @import_module( "sync/layout.csl", .{
    .colors = [5]color{C0, C1, C2, C3, C4},
    .entrypoints = [4]local_task_id{STARTUP, SYNC_Y, SYNC_BCAST, EXIT},
    .width = width,
    .height = height
    });

// Layout of the program: a width-by-height rectangle of PEs that all run
// kernel.csl, plus the symbols exported to the host.
layout{

    // H2D or D2H colors must be less than 15 (smallest color of entrypoints)
    // The asserts below require the five sync color IDs to be strictly
    // increasing, which also makes them pairwise distinct.
    @comptime_assert( C0_ID < C1_ID);
    @comptime_assert( C1_ID < C2_ID);
    @comptime_assert( C2_ID < C3_ID);
    @comptime_assert( C3_ID < C4_ID);

    // step 1: configure the rectangle which does not include halo
    @set_rectangle( width, height );

    // step 2: compile csl code for a set of PEx.y and generate out_x_y.elf
    //   format: @set_tile_code(x, y, code.csl, param_binding);

    // visit every PE row by row; px is the column index, py the row index
    var py: i16 = 0;
    while(py < height) : (py +=1) {
        var px: i16 = 0;
        while( px < width) : (px +=1) {

            // memcpy parameters depend on the column only,
            // sync parameters depend on both coordinates
            const memcpyParams = memcpy.get_params(px);
            const syncParams = sync.get_params(px, py);

            // parameter names must match the `param` declarations of kernel.csl
            var params: comptime_struct = .{
                .memcpyParams = memcpyParams,
                .pe_length = pe_length,

                .syncParams = syncParams,
            };

            @set_tile_code(px, py, "kernel.csl", params);
        }
    }

    // device symbols; run.py resolves these names with runner.get_id()
    @export_name("A", [*]f32, true);
    @export_name("time_memcpy", [*]f32, true);
    @export_name("time_ref", [*]f32, true);

    // device functions; run.py launches these names with runner.call()
    @export_name("f_tic", fn()void);
    @export_name("f_toc", fn()void);
    @export_name("f_memcpy_timestamps", fn()void);
    @export_name("f_sync", fn()void);
    @export_name("f_reference_timestamps", fn()void);
} // end of layout

kernel.csl

// constraints: input/output queue ID = 0 is reserved for memcpy module
// only use microthread 2,3,4,5,6,7

// per-PE parameters bound by layout.csl through @set_tile_code
param memcpyParams: comptime_struct;

param syncParams: comptime_struct;

param pe_length: i16;


const timestamp = @import_module("<time>");
// starting time of H2D/D2H
var tscStartBuffer = @zeros([timestamp.tsc_size_words]u16);
// ending time of H2D/D2H
var tscEndBuffer = @zeros([timestamp.tsc_size_words]u16);


const sys_mod = @import_module( "<memcpy/memcpy>", memcpyParams);

// the sync module gets the parameters computed by sync/layout.csl plus:
//   - f_callback: invoked by the sync module when the sync has finished,
//     here it unblocks the memcpy command stream
//   - queues 2,3,4 (queue 0 is reserved for memcpy, see the note at the top)
const sync_mod = @import_module( "sync/pe.csl", @concat_structs(syncParams, .{
     .f_callback = sys_mod.unblock_cmd_stream,
     .input_queues=[3]u16{2, 3, 4},
     .output_queues=[3]u16{2, 3, 4},
     }));


////////////////////////////////////////////////////////////////////////////////
// Main memory (48KB)
////////////////////////////////////////////////////////////////////////////////

// capacity of A in f32 elements; pe_length <= size is asserted at comptime
const size : i16 = 1024*4;

// destination of the broadcast H2D copy (exported as "A")
var A = @zeros([size]f32);
// time_buf_f32[0:2] = {tscStartBuffer, tscEndBuffer}
var time_buf_f32 = @zeros([3]f32);
// reference clock inside sync module
var time_ref_f32 = @zeros([2]f32);

// pointers exported to the host under the names "A", "time_memcpy", "time_ref"
var ptr_A : [*]f32 = &A;
var ptr_time_memcpy: [*]f32 = &time_buf_f32;
var ptr_time_ref: [*]f32 = &time_ref_f32;

////////////////////////////////////////////////////////////////////////////////
// Tasks
////////////////////////////////////////////////////////////////////////////////


// Host-callable function (exported as "f_tic"): stores the current timestamp
// in tscStartBuffer, i.e. the start time of the measured transfers.
fn f_tic() void {
    timestamp.get_timestamp(&tscStartBuffer);

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Host-callable function (exported as "f_toc"): stores the current timestamp
// in tscEndBuffer, i.e. the end time of the measured transfers.
fn f_toc() void {
    timestamp.get_timestamp(&tscEndBuffer);

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Host-callable function (exported as "f_memcpy_timestamps"): packs the start
// and end timestamps (three u16 words each) into the three f32 slots of
// time_buf_f32 so that the host can fetch them with a 32-bit D2H copy.
//
// Each f32 slot carries two raw u16 words, {high word, low word}:
//   time_buf_f32[0] = {tscStartBuffer[1], tscStartBuffer[0]}
//   time_buf_f32[1] = {tscEndBuffer[0], tscStartBuffer[2]}
//   time_buf_f32[2] = {tscEndBuffer[2], tscEndBuffer[1]}
//
// The slots hold bit patterns, not numeric values; the host has to
// reinterpret them as 32-bit integers (see run.py).
fn f_memcpy_timestamps() void {
    var lo_ : u16 = 0;
    var hi_ : u16 = 0;

    // slot 0: words 0 and 1 of the start timestamp
    lo_ = tscStartBuffer[0];
    hi_ = tscStartBuffer[1];
    time_buf_f32[0] = @bitcast(f32, (@as(u32,hi_) << @as(u16,16)) | @as(u32, lo_) );

    // slot 1: last word of the start timestamp and first word of the end timestamp
    lo_ = tscStartBuffer[2];
    hi_ = tscEndBuffer[0];
    time_buf_f32[1] = @bitcast(f32, (@as(u32,hi_) << @as(u16,16)) | @as(u32, lo_) );

    // slot 2: words 1 and 2 of the end timestamp
    lo_ = tscEndBuffer[1];
    hi_ = tscEndBuffer[2];
    time_buf_f32[2] = @bitcast(f32, (@as(u32,hi_) << @as(u16,16)) | @as(u32, lo_) );

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Host-callable function (exported as "f_sync"): starts the sync module.
// The command stream is not unblocked here; the sync module calls the
// f_callback it was given (sys_mod.unblock_cmd_stream) once it has finished.
fn f_sync() void {
    // sync all PEs and record the reference clock
    sync_mod.f_sync();
}

// Host-callable function (exported as "f_reference_timestamps"): packs the
// reference clock recorded by the sync module (three u16 words) into the two
// f32 slots of time_ref_f32. The slots hold bit patterns, not numeric values.
fn f_reference_timestamps() void {
    // time_ref_f32[0] = {tscRefBuffer[1], tscRefBuffer[0]}
    // time_ref_f32[1] = {0, tscRefBuffer[2]}
    var lo_ : u16 = 0;
    var hi_ : u16 = 0;

    // slot 0: words 0 (low half) and 1 (high half)
    lo_ = sync_mod.tscRefBuffer[0];
    hi_ = sync_mod.tscRefBuffer[1];
    time_ref_f32[0] = @bitcast(f32, (@as(u32,hi_) << @as(u16,16)) | @as(u32, lo_) );

    // slot 1: word 2 in the low half, high half is zero
    lo_ = sync_mod.tscRefBuffer[2];
    hi_ = 0;
    time_ref_f32[1] = @bitcast(f32, (@as(u32,hi_) << @as(u16,16)) | @as(u32, lo_) );

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}


// A has `size` elements, so a per-PE transfer of pe_length elements must fit
comptime {

    @comptime_assert( pe_length <= size );
}

// export the data pointers under the names declared with @export_name in layout.csl
comptime {
    @export_symbol(ptr_A, "A");
    @export_symbol(ptr_time_memcpy, "time_memcpy");
    @export_symbol(ptr_time_ref, "time_ref");
}

// export the host-callable functions declared with @export_name in layout.csl
comptime{
    @export_symbol(f_tic);
    @export_symbol(f_toc);
    @export_symbol(f_memcpy_timestamps);
    @export_symbol(f_sync);
    @export_symbol(f_reference_timestamps);
}

run.py

#!/usr/bin/env cs_python
# pylint: disable=too-many-function-args

""" Test row or column broadcast
    The kernel is the same as bandwidthTest.
    The bandwidth calculation follows bandwidthTest.

    Here is the list of parameters:
    -m=<int> specifies the height of the core.
    -n=<int> specifies the width of the core.
    -k=<int> specifies the maximum number of elements per PE in the core.
    --roi_px=<int> specifies the starting column index of region of interest
    --roi_py=<int> specifies the starting row index of region of interest
    --roi_w=<int> specifies the width of region of interest
    --roi_h=<int> specifies the height of region of interest
    --channels specifies the number of I/O channels, no bigger than 16.
"""

import random
import struct

import numpy as np
from cmd_parser import parse_args

from cerebras.sdk.runtime.sdkruntimepybind import (  # pylint: disable=no-name-in-module
    MemcpyDataType,
    MemcpyOrder,
    SdkRuntime,
)


def float_to_hex(f):
    """Return the IEEE-754 single-precision bit pattern of ``f`` as a hex string.

    The value is serialized as a little-endian 32-bit float and the same four
    bytes are read back as an unsigned 32-bit integer, e.g. ``1.0`` gives
    ``'0x3f800000'``.
    """
    raw_bytes = struct.pack('<f', f)
    (bit_pattern,) = struct.unpack('<I', raw_bytes)
    return hex(bit_pattern)


def make_u48(words):
    """Combine three 16-bit words into one 48-bit integer.

    Args:
        words: indexable sequence whose first three items are the 16-bit
            words of the value, least significant word first. Items may be
            Python ints or NumPy integer scalars.

    Returns:
        int: ``words[0] + (words[1] << 16) + (words[2] << 32)``.
    """
    # Widen to Python int before shifting: a NumPy uint16 scalar shifted by a
    # Python int can stay uint16 (NumPy 2 promotion rules), which would drop
    # the upper words.
    return int(words[0]) + (int(words[1]) << 16) + (int(words[2]) << 32)


def main():
    """Run the host side of the row/column broadcast test.

    The function
      1. builds the 1-column (row broadcast) or 1-row (column broadcast)
         source tensor and cuts out the part that belongs to the ROI,
      2. synchronizes the PEs and brackets ``loop_count`` back-to-back
         broadcast H2D copies with ``f_tic`` / ``f_toc``,
      3. reads back the timestamps, the reference clocks and ``A``,
      4. checks that the core holds the broadcast data inside the ROI and
         zeros elsewhere, and
      5. prints the measured bandwidth.

    Raises:
        AssertionError: if the ROI does not fit into the core or the data
            read back from the device differs from the expected tensor.
    """

    random.seed(127)

    args, dirname = parse_args()

    height = args.m
    width = args.n
    pe_length = args.k
    use_col_major = args.use_col_major
    is_row_bcast = args.is_row_bcast
    loop_count = args.loop_count

    print(f"core: width = {width}, height = {height}, pe_length={pe_length}")

    np.random.seed(2)
    if is_row_bcast:
        print("row broadcast mode: only prepare data for 1 column")
        # A is h-by-1-by-l
        A = (
            np.arange(height * 1 * pe_length)
            .reshape(height, 1, pe_length)
            .astype(np.uint32)
        )
    else:
        print("column broadcast mode: only prepare data for 1 row")
        # A is 1-by-w-by-l
        A = (
            np.arange(1 * width * pe_length)
            .reshape(1, width, pe_length)
            .astype(np.uint32)
        )
    print(f"shape(A) = {A.shape}")
    print(f"A = {A}")

    px = args.roi_px
    py = args.roi_py
    pw = args.roi_w
    ph = args.roi_h

    print(f"ROI: px = {px}, py = {py}, pw = {pw}, ph = {ph}")

    assert 0 <= px, "px must be non-negative"
    assert 0 <= py, "py must be non-negative"
    assert width >= pw, "pw must not be greater than width"
    assert height >= ph, "ph must not be greater than height"
    # the ROI starts at (px, py), so its far edge must stay inside the core
    assert px + pw <= width, "ROI exceeds the core width (px + pw > width)"
    assert py + ph <= height, "ROI exceeds the core height (py + ph > height)"

    # extract ROI from A
    if is_row_bcast:
        B = A[py : (py + ph), 0:, 0:]
    else:
        B = A[0:, px : (px + pw), 0:]
    print(f"shape(B) = {B.shape}")
    print(f"B = {B}")

    bx, by, bz = B.shape
    if is_row_bcast:
        assert bx == ph
        assert by == 1
        assert bz == pe_length
    else:
        assert bx == 1
        assert by == pw
        assert bz == pe_length

    print(f"use_col_major = {use_col_major}")
    if use_col_major:
        B_1d = B.T.ravel()
    else:
        B_1d = B.ravel()

    print('store ELFs and log files in the folder ', dirname)

    memcpy_dtype = MemcpyDataType.MEMCPY_32BIT
    h2d_order = (
        MemcpyOrder.COL_MAJOR if use_col_major else MemcpyOrder.ROW_MAJOR
    )

    runner = SdkRuntime(
        dirname,
        suppress_simfab_trace=True,
        # msg_level="DEBUG",
        cmaddr=args.cmaddr,
    )

    symbol_A = runner.get_id("A")
    symbol_time_memcpy = runner.get_id("time_memcpy")
    symbol_time_ref = runner.get_id("time_ref")

    runner.load()
    runner.run()

    print("step 1: sync() synchronizes all PEs and records reference clock")
    runner.call("f_sync", [], nonblock=True)

    print("step 2: tic() records time_start")
    runner.call("f_tic", [], nonblock=True)

    print(f"len(B_1d) = {len(B_1d)}")
    print(f"B_1d = {B_1d}")
    for _ in range(loop_count):
        if is_row_bcast:
            print("step 3: memcpy_h2d_rowbcast(B)")
            runner.memcpy_h2d_rowbcast(
                symbol_A,
                B_1d,
                px,
                py,
                pw,
                ph,
                pe_length,
                streaming=False,
                data_type=memcpy_dtype,
                order=h2d_order,
                nonblock=True,
            )
        else:
            print("step 3: memcpy_h2d_colbcast(B)")
            runner.memcpy_h2d_colbcast(
                symbol_A,
                B_1d,
                px,
                py,
                pw,
                ph,
                pe_length,
                streaming=False,
                data_type=memcpy_dtype,
                order=h2d_order,
                nonblock=True,
            )

    print("step 4: toc() records time_end")
    runner.call("f_toc", [], nonblock=False)

    print("step 5: prepare (time_start, time_end)")
    runner.call("f_memcpy_timestamps", [], nonblock=False)

    print("step 6: D2H (time_start, time_end)")
    # time_start/time_end is of type u16[3]
    # {time_start, time_end} is packed into three f32
    time_memcpy_1d_f32 = np.zeros(height * width * 3, np.float32)
    runner.memcpy_d2h(
        time_memcpy_1d_f32,
        symbol_time_memcpy,
        0,
        0,
        width,
        height,
        3,
        streaming=False,
        data_type=memcpy_dtype,
        order=MemcpyOrder.ROW_MAJOR,
        nonblock=False,
    )
    time_memcpy_hwl = np.reshape(
        time_memcpy_1d_f32, (height, width, 3), order='C'
    )

    print("step 7: prepare reference clock")
    runner.call("f_reference_timestamps", [], nonblock=False)

    print("step 8: D2H reference clock")
    # time_ref is of type u16[3], packed into two f32
    time_ref_1d_f32 = np.zeros(height * width * 2, np.float32)
    runner.memcpy_d2h(
        time_ref_1d_f32,
        symbol_time_ref,
        0,
        0,
        width,
        height,
        2,
        streaming=False,
        data_type=memcpy_dtype,
        order=MemcpyOrder.ROW_MAJOR,
        nonblock=False,
    )
    time_ref_hwl = np.reshape(time_ref_1d_f32, (height, width, 2), order='C')

    print("step 9: D2H(A)")
    E_1d = np.zeros(height * width * pe_length, A.dtype)
    runner.memcpy_d2h(
        E_1d,
        symbol_A,
        0,
        0,
        width,
        height,
        pe_length,
        streaming=False,
        data_type=memcpy_dtype,
        order=MemcpyOrder.COL_MAJOR,
        nonblock=False,
    )

    runner.stop()

    print("DONE")

    # E is h-by-w-by-l; the data was fetched in column-major order, hence
    # the Fortran-order reshape
    E_hwl = np.reshape(E_1d, (height, width, pe_length), order='F')
    print(f"E_hwl (from device) = {E_hwl}")

    # B_ext is the expected result
    B_ext = (
        np.zeros(height * width * pe_length)
        .reshape(height, width, pe_length)
        .astype(A.dtype)
    )
    if is_row_bcast:
        # copy B to each column of ROI
        for w in range(pw):
            B_ext[py : (py + ph), (px + w) : (px + w + 1), 0:] = B
    else:
        # copy B to each row of ROI
        for h in range(ph):
            B_ext[(py + h) : (py + h + 1), px : (px + pw), 0:] = B
    print(f"B_ext = {B_ext}")

    print("check E_hwl == B_ext")
    assert np.allclose(E_hwl.ravel(), B_ext.ravel(), 0)

    # The 16-bit words below are kept as plain Python ints so that the
    # shifts inside make_u48 cannot overflow a fixed-width NumPy integer.

    # time_start = start time of H2D/D2H
    time_start = np.zeros((height, width)).astype(int)
    # time_end = end time of H2D/D2H
    time_end = np.zeros((height, width)).astype(int)
    for w in range(width):
        for h in range(height):
            hex_t0 = int(float_to_hex(time_memcpy_hwl[(h, w, 0)]), base=16)
            hex_t1 = int(float_to_hex(time_memcpy_hwl[(h, w, 1)]), base=16)
            hex_t2 = int(float_to_hex(time_memcpy_hwl[(h, w, 2)]), base=16)
            # start = {t0.lo, t0.hi, t1.lo}, end = {t1.hi, t2.lo, t2.hi}
            start_words = [
                hex_t0 & 0x0000FFFF,
                (hex_t0 >> 16) & 0x0000FFFF,
                hex_t1 & 0x0000FFFF,
            ]
            time_start[(h, w)] = make_u48(start_words)
            end_words = [
                (hex_t1 >> 16) & 0x0000FFFF,
                hex_t2 & 0x0000FFFF,
                (hex_t2 >> 16) & 0x0000FFFF,
            ]
            time_end[(h, w)] = make_u48(end_words)

    # time_ref = reference clock
    time_ref = np.zeros((height, width)).astype(int)
    for w in range(width):
        for h in range(height):
            hex_t0 = int(float_to_hex(time_ref_hwl[(h, w, 0)]), base=16)
            hex_t1 = int(float_to_hex(time_ref_hwl[(h, w, 1)]), base=16)
            ref_words = [
                hex_t0 & 0x0000FFFF,
                (hex_t0 >> 16) & 0x0000FFFF,
                hex_t1 & 0x0000FFFF,
            ]
            time_ref[(h, w)] = make_u48(ref_words)
    # adjust the reference clock by the propagation delay
    # (row/col are used instead of py/px to avoid shadowing the ROI origin)
    for row in range(height):
        for col in range(width):
            time_ref[(row, col)] = time_ref[(row, col)] - (col + row)

    # shift time_start and time_end by time_ref
    time_start = time_start - time_ref
    time_end = time_end - time_ref

    # cycles_send = max(time_end) - min(time_start) over all PEs
    # 850MHz --> 1 cycle = (1/0.85) ns = (1/0.85)*1.e-3 us
    # time_send = (cycles_send / 0.85) *1.e-3 us
    # bandwidth = ((wvlts * 4)/time_send) * loop_count MB/s
    wvlts = pw * ph * pe_length
    min_time_start = time_start.min()
    max_time_end = time_end.max()
    cycles_send = max_time_end - min_time_start
    time_send = (cycles_send / 0.85) * 1.0e-3
    bandwidth = ((wvlts * 4) / time_send) * loop_count
    print(f"ROI: pw = {pw}, ph= {ph}, pe_length={pe_length}")
    print(f"wvlts = {wvlts}, loop_count = {loop_count}")
    print(f"cycles_send = {cycles_send} cycles")
    print(f"time_send = {time_send} us")
    print(f"bandwidth = {bandwidth} MB/S ")


# Run the test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

cmd_parser.py

# This is not a real test, but a module that gets imported in other tests.

"""command parser for broadcast

   -m <int>      number of rows of the core rectangle
   -n <int>      number of columns of the core rectangle
   -k <int>      number of elements of local tensor
   --latestlink  working directory
   --cmaddr      IP address of a WSE
   --roi_px      starting column index of region of interest
   --roi_py      starting row index of region of interest
   --roi_w       width of region of interest
   --roi_h       height of region of interest
"""


import argparse
import os


def parse_args():
    """Parse the command line of the broadcast test.

    Besides parsing, the function makes sure that the working directory
    (``--latestlink``, default ``latest``) exists, creating it together with
    any missing parent directories.

    Returns:
        tuple: ``(args, logs_dir)`` where ``args`` is the ``argparse``
        namespace and ``logs_dir`` is the working directory path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", default=1, type=int, help="number of rows")
    parser.add_argument("-n", default=1, type=int, help="number of columns")
    parser.add_argument("-k", default=1, type=int, help="size of local tensor")
    parser.add_argument(
        "--latestlink", help="folder to contain the log files (default: latest)"
    )
    parser.add_argument(
        "--cmaddr", help="CM address and port, i.e. <IP>:<port>"
    )
    parser.add_argument(
        "--arch", help="wse2 or wse3. Default is wse2 when not supplied."
    )
    parser.add_argument(
        "--channels", default=1, type=int, help="number of channels"
    )
    parser.add_argument(
        "--roi_px", default=1, type=int, help="starting column index of ROI"
    )
    parser.add_argument(
        "--roi_py", default=1, type=int, help="starting row index of ROI"
    )
    parser.add_argument("--roi_w", default=3, type=int, help="width of ROI")
    parser.add_argument("--roi_h", default=3, type=int, help="height of ROI")
    parser.add_argument(
        "--use_col_major",
        action="store_true",
        help="use column major to send the row or column broadcast",
    )
    parser.add_argument(
        "--is_row_bcast",
        action="store_true",
        help="row broadcast or column broadcast",
    )
    parser.add_argument("--fabric-dims", help="Fabric dimension, i.e. <W>,<H>")
    parser.add_argument(
        "--loop_count",
        default=1,
        type=int,
        help="number of back-to-back H2D/D2H",
    )

    args = parser.parse_args()

    logs_dir = "latest"
    if args.latestlink:
        logs_dir = args.latestlink

    if os.path.isdir(logs_dir):
        print(f"{logs_dir} already exists")
    else:
        print(f"create {logs_dir} to store log files")
        # makedirs (unlike mkdir) also creates missing parent directories;
        # exist_ok tolerates the directory appearing between the check above
        # and this call.
        os.makedirs(logs_dir, exist_ok=True)

    return args, logs_dir

sync/layout.csl

// Parameters supplied by the importing layout file:
//   colors      - the five colors used by the sync module
//   entrypoints - the four local task IDs used by the sync module
param colors:[5]color;
param entrypoints:[4]local_task_id;
param width : i16 ;   // width of the core
param height: i16 ;   // height of the core

// C0/C1 are used for the sync along a row, C2/C3 for the sync along a
// column, C4 for the final broadcast (see get_params)
const C0 : color = colors[0];
const C1 : color = colors[1];
const C2 : color = colors[2];
const C3 : color = colors[3];
const C4 : color = colors[4];

// task IDs in the order {STARTUP, SYNC_Y, SYNC_BCAST, EXIT}
const STARTUP: local_task_id = entrypoints[0];
const SYNC_Y: local_task_id = entrypoints[1];
const SYNC_BCAST: local_task_id = entrypoints[2];
const EXIT: local_task_id = entrypoints[3];

// Build the parameter struct of sync/pe.csl for the PE at column px, row py.
//
// The receive/send colors alternate with the parity of the coordinate:
//   even px: receive on C0, send on C1      odd px: receive on C1, send on C0
//   even py: receive on C2, send on C3      odd py: receive on C3, send on C2
// The struct also carries the broadcast color C4, the four task IDs and the
// flags telling whether the PE sits on the first/last column or row.
fn get_params(px:i16, py:i16) comptime_struct {

    // parity of the coordinates
    var px_is_odd: bool = (0 != (px % 2));
    var py_is_odd: bool = (0 != (py % 2));

    // row direction: start with the even-column assignment, swap for odd columns
    var recv_x: color = C0;
    var send_x: color = C1;
    if (px_is_odd){
        recv_x = C1;
        send_x = C0;
    }

    // column direction: start with the even-row assignment, swap for odd rows
    var recv_y: color = C2;
    var send_y: color = C3;
    if (py_is_odd){
        recv_y = C3;
        send_y = C2;
    }

    // position of the PE relative to the borders of the core
    var on_first_col: bool = (0 == px);
    var on_last_col: bool = ((width-1) == px);
    var on_first_row: bool = (0 == py);
    var on_last_row: bool = ((height-1) == py);

    return .{
        .c_recv_px = recv_x,
        .c_send_px = send_x,
        .c_recv_py = recv_y,
        .c_send_py = send_y,
        .c_bcast = C4,

        .STARTUP = STARTUP,
        .SYNC_Y = SYNC_Y,
        .SYNC_BCAST = SYNC_BCAST,
        .EXIT = EXIT,

        .first_px = on_first_col,
        .last_px = on_last_col,
        .first_py = on_first_row,
        .last_py = on_last_row,
    };
}

sync/pe.csl

// Per-PE parameters; the field names match the struct returned by
// get_params() in sync/layout.csl.

// colors for the sync along the row (px), along the column (py) and for the
// final broadcast
param c_recv_px: color;
param c_send_px: color;
param c_recv_py: color;
param c_send_py: color;
param c_bcast: color;

// local task IDs bound to the tasks of this module in the comptime block
param STARTUP: local_task_id;
param SYNC_Y: local_task_id;
param SYNC_BCAST: local_task_id;
param EXIT: local_task_id;

// true when the PE sits on the first/last column (px) or row (py) of the core
param first_px: bool;
param last_px: bool;
param first_py: bool;
param last_py: bool;

// f_callback = sys_mod.unblock_cmd_stream, to continue next command
param f_callback : fn ()void;

// Queue IDs chosen by the importing kernel, in the order
// {row sync, column sync, broadcast}.
// input_queues={2,3,4}
// output_queues={2,3,4}
param input_queues:[3]u16;
param output_queues:[3]u16;

// row sync: receive on c_recv_px, send on c_send_px
const c_recv_px_iq = @get_input_queue(input_queues[0]);
const c_send_px_oq = @get_output_queue(output_queues[0]);

// column sync: receive on c_recv_py, send on c_send_py
const c_recv_py_iq = @get_input_queue(input_queues[1]);
const c_send_py_oq = @get_output_queue(output_queues[1]);

// broadcast: both directions use c_bcast.
// The output queue is taken from output_queues (not input_queues) so that a
// caller passing different input and output IDs gets the queue it asked for.
const c_bcast_iq = @get_input_queue(input_queues[2]);
const c_bcast_oq = @get_output_queue(output_queues[2]);

const timestamp = @import_module("<time>");

// tsc_size_words = 3
// reference clock of this PE, written by f_exit and read by the importing
// kernel through sync_mod.tscRefBuffer
var tscRefBuffer = @zeros([timestamp.tsc_size_words]u16);

////////////////////////////////////////////////////////////////////////////////
// Main memory (48KB)
////////////////////////////////////////////////////////////////////////////////

// one-element buffer used as source/destination of the sync signal
// (accessed through mem_buf_dsd)
var buf = @zeros([1]f32);

////////////////////////////////////////////////////////////////////////////////
// Tasks
// syntax
//     task_begin(name, entrypoint, color)
////////////////////////////////////////////////////////////////////////////////

// DSDs used by the sync function and tasks below. All of them have extent 1,
// so every @mov32 moves a single element.

// memory DSD over the one-element buffer `buf`
const mem_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> buf[i] });

// row sync: receive side
var fab_recv_data_px_wdsd =  @get_dsd(fabin_dsd, .{
   .extent = 1,
   .fabric_color = c_recv_px,
   .input_queue = c_recv_px_iq
});

// row sync: send side
var fab_trans_data_px_wdsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .fabric_color = c_send_px,
    .output_queue = c_send_px_oq
});

// column sync: receive side
var fab_recv_data_py_wdsd =  @get_dsd(fabin_dsd, .{
   .extent = 1,
   .fabric_color = c_recv_py,
   .input_queue = c_recv_py_iq
});

// column sync: send side
var fab_trans_data_py_wdsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .fabric_color = c_send_py,
    .output_queue = c_send_py_oq
});

// broadcast: receive side (all PEs except P0.0, see f_sync_bcast)
var fab_recv_data_bcast_wdsd =  @get_dsd(fabin_dsd, .{
   .extent = 1,
   .fabric_color = c_bcast,
   .input_queue = c_bcast_iq
});

// broadcast: send side (P0.0 only, see f_sync_bcast)
var fab_trans_data_bcast_wdsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .fabric_color = c_bcast,
    .output_queue = c_bcast_oq
});



// Each row performs a sync from the last PE to first PE
// Entry point of the sync: every branch issues one asynchronous @mov32 and
// activates f_sync_y when that move has completed.
// NOTE(review): last_px is tested before first_px here, while the row routing
// comptime block tests first_px first. Both orders agree for width >= 2; for
// width == 1 (first_px and last_px both true) this function sends on
// c_send_px although only c_recv_px gets a route - confirm that width == 1 is
// unsupported.
fn f_sync() void {
    // sync a row
    if (last_px){
        // px = width-1: send sync signal
        @mov32(fab_trans_data_px_wdsd, mem_buf_dsd, .{.async=true, .activate = f_sync_y });
    }else{
        if (first_px){
            // px = 0: receive signal
            @mov32(mem_buf_dsd, fab_recv_data_px_wdsd, .{.async=true, .activate = f_sync_y });
        }else{
            // 0 < px < width-1: receive signal and forward it
            @mov32(fab_trans_data_px_wdsd, fab_recv_data_px_wdsd, .{.async=true, .activate = f_sync_y });
        }
    }
}


// prerequisite: row synchronization is done
//   the first PE is the last one to receive the signal
// The first column performs a sync from last PE to first PE
// other PEs wait for bcast signal
// Every branch ends by triggering f_sync_bcast, either after the asynchronous
// @mov32 has completed or immediately through @activate.
// NOTE(review): last_py is tested before first_py here, while the column
// routing comptime block tests first_py first; the two orders differ only
// for height == 1 - confirm that height == 1 is unsupported.
task f_sync_y() void {
    if (first_px){
        // 1st column performs a sync
        if (last_py){
            // py = height-1: send sync signal
            @mov32(fab_trans_data_py_wdsd, mem_buf_dsd, .{.async=true, .activate = f_sync_bcast });
        }else{
            if (first_py){
                // py = 0: receive signal
                @mov32(mem_buf_dsd, fab_recv_data_py_wdsd, .{.async=true, .activate = f_sync_bcast });
            }else{
                // 0 < py < height-1: receive signal and forward it
                @mov32(fab_trans_data_py_wdsd, fab_recv_data_py_wdsd, .{.async=true, .activate = f_sync_bcast });
            }
        }
    }else{
        // other PEs wait for bcast signal
        @activate(SYNC_BCAST); // trigger f_sync_bcast
    }
}

// prerequisite: sync is done, P0.0 is the last one to receive the sync
// P0.0 broadcasts the signal, others wait for the bcast signal from P0.0
// Both branches activate f_exit once their asynchronous @mov32 has completed.
task f_sync_bcast() void {

    if ( first_px and first_py ){
        // P0.0 sends the signal
        @mov32(fab_trans_data_bcast_wdsd, mem_buf_dsd, .{.async=true, .activate = f_exit });
    }else{
        // others wait for bcast from P0.0
        @mov32(mem_buf_dsd, fab_recv_data_bcast_wdsd, .{.async=true, .activate = f_exit });
    }
}

// record reference clock T
// T is regarded as clock 0 because all PEs sync with P0.0
// Last step of the sync: stores the timestamp in tscRefBuffer and then calls
// the callback supplied by the importing kernel.
task f_exit() void {

    timestamp.get_timestamp(&tscRefBuffer);

    //sys_mod.unblock_cmd_stream();
    f_callback();
}


// Activated once from the comptime block below (@activate(STARTUP)):
// enables the timestamp counter used by get_timestamp().
task f_startup() void {
    timestamp.enable_tsc();
}

// Task setup: activate the startup task, bind the four tasks to the task IDs
// passed in as parameters and, on WSE-3, bind each queue to its color.
comptime {
    @activate(STARTUP);

    @bind_local_task(f_startup, STARTUP);
    @bind_local_task(f_sync_y, SYNC_Y);
    @bind_local_task(f_sync_bcast, SYNC_BCAST);
    @bind_local_task(f_exit, EXIT);

    // On WSE-3, we must explicitly initialize input and output queues
    if (@is_arch("wse3")) {
        @initialize_queue(c_recv_px_iq, .{ .color = c_recv_px });
        @initialize_queue(c_send_px_oq, .{ .color = c_send_px });

        @initialize_queue(c_recv_py_iq, .{ .color = c_recv_py });
        @initialize_queue(c_send_py_oq, .{ .color = c_send_py });

        @initialize_queue(c_bcast_iq, .{ .color = c_bcast });
        @initialize_queue(c_bcast_oq, .{ .color = c_bcast });
    }
}


// sync a row with C0 and C1
//
//     C0     C1     C0     C1
// P0 <-- P1 <-- P2 <-- P3 <-- P4
//
//     C0     C1     C0     C1     C0
// P0 <-- P1 <-- P2 <-- P3 <-- P4 <-- P5
//
// P0: recv C0
// P_even: recv C0, send C1
// P_odd: recv C1, send C0
// P_last: send C0 if odd; send C1 if even
//
// Routes: the receive color enters from EAST and goes to the RAMP, the send
// color leaves the RAMP towards WEST. The first column only receives, the
// last column only sends.
comptime {
    if (first_px){
        // px = 0: receive from east
        @set_local_color_config(c_recv_px, .{ .routes = .{ .rx = .{EAST}, .tx = .{RAMP} } } );
    }else{
        if (last_px){
           // px = width-1: send to west
           @set_local_color_config(c_send_px, .{ .routes = .{ .rx = .{RAMP}, .tx = .{WEST} } } );
        }else{
           // 0 < px < width-1: receive from east, send to west
           @set_local_color_config(c_recv_px, .{ .routes = .{ .rx = .{EAST}, .tx = .{RAMP} } } );
           @set_local_color_config(c_send_px, .{ .routes = .{ .rx = .{RAMP}, .tx = .{WEST} } } );
        }
    }
}

// sync a col with C2 and C3
//     C2     C3     C2     C3
// P0 <-- P1 <-- P2 <-- P3 <-- P4
//
//     C2     C3     C2     C3     C2
// P0 <-- P1 <-- P2 <-- P3 <-- P4 <-- P5
//
// P0: recv C2
// P_even: recv C2, send C3
// P_odd: recv C3, send C2
// P_last: send C2 if odd; send C3 if even
//
// Routes: the receive color enters from SOUTH and goes to the RAMP, the send
// color leaves the RAMP towards NORTH. The first row only receives, the last
// row only sends. The routes are configured on every column, although only
// the first column issues the column sync (see f_sync_y).
comptime {
    if (first_py){
        // py = 0 (even): receive from south
        @set_local_color_config(c_recv_py, .{ .routes = .{ .rx = .{SOUTH}, .tx = .{RAMP} } } );
    }else{
        if (last_py){
           // py = height-1: send to north
           @set_local_color_config(c_send_py, .{ .routes = .{ .rx = .{RAMP}, .tx = .{NORTH} } } );
        }else{
           // 0 < py < height-1: receive from south, send to north
           @set_local_color_config(c_recv_py, .{ .routes = .{ .rx = .{SOUTH}, .tx = .{RAMP} } } );
           @set_local_color_config(c_send_py, .{ .routes = .{ .rx = .{RAMP}, .tx = .{NORTH} } } );
        }
    }
}


// w > 1 and h > 1
//  x --> x --> x
//  |
//  V
//  x --> x --> x
//  |
//  V
//  x --> x --> x
//
// WARNING: corner case for w=1 or h=1
//
// Routes of the broadcast color: the signal starts at P0.0, travels down the
// first column and from every PE of the first column eastwards along its
// row; every PE except P0.0 also delivers it to its own RAMP.
comptime {
    if (first_px){
        // px = 0
        if (first_py){
            // P0,0: send to east and south
            @set_local_color_config(c_bcast, .{ .routes = .{ .rx = .{RAMP}, .tx = .{EAST, SOUTH} } } );
        }else{
            if (last_py){
                // P0,h-1
                @set_local_color_config(c_bcast, .{ .routes = .{ .rx = .{NORTH}, .tx = .{EAST, RAMP} } } );
            }else{
                // P0,py: 0 < py < height-1
                @set_local_color_config(c_bcast, .{ .routes = .{ .rx = .{NORTH}, .tx = .{EAST, RAMP, SOUTH} } } );
            }
        }
    }else{
        if (last_px){
            // px = width-1
           @set_local_color_config(c_bcast, .{ .routes = .{ .rx = .{WEST}, .tx = .{RAMP} } } );
        }else{
            // 0 < px < width-1
           @set_local_color_config(c_bcast, .{ .routes = .{ .rx = .{WEST}, .tx = .{EAST, RAMP} } } );
        }
    }
}

commands.sh

#!/usr/bin/env bash

# Stop at the first failing command, so run.py is not started after a failed compile.
set -e

# Compile the layout for a 5x5 core with pe_length 5; the sync colors C0..C4
# get the IDs 0..4 and the result is written to ./out.
cslc ./src/layout.csl --arch wse3 --fabric-dims=12,7 --fabric-offsets=4,1 \
--params=width:5,height:5,pe_length:5 --params=C0_ID:0 \
--params=C1_ID:1 --params=C2_ID:2 --params=C3_ID:3 --params=C4_ID:4 -o=out \
--memcpy --channels=2 --width-west-buf=0 --width-east-buf=0
# Run the host program in row-broadcast mode on the compiled output;
# -m/-n/-k are set to the same height/width/pe_length as in the compile step.
cs_python ./run.py -m=5 -n=5 -k=5 --latestlink out --is_row_bcast --loop_count=1