stencil-3d-7pts

This example evaluates the performance of a 7-point stencil. The kernel records the start and end of spmv using the tsc counter. In addition, the tsc counters of all PEs are not synchronized at the beginning. To avoid timing variation among the PEs, sync() synchronizes all PEs and samples the reference clock.

The kernel kernel.csl defines a couple of host-callable functions, f_sync(), f_tic() and f_toc() in order to synchronize the PEs and record the timing of spmv.

The kernel allreduce/pe.csl performs a reduction over the whole rectangle to synchronize the PEs, then the bottom-right PE sends a signal to other PEs to sample the reference clock.

The kernel stencil_3d_7pts/pe.csl performs a matrix-vector product (spmv) where the matrix has 7 diagonals corresponding to 7 point stencil. The stencil coefficients can vary per PE, but must be the same for the local vector. The user can change the coefficients based on the boundary condition or curvilinear coordinate transformation.

The script run.py has the following parameters:

  • -k=<int> specifies the maximum size of local vector.

  • --zDim=<int> specifies how many elements per PE are computed.

  • --channels=<int> specifies the number of I/O channels, no bigger than 16.

The tic() samples “time_start” and toc() samples “time_end”. The sync() samples “time_ref” which is used to adjust “time_start” and “time_end”. The elapsed time (unit: cycles) is measured by cycles_send = max(time_end) - min(time_start)

The overall runtime (us) is computed via the following formula time_send = (cycles_send / 0.85) * 1.e-3 us

The bandwidth is calculated by bandwidth = ((6*w*h*zDim*4)/time_send), where w-by-h is the core rectangle and zDim is the number of elements computed per PE.

layout.csl

// Top-level declarations of the benchmark layout: compile-time parameters,
// the colors and local task IDs derived from them, and the layout helpers of
// the stencil, allreduce and memcpy modules.

// C0..C7 are the internal colors of the 7-point stencil
param C0_ID: i16;
param C1_ID: i16;
param C2_ID: i16;
param C3_ID: i16;
param C4_ID: i16;
param C5_ID: i16;
param C6_ID: i16;
param C7_ID: i16;
// C8 is the internal color of allreduce
param C8_ID: i16;

param MAX_ZDIM: i16; // maximum size of local vector x and y
param width: i16 ; // width of the core
param height: i16 ; // height of the core

param BLOCK_SIZE: i16; // size of temporary buffers for communication

// color objects built from the numeric IDs above
const C0: color = @get_color(C0_ID);
const C1: color = @get_color(C1_ID);
const C2: color = @get_color(C2_ID);
const C3: color = @get_color(C3_ID);
const C4: color = @get_color(C4_ID);
const C5: color = @get_color(C5_ID);
const C6: color = @get_color(C6_ID);
const C7: color = @get_color(C7_ID);
const C8: color = @get_color(C8_ID);

// Local task IDs 13..20 are hard-coded below; each ID is used exactly once.

// entrypoint of the startup task (kernel.csl binds it to f_startup, which
// enables the tsc counter)
const STARTUP: local_task_id = @get_local_task_id(13);

// entrypoints of 7-point stencil
const EN_STENCIL_1: local_task_id = @get_local_task_id(14);
const EN_STENCIL_2: local_task_id = @get_local_task_id(15);
const EN_STENCIL_3: local_task_id = @get_local_task_id(16);

// entrypoints of allreduce
const EN_REDUCE_1: local_task_id = @get_local_task_id(17);
const EN_REDUCE_2: local_task_id = @get_local_task_id(18);
const EN_REDUCE_3: local_task_id = @get_local_task_id(19);
const EN_REDUCE_4: local_task_id = @get_local_task_id(20);

// layout helper of the stencil module: owns colors C0..C7 and three entrypoints
const stencil = @import_module( "../csl-libs/stencil_3d_7pts/layout.csl", .{
    .colors = [8]color{C0, C1, C2, C3, C4, C5, C6, C7},
    .entrypoints = [3]local_task_id{EN_STENCIL_1, EN_STENCIL_2, EN_STENCIL_3},
    .width = width,
    .height = height
    });

// layout helper of the allreduce module: owns color C8 and four entrypoints
const reduce = @import_module( "../csl-libs/allreduce/layout.csl", .{
    .colors = [1]color{C8},
    .entrypoints = [4]local_task_id{EN_REDUCE_1, EN_REDUCE_2, EN_REDUCE_3, EN_REDUCE_4},
    .width = width,
    .height = height
    });

// memcpy parameter generator for the width-by-height core rectangle
const memcpy = @import_module( "<memcpy/get_params>", .{
    .width = width,
    .height = height,
    });

// Layout block: checks the color IDs, places kernel.csl on every PE of the
// width-by-height rectangle, and declares the symbols and functions that the
// host (run.py) accesses by name.
layout{

    // the nine color IDs must be strictly increasing (and therefore distinct)
    @comptime_assert(C0_ID < C1_ID);
    @comptime_assert(C1_ID < C2_ID);
    @comptime_assert(C2_ID < C3_ID);
    @comptime_assert(C3_ID < C4_ID);
    @comptime_assert(C4_ID < C5_ID);
    @comptime_assert(C5_ID < C6_ID);
    @comptime_assert(C6_ID < C7_ID);
    @comptime_assert(C7_ID < C8_ID);

    // step 1: configure the rectangle which does not include halo
    @set_rectangle( width, height );

    // step 2: compile csl code for a set of PEx.y and generate out_x_y.elf
    //   format: @set_tile_code(x, y, code.csl, param_binding);

    // row-major sweep over all PEs: py is the row, px is the column
    var py: i16 = 0;
    while(py < height) : (py +=1) {
        var px: i16 = 0;
        while(px < width) : (px +=1) {

            // per-PE parameters of the three modules; the memcpy parameters
            // depend only on the column px
            const memcpyParams = memcpy.get_params(px);
            const stencilParams = stencil.get_params(px, py);
            const reduceParams = reduce.get_params(px, py);
            // everything kernel.csl declares as a `param`
            var params: comptime_struct = .{
                .memcpyParams = memcpyParams,
                .reduceParams = reduceParams,
                .MAX_ZDIM = MAX_ZDIM,
                .BLOCK_SIZE = BLOCK_SIZE,
                .STARTUP = STARTUP,
                .stencilParams = stencilParams
            };

            @set_tile_code(px, py, "kernel.csl", params);
        }
    }

    // device symbols the host copies to/from; the names match the
    // @export_symbol calls in kernel.csl
    @export_name("x", [*]f32, true);
    @export_name("y", [*]f32, true);
    @export_name("stencil_coeff", [*]f32, true);
    @export_name("time_buf_u16", [*]u16, true);
    @export_name("time_ref", [*]u16, true);

    // host-callable functions; f_spmv and f_sync take one i16 argument
    @export_name("f_tic", fn()void);
    @export_name("f_toc", fn()void);
    @export_name("f_memcpy_timestamps", fn()void);
    @export_name("f_spmv", fn(i16)void);
    @export_name("f_sync", fn(i16)void);
    @export_name("f_reference_timestamps", fn()void);
} // end of layout

kernel.csl

// Per-PE kernel declarations: parameters bound by layout.csl, the imported
// timestamp/memcpy/allreduce/stencil modules, and the buffers exported to the
// host.

param memcpyParams: comptime_struct;

param reduceParams: comptime_struct;

param stencilParams: comptime_struct;

param MAX_ZDIM: i16; // size of vectors x and y

param BLOCK_SIZE: i16; // size of temporary buffers for communication

param STARTUP: local_task_id; // entrypoint bound to f_startup

const timestamp = @import_module("<time>");

// input/output queue ID = 0 is reserved for memcpy module
const sys_mod = @import_module( "<memcpy/memcpy>", memcpyParams);

// allreduce uses input queue/output queue 1 and DSR ID 1 of each kind;
// sys_mod.unblock_cmd_stream is handed over as its callback
const reduce_mod = @import_module( "../csl-libs/allreduce/pe.csl", @concat_structs(reduceParams, .{
     .f_callback = sys_mod.unblock_cmd_stream,
     .queues = [1]u16{1},
     .dest_dsr_ids = [1]u16{1},
     .src0_dsr_ids = [1]u16{1},
     .src1_dsr_ids = [1]u16{1}
     }));

// output queue cannot overlap input queues
// the stencil uses output queue 2, input queues 3..6 and DSR IDs 2 and 3,
// none of which collide with the IDs given to allreduce above
const stencil_mod = @import_module( "../csl-libs/stencil_3d_7pts/pe.csl", @concat_structs(stencilParams, .{
     .f_callback = sys_mod.unblock_cmd_stream,
     .input_queues = [4]u16{3, 4, 5, 6},
     .output_queues = [1]u16{2},
     .BLOCK_SIZE = BLOCK_SIZE,
     .dest_dsr_ids = [2]u16{2,3},
     .src0_dsr_ids = [1]u16{2},
     .src1_dsr_ids = [2]u16{2,3}
     }));


// tsc_size_words = 3
// start timestamp, written by f_tic()
var tscStartBuffer = @zeros([timestamp.tsc_size_words]u16);
// end timestamp, written by f_toc()
var tscEndBuffer = @zeros([timestamp.tsc_size_words]u16);


////////////////////////////////////////////////////////////////////////////////
// Main memory (48KB)
////////////////////////////////////////////////////////////////////////////////

// input vector x and output vector y of spmv, MAX_ZDIM elements each
var x = @zeros([MAX_ZDIM]f32);
var y = @zeros([MAX_ZDIM]f32);

// operand handed to allreduce by f_sync(); not exported to the host
var dot = @zeros([1]f32);

// stencil coefficients are organized as
// {c_west, c_east, c_south, c_north, c_bottom, c_top, c_center}
//
// The formula is
//    c_west * x[i-1][j][k] + c_east * x[i+1][j][k] +
//    c_south * x[i][j-1][k] + c_north * x[i][j+1][k] +
//    c_bottom * x[i][j][k-1] + c_top * x[i][j][k+1] +
//    c_center * x[i][j][k]
var stencil_coeff = @zeros([7]f32);

// time_buf_u16[0:5] = {tscStartBuffer, tscEndBuffer}
// filled by f_memcpy_timestamps() so the host can read both with one D2H
var time_buf_u16 = @zeros([timestamp.tsc_size_words*2]u16);

// local copy of the allreduce module's reference clock (reduce_mod.tscRefBuffer),
// filled by f_reference_timestamps()
var time_ref_u16 = @zeros([timestamp.tsc_size_words]u16);

// pointers exported to the host via @export_symbol
var ptr_x: [*]f32 = &x;
var ptr_y: [*]f32 = &y;
var ptr_stencil_coeff: [*]f32 = &stencil_coeff;
var ptr_time_buf_u16: [*]u16 = &time_buf_u16;
var ptr_time_ref: [*]u16 = &time_ref_u16;

////////////////////////////////////////////////////////////////////////////////
// Tasks
// syntax
//     task_begin(name, entrypoint, color)
////////////////////////////////////////////////////////////////////////////////


// Host-callable: samples the timestamp counter into tscStartBuffer (the
// "time_start" of the measurement), then unblocks the memcpy command stream.
fn f_tic() void {
    timestamp.get_timestamp(&tscStartBuffer);

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Host-callable: samples the timestamp counter into tscEndBuffer (the
// "time_end" of the measurement), then unblocks the memcpy command stream.
fn f_toc() void {
    timestamp.get_timestamp(&tscEndBuffer);

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Host-callable: packs the start and end timestamps into the exported buffer
// time_buf_u16 so the host can fetch both with a single D2H copy.
// The indices are hard-coded for three words per timestamp
// (tsc_size_words = 3).
fn f_memcpy_timestamps() void {

    // words 0..2: start timestamp recorded by f_tic()
    time_buf_u16[0] = tscStartBuffer[0];
    time_buf_u16[1] = tscStartBuffer[1];
    time_buf_u16[2] = tscStartBuffer[2];

    // words 3..5: end timestamp recorded by f_toc()
    time_buf_u16[3] = tscEndBuffer[0];
    time_buf_u16[4] = tscEndBuffer[1];
    time_buf_u16[5] = tscEndBuffer[2];

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// stencil coefficients are organized as
// {c_west, c_east, c_south, c_north, c_bottom, c_top, c_center}
// Host-callable: forwards to stencil_mod.spmv with the coefficient array and
// the vectors x and y. run.py passes zDim as n.
// There is no explicit unblock_cmd_stream() here; stencil_mod was imported
// with f_callback = sys_mod.unblock_cmd_stream, which presumably is invoked
// by the module when spmv finishes -- confirm in stencil_3d_7pts/pe.csl.
fn f_spmv(n:i16) void {
    stencil_mod.spmv(n, &stencil_coeff, &x, &y);
}

// Host-callable: runs an ADD allreduce of n elements of `dot`. run.py calls it
// with n = 1 to synchronize the PEs before timing.
// There is no explicit unblock_cmd_stream() here; reduce_mod was imported with
// f_callback = sys_mod.unblock_cmd_stream, which presumably is invoked by the
// module when the reduction finishes -- confirm in allreduce/pe.csl.
fn f_sync( n: i16 ) void {
   reduce_mod.allreduce(n, &dot, reduce_mod.TYPE_BINARY_OP.ADD);
}

// Host-callable: copies the three-word reference timestamp held by the
// allreduce module (reduce_mod.tscRefBuffer) into the exported buffer
// time_ref_u16, then unblocks the memcpy command stream.
fn f_reference_timestamps() void {

    time_ref_u16[0] = reduce_mod.tscRefBuffer[0];
    time_ref_u16[1] = reduce_mod.tscRefBuffer[1];
    time_ref_u16[2] = reduce_mod.tscRefBuffer[2];

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Startup task: enables the timestamp counter. It is bound to STARTUP and
// activated in a comptime block of this file.
task f_startup() void {
    timestamp.enable_tsc();
}


// Activate STARTUP and bind it to f_startup, the task that enables the
// timestamp counter.
comptime {

    @activate(STARTUP);

    @bind_local_task(f_startup, STARTUP);
}

// Export the data pointers under the names declared with @export_name in
// layout.csl; run.py looks these names up with get_id().
comptime {
    @export_symbol(ptr_x, "x");
    @export_symbol(ptr_y, "y");
    @export_symbol(ptr_stencil_coeff, "stencil_coeff");
    @export_symbol(ptr_time_buf_u16, "time_buf_u16");
    @export_symbol(ptr_time_ref, "time_ref");
}

// Export the host-callable functions; run.py invokes them by name with
// launch().
comptime{
    @export_symbol(f_tic);
    @export_symbol(f_toc);
    @export_symbol(f_memcpy_timestamps);
    @export_symbol(f_spmv);
    @export_symbol(f_sync);
    @export_symbol(f_reference_timestamps);
}

run.py

#!/usr/bin/env cs_python
# pylint: disable=too-many-function-args

""" test 7-point stencil

    The Laplacian operator L on 3-dimensional domain can be represented by 7-point
  stencil based on the standard 2nd order Finite Difference Method. The operator form
  with Dirichlet boundary conditions can be written by
         L[u](i,j,k) = u(i+1, j,  k  ) + u(i-1, j,   k  ) +
                       u(i,   j+1,k  ) + u(i,   j-1, k  ) +
                       u(i,   j,  k+1) + u(i,   j,   k-1) +
                      -6*u(i, j, k)
  In general the coefficients of those 7 points can vary. To minimize the memory
  consumption, this example assumes the coefficients are independent of index k and
  whole vector u(i,j,:) is placed in one PE (px=j, py=i).
  The above formula can be re-written by
     c_west   * x[i-1][j  ][k  ] + c_east  * x[i+1][j  ][k  ] +
     c_south  * x[i  ][j-1][k  ] + c_north * x[i  ][j+1][k  ] +
     c_bot    * x[i  ][j  ][k-1] + c_top   * x[i  ][j  ][k+1] +
     c_center * x[i][j][k]
  Each PE only holds 7 coefficients organized by c_west, c_east, c_south, c_north,
  c_bot, c_top and c_center.

  This example provides two modules, one is allreduce and the other is stencil_3d_7pts.
  "allreduce" module can synchronize all PEs to form a reference clock.
  "stencil_3d_7pts" module can compute y = A*x where A is the matrix from 7-point stencil

  The framework is
  ---
       sync()      // synchronize all PEs to sample the reference clock
       tic()       // record start time
       spmv(zdim)  // compute y = A*x
       toc()       // record end time
  ---

  The tic() samples "time_start" and toc() samples "time_end". The sync() samples
  "time_ref" which is used to shift "time_start" and "time_end".
  The elapsed time is measured by
       cycles_send = max(time_end) - min(time_start)

  The overall runtime is computed via the following formula
       time_send = (cycles_send / 0.85) *1.e-3 us
  where a PE runs with clock speed 850MHz

  Each PE needs to gather six f32 from six neighbors, the cost of the communication is
        6*h*w*zDim*4 bytes
  where w-by-h is the core rectangle and zDim is the length of local vector.

  Here is the list of parameters:
    -m=<int> is the height of the core
    -n=<int> is the width of the core
    -k=<int> is size of x and y allocated in the core
    --zDim=<int> is the number of f32 per PE, computed by y = A*x
                 zDim must be not greater than k
    --channels=<int> specifies the number of I/O channels, no bigger than 16
"""


import struct
import os
from typing import Optional
from pathlib import Path
import shutil
import subprocess
import random

import numpy as np

from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType, MemcpyOrder # pylint: disable=no-name-in-module

from cmd_parser import parse_args


from util import (
    hwl_2_oned_colmajor,
    oned_to_hwl_colmajor,
    laplacian,
)


def float_to_hex(f):
  """Return the hex string of the IEEE-754 single-precision bit pattern of f.

  The value is packed as a little-endian 32-bit float and the same four bytes
  are read back as an unsigned 32-bit integer, e.g. 1.0 -> '0x3f800000'.
  """
  packed = struct.pack('<f', f)
  (bits,) = struct.unpack('<I', packed)
  return hex(bits)

def make_u48(words):
  """Assemble a 48-bit integer from three 16-bit words, least-significant first.

  Args:
    words: indexable with at least three integer entries; words[0] holds
      bits 0-15, words[1] bits 16-31 and words[2] bits 32-47.

  Returns:
    The combined value as a Python int.

  Each word is converted to a Python int before shifting. The caller passes
  elements of an np.uint16 array; with NumPy 2 promotion rules a shift of a
  uint16 scalar by 16 or 32 stays 16-bit and drops the upper words.
  """
  low = int(words[0])
  mid = int(words[1])
  high = int(words[2])
  return low + (mid << 16) + (high << 32)


def csl_compile_core(
    cslc: str,
    width: int,  # width of the core
    height: int, # height of the core
    pe_length: int,
    blockSize: int,
    file_config: str,
    elf_dir: str,
    fabric_width: int,
    fabric_height: int,
    core_fabric_offset_x: int, # fabric-offsets of the core
    core_fabric_offset_y: int,
    use_precompile: bool,
    arch: Optional[str],
    C0: int,
    C1: int,
    C2: int,
    C3: int,
    C4: int,
    C5: int,
    C6: int,
    C7: int,
    C8: int,
    channels: int,
    width_west_buf: int,
    width_east_buf: int
):
  """Compile the layout with cslc, unless precompiled ELFs are to be reused.

  When use_precompile is true nothing is compiled and only a notice is
  printed. Otherwise the cslc command line is assembled from the arguments
  (fabric dimensions and offsets, the layout parameters width/height/MAX_ZDIM/
  BLOCK_SIZE, the nine color IDs, the output folder, the optional --arch and
  the memcpy options), printed, and run with subprocess.check_call, which
  raises CalledProcessError when the compiler exits with a non-zero status.
  """
  if use_precompile:
    print("\tuse pre-compile ELFs")
    return

  # color IDs in index order; each becomes its own --params=C<i>_ID flag
  color_ids = (C0, C1, C2, C3, C4, C5, C6, C7, C8)

  args = [
      cslc, # command
      file_config,
      f"--fabric-dims={fabric_width},{fabric_height}",
      f"--fabric-offsets={core_fabric_offset_x},{core_fabric_offset_y}",
      f"--params=width:{width},height:{height},MAX_ZDIM:{pe_length}",
      f"--params=BLOCK_SIZE:{blockSize}",
  ]
  args.extend(f"--params=C{idx}_ID:{cid}" for idx, cid in enumerate(color_ids))
  args.append(f"-o={elf_dir}")

  # --arch is optional and sits between the output folder and the memcpy flags
  if arch is not None:
    args.append(f"--arch={arch}")

  args.extend([
      "--memcpy",
      f"--channels={channels}",
      f"--width-west-buf={width_west_buf}",
      f"--width-east-buf={width_east_buf}",
  ])

  print(f"subprocess.check_call(args = {args}")
  subprocess.check_call(args)



def main() -> None:
  """Compile and run the 7-point stencil benchmark, then check and time it.

  Sequence:
    1. Parse the command line and build the host reference: x, the seven
       stencil coefficients per PE, and y_ref via laplacian().
    2. Derive the fabric offsets/dimensions and compile layout.csl (skipped
       with --run-only); return early with --compile-only.
    3. Copy x and the coefficients to the device, then launch
       f_sync -> f_tic -> f_spmv -> f_toc and read back the timestamps, y and
       the reference clock.
    4. Convert the timestamps to cycles, shift them by the per-PE reference
       clock, print cycles/time/bandwidth, and compare y against y_ref.

  Raises:
    AssertionError: on invalid arguments, a fabric that is too small, or a
      mismatch between the device result and the host reference.
  """

  random.seed(127)

  args, dirname = parse_args()

  cslc = "cslc"
  if args.driver is not None:
    cslc = args.driver

  print(f"cslc = {cslc}")

  width_west_buf = args.width_west_buf
  width_east_buf = args.width_east_buf
  channels = args.channels
  assert channels <= 16, "only support up to 16 I/O channels"
  assert channels >= 1, "number of I/O channels must be at least 1"

  print(f"width_west_buf = {width_west_buf}")
  print(f"width_east_buf = {width_east_buf}")
  print(f"channels = {channels}")

  height = args.m
  width = args.n
  pe_length = args.k
  zDim = args.zDim
  blockSize = args.blockSize

  print(f"width = {width}, height = {height}, pe_length={pe_length}, zDim={zDim}, blockSize={blockSize}")
  assert pe_length >= 2, "the maximum size of z must be greater than 1"
  assert zDim <= pe_length, "[0, zDim) cannot exceed the storage"

  np.random.seed(2)
  # x is h-by-w-by-l, filled with 100, 101, 102, ...
  x = np.arange(height*width*pe_length).reshape(height, width, pe_length).astype(np.float32) + 100

  x_1d = hwl_2_oned_colmajor(height, width, pe_length, x, np.float32)

  # stencil coefficients has the following order
  # {c_west, c_east, c_south, c_north, c_bottom, c_top, c_center}
  # every PE gets the same seven values here
  stencil_coeff = np.zeros((height, width, 7), dtype = np.float32)
  for i in range(height):
    for j in range(width):
      stencil_coeff[(i, j, 0)] = -1 # west
      stencil_coeff[(i, j, 1)] = -2 # east
      stencil_coeff[(i, j, 2)] = -3 # south
      stencil_coeff[(i, j, 3)] = -4 # north
      stencil_coeff[(i, j, 4)] = -5 # bottom
      stencil_coeff[(i, j, 5)] = -6 # top
      stencil_coeff[(i, j, 6)] = 6  # center

  stencil_coeff_1d = hwl_2_oned_colmajor(height, width, 7, stencil_coeff, np.float32)

  # host reference of y = A*x, written into y_ref by laplacian()
  y_ref = np.zeros((height, width, pe_length), dtype=np.float32)

  laplacian(stencil_coeff, zDim, x, y_ref)

  # fabric-offsets = 1,1
  fabric_offset_x = 1
  fabric_offset_y = 1
  # starting point of the core rectangle = (core_fabric_offset_x, core_fabric_offset_y)
  # memcpy framework requires 3 columns at the west of the core rectangle
  # memcpy framework requires 2 columns at the east of the core rectangle
  core_fabric_offset_x = fabric_offset_x + 3 + width_west_buf
  core_fabric_offset_y = fabric_offset_y
  # (min_fabric_width, min_fabric_height) is the minimal dimension to run the app
  min_fabric_width = (core_fabric_offset_x + width + 2 + 1 + width_east_buf)
  min_fabric_height = (core_fabric_offset_y + height + 1)

  # use --fabric-dims when given, otherwise fall back to the minimal fabric
  fabric_width = 0
  fabric_height = 0
  if args.fabric_dims:
    w_str, h_str = args.fabric_dims.split(",")
    fabric_width = int(w_str)
    fabric_height = int(h_str)

  if fabric_width == 0 or fabric_height == 0:
    fabric_width = min_fabric_width
    fabric_height = min_fabric_height

  assert fabric_width >= min_fabric_width
  assert fabric_height >= min_fabric_height

  # prepare the simulation
  print('store ELFs and log files in the folder ', dirname)

  # text file containing the simulator logs
  # NOTE(review): sim_log is not used below
  sim_log = os.path.join(dirname, "sim.log")

  # layout of a rectangle
  code_csl = "layout.csl"

  # color IDs handed to layout.csl, which asserts they are strictly increasing
  C0 = 0
  C1 = 1
  C2 = 2
  C3 = 3
  C4 = 4
  C5 = 5
  C6 = 6
  C7 = 7
  C8 = 8

  csl_compile_core(
      cslc,
      width,
      height,
      pe_length,
      blockSize,
      code_csl,
      dirname,
      fabric_width,
      fabric_height,
      core_fabric_offset_x,
      core_fabric_offset_y,
      args.run_only,
      args.arch,
      C0,
      C1,
      C2,
      C3,
      C4,
      C5,
      C6,
      C7,
      C8,
      channels,
      width_west_buf,
      width_east_buf
  )
  if args.compile_only:
    print("COMPILE ONLY: EXIT")
    return

  memcpy_dtype = MemcpyDataType.MEMCPY_32BIT
  simulator = SdkRuntime(dirname, cmaddr=args.cmaddr)

  # the names below are the ones exported by layout.csl / kernel.csl
  symbol_x = simulator.get_id("x")
  symbol_y = simulator.get_id("y")
  symbol_stencil_coeff = simulator.get_id("stencil_coeff")
  symbol_time_buf_u16 = simulator.get_id("time_buf_u16")
  symbol_time_ref = simulator.get_id("time_ref")

  simulator.load()
  simulator.run()

  print(f"copy vector x of type f32")
  # the size of x per PE is pe_length
  simulator.memcpy_h2d(symbol_x, x_1d, 0, 0, width, height, pe_length,\
          streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=True)

  print(f"copy coefficients of type f32")
  # each PE holds 7 coefficients
  simulator.memcpy_h2d(symbol_stencil_coeff, stencil_coeff_1d, 0, 0, width, height, 7,\
          streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=True)

  print("step 1: sync all PEs")
  simulator.launch("f_sync", np.int16(1), nonblock=False)

  print("step 2: tic() records time_start")
  simulator.launch("f_tic", nonblock=True)

  print(f"step 3: compute y = A*x with zDim = {zDim}")
  # positive zDim can be smaller than pe_length
  simulator.launch("f_spmv", np.int16(zDim), nonblock=False)

  print("step 4: toc() records time_end")
  simulator.launch("f_toc", nonblock=False)

  print("step 5: prepare (time_start, time_end)")
  simulator.launch("f_memcpy_timestamps", nonblock=False)

  print("step 6: D2H (time_start, time_end)")
  # six u16 words per PE: three for time_start, three for time_end
  # the host buffer is uint32 although data_type is MEMCPY_16BIT; presumably
  # each 16-bit value occupies one 32-bit slot -- confirm against the SdkRuntime docs
  time_memcpy_hwl_1d = np.zeros(height*width*6, np.uint32)
  simulator.memcpy_d2h(time_memcpy_hwl_1d, symbol_time_buf_u16, 0, 0, width, height, 6,\
    streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=False)
  time_memcpy_hwl = oned_to_hwl_colmajor(height, width, 6, time_memcpy_hwl_1d, np.uint16)

  print("step 7: D2H y of type f32")
  y_1d = np.zeros(height*width*pe_length, np.float32)
  simulator.memcpy_d2h(y_1d, symbol_y, 0, 0, width, height, pe_length,\
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.COL_MAJOR, nonblock=False)
  y_wse = np.reshape(y_1d, (height, width, pe_length), order='F')

  print("step 8: prepare reference clock")
  simulator.launch("f_reference_timestamps", nonblock=False)

  print("step 9: D2H reference clock")
  # three u16 words per PE
  time_ref_1d = np.zeros(height*width*3, np.uint32)
  simulator.memcpy_d2h(time_ref_1d, symbol_time_ref, 0, 0, width, height, 3,\
    streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=False)
  time_ref_hwl = oned_to_hwl_colmajor(height, width, 3, time_ref_1d, np.uint16)

  simulator.stop()

  if args.cmaddr is None:
    # move simulation log and core dump to the given folder
    dst_log = Path(f"{dirname}/sim.log")
    src_log = Path("sim.log")
    if src_log.exists():
      shutil.move(src_log, dst_log)

    # replace any previous trace folder with the new one
    dst_trace = Path(f"{dirname}/simfab_traces")
    src_trace = Path("simfab_traces")
    if dst_trace.exists():
      shutil.rmtree(dst_trace)
    if src_trace.exists():
      shutil.move(src_trace, dst_trace)

  # time_start = start time of spmv
  time_start = np.zeros((height, width)).astype(int)
  # time_end = end time of spmv
  time_end = np.zeros((height, width)).astype(int)
  # words 0..2 of each PE form time_start, words 3..5 form time_end
  word = np.zeros(3).astype(np.uint16)
  for w in range(width):
    for h in range(height):
      word[0] = time_memcpy_hwl[(h, w, 0)]
      word[1] = time_memcpy_hwl[(h, w, 1)]
      word[2] = time_memcpy_hwl[(h, w, 2)]
      time_start[(h,w)] = make_u48(word)
      word[0] = time_memcpy_hwl[(h, w, 3)]
      word[1] = time_memcpy_hwl[(h, w, 4)]
      word[2] = time_memcpy_hwl[(h, w, 5)]
      time_end[(h,w)] = make_u48(word)

  # time_ref = reference clock
  time_ref = np.zeros((height, width)).astype(int)
  word = np.zeros(3).astype(np.uint16)
  for w in range(width):
    for h in range(height):
      word[0] = time_ref_hwl[(h, w, 0)]
      word[1] = time_ref_hwl[(h, w, 1)]
      word[2] = time_ref_hwl[(h, w, 2)]
      time_ref[(h, w)] = make_u48(word)

  # adjust the reference clock by the propagation delay
  # the right-bottom PE signals other PEs, the propagation delay is
  #     (h-1) - py + (w-1) - px
  for py in range(height):
    for px in range(width):
      time_ref[(py, px)] = time_ref[(py, px)] - ((width+height-2)-(px + py))

  # shift time_start and time_end by time_ref
  time_start = time_start - time_ref
  time_end = time_end - time_ref

  # cycles_send = max(time_end) - min(time_start) over all PEs
  # 850MHz --> 1 cycle = (1/0.85) ns = (1/0.85)*1.e-3 us
  # time_send = (cycles_send / 0.85) *1.e-3 us
  #
  # each PE needs to gather six f32 from six neighbors, the cost of the communication is
  #      6*h*w*zDim*4 bytes
  #
  # bandwidth = ((wvlts * 4)/time_send) MBS
  wvlts = 6*height*width*zDim
  min_time_start = time_start.min()
  max_time_end = time_end.max()
  cycles_send = max_time_end - min_time_start
  time_send = (cycles_send / 0.85) *1.e-3
  bandwidth = ((wvlts * 4)/time_send)
  print(f"cycles_send = {cycles_send} cycles")
  print(f"time_send = {time_send} us")
  print(f"bandwidth = {bandwidth} MB/S ")

  # compare the device result with the host reference (relative tolerance 1.e-5)
  z = y_ref.ravel() - y_wse.ravel()
  nrm_z = np.linalg.norm(z, np.inf)
  print(f"|y_ref - y_wes| = {nrm_z}")
  np.testing.assert_allclose(y_ref.ravel(), y_wse.ravel(), 1.e-5)
  print("\nSUCCESS!")

# run the benchmark only when this file is executed as a script
if __name__ == "__main__":
  main()

cmd_parser.py

# This is not a real test, but a module that gets imported in other tests.

"""command parser for bandwidthTest

   -m <int>     number of rows of the core rectangle
   -n <int>     number of columns of the core rectangle
   -k <int>     number of elements of local tensor
   --zDim <int>   number of elements to compute y=A*x
   --blockSize <int>  the size of temporary buffers for communication
   --latestlink   working directory
   --driver     path to CSL compiler
   --fabric-dims  fabric dimension of a WSE
   --cmaddr       IP address of a WSE
   --channels        number of I/O channels, 1 <= channels <= 16
   --width-west-buf  number of columns of the buffer in the west of the core rectangle
   --width-east-buf  number of columns of the buffer in the east of the core rectangle
   --compile-only    compile ELFs
   --run-only        run the test with precompiled binary
"""


import os
import argparse


def parse_args(argv=None):
  """Parse the command line and make sure the output folder exists.

  Args:
    argv: optional list of argument strings. When None (the default) the
      arguments are taken from sys.argv, exactly as before.

  Returns:
    (args, logs_dir): the argparse namespace and the folder used for ELFs and
    log files, which is --latestlink when given and "latest" otherwise.

  The folder is created when it does not exist yet, including any missing
  parent folders.

  NOTE(review): the default of -k is 1 although its help text says
  "no less than 2"; callers relying on the defaults should pass -k explicitly.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "-m",
      default=1, type=int,
      help="number of rows")
  parser.add_argument(
      "-n",
      default=1, type=int,
      help="number of columns")
  parser.add_argument(
      "-k",
      default=1, type=int,
      help="size of local tensor, no less than 2")
  parser.add_argument(
      "--zDim",
      default=2, type=int,
      help="[0 zDim-1) is the domain of Laplacian")
  parser.add_argument(
      "--latestlink",
      help="folder to contain the log files (default: latest)")
  parser.add_argument(
      "-d",
      "--driver",
      help="The path to the CSL compiler")
  parser.add_argument(
      "--compile-only",
      help="Compile only", action="store_true")
  parser.add_argument(
      "--fabric-dims",
      help="Fabric dimension, i.e. <W>,<H>")
  parser.add_argument(
      "--cmaddr",
      help="CM address and port, i.e. <IP>:<port>")
  parser.add_argument(
      "--run-only",
      help="Run only", action="store_true")
  # arch = wse1 or wse2
  parser.add_argument(
      "--arch",
      help="wse1 or wse2. Default is wse1 when not supplied.")
  parser.add_argument(
      "--width-west-buf",
      default=0, type=int,
      help="width of west buffer")
  parser.add_argument(
      "--width-east-buf",
      default=0, type=int,
      help="width of east buffer")
  parser.add_argument(
      "--channels",
      default=1, type=int,
      help="number of I/O channels, between 1 and 16")
  parser.add_argument(
      "--blockSize",
      default=2, type=int,
      help="the size of temporary buffers for communication")

  # argv=None makes argparse read sys.argv[1:]
  args = parser.parse_args(argv)

  logs_dir = "latest"
  if args.latestlink:
    logs_dir = args.latestlink

  if os.path.isdir(logs_dir):
    print(f"{logs_dir} already exists")
  else:
    print(f"create {logs_dir} to store log files")
    # makedirs also creates missing parents of a nested --latestlink path and
    # does not fail if the folder appears between the check and the creation
    os.makedirs(logs_dir, exist_ok=True)

  return args, logs_dir

csl-libs/stencil_3d_7pts/layout.csl

// Layout helper of the 7-point stencil: receives eight colors and three
// entrypoints from the top-level layout and hands them out per PE through
// get_params().

param colors:[8]color;
param entrypoints:[3]local_task_id;
param width : i16 ;   // width of the core
param height: i16 ;   // height of the core

// C0..C3 are used for east/west traffic and C4..C7 for north/south traffic
// (see get_params)
const C0 : color = colors[0];
const C1 : color = colors[1];
const C2 : color = colors[2];
const C3 : color = colors[3];
const C4 : color = colors[4];
const C5 : color = colors[5];
const C6 : color = colors[6];
const C7 : color = colors[7];

// entrypoints of the stencil module (passed on to pe.csl as SEND, RECV, COMM)
const SEND: local_task_id = entrypoints[0];
const RECV: local_task_id = entrypoints[1];
const COMM: local_task_id = entrypoints[2];

// Returns the per-PE parameters of the stencil for the PE at column px and
// row py: the eight send/receive colors, the three entrypoints, and four
// flags telling whether the PE lies on the first/last column or row.
//
// Each direction uses a pair of colors that alternates with the parity of the
// coordinate, so the color a PE sends on is the color its neighbor receives
// on. The initial value of every c_* variable is overwritten by one of the
// two branches that follow it.
fn get_params(px:i16, py:i16) comptime_struct {

    // position of the PE along y
    var first_py: bool = (0 == py);
    var last_py: bool = ((height-1) == py);
    var is_py_even: bool = (0 == (py % 2));

    // position of the PE along x
    var first_px: bool = (0 == px);
    var last_px: bool = ((width-1) == px);
    var is_px_even: bool = (0 == (px % 2));

    // C0, C1:recv_west, send_east
    //         C0     C1     C0     C1     C0
    // West P0 --> P1 --> P2 --> P3 --> P4 --> P5 East
    //
    var c_recv_west: color = C1;
    var c_send_east: color = C0;
    if (is_px_even){
        c_recv_west = C1;
        c_send_east = C0;
    }else{
        c_recv_west = C0;
        c_send_east = C1;
    }

    // C2, C3: recv_east, send_west
    //          C2     C3     C2     C3     C2
    // West P0 <-- P1 <-- P2 <-- P3 <-- P4 <-- P5 East
    //
    var c_recv_east: color = C2;
    var c_send_west: color = C3;
    if (is_px_even){
        c_recv_east = C2;
        c_send_west = C3;
    }else{
        c_recv_east = C3;
        c_send_west = C2;
    }

    // C4, C5: recv_south, send_north
    //           C4     C5     C4     C5     C4
    // North P0 <-- P1 <-- P2 <-- P3 <-- P4 <-- P5 south
    //
    var c_recv_south: color = C4;
    var c_send_north: color = C5;
    if (is_py_even){
        c_recv_south = C4;
        c_send_north = C5;
    }else{
        c_recv_south = C5;
        c_send_north = C4;
    }

    // C6, C7: recv_north, send_south
    //           C6     C7     C6     C7     C6
    // North P0 --> P1 --> P2 --> P3 --> P4 --> P5 south
    //
    var c_recv_north: color = C7;
    var c_send_south: color = C6;
    if (is_py_even){
        c_recv_north = C7;
        c_send_south = C6;
    }else{
        c_recv_north = C6;
        c_send_south = C7;
    }

    // field names match the `param` declarations of stencil_3d_7pts/pe.csl
    return .{
        .c_recv_west = c_recv_west,
        .c_send_east = c_send_east,
        .c_recv_east = c_recv_east,
        .c_send_west = c_send_west,
        .c_recv_south = c_recv_south,
        .c_send_north = c_send_north,
        .c_recv_north = c_recv_north,
        .c_send_south = c_send_south,

        .SEND = SEND,
        .RECV = RECV,
        .COMM = COMM,

        .first_px = first_px,
        .last_px = last_px,
        .first_py = first_py,
        .last_py = last_py,
    };
}

csl-libs/stencil_3d_7pts/pe.csl

// Parameters of the per-PE stencil module. The colors, entrypoints and
// boundary flags come from stencil_3d_7pts/layout.csl (get_params); the
// callback, queues, BLOCK_SIZE and DSR IDs are supplied by the importing
// kernel.

// colors for east/west traffic
param c_recv_west: color;
param c_send_east: color;
param c_recv_east: color;
param c_send_west: color;

// colors for north/south traffic
param c_recv_south: color;
param c_send_north: color;
param c_recv_north: color;
param c_send_south: color;

param COMM: local_task_id; // entrypoint f_comm
param SEND: local_task_id; // entrypoint f_send
param RECV: local_task_id; // entrypoint f_recv

// true when the PE lies on the first/last column (px) or row (py) of the core
param first_px: bool;
param last_px: bool;
param first_py: bool;
param last_py: bool;

// To continue next command, f_callback = sys_mod.unblock_cmd_stream
param f_callback : fn ()void;

// four input queue IDs and one output queue ID
param input_queues:[4]u16;
param output_queues:[1]u16;

param BLOCK_SIZE: i16; // size of temporary buffers for communication

// explicit DSR allocation
param dest_dsr_ids:[2]u16;
param src0_dsr_ids:[1]u16;
param src1_dsr_ids:[2]u16;

// The call-graph of the stencil kernel is
//
//  COMM ----> SEND ----> Laplacian
//       |              |
//       +---> RECV ----+
//
// We need two sets of DSRs, one for SEND and one for RECV.
// Once SEND and RECV are done, Laplacian takes several serial FMACs.
// We can reuse either set of DSR for Laplacian.
//
// For example:
//   dest_dsr_send = @get_dsr(dsr_dest, 1);
//   src1_dsr_send = @get_dsr(dsr_src1, 1);
//   dest_dsr_recv = @get_dsr(dsr_dest, 2);
//   src1_dsr_recv = @get_dsr(dsr_src1, 2);
//

////////////////////////////////////////////////////////////////////////////////
// Main memory (48KB)
////////////////////////////////////////////////////////////////////////////////

// The formula of Laplacian is
//     c_west   * x[i-1][j  ][k  ] + c_east  * x[i+1][j  ][k  ] +
//     c_south  * x[i  ][j-1][k  ] + c_north * x[i  ][j+1][k  ] +
//     c_bottom * x[i  ][j  ][k-1] + c_top  * x[i  ][j  ][k+1] +
//     c_center * x[i][j][k]
// The following stencil coefficents are passed by spmv()
// One scalar per stencil direction; spmv() copies them from its coeff
// argument in the order west, east, south, north, bottom, top, center.
var c_west: f32;
var c_east: f32;
var c_south: f32;
var c_north: f32;
var c_bottom: f32;
var c_top: f32;
var c_center: f32;

// The following buffers hold data from four neighbors
// Each buffer holds one block (BLOCK_SIZE elements) of the neighbor's x.
// They start as zeros; f_recv() does not write the buffer of a missing
// neighbor on a boundary PE.
var west_buf = @zeros([BLOCK_SIZE]f32); // from west
var east_buf = @zeros([BLOCK_SIZE]f32); // from east
var south_buf = @zeros([BLOCK_SIZE]f32);// from south
var north_buf = @zeros([BLOCK_SIZE]f32);// from north

// Number of state machines (SEND, RECV) that have finished the current block.
// f_comm() resets it to 0; f_send() and f_recv() each add 1 when they reach
// their DONE state; compute_and_next_send_recv() proceeds once it reaches 2.
var count_send_recv: i16 = 0;

// States of f_send(), visited in this order for every block
const SEND_STATE_EAST: i16 = 0;
const SEND_STATE_WEST: i16 = 1;
const SEND_STATE_NORTH: i16 = 2;
const SEND_STATE_SOUTH: i16 = 3;
const SEND_STATE_DONE: i16 = 4;

// States of f_recv(), visited in this order for every block
const RECV_STATE_WEST: i16 = 0;
const RECV_STATE_EAST: i16 = 1;
const RECV_STATE_SOUTH: i16 = 2;
const RECV_STATE_NORTH: i16 = 3;
const RECV_STATE_DONE: i16 = 4;

// Both state machines start (and, after each block, end) in their first state;
// f_comm() asserts this before it starts a new block.
var recv_state: i16 = RECV_STATE_WEST;
var send_state: i16 = SEND_STATE_EAST;

// zDim: length of the local vectors x and y, set by spmv()
var zDim: i16 = 0;
// cur_length: length of the block currently exchanged, min(zDim, BLOCK_SIZE)
// at the start of spmv() and shortened by f_comm() for a partial last block
var cur_length: i16 = BLOCK_SIZE;
// start_x: index of the first element of the current block in x and y
var start_x: i16 = 0;
// rem_length: zDim - start_x, recomputed by f_comm(); <= 0 means all blocks are done
var rem_length: i16 = 0;

// Explicit DSRs: the "send" set is used by f_send() and by the Laplacian,
// the "recv" set by f_recv(), so a send and a receive can be in flight together.
const dest_dsr_send = @get_dsr(dsr_dest, dest_dsr_ids[0]);
const src1_dsr_send = @get_dsr(dsr_src1, src1_dsr_ids[0]);
const dest_dsr_recv = @get_dsr(dsr_dest, dest_dsr_ids[1]);
const src1_dsr_recv = @get_dsr(dsr_src1, src1_dsr_ids[1]);
const src0_dsr = @get_dsr(dsr_src0, src0_dsr_ids[0]);

// Placeholder array: DSDs that will be bound to x or y at runtime are first
// created on top of it, then spmv() replaces their base address.
const dummy_f32 = @zeros([1]f32);

// communication with neighbors
// mem_center_buf_dsd: send to W, E, S, N
// mem_west_buf_dsd: recv from W
// mem_east_buf_dsd: recv from E
// mem_south_buf_dsd: recv from S
// mem_north_buf_dsd: recv from N
// The portal function spmv() resets these DSDs with proper length, either zDim or BLOCK_SIZE.
// If last iteration has smaller size than BLOCK_SIZE, reset length again.
// spmv() binds x to mem_center_buf_dsd and advances it by BLOCK_SIZE when SEND and RECV are done.
// mem_center_buf_dsd is created on the placeholder dummy_f32; spmv() rebinds it to x.
var mem_center_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{BLOCK_SIZE} -> dummy_f32[i] });
// One DSD per receive buffer; their lengths follow cur_length via update_dsd_length().
var mem_west_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{BLOCK_SIZE} -> west_buf[i] });
var mem_east_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{BLOCK_SIZE} -> east_buf[i] });
var mem_south_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{BLOCK_SIZE} -> south_buf[i] });
var mem_north_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{BLOCK_SIZE} -> north_buf[i] });

// mem_y_buf_dsd holds the partial sum of the Laplacian on the x-y plane.
// spmv() binds y to mem_y_buf_dsd; it is advanced by BLOCK_SIZE after each
// block, once both SEND and RECV are done (see compute_and_next_send_recv()).
var mem_y_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{BLOCK_SIZE} -> dummy_f32[i] });

// mem_y_z_*** and mem_x_z_*** are used by the Laplacian in the z direction.
// spmv() binds x to mem_x_z_*** and y to mem_y_z_***, then shifts the offsets
// and resets the lengths so that the following updates can be expressed:
// boundary condition x[-1] = x[zDim] = 0
// y[k] += x[k-1] * c_bottom for k = 1,2,...,zDim-1
// y[k] += x[k+1] * c_top for k = 0,1,2,...,zDim-2
//
// The length of each of the following DSDs is reset at runtime by spmv():
// |mem_y_z_minus_buf_dsd| = zDim-1
// |mem_y_z_plus_buf_dsd| = zDim-1
// |mem_y_z_buf_dsd| = zDim
//
// |mem_x_z_minus_buf_dsd| = zDim-1
// |mem_x_z_plus_buf_dsd| = zDim-1
// |mem_x_z_buf_dsd| = zDim
//
// The length 1 and the dummy_f32 accesses below are only placeholders.
var mem_y_z_minus_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> dummy_f32[i+1] });
var mem_y_z_plus_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> dummy_f32[i] });
var mem_y_z_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> dummy_f32[i] });

var mem_x_z_minus_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> dummy_f32[i] });
var mem_x_z_plus_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> dummy_f32[i+1] });
var mem_x_z_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> dummy_f32[i] });

// Fabric input DSDs: one per neighbor, each on its own input queue
// (input_queues[0..3] = west, east, south, north).
// The extents are reset at runtime by update_dsd_length().

// receive one block of x from the west neighbor
var fab_recv_west_wdsd =  @get_dsd(fabin_dsd, .{
   .extent = BLOCK_SIZE,
   .fabric_color = c_recv_west,
   .input_queue = @get_input_queue(input_queues[0])
});

// receive one block of x from the east neighbor
var fab_recv_east_wdsd =  @get_dsd(fabin_dsd, .{
   .extent = BLOCK_SIZE,
   .fabric_color = c_recv_east,
   .input_queue = @get_input_queue(input_queues[1])
});

// receive one block of x from the south neighbor
var fab_recv_south_wdsd =  @get_dsd(fabin_dsd, .{
   .extent = BLOCK_SIZE,
   .fabric_color = c_recv_south,
   .input_queue = @get_input_queue(input_queues[2])
});

// receive one block of x from the north neighbor
var fab_recv_north_wdsd =  @get_dsd(fabin_dsd, .{
   .extent = BLOCK_SIZE,
   .fabric_color = c_recv_north,
   .input_queue = @get_input_queue(input_queues[3])
});

// Fabric output DSDs: all four share output_queues[0].
// f_send() issues the four sends one after another, each new send being
// started from the activation triggered by the previous one.

// send one block of x to the east neighbor
var fab_trans_east_wdsd = @get_dsd(fabout_dsd, .{
    .extent = BLOCK_SIZE,
    .fabric_color = c_send_east,
    .output_queue = @get_output_queue(output_queues[0])
});

// send one block of x to the west neighbor
var fab_trans_west_wdsd = @get_dsd(fabout_dsd, .{
    .extent = BLOCK_SIZE,
    .fabric_color = c_send_west,
    .output_queue = @get_output_queue(output_queues[0])
});

// send one block of x to the north neighbor
var fab_trans_north_wdsd = @get_dsd(fabout_dsd, .{
    .extent = BLOCK_SIZE,
    .fabric_color = c_send_north,
    .output_queue = @get_output_queue(output_queues[0])
});

// send one block of x to the south neighbor
var fab_trans_south_wdsd = @get_dsd(fabout_dsd, .{
    .extent = BLOCK_SIZE,
    .fabric_color = c_send_south,
    .output_queue = @get_output_queue(output_queues[0])
});


// The portal function of 7-point stencil module
//   y = A*x
//
// How to use:
//  stencil_mod = @import_module( "<stencil_3d_7_pts/pe>")
//  stencil_mod.spmv(n=zDim, coeff, x, y); // compute y = A*x
//  The user has to prepare the coefficients, the input vector x and
//  the output vector y.
//  spmv() only accepts pointers for coeff, x, and y.
//
//  The callback is triggered when spmv() finishes.
//
//  The user can adjust coefficients around the boundary to handle
//  Neumann condition.
//  For example, (-1, 2, 1) becomes (2, -2) at west boundary.
//
//  Assumption: n >= 2 (enforced by the @assert at the top of spmv())
//  If n = 1, we cannot set the DSD length to zDim-1 = 0.
//  Supporting n = 1 (a 2D problem) would require skipping the z-direction
//  terms and only updating the center term in laplacian_z; this is not
//  implemented here.
//
fn spmv(n: i16, coeff: *[7]f32, x: [*]f32, y: [*]f32) void {
    // Entry point of y = A*x.
    //   n     : number of elements of the local vectors x and y (becomes zDim), n >= 2
    //   coeff : pointer to the 7 stencil coefficients, in the order
    //           west, east, south, north, bottom, top, center
    //   x     : input vector (read only in this module)
    //   y     : output vector, zeroed here and then accumulated into
    // This function only prepares the state and activates COMM; the result is
    // complete when f_comm() calls f_callback().

    @assert(2 <= n);

    zDim = n;

    c_west = (coeff.*)[0];
    c_east = (coeff.*)[1];
    c_south = (coeff.*)[2];
    c_north = (coeff.*)[3];
    c_bottom = (coeff.*)[4];
    c_top = (coeff.*)[5];
    c_center = (coeff.*)[6];

    // case 1: zDim <= BLOCK_SIZE
    //   cur_length = zDim
    //   only one iteration
    // case 2: zDim > BLOCK_SIZE
    //   cur_length = BLOCK_SIZE
    //   There are at least two iterations
    //   1st and last iterations reset the DSD's length
    //
    // Although cur_length = BLOCK_SIZE at comptime, if the user calls spmv()
    // twice and the size of the 1st spmv() is not a multiple of BLOCK_SIZE, then
    // the 2nd spmv() has cur_length < BLOCK_SIZE when it begins.
    // So we need to reset all DSDs with cur_length = min(zDim,BLOCK_SIZE)
    cur_length = BLOCK_SIZE;
    if (zDim < cur_length){
        cur_length = zDim;
    }
    // bind x and y to mem_center_buf_dsd and mem_y_buf_dsd respectively
    // the length of both DSDs will be reset by update_dsd_length()
    mem_center_buf_dsd = @set_dsd_base_addr(mem_center_buf_dsd, x);
    mem_y_buf_dsd = @set_dsd_base_addr(mem_y_buf_dsd, y);

    // reset the length of all DSDs except laplacian of z
    update_dsd_length(cur_length);

//--- mem_y_z_*** and mem_x_z_*** are only used in laplacian_z
//  x, y and zDim are runtime variables, so the DSDs must be reset here.
//  Each commented-out @get_dsd line shows the DSD that the following
//  base-address / offset / length calls build.
    // mem_y_z_minus_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{zDim-1} -> y[i+1] });
    mem_y_z_minus_buf_dsd = @set_dsd_base_addr(mem_y_z_minus_buf_dsd, y);
    mem_y_z_minus_buf_dsd = @increment_dsd_offset(mem_y_z_minus_buf_dsd, 1, f32);
    mem_y_z_minus_buf_dsd = @set_dsd_length(mem_y_z_minus_buf_dsd, @bitcast(u16,zDim-1));

    // mem_y_z_plus_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{zDim-1} -> y[i] });
    mem_y_z_plus_buf_dsd = @set_dsd_base_addr(mem_y_z_plus_buf_dsd, y);
    mem_y_z_plus_buf_dsd = @set_dsd_length(mem_y_z_plus_buf_dsd, @bitcast(u16,zDim-1));

    // mem_y_z_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{zDim} -> y[i] });
    mem_y_z_buf_dsd = @set_dsd_base_addr(mem_y_z_buf_dsd, y);
    mem_y_z_buf_dsd = @set_dsd_length(mem_y_z_buf_dsd, @bitcast(u16,zDim));

    // mem_x_z_minus_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{zDim-1} -> x[i] });
    mem_x_z_minus_buf_dsd = @set_dsd_base_addr(mem_x_z_minus_buf_dsd, x);
    mem_x_z_minus_buf_dsd = @set_dsd_length(mem_x_z_minus_buf_dsd, @bitcast(u16,zDim-1));

    // mem_x_z_plus_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{zDim-1} -> x[i+1] });
    mem_x_z_plus_buf_dsd = @set_dsd_base_addr(mem_x_z_plus_buf_dsd, x);
    mem_x_z_plus_buf_dsd = @increment_dsd_offset(mem_x_z_plus_buf_dsd, 1, f32);
    mem_x_z_plus_buf_dsd = @set_dsd_length(mem_x_z_plus_buf_dsd, @bitcast(u16,zDim-1));

    // mem_x_z_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{zDim} -> x[i] });
    mem_x_z_buf_dsd = @set_dsd_base_addr(mem_x_z_buf_dsd, x);
    mem_x_z_buf_dsd = @set_dsd_length(mem_x_z_buf_dsd, @bitcast(u16,zDim));
//---

    // reset the starting position of z-direction
    start_x = 0;

    // reset y[k] = 0 for k = 0,1,2,..., zDim-1
    // (y is only ever accumulated into afterwards, so it must start at zero)
    // @fmovs(mem_y_z_buf_dsd, zero);
    @load_to_dsr(dest_dsr_send, mem_y_z_buf_dsd);
    @fmovs(dest_dsr_send, @as(f32,0));

    // start first block of spmv
    // COMM is called multiple times to finish the spmv
    @activate(COMM);
}

// Laplacian on x-y plane with neighbors received from COMM
// The formula is
//    y[i][j][k] += c_west * x[i-1][j][k] + c_east * x[i+1][j][k] +
//                  c_south* x[i][j-1][k] + c_north* x[i][j+1][k]
//
// fmacs: Fp32 multiply add
//  @fmacs(dest_dsd, src_dsd1, src_dsd2, f32_value)
//  dest = src0 + src1 * scalar
//
// TODO: to reduce latency, combine laplacian_xy into RECV
fn laplacian_xy() void {
    // Adds the four in-plane neighbor contributions of the current block to y.
    // Each step is y_block = y_block + coefficient * neighbor_block, written
    // with explicit DSRs: dest and src0 both point at the current y block and
    // src1 at the neighbor's receive buffer. The commented-out @fmacs line
    // above each group is the equivalent DSD form.
    // The SEND set of DSRs is reused; compute_and_next_send_recv() only calls
    // this function after both SEND and RECV have reported DONE.

    // west neighbor
    //@fmacs(mem_y_buf_dsd, mem_y_buf_dsd, mem_west_buf_dsd, c_west);
    @load_to_dsr(dest_dsr_send, mem_y_buf_dsd);
    @load_to_dsr(src0_dsr, mem_y_buf_dsd);
    @load_to_dsr(src1_dsr_send, mem_west_buf_dsd);
    @fmacs(dest_dsr_send, src0_dsr, src1_dsr_send, c_west);

    // east neighbor
    //@fmacs(mem_y_buf_dsd, mem_y_buf_dsd, mem_east_buf_dsd, c_east);
    @load_to_dsr(dest_dsr_send, mem_y_buf_dsd);
    @load_to_dsr(src0_dsr, mem_y_buf_dsd);
    @load_to_dsr(src1_dsr_send, mem_east_buf_dsd);
    @fmacs(dest_dsr_send, src0_dsr, src1_dsr_send, c_east);

    // south neighbor
    //@fmacs(mem_y_buf_dsd, mem_y_buf_dsd, mem_south_buf_dsd, c_south);
    @load_to_dsr(dest_dsr_send, mem_y_buf_dsd);
    @load_to_dsr(src0_dsr, mem_y_buf_dsd);
    @load_to_dsr(src1_dsr_send, mem_south_buf_dsd);
    @fmacs(dest_dsr_send, src0_dsr, src1_dsr_send, c_south);

    // north neighbor
    //@fmacs(mem_y_buf_dsd, mem_y_buf_dsd, mem_north_buf_dsd, c_north);
    @load_to_dsr(dest_dsr_send, mem_y_buf_dsd);
    @load_to_dsr(src0_dsr, mem_y_buf_dsd);
    @load_to_dsr(src1_dsr_send, mem_north_buf_dsd);
    @fmacs(dest_dsr_send, src0_dsr, src1_dsr_send, c_north);
}


// Laplacian on z direction with whole local vector x
// The formula is
//   y[i][j][k] += c_bottom * x[i][j][k-1] + c_top * x[i][j][k+1] +
//                 c_center * x[i][j][k]
//
// The reason to separate z-direction from x-y plane:
// - need more logics to handle the boundary condition
//
fn laplacian_z() void {
    // Adds the bottom, top and center contributions over the whole local
    // vector (not block by block). It uses the mem_*_z_* DSDs that spmv()
    // bound to x and y, with the same explicit-DSR pattern as laplacian_xy():
    // dest and src0 point at y, src1 at the (shifted) x.
    // The shifted DSDs have length zDim-1, which leaves out the terms that
    // would read x[-1] or x[zDim] (both taken as 0).
    // Called from f_comm() once all blocks have been exchanged.

    // y[k] += x[k-1] * c_bottom for k = 1,2,...,zDim-1
    //@fmacs(mem_y_z_minus_buf_dsd, mem_y_z_minus_buf_dsd, mem_x_z_minus_buf_dsd, c_bottom);
    @load_to_dsr(dest_dsr_send, mem_y_z_minus_buf_dsd);
    @load_to_dsr(src0_dsr, mem_y_z_minus_buf_dsd);
    @load_to_dsr(src1_dsr_send, mem_x_z_minus_buf_dsd);
    @fmacs(dest_dsr_send, src0_dsr, src1_dsr_send, c_bottom);

    // y[k] += x[k+1] * c_top for k = 0,1,2,...,zDim-2
    //@fmacs(mem_y_z_plus_buf_dsd, mem_y_z_plus_buf_dsd, mem_x_z_plus_buf_dsd, c_top);
    @load_to_dsr(dest_dsr_send, mem_y_z_plus_buf_dsd);
    @load_to_dsr(src0_dsr, mem_y_z_plus_buf_dsd);
    @load_to_dsr(src1_dsr_send, mem_x_z_plus_buf_dsd);
    @fmacs(dest_dsr_send, src0_dsr, src1_dsr_send, c_top);

    // y[k] += x[k] * c_center for k = 0,1,2,..., zDim-1
    //@fmacs(mem_y_z_buf_dsd, mem_y_z_buf_dsd, mem_x_z_buf_dsd, c_center);
    @load_to_dsr(dest_dsr_send, mem_y_z_buf_dsd);
    @load_to_dsr(src0_dsr, mem_y_z_buf_dsd);
    @load_to_dsr(src1_dsr_send, mem_x_z_buf_dsd);
    @fmacs(dest_dsr_send, src0_dsr, src1_dsr_send, c_center);
}

// If both SEND and RECV are done, perform
// - compute laplacian on x-y
// - update the DSD of pointer x and y
// - activate COMM
fn compute_and_next_send_recv() void {
    // Called by both f_send() and f_recv() right after they bump
    // count_send_recv. Whichever finishes first sees a count below 2 and
    // leaves; the second caller closes the block.
    if (count_send_recv < 2){
        return;
    }

    // all neighbor data of this block has arrived: accumulate it into y
    laplacian_xy();

    // move the y and x windows, and the block start index, to the next block
    mem_y_buf_dsd = @increment_dsd_offset(mem_y_buf_dsd, BLOCK_SIZE, f32);
    mem_center_buf_dsd = @increment_dsd_offset(mem_center_buf_dsd, BLOCK_SIZE, f32);
    start_x += BLOCK_SIZE;

    // COMM either starts the next block or finishes the spmv
    @activate(COMM);
}

// Send data to four neighbors sequentially
// The corresponding recv sequence is in f_recv()
//
// SEND_STATE_EAST -> send to east
// SEND_STATE_WEST -> send to west
// SEND_STATE_NORTH --> send to north
// SEND_STATE_SOUTH --> send to south
task f_send() void {
    // State machine for the send side of one block. Each invocation handles
    // exactly one state, then advances send_state:
    //   EAST -> WEST -> NORTH -> SOUTH -> DONE -> (back to EAST)
    // For an existing neighbor, the current block of x (mem_center_buf_dsd) is
    // sent asynchronously and the fabric DSR is loaded with .activate = f_send
    // so that this task is activated again for the next state.
    // For a missing neighbor (boundary PE), SEND is activated directly.
    if (SEND_STATE_EAST == send_state){ // send to east
        // The last PE of x-dir has no east neighbor, so it does not send data to the east
        if (last_px){
            @activate(SEND); // nothing to send
        }else{
            //@mov32(fab_trans_east_wdsd, mem_center_buf_dsd, .{.async=true, .activate = f_send });
            @load_to_dsr(dest_dsr_send, fab_trans_east_wdsd, .{.async=true, .activate = f_send });
            @load_to_dsr(src1_dsr_send, mem_center_buf_dsd);
            @mov32(dest_dsr_send, src1_dsr_send, .{.async=true} );
        }
        send_state = SEND_STATE_WEST;// goto next state
    }else if (SEND_STATE_WEST == send_state){ // send to west
        // The first PE of x-dir has no west neighbor, so it does not send data to the west
        if (first_px){
            @activate(SEND); // nothing to send
        }else{
            //@mov32(fab_trans_west_wdsd, mem_center_buf_dsd, .{.async=true, .activate = f_send });
            @load_to_dsr(dest_dsr_send, fab_trans_west_wdsd, .{.async=true, .activate = f_send });
            @load_to_dsr(src1_dsr_send, mem_center_buf_dsd);
            @mov32(dest_dsr_send, src1_dsr_send, .{.async=true} );
        }
        send_state = SEND_STATE_NORTH;
    }else if (SEND_STATE_NORTH == send_state){ // send to north
        // The first PE of y-dir has no north neighbor, so it does not send data to the north
        if (first_py){
            @activate(SEND); // nothing to send
        }else{
            //@mov32(fab_trans_north_wdsd, mem_center_buf_dsd, .{.async=true, .activate = f_send });
            @load_to_dsr(dest_dsr_send, fab_trans_north_wdsd, .{.async=true, .activate = f_send });
            @load_to_dsr(src1_dsr_send, mem_center_buf_dsd);
            @mov32(dest_dsr_send, src1_dsr_send, .{.async=true} );
        }
        send_state = SEND_STATE_SOUTH;
    }else if (SEND_STATE_SOUTH == send_state) { // send to south
        // The last PE of y-dir has no south neighbor, so it does not send data to the south
        if (last_py){
            @activate(SEND); // nothing to send
        }else{
            //@mov32(fab_trans_south_wdsd, mem_center_buf_dsd, .{.async=true, .activate = f_send });
            @load_to_dsr(dest_dsr_send, fab_trans_south_wdsd, .{.async=true, .activate = f_send });
            @load_to_dsr(src1_dsr_send, mem_center_buf_dsd);
            @mov32(dest_dsr_send, src1_dsr_send, .{.async=true} );
        }
        send_state = SEND_STATE_DONE;
    }else{
        // SEND_STATE_DONE: all four directions have been handled for this block
        count_send_recv += 1;
        // if both SEND and RECV are done, perform
        // - compute laplacian on x-y
        // - update the DSD of pointer x and y
        // - activate COMM
        compute_and_next_send_recv();
        // reset send_state for next block SEND
        // (f_comm() asserts send_state == SEND_STATE_EAST before starting a
        //  block; NOTE(review): this relies on COMM running only after this
        //  task has returned - confirm against the task scheduling rules)
        send_state = SEND_STATE_EAST;
    }
}

// Receive data from four neighbors sequentially
// The corresponding send sequence is in f_send()
//
// RECV_STATE_WEST -> receive from west
// RECV_STATE_EAST -> receive from east
// RECV_STATE_SOUTH -> receive from south
// RECV_STATE_NORTH -> receive from north
//
task f_recv() void {
    // State machine for the receive side of one block. Each invocation handles
    // exactly one state, then advances recv_state:
    //   WEST -> EAST -> SOUTH -> NORTH -> DONE -> (back to WEST)
    // For an existing neighbor, one block is moved asynchronously from the
    // fabric into that neighbor's buffer; the fabric DSR is loaded with
    // .activate = f_recv so that this task is activated again for the next state.
    // For a missing neighbor (boundary PE), RECV is activated directly and the
    // corresponding buffer is not written.
    // The RECV set of DSRs is used here, separate from the set used by f_send().
    if (RECV_STATE_WEST == recv_state){ // receive from west
        // The first PE of x-dir has no west neighbor so it does not receive data from the west
        if (first_px){
            @activate(RECV); // nothing to receive
        }else{
            //@mov32(mem_west_buf_dsd, fab_recv_west_wdsd, .{.async=true, .activate = f_recv });
            @load_to_dsr(dest_dsr_recv, mem_west_buf_dsd);
            @load_to_dsr(src1_dsr_recv, fab_recv_west_wdsd, .{.async=true, .activate = f_recv });
            @mov32(dest_dsr_recv, src1_dsr_recv, .{.async=true} );
        }
        recv_state = RECV_STATE_EAST; // goto next state
    }else if (RECV_STATE_EAST == recv_state){ // receive from east
        // The last PE of x-dir has no east neighbor, so it does not recv data from the east
        if (last_px){
            @activate(RECV); // nothing to receive
        }else{
            //@mov32(mem_east_buf_dsd, fab_recv_east_wdsd, .{.async=true, .activate = f_recv });
            @load_to_dsr(dest_dsr_recv, mem_east_buf_dsd);
            @load_to_dsr(src1_dsr_recv, fab_recv_east_wdsd, .{.async=true, .activate = f_recv });
            @mov32(dest_dsr_recv, src1_dsr_recv, .{.async=true} );
        }
        recv_state = RECV_STATE_SOUTH;
    }else if (RECV_STATE_SOUTH == recv_state){ // receive from south
        // The last PE of y-dir has no south neighbor, so it does not recv data from the south
        if (last_py){
            @activate(RECV); // nothing to receive
        }else{
            //@mov32(mem_south_buf_dsd, fab_recv_south_wdsd, .{.async=true, .activate = f_recv });
            @load_to_dsr(dest_dsr_recv, mem_south_buf_dsd);
            @load_to_dsr(src1_dsr_recv, fab_recv_south_wdsd, .{.async=true, .activate = f_recv });
            @mov32(dest_dsr_recv, src1_dsr_recv, .{.async=true} );
        }
        recv_state = RECV_STATE_NORTH;
    }else if (RECV_STATE_NORTH == recv_state){ // receive from north
        // The first PE of y-dir has no north neighbor so it does not receive data from the north
        if (first_py){
            @activate(RECV); // nothing to receive
        }else{
            //@mov32(mem_north_buf_dsd, fab_recv_north_wdsd, .{.async=true, .activate = f_recv });
            @load_to_dsr(dest_dsr_recv, mem_north_buf_dsd);
            @load_to_dsr(src1_dsr_recv, fab_recv_north_wdsd, .{.async=true, .activate = f_recv });
            @mov32(dest_dsr_recv, src1_dsr_recv, .{.async=true} );
        }
        recv_state = RECV_STATE_DONE;
    }else{
        // RECV_STATE_DONE: all four directions have been handled for this block
        count_send_recv += 1;
        // if both SEND and RECV are done, perform
        // - compute laplacian on x-y
        // - update the DSD of pointer x and y
        // - activate COMM
        compute_and_next_send_recv();
        // reset recv_state for next block RECV
        // (f_comm() asserts recv_state == RECV_STATE_WEST before starting a block)
        recv_state = RECV_STATE_WEST;
    }
}

// Set the length of every per-block DSD (memory and fabric) to cur_length.
// The z-direction DSDs (mem_x_z_***, mem_y_z_***) are not touched here;
// spmv() sizes them separately.
fn update_dsd_length( cur_length: i16) void {

    // DSD lengths are unsigned: reinterpret the signed block length once
    var block_len: u16 = @bitcast(u16,cur_length);

    // memory side: current block of y (x-y laplacian) and of x (data to send)
    mem_y_buf_dsd = @set_dsd_length(mem_y_buf_dsd, block_len);
    mem_center_buf_dsd = @set_dsd_length(mem_center_buf_dsd, block_len);

    // memory side: the four receive buffers
    mem_west_buf_dsd = @set_dsd_length(mem_west_buf_dsd, block_len);
    mem_east_buf_dsd = @set_dsd_length(mem_east_buf_dsd, block_len);
    mem_south_buf_dsd = @set_dsd_length(mem_south_buf_dsd, block_len);
    mem_north_buf_dsd = @set_dsd_length(mem_north_buf_dsd, block_len);

    // fabric side: outgoing streams
    fab_trans_east_wdsd = @set_dsd_length(fab_trans_east_wdsd, block_len);
    fab_trans_west_wdsd = @set_dsd_length(fab_trans_west_wdsd, block_len);
    fab_trans_north_wdsd = @set_dsd_length(fab_trans_north_wdsd, block_len);
    fab_trans_south_wdsd = @set_dsd_length(fab_trans_south_wdsd, block_len);

    // fabric side: incoming streams
    fab_recv_west_wdsd = @set_dsd_length(fab_recv_west_wdsd, block_len);
    fab_recv_east_wdsd = @set_dsd_length(fab_recv_east_wdsd, block_len);
    fab_recv_south_wdsd = @set_dsd_length(fab_recv_south_wdsd, block_len);
    fab_recv_north_wdsd = @set_dsd_length(fab_recv_north_wdsd, block_len);
}

// case 1: zDim <= BLOCK_SIZE
//   cur_length = zDim set by f_spmv
//   rem_length = zDim = cur_length, no update
// case 2: 2*BLOCK_SIZE > zDim > BLOCK_SIZE
//   cur_length = BLOCK_SIZE set by f_spmv
//   0 < rem_length < BLOCK_SIZE=cur_length
//   this is last iteration, so update all DSDs
// case 3: zDim >= 2*BLOCK_SIZE
//   cur_length = BLOCK_SIZE set by f_spmv
//   BLOCK_SIZE <= rem_length
//   This is NOT the last iteration
//
// start_x is the starting position of current spmv.
// start_x is updated by the end of previous COMM.
// It is possible that rem_length = zDim - start_x < 0 because COMM always
// advances start_x by BLOCK_SIZE, even after a shorter last block.
//
// Example 1: zDim = 3, BLOCK_SIZE = 2
// 1st iteration: start_x=0, cur_length=2, rem_length=3
// 2nd iteration: start_x=2, cur_length=2, rem_length=1 -> last iteration, update cur_length=1
// 3rd iteration: start_x=4, cur_length=1, rem_length=-1 -> spmv finishes
//
// Example 2: zDim = 3, BLOCK_SIZE = 4
// 1st iteration: start_x=0, cur_length=3, rem_length=3, last iteration
// 2nd iteration: start_x=4, cur_length=3, rem_length=-1 -> spmv finishes
//
// Example 3: zDim = 3, BLOCK_SIZE = 3
// 1st iteration: start_x=0, cur_length=3, rem_length=3, last iteration
// 2nd iteration: start_x=3, cur_length=3, rem_length=0 -> spmv finishes
//
// Example 4: zDim = 4, BLOCK_SIZE = 2
// 1st iteration: start_x=0, cur_length=2, rem_length=4
// 2nd iteration: start_x=2, cur_length=2, rem_length=2, last iteration
// 3rd iteration: start_x=4, cur_length=2, rem_length=0 -> spmv finishes
//
// Example 5: zDim = 5, BLOCK_SIZE = 2
// 1st iteration: start_x=0, cur_length=2, rem_length=5
// 2nd iteration: start_x=2, cur_length=2, rem_length=3
// 3rd iteration: start_x=4, cur_length=1, rem_length=1, last iteration
// 4th iteration: start_x=6, cur_length=1, rem_length=-1 -> spmv finishes
//
task f_comm() void {
    // Driver of the block loop: runs once per block, plus one final time
    // when nothing is left to exchange.
    rem_length = (zDim - start_x);
    if (rem_length <= 0){
        // every block has been exchanged and the x-y part is already in y:
        // add the z-direction part and report completion
        laplacian_z();
        //sys_mod.unblock_cmd_stream();
        f_callback();
    }else{
        // A remainder shorter than the current block length means this is a
        // partial last block: shrink every per-block DSD (the z-direction
        // DSDs keep their own lengths). An equal remainder needs no update.
        if (rem_length < cur_length){
            cur_length = rem_length;
            update_dsd_length(cur_length);
        }
        // start_x, mem_center_buf_dsd and mem_y_buf_dsd already point at this
        // block; compute_and_next_send_recv() advanced them.

        // both state machines must be back in their first state
        @assert(RECV_STATE_WEST == recv_state);
        @assert(SEND_STATE_EAST == send_state);

        // neither side has finished this block yet
        count_send_recv = 0;
        // kick off the send side (first direction: east) ...
        @activate(SEND);
        // ... and the receive side (first direction: west)
        @activate(RECV);
    }
}

comptime {
    // attach each task to its local task ID
    // COMM drives the block loop; SEND and RECV are the two per-block state machines
    @bind_local_task(f_comm, COMM);
    @bind_local_task(f_send, SEND);
    @bind_local_task(f_recv, RECV);
}

// C0, C1:recv_west, send_east
//
//         C0     C1     C0     C1
// West P0 --> P1 --> P2 --> P3 --> P4  East
//
//         C0     C1     C0     C1     C0
// West P0 --> P1 --> P2 --> P3 --> P4 --> P5 East
//
// P0: send C0
// P_even: recv C1, send C0
// P_odd: recv C0, send C1
// P_last: recv C0 if odd; recv C1 if even
comptime {
    // West-to-east traffic: c_send_east leaves through EAST, c_recv_west enters from WEST.
    if (first_px){
        // px = 0: no west neighbor, only send to the east
        @set_local_color_config(c_send_east, .{ .routes = .{ .rx = .{RAMP}, .tx = .{EAST} } } );
    }else if (last_px){
        // px = width-1: no east neighbor, only receive from the west
        @set_local_color_config(c_recv_west, .{ .routes = .{ .rx = .{WEST}, .tx = .{RAMP} } } );
    }else{
        // 0 < px < width-1: send to the east and receive from the west
        @set_local_color_config(c_send_east, .{ .routes = .{ .rx = .{RAMP}, .tx = .{EAST} } } );
        @set_local_color_config(c_recv_west, .{ .routes = .{ .rx = .{WEST}, .tx = .{RAMP} } } );
    }
}


// C2, C3: recv_east, send_west
//
//          C2     C3     C2     C3
// West P0 <-- P1 <-- P2 <-- P3 <-- P4  East
//
//          C2     C3     C2     C3     C2
// West P0 <-- P1 <-- P2 <-- P3 <-- P4 <-- P5 East
//
// P0: recv C2
// P_even: recv C2, send C3
// P_odd: recv C3, send C2
// P_last: send C2 if odd; send C3 if even
comptime {
    // East-to-west traffic: c_send_west leaves through WEST, c_recv_east enters from EAST.
    if (first_px){
        // px = 0: no west neighbor, only receive from the east
        @set_local_color_config(c_recv_east, .{ .routes = .{ .rx = .{EAST}, .tx = .{RAMP} } } );
    }else if (last_px){
        // px = width-1: no east neighbor, only send to the west
        @set_local_color_config(c_send_west, .{ .routes = .{ .rx = .{RAMP}, .tx = .{WEST} } } );
    }else{
        // 0 < px < width-1: send to the west and receive from the east
        @set_local_color_config(c_send_west, .{ .routes = .{ .rx = .{RAMP}, .tx = .{WEST} } } );
        @set_local_color_config(c_recv_east, .{ .routes = .{ .rx = .{EAST}, .tx = .{RAMP} } } );
    }
}

// C4, C5: recv_south, send_north
//
//           C4     C5     C4     C5
// North P0 <-- P1 <-- P2 <-- P3 <-- P4   south
//
//           C4     C5     C4     C5     C4
// North P0 <-- P1 <-- P2 <-- P3 <-- P4 <-- P5 south
//
// P0: recv C4
// P_even: recv C4, send C5
// P_odd: recv C5, send C4
// P_last: send C4 if odd; send C5 if even
comptime {
    // South-to-north traffic: c_send_north leaves through NORTH, c_recv_south enters from SOUTH.
    if (first_py){
        // py = 0: no north neighbor, only receive from the south
        @set_local_color_config(c_recv_south, .{ .routes = .{ .rx = .{SOUTH}, .tx = .{RAMP} } } );
    }else if (last_py){
        // py = height-1: no south neighbor, only send to the north
        @set_local_color_config(c_send_north, .{ .routes = .{ .rx = .{RAMP}, .tx = .{NORTH} } } );
    }else{
        // 0 < py < height-1: send to the north and receive from the south
        @set_local_color_config(c_send_north, .{ .routes = .{ .rx = .{RAMP}, .tx = .{NORTH} } } );
        @set_local_color_config(c_recv_south, .{ .routes = .{ .rx = .{SOUTH}, .tx = .{RAMP} } } );
    }
}

// C6, C7: recv_north, send_south
//
//           C6     C7     C6     C7
// North P0 --> P1 --> P2 --> P3 --> P4   south
//
//           C6     C7     C6     C7     C6
// North P0 --> P1 --> P2 --> P3 --> P4 --> P5 south
//
// P0: send C6
// P_even: recv C7, send C6
// P_odd: recv C6, send C7
// P_last: recv C6 if odd; recv C7 if even
comptime {
    // North-to-south traffic: c_send_south leaves through SOUTH, c_recv_north enters from NORTH.
    if (first_py){
        // py = 0: no north neighbor, only send to the south
        @set_local_color_config(c_send_south, .{ .routes = .{ .rx = .{RAMP}, .tx = .{SOUTH} } } );
    }else if (last_py){
        // py = height-1: no south neighbor, only receive from the north
        @set_local_color_config(c_recv_north, .{ .routes = .{ .rx = .{NORTH}, .tx = .{RAMP} } } );
    }else{
        // 0 < py < height-1: send to the south and receive from the north
        @set_local_color_config(c_send_south, .{ .routes = .{ .rx = .{RAMP}, .tx = .{SOUTH} } } );
        @set_local_color_config(c_recv_north, .{ .routes = .{ .rx = .{NORTH}, .tx = .{RAMP} } } );
    }
}


// binding a color to an input queue.
// This is necessary when an explicit DSR binds to a fabin DSD because
// the compiler no longer can generate the instruction to set up the
// config register of input queue.
comptime {
    // one input queue per receive color, in the same pairing as the fabin
    // DSDs: [0]=west, [1]=east, [2]=south, [3]=north
    // y-direction queues
    @initialize_queue(@get_input_queue(input_queues[3]), .{.color = c_recv_north});
    @initialize_queue(@get_input_queue(input_queues[2]), .{.color = c_recv_south});
    // x-direction queues
    @initialize_queue(@get_input_queue(input_queues[1]), .{.color = c_recv_east});
    @initialize_queue(@get_input_queue(input_queues[0]), .{.color = c_recv_west});
}

csl-libs/allreduce/layout.csl

// Resources handed to the allreduce module by the caller:
// one routable color and four local task IDs.
param colors: [1]color;
param entrypoints: [4]local_task_id;
// get_params() compares px against width-1 and py against height-1
// to decide which PEs are on the last column / last row.
param width: i16 ;   // width of the core
param height: i16 ;  // height of the core


// the only routable color; passed to each PE as C_ROUTE by get_params()
const C0: color = colors[0];

// entrypoints of allreduce module
const SEND_CTRL: local_task_id = entrypoints[0];
const SEND_DATA: local_task_id = entrypoints[1];
const STATE_ENTRY: local_task_id = entrypoints[2];
// LOCK runs only if teardown is received and the operation is done
// LOCK performs the state transition
// teardown handler activates LOCK
// the operation blocks LOCK in the beginning and unblocks it when it finishes
const C_LOCK: local_task_id = entrypoints[3];

// Build the per-PE parameter struct of the allreduce module for the PE at
// column px, row py.
// The four boundary flags tell the PE whether it sits on the first/last
// column or row; the remaining fields forward the module-wide color, task
// IDs and rectangle dimensions unchanged.
fn get_params(px:i16, py:i16) comptime_struct {

    return .{
        // position flags along x
        .first_px = (px == 0),
        .last_px = (px == (width-1)),
        // position flags along y
        .first_py = (py == 0),
        .last_py = (py == (height-1)),
        // shared resources, identical for every PE
        .C_ROUTE = C0,
        .C_SEND_CTRL = SEND_CTRL,
        .C_SEND_DATA = SEND_DATA,
        .C_STATE_ENTRY = STATE_ENTRY,
        .C_LOCK = C_LOCK,
        .width = width,
        .height = height
    };
}

csl-libs/allreduce/pe.csl

// allreduce module has the following three operations
//  - row reduction
//  - column reduction
//  - broadcasting
//
// It only uses a single routable color, four entrypoints and a single
// input/output queue. Any user's kernel can combine this module with other
// modules without running out of resources.
//
// 1. row reduction
//   The reduction is from left to right. The last PE (px=width-1) receives
//   all data from the neighbors one by one. The result is stored in the last
//   PE, other PEs do not change the content.
//
// 2. column reduction
//   The reduction is from north to south. The last PE (py=height-1) receives
//   all data from the neighbors one by one. The result is stored in the last
//   PE, other PEs do not change the content.
//
// 3. broadcast
//   The right-bottom PE (px=width-1, py=height-1) broadcasts the data upwards to
//   whole column, then each PE in this column broadcasts data to its west neighbors.
//

// The portal allreduce_nrm2() computes nrm2(x) by using allreduce(MAX)
// x must be a scalar (for simplicity)
// Here is the sequence of operations
// 1. allreduce(MAX, |x|)
//    xmax = max(|x|) overwrites |x|
// 2. SCALE_AND_SQUARE
//    alpha = approx(xmax)
//    |x| = |x|/alpha
//    |x| = |x| * |x|
// 3. allreduce(ADD, |x|)
//    |x| = sum{ (xj/alpha)^2 }
// 4. NRM2
//    |x| = alpha * sqrt(|x|)
//    All PEs perform NRM2 because of broadcasting, so we don't need to broadcast
//    the final result to all PEs.
//
// The state machine has 9 states
//   # (1) allreduce(MAX)
//   wvlts_per_pe = 1
//   functorop = MAX
//   state_seq[0] = STATE_ROW_REDUCE;
//   state_seq[1] = STATE_COL_REDUCE;
//   state_seq[2] = STATE_BCAST;
//   # (2) SCALE_AND_SQUARE
//   state_seq[3] = STATE_SCALE_AND_SQUARE; // next operation is ADD
//   # (3) allreduce(ADD)
//   state_seq[4] = STATE_ROW_REDUCE;
//   state_seq[5] = STATE_COL_REDUCE;
//   state_seq[6] = STATE_BCAST;
//   # (4) NRM2
//   state_seq[7] = STATE_NRM2;
//   # (5) END
//   state_seq[8] = STATE_DONE;
//


// How to assign explicit DSRs
//
// reduction:
//  last PE: f_send_data --> @fadds(mem_x_buf_dsd, mem_x_buf_dsd, fab_recv_wdsd, .{.async=true, .activate=f_send_data} );
//                 ^                          |
//                 |--------------------------+
//  others: f_send_data --> @mov32(fab_trans_x_wdsd, mem_x_buf_dsd, .{.async=true, .activate=f_send_ctrl} );
//          --> @mov32(fab_trans_ctrl_wdsd, mem_ctrl_buf_dsd, .{.async=true, .activate=f_send_data } );
//          --> f_send_data
//          1st PE: @mov32(fab_trans_ctrl_wdsd, mem_buf_td_dsd, .{.async=true} );
//
// bcast:
//  right-bottom PE: @mov32(fab_trans_x_wdsd, mem_x_buf_dsd, .{.async=true, .activate=f_send_ctrl} );
//                   --> @mov32(fab_trans_ctrl_wdsd, mem_buf_td_dsd, .{.async=true} );
//  others: @mov32(mem_x_buf_dsd, fab_recv_wdsd, .{.async=true} );
//
// Only one dest DSR, one src0 DSR and one src1 DSR are enough because
// - the teardown separates different operations
// - when TD arrives, sender has sent out the data/ctrl
//   the receiver has received all data because there is only one color
// - all DSD operations are serialized
//
// For example:
//   dest_dsr = @get_dsr(dsr_dest, 1);
//   src0_dsr = @get_dsr(dsr_src0, 1);
//   src1_dsr = @get_dsr(dsr_src1, 1);
//

// The sequence of LOCK of { row_reduce, col_reduce, bcast}
//
//  row_reduce blocks LOCK
//  T29 activates LOCK
//  row_reduce unblocks LOCK when it finishes
//
//  LOCK goes to next state
//
//  col_reduce blocks LOCK
//  T29 activates LOCK
//  col_reduce unblocks LOCK when it finishes
//
//  LOCK goes to next state
//
//  bcast blocks LOCK
//  T29 activates LOCK
//  bcast unblocks LOCK when it finishes
//
//  LOCK goes to next state (done)
//

// The single fabric color used by every phase (row reduce, column reduce,
// broadcast). Its routing is rewritten at runtime by the xxx_configure()
// functions before each phase.
param C_ROUTE: color;

// Local task IDs; they are bound to f_send_ctrl, f_send_data, f_state_entry
// and f_lock respectively in the comptime block of this module.
param C_SEND_CTRL: local_task_id;  // send switch advance
param C_SEND_DATA: local_task_id;  // send data
param C_STATE_ENTRY: local_task_id; // state machine
// LOCK runs only if teardown is received and the operation is done
// LOCK performs the state transition
// teardown handler activates LOCK
// the operation blocks LOCK in the beginning and unblocks it when it finishes
param C_LOCK: local_task_id;

// Position of this PE inside the rectangle, precomputed by the layout
param first_px: bool; // (0 == px)
param last_px: bool;  // ((width-1) == px)
param first_py: bool; // (0 == py)
param last_py: bool;  // ((height-1) == py)

// row reduction needs to receive width-1 neighbors
// column reduction needs to receive height-1 neighbors
param width: i16;
param height: i16;

// f_callback = sys_mod.unblock_cmd_stream, to continue next command
// It is invoked when the state machine reaches STATE_DONE.
param f_callback: fn ()void;

// last PE uses this ID as the input queue
// others use this ID as the output queue
param queues:[1]u16;

// explicit DSR allocation
// one destination DSR, one src0 DSR and one src1 DSR are shared by all DSD operations
param dest_dsr_ids: [1]u16;
param src0_dsr_ids: [1]u16;
param src1_dsr_ids: [1]u16;

const timestamp = @import_module("<time>");

const math_lib = @import_module("<math>");

// A new type for binary operators
// compiler assigns ADD=0 and MAX=1
const TYPE_BINARY_OP = enum(u16) { ADD, MAX };

// tsc_size_words = 3
// Filled with the timestamp counter when the state machine reaches STATE_DONE
var tscRefBuffer = @zeros([timestamp.tsc_size_words]u16);

////////////////////////////////////////////////////////////////////////////////
// Main memory (48KB)
////////////////////////////////////////////////////////////////////////////////

// Pointer to the caller's vector; set by allreduce()/allreduce_nrm2() and
// used as both the input and the output of the reduction (in-place).
var x: [*]f32;

// Binary operator applied by the last PE while it accumulates incoming data
var functor: TYPE_BINARY_OP = TYPE_BINARY_OP.ADD;

// State IDs of the state machine implemented by f_state_entry()
const STATE_ROW_REDUCE: i16 = 0;
const STATE_COL_REDUCE: i16 = 1;
const STATE_BCAST: i16 = 2;
const STATE_SCALE_AND_SQUARE: i16 = 3;
const STATE_NRM2: i16 = 4;
const STATE_DONE: i16 = 5;

// allreduce(ADD/MAX) has four states
// allreduce_nrm2 has 9 states
// "+1" is to avoid out-of-bound if
// STATE_DONE also dereference next state
var state_seq = @zeros([9+1]i16);
// index of the current entry of state_seq
var state_idx: i16 = 0;
// state executed by the next run of f_state_entry()
var cur_state: i16 = 0;
// state prefetched by f_state_entry(); f_lock() copies it into cur_state
var next_state: i16 = 0;

// record the reduction length from the caller
var wvlts_per_pe: u16 = 0;

// number of PEs involved in the reduction: last PE needs to count number of received neighbors
// WARNING: reduce_pes only records number of received PEs
//   row reduction: width-1
//   column reduction: height-1
// If reduce_pes is wrong, simfab shows re-entry error of UT when row reduction and col reduction
// are combined because row reduction has extra UT1 waiting for wavelets
var reduce_pes: i16 = 0;
// 1st PE during the reduction: send TD to others
//   row reduction: {px = 0}
//   column reduction: {py = 0}
var reduce_first_pe: bool;
// last PE during the reduction: receive data from w-1 or h-1 neighbors
//   row reduction: {px = w-1}
//   column reduction: {py = h-1}
var reduce_last_pe: bool;

// last PE uses count_recv_or_send to receive data from w-1 neighbors
// other PEs use count_recv_or_send to send data and control
var count_recv_or_send: i16 = 0;


// Explicitly allocated DSRs (IDs come from the layout); every DSD operation
// in this module loads its operands into these three registers.
const dest_dsr = @get_dsr(dsr_dest, dest_dsr_ids[0]);
const src0_dsr = @get_dsr(dsr_src0, src0_dsr_ids[0]);
const src1_dsr = @get_dsr(dsr_src1, src1_dsr_ids[0]);


// The portal function of allreduce(ADD/MAX)
//
// How to use:
//  reduce_mod = @import_module( "<allreduce/pe>");
//  reduce_mod.allreduce(n, x, op);  // op is TYPE_BINARY_OP.ADD or TYPE_BINARY_OP.MAX
//  The user has to prepare input vector x.
//
//  When allreduce() finishes, it will call user's callback.
//
// case 1: row reduction
//   state_seq = {STATE_ROW_REDUCE, STATE_DONE}
// case 2: column reduction
//   state_seq = {STATE_COL_REDUCE, STATE_DONE}
// case 3: row + column reduction
//   state_seq = {STATE_ROW_REDUCE, STATE_COL_REDUCE, STATE_DONE}
// case 4: broadcast
//   state_seq = {STATE_BCAST, STATE_DONE}
//
fn allreduce( n: i16, in_tensor: [*]f32, op: TYPE_BINARY_OP ) void {

   // remember the operator and the caller's buffer; the reduction is in-place
   functor = op;
   x = in_tensor;

   // the length is reinterpreted as an unsigned DSD length, so it must be positive
   @assert(n > 0);
   wvlts_per_pe = @bitcast(u16, n);

   // program the state machine:
   //   row reduce --> column reduce --> broadcast --> done
   state_seq[3] = STATE_DONE;
   state_seq[2] = STATE_BCAST;
   state_seq[1] = STATE_COL_REDUCE;
   state_seq[0] = STATE_ROW_REDUCE;

   // start from the first entry and kick off the state machine
   state_idx = 0;
   cur_state = state_seq[state_idx];
   @activate(C_STATE_ENTRY);
}

// nrm2_x_copy keeps a copy of this PE's |x[0]| during nrm2 because
// x[0] itself is overwritten by the allreduce.
// After allreduce(MAX,x), all PEs have the same x = max(|xj|)
// nrm2_x_copy is used in scale_and_square:
//    alpha = approx(x[0])
//    x[0] <- (nrm2_x_copy / alpha)^2
// Then allreduce(ADD, x) updates x[0] = (|x|_2/alpha)^2
var nrm2_x_copy: f32;


// The portal function of nrm2 (allreduce_nrm2)
//
// It only computes nrm2(x[0]) because
// - common case is n = 1
// - no SIMD on sqrt
//
fn allreduce_nrm2(in_tensor: [*]f32) void {

    // The first pass is allreduce(MAX) over a single element, x[0]
    wvlts_per_pe = 1;
    functor = TYPE_BINARY_OP.MAX;
    x = in_tensor;

    // Replace x[0] by |x[0]| and keep a private copy, because x[0] is
    // overwritten by the global maximum during the first allreduce.
    const abs_x0 = math_lib.abs(x[0]);
    x[0] = abs_x0;
    nrm2_x_copy = abs_x0;

    // State sequence of nrm2
    //
    // pass 1: allreduce(MAX)        --> x[0] = max(|xj|)
    state_seq[0] = STATE_ROW_REDUCE;
    state_seq[1] = STATE_COL_REDUCE;
    state_seq[2] = STATE_BCAST;
    // local step: scale and square  --> x[0] = (|xj|/alpha)^2
    // (this step also switches the operator to ADD)
    state_seq[3] = STATE_SCALE_AND_SQUARE;
    // pass 2: allreduce(ADD)        --> x[0] = sum{(|xj|/alpha)^2}
    state_seq[4] = STATE_ROW_REDUCE;
    state_seq[5] = STATE_COL_REDUCE;
    state_seq[6] = STATE_BCAST;
    // local step: sqrt and rescale  --> x[0] = |x|_2
    state_seq[7] = STATE_NRM2;
    // finished, return to the caller
    state_seq[8] = STATE_DONE;

    // enter the state machine at the first entry
    state_idx = 0;
    cur_state = state_seq[state_idx];
    @activate(C_STATE_ENTRY);
}

//--------------------- system utility for teardown

// ref: old monolith/src/ucode/kernels/lib/pe_address_map.casm
// const TAMAP_FAB_MAP_START_ADDR       = 0x7f20
// ref: monolith/src/ucode/kernels/lib/pe_addr_map_ein.casm
// const TAMAP_FAB_TRAFFIC_MAP_START_ADDR = 0x7f20
// Word address of the per-color fabric config registers; the register of
// color c is at (start + c).
const TAMAP_FAB_MAP_START_ADDR : u16 = 0x7f20;
// Word address of the color config register of C_ROUTE (holds switch pos0 and TIP)
const D2H_COLOR_CONFIG_ADDR : u16 = TAMAP_FAB_MAP_START_ADDR + @get_int(C_ROUTE);

// mask out bit 0:9, including input/output pos0
// keep bit 10:15, including
//  bit 15: point to point
//  bit 14: TIP
//  bit 12-13: control wavelet pop mode
//  bit 11: color swap for E,W inputs
//  bit 10: color swap for N,S inputs
//
// The teardown clears bit 14 (TIP)
// bit 12-13 is comptime decided, only last PE uses pop_always, others pop_on_advance
const MASK_INPUT_OUTPUT_POS0: u16 = 0xfc00;

// bits 0:4 define the initial output switch position
// (one bit per direction; several bits can be OR-ed to fan out)
const OUTPUT_WEST: u16  = 0x1;  // bit 0: west output mask
const OUTPUT_EAST: u16  = 0x2;  // bit 1: east output mask
const OUTPUT_SOUTH: u16 = 0x4;  // bit 2: south output mask
const OUTPUT_NORTH: u16 = 0x8;  // bit 3: north output mask
const OUTPUT_RAMP: u16  = 0x10; // bit 4: offramp output mask

// bits 5:9 define the initial input switch position
const INPUT_WEST: u16  = 0x20;  // bit 5: west input mask
const INPUT_EAST: u16  = 0x40;  // bit 6: east input mask
const INPUT_SOUTH: u16 = 0x80;  // bit 7: south input mask
const INPUT_NORTH: u16 = 0x100; // bit 8: north input mask
const INPUT_RAMP: u16  = 0x200; // bit 9: onramp input mask

// Fabric switch configuration
// 0x7f40 - 0x7f57 - colors 0-23. Each address is for a single color
// Bits 14:13 Current Switch position (writes both input and output switch position; reads input position)
// Bits 12 Ring mode (1) (Switch movements Stop on last valid setting if ring mode is 0.)
// Bit 11 Switch position 3 switch select (1=input; 0 = output)
// Bits 10:8 Switch position 3 (5 = INVALID; 4 = CE; 3 = N; 2 = S; 1 = E; 0 = W)
// Bit 7 Switch position 2 switch select (1=input; 0 = output)
// Bits 6:4 Switch position 2 (5 = INVALID; 4 = CE; 3 = N; 2 = S; 1 = E; 0 = W)
// Bit 3 Switch position 1 switch select (1=input; 0 = output)
// Bits 2:0 Switch position 1 (5 = INVALID; 4 = CE; 3 = N; 2 = S; 1 = E; 0 = W)
//
// ref: monolith/src/ucode/kernels/lib/pe_addr_map_fyn.casm
// .const TAMAP_FAB_SWITCH_CFG_START_ADDR = 0x7f40
const TAMAP_FAB_SWITCH_CFG_START_ADDR: u16 = 0x7f40;
// Word address of the switch config register of C_ROUTE (holds pos1..pos3 and the current position)
const D2H_SWITCH_CONFIG_ADDR: u16 = TAMAP_FAB_SWITCH_CFG_START_ADDR + @get_int(C_ROUTE);
// mask bits 14:13
// masking with MASK_SWITCH_RESET_POS0 is equivalent to set bits14:13 to zero (i.e. back to pos0)
const MASK_SWITCH_RESET_POS0: u16 = 0x9fff;

// To clear setting of pos1, set bits 2:0 to zero, but keep others unchanged
const MASK_SWITCH_CLEAR_POS1: u16 = 0xfff8;
// Bit 3 is always 1 because "pos1 = {.rx = RAMP}" implies position 1 switch select is "1=input"
// The two values below are written into bits 2:0 only (see the table above:
// 5 = INVALID, 4 = CE).
const SWITCH_POS1_INVALID: u16 = 0x5;
const SWITCH_POS1_RAMP: u16 = 0x4;

// Convert a 16-bit word address into the corresponding byte address.
// The register addresses above are word addresses, whereas the pointers
// built by the xxx_configure() functions take byte addresses.
fn translate_word_to_bytes( addr: u16 ) u16 {
    return addr * 2;
}


////////////////////////////////////////////////////////////////////////////////
// DSDs
// data-structure descriptors (DSDs), loaded into data-structure registers (DSRs)
//
// Queues 0,1: input depth 6 wavelets
// Queues 2,3: input depth 4 wavelets
// Queues 4-7: input depth 2 wavelets
//
// queues 0,1: output depth 2 wavelets
// queues 2,3: output depth 6 wavelets
// queues 4,5: output depth 2 wavelets
//
// Length of an operand:
// The length of all other types of DSRs is specified by the length field of its DSD. When
// the bits encoding the length are 0x7fff, the length is infinite.
//
// Length of the vector instruction is then determined in the following order:
// 1. If src0 has a non-zero length, that length is used
// 2. If src1 has a non-zero length, that length is used
// 3. If dst has a non-zero length, that length is used
// 4. if no operands have length (all operands are GPR), length = 1
////////////////////////////////////////////////////////////////////////////////

// Placeholder storage so that mem_x_buf_dsd can be declared at comptime;
// the base address is replaced by the pointer x before any use.
const dummy = @zeros([1]i16);

// rowReduce() binds mem_x_buf_dsd to pointer x and resets its length to n (given by the caller)
// Last PE adds data from neighbors to mem_x_buf_dsd
// Other PEs send mem_x_buf_dsd to the east
var mem_x_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> dummy[i] });

// other PE (not last PE) uses this DSD to send x
// (.extent is a placeholder; the length is reset to n by reduce() and broadcast_to_all())
var fab_trans_x_wdsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .fabric_color = C_ROUTE,
    .output_queue = @get_output_queue(queues[0])
});

// WARNING: control wavelet must be sent with the same microthread, via the same output buffer,
// otherwise, we may see only one data wavelet, then 2nd is the control wavelet, then
// the remaining data cannot be sent out because the routing is back to {.rx=WEST, .tx=EAST},
// there is no path from RAMP to the router.
//
// This DSD sends exactly one control wavelet: either the switch advance
// (ctrl_11) or the teardown command (teardown_buf).
const fab_trans_ctrl_wdsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .control = true,
    .fabric_color = C_ROUTE,
    .output_queue = @get_output_queue(queues[0]),
});


// row reduction: the last PE receives the data from its w-1 neighbors,
// the receiving sequence is p0, p1, ..., p{w-2}.
// It uses the same queue ID because it does not send, only receives.
// It does not receive ctrl wavelets because of NOCE.
// f_send_data() receives data (w-1) times
//
// (.extent is a placeholder; the length is reset to n by reduce() and broadcast_to_all())
var fab_recv_wdsd =  @get_dsd(fabin_dsd, .{
   .extent = 1,
   .fabric_color = C_ROUTE,
   .input_queue = @get_input_queue(queues[0])
});


////////////////////////////////////////////////////////////////////////////////
// Tasks
// syntax
//     task_begin(name, entrypoint, color)
////////////////////////////////////////////////////////////////////////////////


const switches = @import_module("<memcpy/memcpy_switches>");

// The following arrays define values for control wavelets, which update the
// switch position at the recipient PEs.
// All are comptime constants
//
// ctrl_11 is for other PEs which changes switch of two consecutive PEs
// One control wavelet that carries the switch-advance command.
// NOTE(review): declared as `var` although the value is never modified in
// this module; presumably required because its address is passed to
// @set_dsd_base_addr() in reduce() -- confirm before changing it to `const`.
var ctrl_11 = [1]u32 { switches.ctrl(switches.switch_cmd_11()) };

// memory DSD of length 1 over ctrl_11, the source operand of the switch advance
var mem_ctrl_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> ctrl_11[i] });

// teardown from casm
// teardown_buf[0] = (1 << 31) | (0b111 << 22)
//
// teardown from csl
// from cslang/test-e2e/dynamic_filters/sender.csl
//  31=0x1f = no entrypoint
//
// teardown wavelet = 0x1df 0000
//const teardown_buf = [1]u32{(31 << 16) | 0b111 << 22};
// teardown wavelet = 0x9df 9249
const teardown_buf = [1]u32 { switches.ctrl(switches.teardown_cmd_1()) };

// memory DSD of length 1 over teardown_buf, the source operand of the teardown (TD)
const mem_buf_td_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> teardown_buf[i] });

// the color C_ROUTE is in teardown mode at comptime by specifying {.teardown = true}
// reduce() and broadcast_to_all() block LOCK in the beginning and unblock it
// when the operation finishes
//
// WARNING: we don't block LOCK in the xxx_configure() because it is sequential, and
// it is more intuitive to unblock LOCK in reduce() and broadcast_to_all()
//
// Entry point of the state machine: executes the operation selected by
// cur_state.
//
// - ROW_REDUCE, COL_REDUCE and BCAST are communication states: the routing of
//   C_ROUTE is reconfigured, the operation is started, and the following
//   state is only prefetched into next_state. The transition itself is done
//   by f_lock().
// - SCALE_AND_SQUARE and NRM2 are local computations: cur_state is advanced
//   directly and this task re-activates itself.
// - DONE samples the timestamp counter and calls f_callback().
task f_state_entry() void {

    if (STATE_ROW_REDUCE == cur_state){
        // rowReduce_configure() reconfigures the pos0, pos1 and clears TIP
        rowReduce_configure();
        // perform row reduction and store the result to last PE
        // 1st PE will send TD to turn C_ROUTE back to teardown mode
        reduce_pes = width-1;
        reduce_first_pe = first_px;
        reduce_last_pe = last_px;
        reduce( wvlts_per_pe );

        // prefetch next state which will be copied into cur_state in teardown handler
        state_idx += 1;
        next_state = state_seq[state_idx];

    }else if (STATE_COL_REDUCE == cur_state){
        // colReduce_configure() reconfigures the pos0, pos1 and clears TIP
        colReduce_configure();
        // perform column reduction and store the result to last PE
        // 1st PE will send TD to turn C_ROUTE back to teardown mode
        reduce_pes = height-1;
        reduce_first_pe = first_py;
        reduce_last_pe = last_py;
        reduce( wvlts_per_pe );

        // prefetch next state which will be copied into cur_state in teardown handler
        state_idx += 1;
        next_state = state_seq[state_idx];

    }else if (STATE_BCAST == cur_state){
        // bcast_configure() reconfigures pos0 and disables pos1
        bcast_configure();
        // right-bottom PE broadcasts the data to others and also sends TD
        broadcast_to_all( wvlts_per_pe );

        // prefetch next state which will be copied into cur_state in teardown handler
        state_idx += 1;
        next_state = state_seq[state_idx];

    }else if (STATE_SCALE_AND_SQUARE == cur_state){
        // Assume allreduce(MAX) is done
        // x[0] = xmax = max({|xj|})

        // Update x[0] by (xmax/alpha)^2
        scale_and_square();
        // next reduction is allreduce(ADD)
        functor = TYPE_BINARY_OP.ADD;

        // prefetch next state which will be copied into cur_state in teardown handler
        // sequential code: f_lock is not triggered to assign (cur_state = next_state)
        // update cur_state directly
        state_idx += 1;
        cur_state = state_seq[state_idx];

        // no communication in this state, so continue with the next state immediately
        @activate(C_STATE_ENTRY);

    }else if (STATE_NRM2 == cur_state){
        // Assume allreduce(ADD) is done
        // x[0] = sum({(|xj|/alpha)^2})
        // Update x[0] by |x|_2
        nrm2_postprocessing();

        // prefetch next state which will be copied into cur_state in teardown handler
        // sequential code: f_lock is not triggered to assign (cur_state = next_state)
        // update cur_state directly
        state_idx += 1;
        cur_state = state_seq[state_idx];

        // no communication in this state, so continue with the next state immediately
        @activate(C_STATE_ENTRY);

    }else if (STATE_DONE == cur_state){
        // state machine is done, return control back to the caller
        // sample the timestamp counter into tscRefBuffer before returning
        timestamp.get_timestamp(&tscRefBuffer);

        f_callback();
    }else{
        @assert(false); // Error: unknown state
        // assert() is ignored by HW, it could hang here
        // To avoid a stall, trigger callback (the caveat is the wrong result)
        f_callback();
    }
}

// Start a row or column reduction of n elements of x.
//
// The caller (f_state_entry) must have set reduce_pes, reduce_first_pe and
// reduce_last_pe and configured the routing of C_ROUTE. This function only
// prepares the DSDs and activates f_send_data, which performs the transfers.
fn reduce( n: u16 ) void {

    // WARNING: block LOCK in the beginning and only
    // unblock LOCK when "reduce" finishes
    @block(C_LOCK);

    // restart the receive/send counter used by f_send_data()
    count_recv_or_send = 0;

    // changes switch of itself and its neighbor
    // The last PE does not call f_send_ctrl(), so this op is DON'T care
    mem_ctrl_buf_dsd =  @set_dsd_base_addr(mem_ctrl_buf_dsd, ctrl_11);

    // bind the memory DSD to the caller's vector x with length n
    mem_x_buf_dsd = @set_dsd_base_addr(mem_x_buf_dsd, x);
    mem_x_buf_dsd = @set_dsd_length(mem_x_buf_dsd, n);

    // the fabric DSDs transfer n wavelets per operation as well
    fab_recv_wdsd = @set_dsd_length(fab_recv_wdsd, n);
    fab_trans_x_wdsd = @set_dsd_length(fab_trans_x_wdsd, n);

    // last PE receives data from w-1 neighbors
    // other PEs send data and control to the east/south
    @activate(C_SEND_DATA);  // triggers f_send_data
}

// Broadcast n elements of x from the bottom-right PE (last_px and last_py)
// to all other PEs.
//
// The caller (f_state_entry) must have configured the routing of C_ROUTE by
// bcast_configure(). The sender streams x to the fabric and then sends the
// teardown through f_send_ctrl; every other PE receives into x.
fn broadcast_to_all( n: u16 ) void {

    // WARNING: block LOCK in the beginning and only
    // unblock LOCK when "broadcast" finishes
    @block(C_LOCK);

    // No PE sends switch advance
    // mem_ctrl_buf_dsd =  @set_dsd_base_addr(mem_ctrl_buf_dsd, ctrl_11);

    // bind the memory DSD to x and set all lengths to n
    mem_x_buf_dsd = @set_dsd_base_addr(mem_x_buf_dsd, x);
    mem_x_buf_dsd = @set_dsd_length(mem_x_buf_dsd, n);
    fab_recv_wdsd = @set_dsd_length(fab_recv_wdsd, n);
    fab_trans_x_wdsd = @set_dsd_length(fab_trans_x_wdsd, n);

    if ( last_px and last_py ){
        // Pw-1,h-1 sends data and then f_send_ctrl sends a TD
        // f_send_ctrl() will unblock LOCK
        //@mov32(fab_trans_x_wdsd, mem_x_buf_dsd, .{.async=true, .activate=f_send_ctrl} );
        // With explicit DSRs the completion action (.activate) is attached
        // when the fabric DSD is loaded into the DSR, not to the @mov32.
        @load_to_dsr(dest_dsr, fab_trans_x_wdsd, .{.async=true, .activate=f_send_ctrl} );
        @load_to_dsr(src1_dsr, mem_x_buf_dsd);
        @mov32(dest_dsr, src1_dsr, .{.async=true} );
    }else{
        // other PEs receive data and wait for TD
        // unblock LOCK after data is received, T29 will activate LOCK
        //@mov32(mem_x_buf_dsd, fab_recv_wdsd, .{.async=true, .unblock=C_LOCK} );
        // With explicit DSRs the completion action (.unblock) is attached
        // when the fabric DSD is loaded into the DSR, not to the @mov32.
        @load_to_dsr(dest_dsr, mem_x_buf_dsd);
        @load_to_dsr(src1_dsr, fab_recv_wdsd, .{.async=true, .unblock=C_LOCK} );
        @mov32(dest_dsr, src1_dsr, .{.async=true} );
    }
}


// Scaling factor of nrm2 and its reciprocal.
// Both are written by scale_and_square() (through approx());
// alpha is read again by nrm2_postprocessing() to undo the scaling.
var alpha: f32;
var inv_alpha: f32;


// Assume the caller finishes allreduce(MAX,|x|) so x[0] = max({|xj|})
// Update x[0] by (x[0]/alpha)^2
// where
//     alpha = 2^(E-127) approximates x[0]
//
fn scale_and_square() void {
    // x[0] holds the global maximum here; derive the scaling factor alpha
    // and its reciprocal from it (stored in the module-level variables)
    approx(x[0], &alpha, &inv_alpha);

    // scale the value saved by this PE before the first allreduce;
    // the ratio is O(1), so squaring it does not overflow
    const scaled: f32 = nrm2_x_copy * inv_alpha;

    // x[0] <- (xj/alpha)^2, the input of the following allreduce(ADD)
    x[0] = scaled * scaled;
}

// Assume the caller has computed
//   x[0] = allreduce(ADD, (xj/alpha)^2)
//
// Update x[0] by |x|_2
//
fn nrm2_postprocessing() void{
    // On entry x[0] = sum({(xj/alpha)^2}) = |x|^2 / alpha^2,
    // so sqrt(x[0]) = |x|_2 / alpha
    const scaled_norm: f32 = math_lib.sqrt(x[0]);
    // undo the scaling applied by scale_and_square()
    x[0] = scaled_norm * alpha;
}

// last PE does not send data, it only receives data
// row-reduce sequence: f_send_data() --> f_send_ctrl()
//                      ^                  |
//                      |------------------+
//
// f_send_data() is the last call when the reduction finishes
// unblock LOCK here when the operation is done
// Data phase of a row/column reduction. The task is activated by reduce()
// and re-activated by the completion of the async operations it starts.
//
// - last PE: accumulates one incoming vector per run into x (ADD or MAX,
//   selected by functor) until reduce_pes vectors have arrived, then unblocks
//   LOCK.
// - other PEs: the first run sends x and chains to f_send_ctrl; the second
//   run (activated by f_send_ctrl) unblocks LOCK, and the first PE also
//   sends the teardown.
task f_send_data() void {
    if (reduce_last_pe){
        // last PE receives data from reduce_pes neighbors
        if (count_recv_or_send < reduce_pes){
            //@fadds(mem_x_buf_dsd, mem_x_buf_dsd, fab_recv_wdsd, .{.async=true, .activate=f_send_data} );
            // the completion action (.activate) is attached to the load of the fabric DSD
            @load_to_dsr(src1_dsr, fab_recv_wdsd, .{.async=true, .activate=f_send_data} );
            @load_to_dsr(src0_dsr, mem_x_buf_dsd);
            @load_to_dsr(dest_dsr, mem_x_buf_dsd);
            // x <- op(x, received vector)
            if (TYPE_BINARY_OP.ADD == functor){
                @fadds(dest_dsr, src0_dsr, src1_dsr, .{.async=true} );
            }else{
                @fmaxs(dest_dsr, src0_dsr, src1_dsr, .{.async=true} );
            }
            count_recv_or_send += 1;
        }else{
            // last PE has received all data from the reduce_pes neighbors
            // wait for TD from 1st PE
            // unblock LOCK, T29 will activate LOCK
            @unblock(C_LOCK);
        }
    }else{
        // other PE (not last PE) sends data and control
        if (count_recv_or_send < 1){
            //@mov32(fab_trans_x_wdsd, mem_x_buf_dsd, .{.async=true, .activate=f_send_ctrl} );
            // send x, then f_send_ctrl sends the switch advance
            @load_to_dsr(dest_dsr, fab_trans_x_wdsd, .{.async=true, .activate=f_send_ctrl} );
            @load_to_dsr(src1_dsr, mem_x_buf_dsd);
            @mov32(dest_dsr, src1_dsr, .{.async=true} );
            count_recv_or_send += 1; 
        }else{
            // sending is done (including data wavelets and control wavelets)
            @unblock(C_LOCK);
            // only 1st PE sends TD to other PEs
            // T29 will activate LOCK
            if (reduce_first_pe){
                //@mov32(fab_trans_ctrl_wdsd, mem_buf_td_dsd, .{.async=true} );
                @load_to_dsr(dest_dsr, fab_trans_ctrl_wdsd, .{.async=true} );
                @load_to_dsr(src1_dsr, mem_buf_td_dsd);
                @mov32(dest_dsr, src1_dsr, .{.async=true} );
            }
        }
    }
}


// Control phase: sends exactly one control wavelet on C_ROUTE after the data
// has been sent.
//
// - broadcast: the bottom-right PE sends the teardown and unblocks LOCK when
//   the send completes.
// - reduction: a non-last PE sends the switch advance (ctrl_11) and then
//   re-activates f_send_data, which finishes the operation.
task f_send_ctrl() void{
    if (STATE_BCAST == cur_state){
        //broadcast: Pw-1,h-1 only sends the TD
        // unblock LOCK after TD is sent out
        //@mov32(fab_trans_ctrl_wdsd, mem_buf_td_dsd, .{.async=true, .unblock=C_LOCK} );
        @load_to_dsr(dest_dsr, fab_trans_ctrl_wdsd, .{.async=true, .unblock=C_LOCK} );
        @load_to_dsr(src1_dsr, mem_buf_td_dsd);
        @mov32(dest_dsr, src1_dsr, .{.async=true} );
    }else{
        // reduction: other PEs (not last PE) sends switch advance
        //   last PE does not trigger f_send_ctrl because it only receives data
        // f_send_data() will unblock LOCK
        //@mov32(fab_trans_ctrl_wdsd, mem_ctrl_buf_dsd, .{.async=true, .activate=f_send_data } );
        @load_to_dsr(dest_dsr, fab_trans_ctrl_wdsd, .{.async=true, .activate=f_send_data } );
        @load_to_dsr(src1_dsr, mem_ctrl_buf_dsd);
        @mov32(dest_dsr, src1_dsr, .{.async=true} );
    }
}


// LOCK runs only if TD is received and the operation (*) finishes
//
// Here is the sequence
// - the operation blocks LOCK in the beginning
// - teardown handler activates LOCK
// - the operation unblocks LOCK when it finishes
// - LOCK is picked by the HW scheduler to perform the state transition
//
// (*) operation is {row_reduce, col_reduce, bcast}
//
// State transition: commit the state prefetched by f_state_entry() and
// re-enter the state machine.
task f_lock() void {
    // next_state was stored by f_state_entry() when the operation was started
    cur_state = next_state; // go to next state
    @activate(C_STATE_ENTRY);
}


// (alpha, inv_alpha) = approx(x) approximates x by positive alpha such that
//     x = alpha * (x/alpha)
// where alpha = 2^(exp) and (x/alpha) has no precision loss.
//
// If x is a normal number, |x| = 2^(exp) * r, then alpha = 2^(exp)
//
// The purpose of this approximation is for nrm2(x).
// nrm2(x) can hit overflow if we just do square-sum.
// The simple workaround is to square-sum of x/max(x).
// However the division is very expensive, about 50 cycles.
// We just need a number alpha close to max(x) such that x/alpha = O(1).
// The cost of approx is about 11 instructions, much cheaper than div.
//
// Assume x = sign * 2^(E-127) * mantissa, "approx" handles the following
// four cases:
//
// case 1: x is a normal number
//    0 < E < 255
//    x is normal
//    x = sign * 2^(E-127) * 1.b22b21... b1b0
//    min(x) = 0x0080 0000
//           = 2^(-126) = 1.1754943508 x 10^(-38)
//    max(x) = 0x7f7f ffff
//           = 2^127 x (2 - 2^(-23)) = 3.4028234664 * 10^38
//
// case 2: x is a subnormal number
//    E = 0 and mantissa > 0
//    x = sign * 2^(-127) * b22.b21... b1b0
//      = sign * 2^(-126) * 0.b22b21... b1b0
//    min(x) = 0x0000 0001
//           = 2^(-126) x 2^(-23) = 2^(-149) = 1.4*10^(-45)
//    max(x) = 0x007f ffff
//           = 2^(-126) x (1 - 2^(-23)) = 1.17 x 10^(-38)
//
// case 3: x = 0
//    E = 0 and mantissa = 0
//
// case 4: x = inf or nan
//    inf: E = 255 and mantissa = 0
//    nan: E = 255 and mantissa > 0
//
// Example 1: x = 14.0
//    alpha_u32 = 0x41000000
//    inv_alpha_u32 = 0x3e000000
//    alpha = 8.
//    inv_alpha = 0.125
// Example 2: x = 0.15625
//    alpha_u32 = 0x3e000000
//    inv_alpha_u32 = 0x41000000
//    alpha = 0.125
//    inv_alpha = 8.0
// Example 3: x = 1.e-43
//    alpha_u32 = 0x800000
//    inv_alpha_u32 = 0x7e800000
//    alpha = 1.1754943508222875e-38
//    inv_alpha = 8.507059173023462e+37
// Example 4: x = 1.0/0.0 (np.float32(np.inf))
//    alpha_u32 = 0x3f800000
//    inv_alpha_u32 = 0x3f800000
//    alpha = 1.0
//    inv_alpha = 1.0
// Example 5: x = 0.0/0.0 (np.float32(np.nan))
//    alpha_u32 = 0x3f800000
//    inv_alpha_u32 = 0x3f800000
//    alpha = 1.0
//    inv_alpha = 1.0
//
// NOTE on the largest exponent (E = 254, i.e. |x| >= 2^127):
//   inv_alpha = 2^(127 - E) is encoded by the exponent field (254 - E).
//   For E = 254 that field is 0, which encodes 0.0 instead of 2^(-127), so
//   every scaled value x * inv_alpha would become 0 and nrm2 would return 0.
//   E is therefore clamped to 253: alpha = 2^126, inv_alpha = 2^(-126) (the
//   smallest normal number) and 2 <= |x|/alpha < 4, which is still O(1).
//
// Parameters:
//   x          value to approximate (only exponent and mantissa are used,
//              the sign is ignored)
//   alpha      out: power of two close to |x|, always a positive normal number
//   inv_alpha  out: 1/alpha, always a positive normal number
fn approx(x: f32, alpha: *f32, inv_alpha: *f32) void {
    const MASK_EXPONENT: u32 = 0x7F800000;
    const MASK_MANTISSA: u32 = 0x007FFFFF;
    const x_u32: u32 = @bitcast(u32, x);
    // x is presented by (sign | E | mantissa)
    // sign has 1 bit, E has 8 bits and mantissa has 23 bits
    // E = (x & MASK_EXPONENT) >> 23
    const exp: u32 = (x_u32 & MASK_EXPONENT);
    // mantissa = b22b21...b1b0 has 23-bit, need u32
    const mantissa: u32 = (x_u32) & MASK_MANTISSA;
    // E has 8-bit, use u16
    var E: u16 = @as(u16, (exp >> 23));

    // case 1: 0 < E < 255, x is normal
    // the following if-clauses handle case 2, 3 and 4
    if (0 == E){
        if (0 == mantissa){
            // case 3: x = 0
            // reset alpha = 1
            E = 127;
        }else{
            // case 2: x is subnormal
            // reset alpha= 2^(-126)
            E = 1;
        }
    }
    if (255 == E){
        // case 4: x = inf or NAN
        // reset alpha = 1
        E = 127;
    }
    if (254 == E){
        // largest normal exponent: (254 - E) would be 0 and inv_alpha would
        // be encoded as 0.0, so use alpha = 2^126 instead of 2^127
        E = 253;
    }
    // Here 1 <= E <= 253, so both exponent fields below are in [1, 253]
    // alpha and inv_alpha are u32
    // alpha = 2^(E - 127)
    // inv_alpha = 1/alpha = 2^(127 - E)
    var alpha_u32: u32 = (@as(u32, E) << 23);
    var inv_alpha_u32: u32 = @as(u32, (254 - E)) << 23;

    alpha.* = @bitcast(f32, alpha_u32);
    inv_alpha.* = @bitcast(f32, inv_alpha_u32);
}


// Bind each task of this module to the local task ID provided by the layout.
comptime {
    // state machine and its transition
    @bind_local_task(f_state_entry, C_STATE_ENTRY);
    @bind_local_task(f_lock, C_LOCK);
    // data and control phases of reduce/broadcast
    @bind_local_task(f_send_data, C_SEND_DATA);
    @bind_local_task(f_send_ctrl, C_SEND_CTRL);
}


//----------------- the following is the routing of C_ROUTE

const tile_config = @import_module("<tile_config>");

// Reconfigure the routing of C_ROUTE for a row reduction (data flows from
// west to east and is collected by the PE with last_px) and take the color
// out of teardown mode.
//
// Steps: reset the current switch position and rewrite pos1 in the switch
// config register, rewrite pos0 in the color config register, and finally
// clear the teardown-in-progress bit.
fn rowReduce_configure() void {

    // (1) setup switch according to config parameters
    // 1. pos0 (color config reg)
    // 2. pos1 (switch config reg)
    //    pos1 = {.rx = RAMP} for all PEs except last PE

    // reset switch position to pos0
    // WARNING: if switch config register does not reset the switch position back to pos0,
    // it is possible that some PE is at pos1 after the switch is reconfigured and the sending
    // pattern is messed up, for example, the PE sends data first, then forwards the data from
    // the west.
    var r_switch_state : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* ;
    // mask bits 14:13 to reset input&output position to pos0
    r_switch_state = r_switch_state & MASK_SWITCH_RESET_POS0;
    // WARNING: all PEs are configured by
    //   - ".pos1 = .{ .rx = RAMP }"  --> bit 3 is 1
    //   - ".ring_mode = true"  --> bit 12 is 1
    //   - ".pop_mode = .{ .pop_on_advance = true }" --> bits 13:12 of fabric per-color config
    // mask bits 2:0 to clear setting of pos1
    r_switch_state = r_switch_state & MASK_SWITCH_CLEAR_POS1;
    if (last_px){
        // last PE does not have pos1
        r_switch_state = r_switch_state | SWITCH_POS1_INVALID;
    }else{
        // others have ".pos1 = .{ .rx = RAMP }"
        r_switch_state = r_switch_state | SWITCH_POS1_RAMP;
    }
    // write the switch config register back
    @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* = r_switch_state;

    var r_state : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* ;
    // clear input/output switch pos0
    r_state = r_state & MASK_INPUT_OUTPUT_POS0 ;

    if (first_px){
        // 1st PE must have {rx = RAMP} to send out the data
        // .rx = .{ RAMP },.tx = .{ EAST },
        r_state = r_state | INPUT_RAMP | OUTPUT_EAST;
    }else if (last_px){
        // last PE only receives data
        // .rx = .{ WEST }, .tx = .{ RAMP },
        r_state = r_state | INPUT_WEST | OUTPUT_RAMP;
    }else{ 
        // 0 < px < width-1
        // .rx = .{ WEST }, .tx = .{ EAST },
        r_state = r_state | INPUT_WEST | OUTPUT_EAST;
    }
    // update the switch pos0
    @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* = r_state;

    // (2) clear teardown-in-progress bit
    // config_reg[c] ^= mask where mask = 1 << 14
    tile_config.teardown.exit(C_ROUTE);
}


// Reconfigure the routing of C_ROUTE for a column reduction (data flows from
// north to south and is collected by the PE with last_py) and take the color
// out of teardown mode.
//
// Steps: reset the current switch position and rewrite pos1 in the switch
// config register, rewrite pos0 in the color config register, and finally
// clear the teardown-in-progress bit.
fn colReduce_configure() void {

    // (1) setup switch according to config parameters
    // 1. pos0 (color config reg)
    // 2. pos1 (switch config reg)
    //    pos1 = {.rx = RAMP} for all PEs except last PE

    // reset switch position to pos0
    // WARNING: if switch config register does not reset the switch position back to pos0,
    // it is possible that some PE is at pos1 after the switch is reconfigured and the sending
    // pattern is messed up, for example, the PE sends data first, then forwards the data from
    // the north.
    var r_switch_state : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* ;
    // mask bits 14:13 to reset input&output position to pos0
    r_switch_state = r_switch_state & MASK_SWITCH_RESET_POS0;
    // WARNING: all PEs are configured by
    //   - ".pos1 = .{ .rx = RAMP }"  --> bit 3 is 1
    //   - ".ring_mode = true"  --> bit 12 is 1
    //   - ".pop_mode = .{ .pop_on_advance = true }" --> bits 13:12 of fabric per-color config
    // mask bits 2:0 to clear setting of pos1
    r_switch_state = r_switch_state & MASK_SWITCH_CLEAR_POS1;
    if (last_py){
        // last PE does not have pos1
        r_switch_state = r_switch_state | SWITCH_POS1_INVALID;
    }else{
        // others have ".pos1 = .{ .rx = RAMP }"
        r_switch_state = r_switch_state | SWITCH_POS1_RAMP;
    }
    // write the switch config register back
    @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* = r_switch_state;

    var r_state : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* ;
    // clear input/output switch pos0
    r_state = r_state & MASK_INPUT_OUTPUT_POS0 ;

    if (first_py){
        // 1st PE must have {rx = RAMP} to send out the data
        // .rx = .{ RAMP },.tx = .{ SOUTH },
        r_state = r_state | INPUT_RAMP | OUTPUT_SOUTH;
    }else if (last_py){
        // last PE only receives data
        // .rx = .{ NORTH }, .tx = .{ RAMP },
        r_state = r_state | INPUT_NORTH | OUTPUT_RAMP;
    }else{
        // 0 < py < height-1
        // .rx = .{ NORTH }, .tx = .{ SOUTH },
        r_state = r_state | INPUT_NORTH | OUTPUT_SOUTH;
    }
    // update the switch pos0
    @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* = r_state;

    // (2) clear teardown-in-progress bit
    // config_reg[c] ^= mask where mask = 1 << 14
    tile_config.teardown.exit(C_ROUTE);
}


// w > 1 and h > 1
//  x <-- x <-- x
//              ^
//              |
//  x <-- x <-- x
//              ^
//              |
//  x <-- x <-- x
//
// Reconfigure the switch of C_ROUTE for the broadcast phase.
//
// The data enters the fabric at the bottom-right PE (px = w-1, py = h-1),
// moves north along the last column and west along every row; every other PE
// also delivers the data to its own RAMP.
//
// Order of operations (must not be changed):
//   1. rewrite the switch config register (back to pos0, pos1 disabled)
//   2. rewrite the per-color config register (input/output of pos0)
//   3. leave teardown mode for C_ROUTE
fn bcast_configure() void {

    // ---- step 1: switch config register -------------------------------
    // Three updates are folded into one expression:
    //   - MASK_SWITCH_RESET_POS0 clears bits 14:13, forcing the input and
    //     output position back to pos0. Without this a PE could still sit
    //     at pos1 after the reconfiguration and emit its data in the wrong
    //     order relative to the data it has to forward.
    //   - MASK_SWITCH_CLEAR_POS1 clears bits 2:0 (the pos1 setting).
    //   - SWITCH_POS1_INVALID disables pos1; broadcasting uses pos0 only.
    // ring_mode (bit 12) and pop_mode (bits 13:12 of the per-color config)
    // are left untouched.
    var sw_cfg : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* ;
    sw_cfg = ((sw_cfg & MASK_SWITCH_RESET_POS0) & MASK_SWITCH_CLEAR_POS1) | SWITCH_POS1_INVALID;
    @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* = sw_cfg;

    // ---- step 2: per-color config register (pos0 routing) --------------
    var color_cfg : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* ;
    // drop the previous input/output selection of pos0
    color_cfg = color_cfg & MASK_INPUT_OUTPUT_POS0 ;

    if (last_px){
        // last column, px = w-1
        if (last_py){
            // P(w-1,h-1) is the source: { .rx = .{RAMP}, .tx = .{WEST, NORTH} }
            color_cfg = color_cfg | INPUT_RAMP | OUTPUT_WEST | OUTPUT_NORTH;
        }else if (first_py){
            // P(w-1,0) ends the column: { .rx = .{SOUTH}, .tx = .{WEST, RAMP} }
            color_cfg = color_cfg | INPUT_SOUTH | OUTPUT_WEST | OUTPUT_RAMP;
        }else{
            // P(w-1,py), 0 < py < h-1: { .rx = .{SOUTH}, .tx = .{WEST, RAMP, NORTH} }
            color_cfg = color_cfg | INPUT_SOUTH | OUTPUT_WEST | OUTPUT_RAMP | OUTPUT_NORTH;
        }
    }else if (first_px){
        // px = 0 ends the row: { .rx = .{EAST}, .tx = .{RAMP} }
        color_cfg = color_cfg | INPUT_EAST | OUTPUT_RAMP;
    }else{
        // 0 < px < w-1 keeps a copy and forwards: { .rx = .{EAST}, .tx = .{WEST, RAMP} }
        color_cfg = color_cfg | INPUT_EAST | OUTPUT_RAMP | OUTPUT_WEST;
    }

    // commit the new pos0 routing
    @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* = color_cfg;

    // ---- step 3: clear the teardown-in-progress bit ---------------------
    // config_reg[c] ^= mask where mask = 1 << 14
    tile_config.teardown.exit(C_ROUTE);
}

// state 1: row-reduce
// state 2: col-reduce
// state 3: bcast
//
// Teardown handler of color C_ROUTE (it is registered with
// @set_teardown_handler in the comptime block right after this function).
// The handler does not touch the switch itself; it only activates C_LOCK.
// NOTE(review): the task bound to C_LOCK is defined outside this excerpt;
// presumably it performs the next reconfiguration -- confirm against the
// rest of the file.
fn teardown_allreduce() void {
    // turn C_ROUTE back to teardown mode
    // LOCK can be picked only when the operation finishes
    @activate(C_LOCK);
}

comptime {
    // Register teardown_allreduce as the teardown handler of color C_ROUTE.
    @set_teardown_handler(teardown_allreduce, C_ROUTE);
}

//
// routing of C_ROUTE (send data to the east, starting from the leftmost PE)
//    -->   --->-->   -->-->
//    ^
//    |
//   sw_adv
//    -->      -->   -->-->
//    ^        ^
//    |        |
//   data     data
//            sw_adv
//    -->     --> -->     -->
//    ^                   ^
//    |                   |
//   sw_adv              data
//                      sw_adv
//    -->       -->    --> -->
//    ^         ^
//    |         |
//             data
//             sw_adv
//
comptime {

    // Purpose: install the initial (comptime) configuration of color C_ROUTE.
    //
    // The switch must work for different operations, including
    //   - row reduction
    //   - column reduction
    //   - broadcasting
    // The initial setting must be universal so we can reconfigure the
    // switch for these three operations at runtime
    //
    // We need to set invariant parameters at comptime (runtime does not alter):
    // 1. teardown mode at comptime
    //   {.teardown = true} implies color is in teardown mode at comptime
    // 2. ring mode
    //   ".ring_mode = true"  --> fabric switch config reg sets bit 12 as 1
    // 3. pop on advance
    //   ".pop_mode = .{ .pop_on_advance = true }" --> fabric per-color config reg sets bits 13:12
    // 4. position 1
    //   ".pos1 = .{ .rx = RAMP }"  --> fabric switch config reg sets bit 3 as 1
    //
    // The following (last) PEs do not have position 1:
    //   - "px = width-1" for row reduction
    //   - "py = height-1" for column reduction
    //   - all for broadcasting
    // The runtime resets position 1 (bits 2:0 of fabric switch config) to either
    //   SWITCH_POS1_INVALID to disable position 1 or
    //   SWITCH_POS1_RAMP to reset position 1 back to ".pos1 = .{ .rx = RAMP }"
    // The bit 3 of fabric switch config is always 1 (position 1 switch select is "1=input")
    // If position 1 is disabled, bit 3 is don't care
    // If position 1 is disabled, pop mode is also don't care because of NOCE
    // If position 1 is disabled, ring mode is also don't care
    //
    // Remark: we don't use ".pop_mode = .{ .always_pop = true }" because there is no need
    // to propagate the TD to mux. All PEs have a teardown handler to deal with this TD, so
    // we do not need to pop out an instruction in TD wavelet, for example
    //     0x9df 9249 --> 0x91f 9249
    // (The instruction is NOT teardown 0b111, but 0b100 (NOCE, NOP))
    // (teardown = 0x9df,9249  (cmd1=teardown+NOCE, others=NOP+NOCE))
    //
    // The original setting of row reduction
    // 1st PE: px = 0
    //   .pos0 = .{ .rx = .{ RAMP }, .tx = .{ EAST }}
    //   .pop_mode = .{ .pop_on_advance = true }
    //   .pos1 = .{ .rx = RAMP }
    //   .ring_mode = true
    //   .teardown = true
    // middle: 1st PE < px < last PE
    //   .pos0 = .{ .rx = .{ WEST }, .tx = .{ EAST }}
    //   .pop_mode = .{ .pop_on_advance = true }
    //   .pos1 = .{ .rx = RAMP }
    //   .ring_mode = true
    //   .teardown = true
    // last PE: px = w-1
    //   .pos0 = .{ .rx = .{ WEST }, .tx = .{ RAMP }}
    //   .teardown = true
    //
    // The original setting of column reduction
    // 1st PE: py = 0
    //   .pos0 = .{ .rx = .{ RAMP }, .tx = .{ SOUTH }}
    //   .pop_mode = .{ .pop_on_advance = true }
    //   .pos1 = .{ .rx = RAMP }
    //   .ring_mode = true
    //   .teardown = true
    // middle: 1st PE < py < last PE
    //   .pos0 = .{ .rx = .{ NORTH }, .tx = .{ SOUTH }}
    //   .pop_mode = .{ .pop_on_advance = true }
    //   .pos1 = .{ .rx = RAMP }
    //   .ring_mode = true
    //   .teardown = true
    // last PE: py = h-1
    //   .pos0 = .{ .rx = .{ NORTH }, .tx = .{ RAMP }}
    //   .teardown = true
    //

    // Every PE starts with the same setting: the pos0 route matches the
    // "middle PE" of row reduction (WEST -> EAST) and the color is created in
    // teardown mode. The runtime configure functions (e.g. bcast_configure)
    // rewrite pos0/pos1 for the actual operation.
    const universalConfig = .{
        .routes= .{
            .rx = .{ WEST },
            .tx = .{ EAST },
        },
        .switches=.{
            .pos1 = .{ .rx = RAMP },
            .ring_mode = true,
            .pop_mode = .{ .pop_on_advance = true },
        },
        .teardown = true
    };

    // width == 1 is not supported: in that branch the assertion "1 < width"
    // can never hold, so the unsupported case becomes a compile-time error.
    // For any other width the universal configuration is installed on C_ROUTE.
    if (1 == width){
        @comptime_assert(1 < width);
    }else{
        @set_local_color_config(C_ROUTE, universalConfig);
    }
}


// binding a color to an input queue.
// This is necessary when an explicit DSR binds to a fabin DSD because
// the compiler no longer can generate the instruction to set up the
// config register of input queue.
comptime {
    // Bind color C_ROUTE to the input queue obtained from queues[0].
    @initialize_queue(@get_input_queue(queues[0]), .{.color = C_ROUTE} );
}

commands.sh

#!/usr/bin/env bash

# Stop at the first failing command.
set -e

# Compile layout.csl into ./out for a 5x5 rectangle of PEs (width:5,height:5)
# with MAX_ZDIM 5 and BLOCK_SIZE 2, placed at offset (4,1) of a 12x7 fabric.
# Colors C0..C8 are assigned the IDs 0..8; memcpy uses a single I/O channel
# and no west/east buffer columns.
cslc ./layout.csl \
  --fabric-dims=12,7 --fabric-offsets=4,1 \
  --params=width:5,height:5,MAX_ZDIM:5 \
  --params=BLOCK_SIZE:2 \
  --params=C0_ID:0 --params=C1_ID:1 --params=C2_ID:2 \
  --params=C3_ID:3 --params=C4_ID:4 --params=C5_ID:5 \
  --params=C6_ID:6 --params=C7_ID:7 --params=C8_ID:8 \
  -o=out \
  --memcpy --channels=1 --width-west-buf=0 --width-east-buf=0

# Run the host script on the compiled output with the matching sizes
# (-k is the maximum size of the local vector, --zDim the number of elements
# computed per PE).
cs_python ./run.py -m=5 -n=5 -k=5 \
  --latestlink out --channels=1 \
  --width-west-buf=0 --width-east-buf=0 \
  --zDim=5 --run-only