spmv-hypersparse

This example evaluates the performance of sparse matrix-vector multiplication (spmv). The kernel records the start and end of spmv using the TSC (timestamp counter). Because the TSC counters of the PEs are not synchronized at startup, f_sync() synchronizes all PEs and samples a reference clock, which is used to remove the timing offsets among the PEs.

The kernel kernel.csl defines several host-callable functions, including f_sync(), f_tic() and f_toc(), to synchronize the PEs and record the timing of spmv.

The kernel allreduce2R1E/pe.csl performs a reduction over the whole rectangle to synchronize the PEs; then the bottom-right PE sends a signal to the other PEs to sample the reference clock. allreduce2R1E is a variant of the allreduce in stencil-3d-7pts: the former uses 2 routable colors and 1 entrypoint, while the latter uses 1 routable color and 4 entrypoints. allreduce2R1E is designed for the spmv kernel, which leaves only three colors unused.

The kernel hypersparse_spmv/pe.csl performs a matrix-vector product (spmv) where the matrix A is hypersparse and partitioned over a 2D grid of PEs. The input vector x and output vector y are also distributed over the 2D grid.

The user has to provide the matrix A in Matrix Market File format with 1-based indices. To obtain the best performance, the user may need to reorder the matrix so that the number of nonzeros varies little across partitions. One option is util/analyze.cpp, which provides a load-balancing algorithm.

The script run.py has the following parameters:

  • --infile_mtx=<path to mtx file> contains the sparse matrix A

  • --num_pe_rows=<int> specifies the height of the core rectangle

  • --num_pe_cols=<int> specifies the width of the core rectangle

  • --channels=<int> specifies the number of I/O channels, no bigger than 16.

The tic() samples “time_start” and toc() samples “time_end”. The sync() samples “time_ref”, which is used to shift “time_start” and “time_end”. The elapsed time (unit: cycles) is measured by cycles_send = max(time_end) - min(time_start).

The overall runtime (us) is computed via the formula time_send = (cycles_send / 0.85) * 1.e-3 us, assuming a PE clock speed of 850 MHz.

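The bandwidth (MB/s) is calculated by bandwidth = ((2*nnz+m)*4)/time_send, where A is m-by-n with nnz nonzeros: each nonzero reads one matrix value and one entry of x, and each of the m rows writes one entry of y, all as 4-byte f32 values.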

See the SDK examples repository or the release tarball for additional supporting data preparation scripts.

layout.csl

// layout.csl -- top-level layout of the hypersparse spmv example.
// Assigns the routable colors and local task IDs used by the spmv and
// allreduce modules, declares the compile-time parameters (passed by run.py
// as cslc --params=...), and places kernel.csl on every PE of the core.
//
// color var           color  var           color  var                color  var
//   0                 10  init             20   tx_east              30 reserved (memcpy)
//   1  c0             11  compute_north    21   reserved (memcpy)    31 reserved
//   2  c1             12  compute_south    22   reserved (memcpy)    32
//   3  c2             13  tx_north         23   reserved (memcpy)    33 reserved (memcpy)
//   4  c3             14  tx_south         24   compute_local        34 reserved (memcpy)
//   5  c4             15  rx_north         25   curr_rx_north_done   35 reserved (memcpy)
//   6  c5             16  rx_south         26   curr_rx_south_done   36 reserved (memcpy)
//   7  allreduce_c0   17  rx_east          27   reserved (memcpy)    37 reserved (memcpy)
//   8  allreduce_c1   18  rx_west          28   reserved (memcpy)
//   9  allreduce_EN1  19  tx_west          29   reserved (memcpy)

// routable colors for spmv
param c0 = @get_color(1);
param c1 = @get_color(2);
param c2 = @get_color(3);
param c3 = @get_color(4);
param c4 = @get_color(5);
param c5 = @get_color(6);

// routable colors for allreduce
param allreduce_c0 = @get_color(7);
param allreduce_c1 = @get_color(8);
// entrypoint for allreduce
param allreduce_EN1: local_task_id = @get_local_task_id(9);

// entrypoints for spmv
// IDs 10..20 map to EN1..EN11 and IDs 24..26 to EN12..EN14; IDs 21..23 are
// skipped because the table above marks them as reserved (memcpy).
param EN1: local_task_id = @get_local_task_id(10);
param EN2: local_task_id = @get_local_task_id(11);
param EN3: local_task_id = @get_local_task_id(12);
param EN4: local_task_id = @get_local_task_id(13);
param EN5: local_task_id = @get_local_task_id(14);
param EN6: local_task_id = @get_local_task_id(15);
param EN7: local_task_id = @get_local_task_id(16);
param EN8: local_task_id = @get_local_task_id(17);
param EN9: local_task_id = @get_local_task_id(18);
param EN10: local_task_id = @get_local_task_id(19);
param EN11: local_task_id = @get_local_task_id(20);
param EN12: local_task_id = @get_local_task_id(24);
param EN13: local_task_id = @get_local_task_id(25);
param EN14: local_task_id = @get_local_task_id(26);


// parameters of spmv layout
//          pcols
//       +----------+
// prows |  core    |
//       +----------+
//
param prows: u16;   // number of PE rows (height of the core rectangle)
param pcols: u16;   // number of PE cols (width of the core rectangle)

// structure of the matrix
param nrows: u32;   // total number of matrix rows
param ncols: u32;   // total number of matrix cols
param max_local_nnz: u16;       // max of the local number of nonzeros (among all PEs)
param max_local_nnz_cols: u16;  // max of the local nnz cols
param max_local_nnz_rows: u16;  // max of the local nnz rows
param local_vec_sz: u16;        // size of the local input vector x (per PE)
param local_out_vec_sz: u16;    // size of the local output vector y (per PE)
param y_pad_start_row_idx: u16; // local row index where padding starts


// spmv module: 6 routable colors and 14 local task IDs
const spmv = @import_module( "hypersparse_spmv/layout.csl", .{
    .colors = [6]color{c0, c1, c2, c3, c4, c5},
    .entrypoints = [14]local_task_id{EN1, EN2, EN3, EN4, EN5, EN6, EN7, EN8, EN9, EN10, EN11, EN12, EN13, EN14},
    .width = pcols,
    .height = prows
    });

// allreduce module (used by f_sync): 2 routable colors and 1 local task ID
const reduce = @import_module( "allreduce2R1E/layout.csl", .{
    .colors = [2]color{allreduce_c0, allreduce_c1},
    .entrypoints = [1]local_task_id{allreduce_EN1},
    .width = pcols,
    .height = prows
    });

// memcpy infrastructure parameters for the host <-> device transfers
const memcpy = @import_module( "<memcpy/get_params>", .{
    .width = pcols,
    .height = prows,
    });

// Places kernel.csl on every PE of the pcols-by-prows core rectangle and
// declares the symbols the host may access through memcpy.
layout {
    // NOTE: This scheme assumes prows >= 4
    @comptime_assert(prows >= 4);

    //         --> px = pcol_id
    //          pcols
    //       +----------+
    // prows |  core    |  | py = prow_id
    //       |          |  V
    //       +----------+
    @set_rectangle(pcols, prows);

    // All PEs run the same kernel.csl; per-PE differences come only from the
    // parameter structs returned by the get_params() calls below.
    var pcol_id: u16 = 0;
    while (pcol_id < pcols) : (pcol_id += 1) {

        var prow_id: u16 = 0;
        while (prow_id < prows) : (prow_id += 1) {

            // memcpy.get_params takes only the column index
            const memcpyParams = memcpy.get_params(pcol_id);
            const spmvParams = spmv.get_params(pcol_id, prow_id);
            const reduceParams = reduce.get_params(pcol_id, prow_id);
            var params: comptime_struct = .{
                .memcpyParams = memcpyParams,
                .spmvParams = spmvParams,
                .reduceParams = reduceParams,
                .nrows = nrows,
                .ncols = ncols,
                .local_vec_sz = local_vec_sz,
                .max_local_nnz = max_local_nnz,
                .max_local_nnz_cols = max_local_nnz_cols,
                .max_local_nnz_rows = max_local_nnz_rows,
                .local_out_vec_sz = local_out_vec_sz,
                .y_pad_start_row_idx = y_pad_start_row_idx,
            };
            @set_tile_code(pcol_id, prow_id, "kernel.csl", params);

        } // while prow_id
    } // while pcol_id

    // Host-visible data symbols; each name matches an @export_symbol in
    // kernel.csl. NOTE(review): the trailing `true` presumably marks the
    // symbol as writable by the host -- confirm against the CSL docs.
    @export_name("mat_vals_buf", [*]f32, true);
    @export_name("x_tx_buf", [*]f32, true);
    @export_name("y_local_buf", [*]f32, true);

    @export_name("mat_rows_buf", [*]u16, true);
    @export_name("mat_col_idx_buf", [*]u16, true);
    @export_name("mat_col_loc_buf", [*]u16, true);
    @export_name("mat_col_len_buf", [*]u16, true);
    @export_name("y_rows_init_buf", [*]u16, true);

    @export_name("local_nnz", [*]u16, true);
    @export_name("local_nnz_cols", [*]u16, true);
    @export_name("local_nnz_rows", [*]u16, true);

    @export_name("time_buf_u16", [*]u16, true);

    @export_name("time_ref_u16", [*]u16, true);

    // Host-callable functions defined in kernel.csl.
    @export_name("f_enable_tsc", fn()void);
    @export_name("f_tic", fn()void);
    @export_name("f_toc", fn()void);
    @export_name("f_spmv", fn()void);
    @export_name("f_memcpy_timestamps", fn()void);
    @export_name("f_sync", fn(i16)void);
    @export_name("f_reference_timestamps", fn()void);
}

kernel.csl

// kernel.csl -- per-PE program of the hypersparse spmv example.
// Owns the matrix/vector buffers, wires them into the spmv module
// (hypersparse_spmv/pe.csl) and the allreduce module (allreduce2R1E/pe.csl),
// and defines the host-callable functions used for timing and synchronization.

// per-PE parameter structs produced by layout.csl
// (memcpy.get_params, spmv.get_params, reduce.get_params)
param memcpyParams: comptime_struct;

param spmvParams: comptime_struct;

param reduceParams: comptime_struct;

// parameters
param nrows: u32;   // total number of matrix rows
param ncols: u32;   // total number of matrix cols
param max_local_nnz: u16;       // max of the local number of nonzeros (among all PEs)
param max_local_nnz_cols: u16;  // max of the local nnz cols
param max_local_nnz_rows: u16;  // max of the local nnz rows
param local_vec_sz: u16;    // size of the local input vector x (x_tx_buf)
param local_out_vec_sz: u16;    // size of the local output vector y (y_local_buf)
param y_pad_start_row_idx: u16;   // local row index where padding starts

// data buffers
// input matrix
var mat_vals_buf = @zeros([max_local_nnz]f32);      // in matrix values (sparse): 4B
// input vector: for north-going and south-going trains
// buffer storing data for tx
var x_tx_buf = @zeros([local_vec_sz]f32);       // in vector values (dense): 4B

var mat_rows_buf = @zeros([max_local_nnz]u16);      // in matrix relative row offsets: 2B
                                                // need this in preprocessing: 2B
var mat_col_idx_buf = @zeros([max_local_nnz_cols]u16);   // column idx of nnz cols (max possible size is nnz)
var mat_col_loc_buf = @zeros([max_local_nnz_cols]u16);   // col location in mat_vals_buf and mat_rows_buf (max nnz)
var mat_col_len_buf = @zeros([max_local_nnz_cols]u16);   // col length (nnz rows in a col)
// precomputed output vector (sparse format) local rows index information
var y_rows_init_buf = @zeros([max_local_nnz_rows]u16);       // init -- this should not be modified

var local_nnz = @zeros([1]u16);         // actual local number of nonzeros
var local_nnz_cols = @zeros([1]u16);    // actual local number of nnz cols
var local_nnz_rows = @zeros([1]u16);    // actual local number of nnz rows

// final reduced local output vector (dense)
var y_local_buf = @zeros([local_out_vec_sz]f32);

// temporary f32 buffer handed to reduce_mod.allreduce by f_sync
var dot = @zeros([1]f32);

// <time> library: enable_tsc(), get_timestamp(), tsc_size_words
const timestamp = @import_module("<time>");

// memcpy runtime; every host-callable function must end by releasing the
// command stream with sys_mod.unblock_cmd_stream() (directly, or presumably
// through the f_callback handed to the spmv / allreduce modules)
const sys_mod = @import_module( "<memcpy/memcpy>", memcpyParams);

// input_queues cannot overlap with output_queues
const spmv_mod = @import_module( "hypersparse_spmv/pe.csl", @concat_structs(spmvParams, .{
     .f_callback = sys_mod.unblock_cmd_stream,

     .nrows = nrows,
     .ncols = ncols,
     .local_vec_sz = local_vec_sz,
     .max_local_nnz = max_local_nnz,
     .max_local_nnz_cols = max_local_nnz_cols,
     .max_local_nnz_rows = max_local_nnz_rows,
     .local_out_vec_sz = local_out_vec_sz,
     .y_pad_start_row_idx = y_pad_start_row_idx,

     .mat_vals_buf = &mat_vals_buf,
     .mat_rows_buf = &mat_rows_buf,
     .mat_col_idx_buf = &mat_col_idx_buf,
     .mat_col_loc_buf = &mat_col_loc_buf,
     .mat_col_len_buf = &mat_col_len_buf,
     .y_rows_init_buf = &y_rows_init_buf,
     .local_nnz = &local_nnz,
     .local_nnz_cols = &local_nnz_cols,
     .local_nnz_rows = &local_nnz_rows,

     .input_queues=[4]u16{4, 1, 6, 7},
     .output_queues=[2]u16{2,3},
     .dest_dsr_ids = [6]u16{1, 4, 5, 6, 2, 3},
     .src1_dsr_ids = [6]u16{4, 1, 6, 7, 2, 3},
     }));

// allreduce uses input queue/output queue 5
// (spmv above takes input queues 1,4,6,7 and output queues 2,3, so 5 is free)
// dest_dsr and src0_dsr must be a valid pair, for example (7,1) is invalid
const reduce_mod = @import_module( "allreduce2R1E/pe.csl", @concat_structs(reduceParams, .{
     .f_callback = sys_mod.unblock_cmd_stream,
     .MAX_ZDIM = 1,
     .queues = [1]u16{5},
     .dest_dsr_ids = [1]u16{7},
     .src0_dsr_ids = [1]u16{7},
     .src1_dsr_ids = [1]u16{5}
     }));

// tsc library
var tsc_start_buffer = @zeros([timestamp.tsc_size_words]u16);
var tsc_end_buffer = @zeros([timestamp.tsc_size_words]u16);

// filled by f_memcpy_timestamps:
//   time_buf_u16[0..2] = tsc_start_buffer, time_buf_u16[3..5] = tsc_end_buffer
var time_buf_u16 = @zeros([timestamp.tsc_size_words*2]u16);
var ptr_time_buf_u16: [*]u16 = &time_buf_u16;

// copy of the reference clock sampled by the allreduce module
// (reduce_mod.tscRefBuffer), filled by f_reference_timestamps
var time_ref_u16 = @zeros([timestamp.tsc_size_words]u16);
var ptr_time_ref_u16: [*]u16 = &time_ref_u16;

// [*] pointers to the buffers; exported to the host by name in the
// comptime block at the end of this file
var ptr_mat_vals_buf: [*]f32 = &mat_vals_buf;
var ptr_x_tx_buf: [*]f32 = &x_tx_buf;
var ptr_y_local_buf: [*]f32 = &y_local_buf;
var ptr_mat_rows_buf: [*]u16 = &mat_rows_buf;
var ptr_mat_col_idx_buf: [*]u16 = &mat_col_idx_buf;
var ptr_mat_col_loc_buf: [*]u16 = &mat_col_loc_buf;
var ptr_mat_col_len_buf: [*]u16 = &mat_col_len_buf;
var ptr_y_rows_init_buf: [*]u16 = &y_rows_init_buf;
var ptr_local_nnz: [*]u16 = &local_nnz;
var ptr_local_nnz_cols: [*]u16 = &local_nnz_cols;
var ptr_local_nnz_rows: [*]u16 = &local_nnz_rows;

// Host-callable: enable the timestamp counter through the <time> library,
// then release the memcpy command stream.
fn f_enable_tsc() void {
    timestamp.enable_tsc();

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Host-callable: record the start timestamp ("time_start") into
// tsc_start_buffer, then release the memcpy command stream.
fn f_tic() void {
    timestamp.get_timestamp(&tsc_start_buffer);

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Host-callable: record the end timestamp ("time_end") into
// tsc_end_buffer, then release the memcpy command stream.
fn f_toc() void {
    timestamp.get_timestamp(&tsc_end_buffer);

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// compute y = A*x
//
// Host-callable. Unlike f_tic/f_toc, this function does not call
// sys_mod.unblock_cmd_stream() itself: spmv_mod was configured with
// .f_callback = sys_mod.unblock_cmd_stream, which it presumably invokes
// once the spmv has completed -- confirm in hypersparse_spmv/pe.csl.
//
// To ping-pong the spmv by
//    spmv(x, y) // y = A*x
//    spmv(y, x) // x = A*y
// we need to make sure local_vec_sz = local_out_vec_sz, otherwise compilation fails
// because of mismatch of the dimensions
//
fn f_spmv() void {
    spmv_mod.spmv(&x_tx_buf, &y_local_buf);
}

// Host-callable: pack the start and end timestamps into time_buf_u16 so the
// host can fetch both with a single copy of the exported "time_buf_u16".
// Layout: [0..2] = start words, [3..5] = end words.
// NOTE(review): exactly 3 words per timestamp are copied; this assumes
// timestamp.tsc_size_words == 3 (time_buf_u16 is sized 2*tsc_size_words).
fn f_memcpy_timestamps() void {

    time_buf_u16[0] = tsc_start_buffer[0];
    time_buf_u16[1] = tsc_start_buffer[1];
    time_buf_u16[2] = tsc_start_buffer[2];
    time_buf_u16[3] = tsc_end_buffer[0];
    time_buf_u16[4] = tsc_end_buffer[1];
    time_buf_u16[5] = tsc_end_buffer[2];

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Host-callable: run the allreduce module on the 1-element buffer `dot`
// to synchronize all PEs (and, per the README, sample the reference clock).
// `n` is forwarded unchanged to reduce_mod.allreduce; presumably it is the
// number of f32 entries to reduce (dot has 1, MAX_ZDIM = 1) -- confirm.
// The cmd stream is released by the module through its f_callback.
fn f_sync( n: i16 ) void {
   reduce_mod.allreduce(n, &dot);
}

// Host-callable: copy the reference clock sampled inside the allreduce module
// (reduce_mod.tscRefBuffer, words 0..2) into the exported time_ref_u16,
// then release the memcpy command stream.
fn f_reference_timestamps() void {

    time_ref_u16[0] = reduce_mod.tscRefBuffer[0];
    time_ref_u16[1] = reduce_mod.tscRefBuffer[1];
    time_ref_u16[2] = reduce_mod.tscRefBuffer[2];

    // the user must unblock cmd color for every PE
    sys_mod.unblock_cmd_stream();
}

// Bind each buffer pointer to the host-visible name declared with
// @export_name in layout.csl.
comptime{

    @export_symbol(ptr_mat_vals_buf, "mat_vals_buf");
    @export_symbol(ptr_x_tx_buf, "x_tx_buf");
    @export_symbol(ptr_y_local_buf, "y_local_buf");

    @export_symbol(ptr_mat_rows_buf, "mat_rows_buf");
    @export_symbol(ptr_mat_col_idx_buf, "mat_col_idx_buf");
    @export_symbol(ptr_mat_col_loc_buf, "mat_col_loc_buf");
    @export_symbol(ptr_mat_col_len_buf, "mat_col_len_buf");
    @export_symbol(ptr_y_rows_init_buf, "y_rows_init_buf");

    @export_symbol(ptr_local_nnz, "local_nnz");
    @export_symbol(ptr_local_nnz_cols, "local_nnz_cols");
    @export_symbol(ptr_local_nnz_rows, "local_nnz_rows");

    @export_symbol(ptr_time_buf_u16, "time_buf_u16");

    @export_symbol(ptr_time_ref_u16, "time_ref_u16");
}


// Export the host-callable functions under their own names (matching the
// fn signatures declared with @export_name in layout.csl).
comptime{
    @export_symbol(f_enable_tsc);
    @export_symbol(f_tic);
    @export_symbol(f_toc);
    @export_symbol(f_memcpy_timestamps);
    @export_symbol(f_spmv);
    @export_symbol(f_sync);
    @export_symbol(f_reference_timestamps);
}

run.py

#!/usr/bin/env cs_python
# pylint: disable=too-many-function-args

""" test sparse matrix-vector multiplication

  This example aims at a hypersparse matrix with almost uniform distribution.
  The algorithm partitions the sparse matrix into 2D grids. The algorithm may
  fail if any partition has too many nonzeros to fit in the
  memory capacity (48KB) of a PE.

  To obtain the best performance, the user may need to reorder the matrix so
  that the number of nonzeros varies little across partitions.

  To run this example, the user has to provide a file of Matrix Market File
  format with 1-based indices. For example, the user can reorder the matrix A by
  the permutation matrices P and Q, and write P*A*Q^T to a file. One option is
  "util/analyze.cpp" which provides a load balancing algorithm.

  This example reads a MTX file, generates the vector x, partitions the matrix,
  and computes y = A*x.

  The framework is
  ---
       sync()  // synchronize all PEs to sample the reference clock
       tic()   // record start time
       spmv()  // compute y = A*x
       toc()   // record end time
  ---

  The tic() samples "time_start" and toc() samples "time_end". The sync() samples
  "time_ref" which is used to shift "time_start" and "time_end".
  The elapsed time is measured by
       cycles_send = max(time_end) - min(time_start)

  The overall runtime is computed via the following formula
       time_send = (cycles_send / 0.85) *1.e-3 us
  where a PE runs with clock speed 850MHz

  The spmv kernel performs y = A * x
  where A is m-by-n with nnz nonzeros

  The standard measurement counts the number of memory accesses of
       y[i] = sum{ Aij * xj : Aij is nonzero }
  - read Aij: nnz
  - read xj: nnz
  - write y[i]: m
  Total number of memory accesses: (2*nnz + m) f32

  Here is the list of parameters:
    --infile_mtx=<path to mtx file> contains the sparse matrix A
    --num_pe_rows=<int> specifies the height of the core rectangle
    --num_pe_cols=<int> specifies the width of the core rectangle
    --channels=<int> specifies the number of I/O channels, no bigger than 16

  How to compile and run
     To build a 5-by-4 core rectangle, we need to pass --num_pe_cols=5 --num_pe_rows=4
     Use the following command to compile
        python run.py --arch=wse2 --num_pe_cols=5 --num_pe_rows=4 --channels=1
           --driver=<path to cslc> --compile-only --infile_mtx=<path to mtx file>
     Use the following command to run
        python run.py --arch=wse2 --num_pe_cols=5 --num_pe_rows=4 --channels=1
           --is_weight_one --run-only --infile_mtx=<path to mtx file>
"""

import os, sys
import subprocess
import time
import math
import numpy as np
import scipy.sparse as sparse
import shutil

from pathlib import Path
from datetime import datetime
from typing import Optional

from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType, MemcpyOrder # pylint: disable=no-name-in-module
#from cerebras.sdk.debug.debug_util import debug_util

from cmd_parser import parse_args

from memory_usage import memory_per_pe

from scipy.io import mmread

from preprocess import preprocess


def make_u48(words):
  """Assemble a 48-bit timestamp from three 16-bit words, least significant first.

  Each word is converted to a Python int before shifting. Under NumPy 2
  promotion rules (NEP 50), ``np.uint16(w) << 16`` stays uint16 and silently
  discards the high bits, which would corrupt every timestamp.

  Args:
    words: indexable with at least three integer entries; only words[0..2]
      are used (e.g. a uint16 array or a list of ints).

  Returns:
    int: words[0] + (words[1] << 16) + (words[2] << 32)
  """
  lo, mid, hi = (int(words[i]) for i in range(3))
  return lo + (mid << 16) + (hi << 32)


def hwl_to_oned_colmajor(
    height: int,
    width: int,
    pe_length: int,
    A_hwl: np.ndarray,
    dtype
):
    """
    Given a 3-D tensor A[height][width][pe_length], transform it to
    1D array by column-major (Fortran) order: the height index varies
    fastest, then width, then pe_length. Only the leading
    A_hwl[0:height, 0:width, 0:pe_length] sub-block is read.

    Supported conversions:
      - float32 A_hwl -> 1-D array of ``dtype`` (values cast to ``dtype``).
      - 16-bit A_hwl (uint16, int16 or float16) -> uint32: the raw 16-bit
        pattern of each element is zero-extended; ``dtype`` must be np.uint32.

    Raises:
        AssertionError: 16-bit A_hwl with ``dtype`` other than np.uint32.
        ValueError: A_hwl is smaller than (height, width, pe_length).
        RuntimeError: A_hwl has any other dtype.
    """
    is_16bit = A_hwl.dtype in (np.uint16, np.int16, np.float16)
    if is_16bit:
        assert dtype == np.uint32, "only support dtype = u32 if A is 16-bit (u16/i16/f16)"
    elif A_hwl.dtype != np.float32:
        # report the element type; type(A_hwl) would always be ndarray
        raise RuntimeError(f"{A_hwl.dtype} is not supported")

    block = A_hwl[:height, :width, :pe_length]
    if block.shape != (height, width, pe_length):
        raise ValueError(
            f"A_hwl has shape {A_hwl.shape}, expected at least "
            f"({height}, {width}, {pe_length})")
    if is_16bit:
        # reinterpret the bits as u16 so astype() zero-extends the raw
        # pattern instead of converting the numeric value (int16 / float16)
        block = block.view(np.uint16)
    return block.ravel(order='F').astype(dtype)


def oned_to_hwl_colmajor(
    height: int,
    width: int,
    pe_length: int,
    A_1d: np.ndarray,
    dtype
):
    """
    Given a 1-D tensor A_1d[height*width*pe_length], transform it to
    3-D tensor A[height][width][pe_length] by column-major order: the
    height index varies fastest, then width, then pe_length.

    Supported conversions:
      - f32 -> f32: a Fortran-order reshape of A_1d (a view when possible).
      - u32 -> u16: keeps the low 16 bits of the first
        height*width*pe_length words; returns a new C-ordered array.

    Raises:
        AssertionError: A_1d has the wrong dtype for the requested conversion.
        RuntimeError: unsupported ``dtype``.
    """
    shape = (height, width, pe_length)
    if dtype == np.float32:
        # only support f32 to f32
        assert A_1d.dtype == np.float32, "only support f32 to f32"
        A_hwl = np.reshape(A_1d, shape, order='F')

    elif dtype == np.uint16:
        # only support u32 to u16 by dropping upper 16-bit
        assert A_1d.dtype == np.uint32, "only support u32 to u16"
        count = height * width * pe_length
        low16 = (A_1d[:count] & 0x0000FFFF).astype(np.uint16)
        # same element placement as the f32 branch; ascontiguousarray gives
        # the result a C memory layout
        A_hwl = np.ascontiguousarray(np.reshape(low16, shape, order='F'))
    else:
        raise RuntimeError(f"{dtype} is not supported")

    return A_hwl


def read_input_vector(IS_INVEC_1, vec_len):
    """Build the float32 input vector x of length vec_len.

    Returns all ones when IS_INVEC_1 is truthy. Otherwise returns uniform
    [0, 1) samples from NumPy's legacy global generator, which is re-seeded
    with 0 on every call: the vector is reproducible, and the global random
    state is reset as a side effect.
    """
    if not IS_INVEC_1:
        np.random.seed(0)
        samples = np.random.rand(vec_len)
        return samples.astype(np.float32)
    return np.full(vec_len, 1.0, dtype=np.float32)


# x is distributed into the core rectangle by the following steps
# step 1: distribute x into columns
#    vec_len_per_pe_col = ceil(vec_len / np_cols)
# step 2: distribute the column into PEs
#    vec_len_per_pe = ceil(vec_len_per_pe_col / np_rows)
#
# For example, if core rectangle is 2-by-2 and the length of x (ncols) is 13
#    Each column has vec_len_per_pe_col = ceil(13/2) = 7
#    The padded length of x is 7*2 = 14, which is bigger than ncols = 13 due to padding
#    Each PE has vec_len_per_pe = ceil(7/2) = 4 (= local_vec_sz)
#
# If x is {1,2,3,4,5,6,7,8,9,10,11,12,13}, the core has
#          PE.x=0      PE.x=1
#    +-------------+-------------+
#    | {1,2,3,4}   | {8,9,10,11} | PE.y=0
#    +-------------+-------------+
#    | {5,6,7,x}   | {12,13,x,x} | PE.y=1
#    +-------------+-------------+
# column 0 has 7 elements, {1,2,3,4,5,6,7}
# column 1 has 6 elements, {8,9,10,11,12,13}
#
# The symbol x is DON'T CARE
#
def dist_x_to_hwl(ncols, x, local_vec_sz, np_cols, np_rows):
    """Scatter the input vector x over the np_rows-by-np_cols core rectangle.

    x is cut into np_cols contiguous column slices of ceil(ncols / np_cols)
    entries each, and every column slice is cut into np_rows contiguous
    per-PE chunks of ceil(slice / np_rows) entries. Positions past the end of
    the data are padded with 1 (this is an input vector).

    Args:
        ncols: length of x (number of columns of A).
        x: 1-D numpy array (assumed); its dtype is the dtype of the result.
        local_vec_sz: expected per-PE chunk length (asserted).
        np_cols: width of the core rectangle.
        np_rows: height of the core rectangle.

    Returns:
        C-ordered array of shape (np_rows, np_cols, local_vec_sz) whose
        entry [row, col] is the chunk of PE (x=col, y=row).
    """
    slice_len = math.ceil(ncols / np_cols)      # entries per PE column
    chunk_len = math.ceil(slice_len / np_rows)  # entries per PE
    assert(chunk_len == local_vec_sz)

    # Step 1: pad x with ones to a whole number of column slices, then lay it
    # out as one row per PE column.
    tail_pad = slice_len * np_cols - ncols
    padded = np.append(x, np.ones(tail_pad)) if tail_pad > 0 else np.copy(x)
    by_col = padded[:slice_len * np_cols].reshape(np_cols, slice_len)

    # Step 2: pad every column slice with ones up to np_rows whole chunks.
    col_pad = chunk_len * np_rows - slice_len
    if col_pad > 0:
        by_col = np.hstack([by_col, np.ones((np_cols, col_pad))])

    # (col, row, chunk) -> (row, col, chunk) as a fresh array of x's dtype
    chunks = by_col.reshape(np_cols, np_rows, chunk_len)
    return np.transpose(chunks, (1, 0, 2)).astype(x.dtype, order='C')

# The dimension of out_vec is h-by-w-by-l
# h = np_rows is the height of the core
# w = np_cols is the width of the core
# l = local_out_vec_sz is the size of local vector
#
# The out_vec_sz is the length of y = A*x
#
# y is distributed into the core rectangle by the following steps
# step 1: distribute y into rows
#    vec_len_per_pe_row = math.ceil(out_vec_sz / np_rows)
# step 2: distribute the row into PEs
#    vec_len_per_pe = math.ceil(vec_len_per_pe_row / np_cols)
#
# If out_vec_sz is smaller than (vec_len_per_pe_row*np_rows), padding is added
#
# The function unpad_3d_to_1d returns a result of size (vec_len_per_pe_row*np_rows)
#
# For example, if core rectangle is 2-by-2 and out_vec_sz is 13
#    Each row has vec_len_per_pe_row = ceil(13/2) = 7
#    The size of result is 7*2 = 14 which is bigger than out_vec_sz due to padding
#    Each PE has vec_len_per_pe = ceil(7/2) = 4
#
# If y is {1,2,3,4,5,6,7,8,9,10,11,12,13}, the core has
#          PE.x=0      PE.x=1
#    +-------------+-------------+
#    | {1,2,3,4}   | {5,6,7,x}   | PE.y=0
#    +-------------+-------------+
#    | {8,9,10,11} | {12,13,x,x} | PE.y=1
#    +-------------+-------------+
# row 0 has 7 elements, {1,2,3,4,5,6,7}
# row 1 has 6 elements, {8,9,10,11,12,13}
#
# The symbol x is DON'T CARE
#
def unpad_3d_to_1d(out_vec_sz, out_vec):
    """Gather the distributed output vector y back into a single 1-D array.

    out_vec is h-by-w-by-l (PE rows, PE columns, local_out_vec_sz). Each PE
    row owns ceil(out_vec_sz / h) consecutive entries of y, split across the
    PEs of that row in chunks of l. The local vectors of each PE row are
    concatenated and trimmed to the row's length.

    Returns:
        float32 array of length ceil(out_vec_sz / h) * h; y is result[0:out_vec_sz].
        Entries past out_vec_sz come from the padding of the last PE row.
    """
    assert 3 == out_vec.ndim, "y must be a 3-d tensor of the form h-by-w-by-l"
    np_rows, np_cols, local_out_vec_sz = out_vec.shape

    row_len = math.ceil(out_vec_sz / np_rows)  # entries owned by one PE row
    pe_len = math.ceil(row_len / np_cols)      # entries owned by one PE
    # check if local_out_vec_sz = math.ceil(math.ceil(out_vec_sz / np_rows) / np_cols)
    assert(pe_len == local_out_vec_sz)

    # Concatenate the local vectors of PE.x = 0, 1, ... for every PE row, then
    # keep the first row_len entries of each row (the rest is padding).
    per_row = np.reshape(out_vec, (np_rows, np_cols * pe_len))[:, :row_len]
    return per_row.astype(np.float32).reshape(-1)


def verify_result(ref, res):
    """Compare the device result ``res`` with the host reference ``ref``.

    Prints both vectors, the summed and the mean absolute difference, and a
    PASS/FAIL verdict from ``np.allclose(ref, res, rtol=1e-5, atol=1e-8)``.
    On FAIL, also prints a pandas table of the mismatching entries, selected
    with the same tolerances as the verdict.

    Args:
        ref: reference array (e.g. the output of generate_reference).
        res: result array of the same shape.

    Returns:
        bool: True when every entry is within tolerance.
    """
    print('Comparing result with reference...')
    abs_diff = np.sum(abs(ref - res))
    abs_rel = abs_diff / len(ref)
    print(f'reference[{len(ref)}]: \n{ref}')
    print(f'result   [{len(res)}]: \n{res}')
    print(f'[[ Absolute diff: {abs_diff} ]]')
    print(f'[[ Average diff : {abs_rel} ]]')
    atol = 1e-8
    rtol = 1e-5
    is_correct = bool(np.allclose(ref, res, rtol=rtol, atol=atol))
    result = 'PASS' if is_correct else 'FAIL'
    # print the verdict exactly once
    print(f'[[ Result within tolerance {atol}: {result} ]]')
    if not is_correct:
        import pandas as pd
        # same tolerances as the verdict, so the table lists exactly the
        # entries that caused FAIL
        unequal = ~np.isclose(ref, res, rtol=rtol, atol=atol)
        unequal_idx = list(np.where(unequal))
        mismatches = list(zip(ref[tuple(unequal_idx)], res[tuple(unequal_idx)]))
        df = pd.DataFrame(mismatches, columns=['reference', 'result'], index=unequal_idx)
        print(f'{df}')
    return is_correct


# y = A*x
# where A is nrows-by-ncols, represented by a CSR triplet
def generate_reference(nrows, ncols, csrRowPtr, csrColInd, csrVal, x):
    """Compute the host-side reference y = A*x with SciPy.

    A is nrows-by-ncols, given as a CSR triplet (csrRowPtr, csrColInd,
    csrVal); x must have exactly ncols entries.
    """
    assert ncols == len(x), "the dimension of x does not match the dimension of A"
    A = sparse.csr_matrix((csrVal, csrColInd, csrRowPtr), shape=(nrows, ncols))
    x_vec = np.array(x).T
    return A.dot(x_vec)


def timing_analysis(height, width, nrows, ncols, nnz, time_memcpy_hwl, time_ref_hwl):
    """Print the elapsed cycles, runtime (us) and bandwidth (MB/S) of spmv.

    Args:
        height: number of PE rows of the core rectangle.
        width: number of PE columns of the core rectangle.
        nrows: number of rows m of A (y has m entries).
        ncols: number of columns of A (not used by the formulas below).
        nnz: number of nonzeros of A.
        time_memcpy_hwl: (height, width, >=6) array of 16-bit words per PE;
            words 0..2 are time_start and words 3..5 are time_end, least
            significant word first (the layout written by f_memcpy_timestamps).
        time_ref_hwl: (height, width, >=3) array holding each PE's reference
            clock words in the same format.

    Each PE's reference clock is corrected for the propagation delay of the
    signal from the bottom-right PE; time_start/time_end are shifted by it and
    cycles_send = max(time_end) - min(time_start).
    """
    def u48(words_hwl, first):
        # Truncate to 16 bits (as a uint16 buffer would) and widen to int64
        # *before* shifting: shifting uint16 values directly loses the high
        # words under NumPy 2 promotion rules.
        w = np.asarray(words_hwl)[:height, :width, first:first + 3]
        w = w.astype(np.uint16).astype(np.int64)
        return w[:, :, 0] + (w[:, :, 1] << 16) + (w[:, :, 2] << 32)

    time_start = u48(time_memcpy_hwl, 0)
    time_end = u48(time_memcpy_hwl, 3)
    # time_ref = reference clock
    time_ref = u48(time_ref_hwl, 0)

    # adjust the reference clock by the propagation delay
    # the right-bottom PE signals other PEs, the propagation delay is
    #     (h-1) - py + (w-1) - px
    py, px = np.indices((height, width))
    time_ref = time_ref - ((width + height - 2) - (px + py))

    # shift time_start and time_end by time_ref
    time_start = time_start - time_ref
    time_end = time_end - time_ref

    # 850MHz --> 1 cycle = (1/0.85) ns = (1/0.85)*1.e-3 us
    # time_send = (cycles_send / 0.85) *1.e-3 us
    #
    # The spmv kernel performs y = A * x
    #   y[i] = sum{ Aij * xj : Aij is nonzero }
    # where A is m-by-n with nnz nonzeros
    # - read Aij: nnz
    # - read xj: nnz
    # - write y[i]: m   (m = nrows, the number of matrix rows, NOT the
    #                    number of PE rows)
    # Total number of wavelets: (2*nnz + m)
    wvlts = 2 * nnz + nrows
    cycles_send = time_end.max() - time_start.min()
    time_send = (cycles_send / 0.85) * 1.e-3
    bandwidth = (wvlts * 4) / time_send
    print(f"cycles_send = {cycles_send} cycles")
    print(f"time_send = {time_send} us")
    print(f"bandwidth = {bandwidth} MB/S ")


def csl_compile_core(
    cslc: str,
    file_config : str,
    elf_dir : str,
    fabric_width : int,
    fabric_height : int,
    core_fabric_offset_x : int, # fabric-offsets of the core
    core_fabric_offset_y : int,
    use_precompile: bool,
    arch : Optional[str],
    ncols: int,
    nrows: int,
    np_cols: int,
    np_rows: int,
    max_local_nnz: int,
    max_local_nnz_cols: int,
    max_local_nnz_rows: int,
    local_vec_sz: int,
    local_out_vec_sz: int,
    out_pad_start_idx: int,
    channels:int,
    width_west_buf: int,
    width_east_buf:int
):
    """Compile layout.csl with cslc into elf_dir (memcpy enabled).

    The matrix/partition sizes are forwarded as --params=<name>:<value>
    (np_cols -> pcols, np_rows -> prows, out_pad_start_idx ->
    y_pad_start_row_idx). --arch is passed only when arch is not None.
    When use_precompile is True nothing is compiled; existing ELFs are reused.
    Raises subprocess.CalledProcessError if cslc fails.
    """
    if use_precompile:
        print(f"[csl_compile_core] use pre-compile ELFs")
        return

    kernel_params = [
        ("ncols", ncols),
        ("nrows", nrows),
        ("pcols", np_cols),
        ("prows", np_rows),
        ("max_local_nnz", max_local_nnz),
        ("max_local_nnz_cols", max_local_nnz_cols),
        ("max_local_nnz_rows", max_local_nnz_rows),
        ("local_vec_sz", local_vec_sz),
        ("local_out_vec_sz", local_out_vec_sz),
        ("y_pad_start_row_idx", out_pad_start_idx),
    ]
    cmd = [
        cslc,
        file_config,
        f"--fabric-dims={fabric_width},{fabric_height}",
        f"--fabric-offsets={core_fabric_offset_x},{core_fabric_offset_y}",
    ]
    cmd += [f"--params={name}:{value}" for name, value in kernel_params]
    cmd.append(f"-o={elf_dir}")
    if arch is not None:
        cmd.append(f"--arch={arch}")
    cmd += [
        "--memcpy",
        f"--channels={channels}",
        f"--width-west-buf={width_west_buf}",
        f"--width-east-buf={width_east_buf}",
    ]

    print(f"subprocess.check_call(args = {cmd}")
    subprocess.check_call(cmd)


def main():
    """Main method to run the example code.

    Flow (driven by the options returned by parse_args()):
      * load the MTX matrix A, build its CSR and CSC forms and partition it
        over the np_rows-by-np_cols core rectangle with preprocess();
      * compile the CSL program (compilation is skipped with --run-only) and
        return right after compilation when --compile-only is given;
      * copy A and x to the device, synchronize the PEs, run spmv between
        f_tic/f_toc, then fetch the timestamps and the output vector y;
      * report the timing and verify y against the host reference y_ref.
    """

    args = parse_args()

    cslc = "cslc"
    if args.driver is not None:
        cslc = args.driver
    print(f"cslc = {cslc}")

    width_west_buf = args.width_west_buf
    width_east_buf = args.width_east_buf
    channels = args.channels
    assert channels <= 16, "only support up to 16 I/O channels"
    assert channels >= 1, "number of I/O channels must be at least 1"

    print(f"width_west_buf = {width_west_buf}")
    print(f"width_east_buf = {width_east_buf}")
    print(f"channels = {channels}")

    dirname = args.latestlink

    # core rectangle is np_cols-by-np_rows
    np_cols = args.num_pe_cols
    np_rows = args.num_pe_rows
    IS_INVEC_1 = args.is_invec_one

    width = np_cols
    height = np_rows
    print(f"width = {width}, height = {height}")

    start = time.time()
    infile_mtx = args.infile_mtx
    print(f"infile_mtx = {infile_mtx}")

    A_coo = mmread(infile_mtx)
    # the CSR format is 0-based
    A_csr = A_coo.tocsr(copy=True)
    # sort column indices
    A_csr = A_csr.sorted_indices().astype(np.float32)
    assert 1 == A_csr.has_sorted_indices, "Error: A is not sorted"

    [nrows, ncols] = A_csr.shape
    nnz = A_csr.nnz

    print(f"Load matrix A, {nrows}-by-{ncols} with {nnz} nonzeros")

    # Without --is_weight_one the values read from the file are replaced by
    # random values; the fixed seed keeps runs reproducible.
    if not args.is_weight_one:
        print("WARNING: reset the matrix with random values")
        np.random.seed(123)
        (A_csr.data)[0:nnz] = np.random.rand(nnz).astype(np.float32)

    csrRowPtr = A_csr.indptr
    csrColInd = A_csr.indices
    csrVal    = A_csr.data

    # CSC is derived after the optional reset, so CSR and CSC hold the same values
    A_csc = A_csr.tocsc(copy=True)
    # sort row indices
    A_csc = A_csc.sorted_indices().astype(np.float32)
    assert 1 == A_csc.has_sorted_indices, "Error: A is not sorted"

    cscColPtr = A_csc.indptr
    cscRowInd = A_csc.indices
    cscVal    = A_csc.data

    matrix_info = preprocess(
        # A is nrows-by-ncols with nnz nonzeros
        nrows,
        ncols,
        nnz,
        # core rectangle is np_cols-by-np_rows (called fabx-by-faby in preprocess)
        np_cols,
        np_rows,
        # (csrRowPtr, csrColInd, csrVal) is the CSR representation
        csrRowPtr,
        csrColInd,
        csrVal,
        # (cscColPtr, cscRowInd, cscVal) is the CSC representation
        cscColPtr,
        cscRowInd,
        cscVal)

    end = time.time()
    print(f"prepare the structure for spmv kernel: {end-start}s", flush=True)

    max_local_nnz = matrix_info['max_local_nnz']
    max_local_nnz_cols = matrix_info['max_local_nnz_cols']
    max_local_nnz_rows = matrix_info['max_local_nnz_rows']
    mat_vals_buf = matrix_info['mat_vals_buf']
    mat_rows_buf = matrix_info['mat_rows_buf']
    mat_col_idx_buf = matrix_info['mat_col_idx_buf']
    mat_col_loc_buf = matrix_info['mat_col_loc_buf']
    mat_col_len_buf = matrix_info['mat_col_len_buf']
    y_rows_init_buf = matrix_info['y_rows_init_buf']
    local_nnz = matrix_info['local_nnz']
    local_nnz_cols = matrix_info['local_nnz_cols']
    local_nnz_rows = matrix_info['local_nnz_rows']

    x_ref = read_input_vector(IS_INVEC_1, ncols)

    # core rectangle is np_cols-by-np_rows
    #            np_cols
    #         +----------+
    # np_rows |  core    |
    #         +----------+
    # input vector is distributed into columns, then distributed into rows
    # output vector is distributed into rows, then distributed into columns
    local_vec_sz = math.ceil(math.ceil(ncols / np_cols) / np_rows)
    local_out_vec_sz = math.ceil(math.ceil(nrows / np_rows) / np_cols)

    x_tx_buf = dist_x_to_hwl(ncols, x_ref, local_vec_sz, np_cols, np_rows)

    print(f'Generating reference y = A*x ...')
    y_ref = generate_reference(nrows, ncols, csrRowPtr, csrColInd, csrVal, x_ref)

    mem_use_per_pe = memory_per_pe(max_local_nnz, max_local_nnz_cols, max_local_nnz_rows, local_vec_sz, local_out_vec_sz)
    print(f'Total memory use per PE = {mem_use_per_pe} bytes = {mem_use_per_pe / 1024} KB', flush=True)
    assert mem_use_per_pe < 46*1024, "exceed maximum memory capacity, increase the core rectangle"

    # fabric-offsets = 1,1
    fabric_offset_x = 1
    fabric_offset_y = 1
    # starting point of the core rectangle = (core_fabric_offset_x, core_fabric_offset_y)
    # memcpy framework requires 3 columns at the west of the core rectangle
    # memcpy framework requires 2 columns at the east of the core rectangle
    core_fabric_offset_x = fabric_offset_x + 3 + width_west_buf
    core_fabric_offset_y = fabric_offset_y
    # (min_fabric_width, min_fabric_height) is the minimal dimension to run the app
    min_fabric_width = (core_fabric_offset_x + width + 2 + 1 + width_east_buf)
    min_fabric_height = (core_fabric_offset_y + height + 1)

    # --fabric-dims=<W>,<H> overrides the minimal fabric; 0 in either means "minimal"
    fabric_width = 0
    fabric_height = 0
    if args.fabric_dims:
        w_str, h_str = args.fabric_dims.split(",")
        fabric_width = int(w_str)
        fabric_height = int(h_str)

    if fabric_width == 0 or fabric_height == 0:
        fabric_width = min_fabric_width
        fabric_height = min_fabric_height

    assert fabric_width >= min_fabric_width
    assert fabric_height >= min_fabric_height

    print(f"fabric_width = {fabric_width}, fabric_height = {fabric_height}")
    print(f"core_fabric_offset_x = {core_fabric_offset_x}, core_fabric_offset_y = {core_fabric_offset_y}")

    # prepare the simulation
    print('store ELFs and log files in the folder ', dirname)

    # layout of a rectangle
    code_csl = "layout.csl"

    ## calculate the output vector padding info
    out_vec_len_per_pe_row = math.ceil(nrows / np_rows)
    out_pad_start_idx = out_vec_len_per_pe_row

    start = time.time()
    csl_compile_core(
        cslc,
        code_csl,
        dirname,
        fabric_width,
        fabric_height,
        core_fabric_offset_x, # fabric-offsets of the core
        core_fabric_offset_y,
        args.run_only,
        args.arch,
        ncols, # number of columns of the matrix
        nrows, # number of rows of the matrix
        np_cols, # width
        np_rows, # height
        max_local_nnz,
        max_local_nnz_cols,
        max_local_nnz_rows,
        local_vec_sz,
        local_out_vec_sz,
        out_pad_start_idx,
        channels,
        width_west_buf,
        width_east_buf
    )
    end = time.time()
    print(f"Compilation done in {end-start}s", flush=True)

    if args.compile_only:
        print("COMPILE ONLY: EXIT")
        return

    simulator = SdkRuntime(dirname, cmaddr=args.cmaddr)

    sym_mat_vals_buf = simulator.get_id("mat_vals_buf")
    sym_x_tx_buf = simulator.get_id("x_tx_buf")
    sym_y_local_buf = simulator.get_id("y_local_buf")

    sym_mat_rows_buf = simulator.get_id("mat_rows_buf")
    sym_mat_col_idx_buf = simulator.get_id("mat_col_idx_buf")
    sym_mat_col_loc_buf = simulator.get_id("mat_col_loc_buf")
    sym_mat_col_len_buf = simulator.get_id("mat_col_len_buf")
    sym_y_rows_init_buf = simulator.get_id("y_rows_init_buf")
    sym_local_nnz = simulator.get_id("local_nnz")
    sym_local_nnz_cols = simulator.get_id("local_nnz_cols")
    sym_local_nnz_rows = simulator.get_id("local_nnz_rows")
    sym_time_buf_u16 = simulator.get_id("time_buf_u16")
    sym_time_ref_u16 = simulator.get_id("time_ref_u16")

    start = time.time()
    simulator.load()
    end = time.time()
    print(f"*** Load done in {end-start}s")

    start = time.time()
    simulator.run()

    print("step 1: enable tsc counter to sample the clock")
    simulator.launch("f_enable_tsc", nonblock=True)

    # Note: u16 device buffers are transferred with MEMCPY_16BIT from a
    # 32-bit (np.uint32) host container.
    print("step 2: copy the structure of A and vector x to the device")
    # 1. mat_vals_buf[max_local_nnz], type = f32
    mat_vals_buf_1d = hwl_to_oned_colmajor(height, width, max_local_nnz, mat_vals_buf, np.float32)
    simulator.memcpy_h2d(sym_mat_vals_buf, mat_vals_buf_1d, 0, 0, width, height, max_local_nnz,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_32BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    # 2: x_tx_buf[local_vec_sz], type = f32
    x_tx_buf_1d = hwl_to_oned_colmajor(height, width, local_vec_sz, x_tx_buf, np.float32)
    simulator.memcpy_h2d(sym_x_tx_buf, x_tx_buf_1d, 0, 0, width, height, local_vec_sz,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_32BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    # 3: mat_rows_buf[max_local_nnz], type = u16
    mat_rows_buf_1d = hwl_to_oned_colmajor(height, width, max_local_nnz, mat_rows_buf, np.uint32)
    simulator.memcpy_h2d(sym_mat_rows_buf, mat_rows_buf_1d, 0, 0, width, height, max_local_nnz,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    # 4: mat_col_idx_buf[max_local_nnz_cols], type = u16
    mat_col_idx_buf_1d = hwl_to_oned_colmajor(height, width, max_local_nnz_cols, mat_col_idx_buf, np.uint32)
    simulator.memcpy_h2d(sym_mat_col_idx_buf, mat_col_idx_buf_1d, 0, 0, width, height, max_local_nnz_cols,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    # 5: mat_col_loc_buf[max_local_nnz_cols], type = u16
    mat_col_loc_buf_1d = hwl_to_oned_colmajor(height, width, max_local_nnz_cols, mat_col_loc_buf, np.uint32)
    simulator.memcpy_h2d(sym_mat_col_loc_buf, mat_col_loc_buf_1d, 0, 0, width, height, max_local_nnz_cols,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    # 6: mat_col_len_buf[max_local_nnz_cols], type = u16
    mat_col_len_buf_1d = hwl_to_oned_colmajor(height, width, max_local_nnz_cols, mat_col_len_buf, np.uint32)
    simulator.memcpy_h2d(sym_mat_col_len_buf, mat_col_len_buf_1d, 0, 0, width, height, max_local_nnz_cols,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    # 7: y_rows_init_buf[max_local_nnz_rows], type = u16
    y_rows_init_buf_1d = hwl_to_oned_colmajor(height, width, max_local_nnz_rows, y_rows_init_buf, np.uint32)
    simulator.memcpy_h2d(sym_y_rows_init_buf, y_rows_init_buf_1d, 0, 0, width, height, max_local_nnz_rows,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    # 8: local_nnz, type = u16
    local_nnz_1d = hwl_to_oned_colmajor(height, width, 1, local_nnz, np.uint32)
    simulator.memcpy_h2d(sym_local_nnz, local_nnz_1d, 0, 0, width, height, 1,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    # 9: local_nnz_cols, type = u16
    local_nnz_cols_1d = hwl_to_oned_colmajor(height, width, 1, local_nnz_cols, np.uint32)
    simulator.memcpy_h2d(sym_local_nnz_cols, local_nnz_cols_1d, 0, 0, width, height, 1,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    # 10: local_nnz_rows, type = u16
    local_nnz_rows_1d = hwl_to_oned_colmajor(height, width, 1, local_nnz_rows, np.uint32)
    simulator.memcpy_h2d(sym_local_nnz_rows, local_nnz_rows_1d, 0, 0, width, height, 1,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=True)

    print("step 3: sync all PEs to sample the reference clock")
    simulator.launch("f_sync", np.int16(1), nonblock=False)

    print("step 4: tic() records time_start")
    simulator.launch("f_tic", nonblock=True)

    print("step 5: spmv")
    simulator.launch("f_spmv", nonblock=False)

    print("step 6: toc() records time_end")
    simulator.launch("f_toc", nonblock=False)

    print("step 7: prepare (time_start, time_end)")
    simulator.launch("f_memcpy_timestamps", nonblock=False)

    print("step 8: fetch the timing time_buf_u16[6] = (time_start, time_end), type = u16")
    time_memcpy_hwl_1d = np.zeros(height*width*6, np.uint32)
    simulator.memcpy_d2h(time_memcpy_hwl_1d, sym_time_buf_u16, 0, 0, width, height, 6,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=False)
    time_memcpy_hwl = oned_to_hwl_colmajor(height, width, 6, time_memcpy_hwl_1d, np.uint16)

    print("step 9: fetch the output vector y of type f32")
    y_1d = np.zeros(height*width*local_out_vec_sz, np.float32)
    simulator.memcpy_d2h(y_1d, sym_y_local_buf, 0, 0, width, height, local_out_vec_sz,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_32BIT, order=MemcpyOrder.COL_MAJOR, nonblock=False)

    print("step 10: prepare reference clock")
    simulator.launch("f_reference_timestamps", nonblock=False)

    print("step 11: D2H reference clock")
    time_ref_1d = np.zeros(height*width*3, np.uint32)
    simulator.memcpy_d2h(time_ref_1d, sym_time_ref_u16, 0, 0, width, height, 3,\
        streaming=False, data_type=MemcpyDataType.MEMCPY_16BIT, order=MemcpyOrder.COL_MAJOR, nonblock=False)
    time_ref_hwl = oned_to_hwl_colmajor(height, width, 3, time_ref_1d, np.uint16)

    simulator.stop()

    end = time.time()
    print(f"*** Run done in {end-start}s")

    timing_analysis( height, width, nrows, ncols, nnz, time_memcpy_hwl, time_ref_hwl)

    # The output y_wse is distributed into height-by-width PEs, local_out_vec_sz
    # entries per PE (column-major, matching MemcpyOrder.COL_MAJOR above)
    y_wse = np.reshape(y_1d, (height, width, local_out_vec_sz), order='F')
    # y_wse is packed into 1d vector with zero padding
    y_wse = unpad_3d_to_1d(nrows, y_wse)
    # remove padding of y_wse because y_ref has no padding
    verify_result(y_ref, y_wse[0:nrows])

    if args.cmaddr is None:
        # move simulation log and core dump to the given folder
        dst_log = Path(f"{dirname}/sim.log")
        src_log = Path("sim.log")
        if src_log.exists():
            shutil.move(src_log, dst_log)

        dst_trace = Path(f"{dirname}/simfab_traces")
        src_trace = Path("simfab_traces")
        if dst_trace.exists():
            shutil.rmtree(dst_trace)
        if src_trace.exists():
            shutil.move(src_trace, dst_trace)

    # dump the device memory via debug tool (disabled by default)
    dump_device_time_ref = False
    if dump_device_time_ref:
        print(f"time_ref_hwl = \n{time_ref_hwl}")
        debug_mod = debug_util(dirname, cmaddr=args.cmaddr)
        for py in range(height):
            for px in range(width):
                t = debug_mod.get_symbol(core_fabric_offset_x+px, core_fabric_offset_y+py, 'time_ref_u16', np.uint16)
                print(f"(py, px) = {py, px}, time_ref_u16_ij = {t}")


if __name__ == "__main__":
    main()

cmd_parser.py


import os
import tempfile
import numpy as np
import argparse

def parse_args(argv=None):
    """Parse the command-line options of the spmv-hypersparse example.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what existing callers
            rely on; passing a list is useful for tests and scripting.

    Returns:
        argparse.Namespace with one attribute per option below. Dashed option
        names become underscored attributes (``--fabric-dims`` ->
        ``args.fabric_dims``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--infile_mtx',
        help='the sparse matrix in MTX format',
        required=True
    )
    parser.add_argument(
        '--num_pe_cols',
        type=int,
        help='width of the core rectangle',
        required=True
    )
    parser.add_argument(
        '--num_pe_rows',
        type=int,
        help='height of the core rectangle',
        required=True
    )
    parser.add_argument(
        "--fabric-dims",
        help="Fabric dimension, i.e. <W>,<H>"
    )
    parser.add_argument(
        "--compile-only",
        help="Compile only", action="store_true"
    )
    parser.add_argument(
        "--run-only",
        help="Run only", action="store_true"
    )
    parser.add_argument(
        "--width-west-buf",
        default=0,
        type=int,
        help="width of west buffer"
    )
    parser.add_argument(
        "--width-east-buf",
        default=0,
        type=int,
        help="width of east buffer"
    )
    parser.add_argument(
        "--channels",
        default=1,
        type=int,
        help="number of I/O channels, between 1 and 16"
    )
    parser.add_argument(
        "-d",
        "--driver",
        help="The path to the CSL compiler",
    )
    parser.add_argument(
        "--cmaddr",
        help="CM address and port, i.e. <IP>:<port>"
    )
    parser.add_argument(
        "--arch",
        help="wse1 or wse2. Default is wse1 when not supplied."
    )
    parser.add_argument(
        '--is_invec_one',
        help="input vector x is all one",
        action="store_true",
        default=False,
    )
    # The run script keeps the values read from the MTX file when this flag is
    # set and overwrites them with random values otherwise; the help text says so.
    parser.add_argument(
        '--is_weight_one',
        help="use the nonzero values of the given matrix as they are; "
             "if not set, the nonzero values are replaced by random values",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--latestlink",
        help="folder to contain the log files (default: latest)",
        default="latest"
    )

    args = parser.parse_args(argv)

    return args

memory_usage.py

import numpy as np


def memory_per_pe(max_local_nnz, max_local_nnz_cols, max_local_nnz_rows, local_in_vec_sz, local_out_vec_sz):
    '''
    Estimate the number of bytes of data buffers one PE needs for the
    hypersparse spmv kernel. The estimate mirrors these kernel buffers:

    // input matrix (sparse)
    var mat_vals_buf    = @zeros([max_local_nnz]f32);       // 4B per nonzero
    var mat_rows_buf    = @zeros([max_local_nnz]u16);       // 2B per nonzero
    var mat_col_idx_buf = @zeros([max_local_nnz_cols]u16);  // 2B per nonzero column
    var mat_col_loc_buf = @zeros([max_local_nnz_cols]u16);  // 2B per nonzero column
    var mat_col_len_buf = @zeros([max_local_nnz_cols]u16);  // 2B per nonzero column

    // input vector (dense): one tx buffer plus double-buffered rx for
    // the north-going and south-going trains
    var x_tx_buf     = @zeros([local_vec_sz]f32);   // 4B
    var x_north_buf0 = @zeros([local_vec_sz]f32);   // 4B
    var x_south_buf0 = @zeros([local_vec_sz]f32);   // 4B
    var x_north_buf1 = @zeros([local_vec_sz]f32);   // 4B
    var x_south_buf1 = @zeros([local_vec_sz]f32);   // 4B

    // precomputed local row indices of the output vector (never modified)
    var y_rows_init_buf = @zeros([max_local_nnz_rows]u16);  // 2B

    // partial output vector (sparse), one (vals, rows) pair per train
    var y_vals_north_buf = @zeros([max_local_nnz_rows]f32); // 4B
    var y_rows_north_buf = @zeros([max_local_nnz_rows]u16); // 2B
    var y_vals_south_buf = @zeros([max_local_nnz_rows]f32); // 4B
    var y_rows_south_buf = @zeros([max_local_nnz_rows]u16); // 2B
    var y_vals_west_buf  = @zeros([max_local_nnz_rows]f32); // 4B
    var y_rows_west_buf  = @zeros([max_local_nnz_rows]u16); // 2B
    var y_vals_east_buf  = @zeros([max_local_nnz_rows]f32); // 4B
    var y_rows_east_buf  = @zeros([max_local_nnz_rows]u16); // 2B

    // final reduced local output vector (dense)
    var y_local_buf = @zeros([local_out_vec_sz]f32);        // 4B

    Parameters are element counts (local_in_vec_sz is the kernel's
    local_vec_sz); the return value is the total size in bytes.
    '''

    bytes_per_u16 = np.dtype(np.uint16).itemsize     ## 2 bytes
    bytes_per_f32 = np.dtype(np.float32).itemsize    ## 4 bytes

    ## mat_vals_buf + mat_rows_buf, then the three per-column u16 buffers
    matrix_bytes = (bytes_per_f32 + bytes_per_u16) * max_local_nnz \
        + 3 * bytes_per_u16 * max_local_nnz_cols
    ## x_tx_buf plus four rx buffers
    x_bytes = 5 * bytes_per_f32 * local_in_vec_sz
    ## y_rows_init_buf
    y_rows_init_bytes = bytes_per_u16 * max_local_nnz_rows
    ## north/south/west/east (f32 vals, u16 rows) pairs
    y_sparse_bytes = 4 * ((bytes_per_f32 + bytes_per_u16) * max_local_nnz_rows)
    ## y_local_buf
    y_dense_bytes = bytes_per_f32 * local_out_vec_sz

    return matrix_bytes + x_bytes + y_rows_init_bytes + y_sparse_bytes + y_dense_bytes

preprocess.py


import os

import numpy as np

# name mapping between the spmv kernel and this preprocessing code
#   preprocess       spmv kernel
# ----------------------------------
#  local_nzcols     local_nnz_cols
#  local_nzrows     local_nnz_rows
#  local_nnz        local_nnz
#  y_rows           y_rows_init_buf
#  A_colloc         mat_col_loc_buf
#  A_collen         mat_col_len_buf
#  A_colidx         mat_col_idx_buf
#  A_rows           mat_rows_buf
#  A_vals           mat_vals_buf
#
def preprocess(
    # A is nrows-by-ncols with nnz nonzeros
    nrows: int,
    ncols: int,
    nnz: int,
    # core rectangle of spmv is fabx-by-faby
    fabx: int,
    faby: int,
    # (csrRowPtr, csrColInd, csrVal) is the CSR representation
    csrRowPtr: np.ndarray,
    csrColInd: np.ndarray,
    csrVal: np.ndarray,
    # (cscColPtr, cscRowInd, cscVal) is the CSC representation
    cscColPtr: np.ndarray,
    cscRowInd: np.ndarray,
    cscVal: np.ndarray):
    """
    Given a sparse matrix A of dimension nrows-by-ncols with nnz nonzeros
    and the dimension of core rectangle fabx-by-faby, partition the matrix
    A into blocks of by = ceil(nrows/faby) rows and bx = ceil(ncols/fabx)
    columns such that PE(px=j, py=i) contains the submatrix Aij with the
    following quantities:

    local_nzrows: number of nonzero rows
    local_nzcols: number of nonzero columns
    local_nnz: number of nonzero elements
    y_rows[local_nzrows]: local index of the nonzero rows, in ascending order
    y_vals[local_nzrows]: not used
    A_colloc[local_nzcols]: prefix sum of A_collen, used to point to A_rows
    A_collen[local_nzcols]: A_collen[j] is number of nonzeros of j-th nonzero column
    A_colidx[local_nzcols]: local column index of nonzero columns
    A_rows[local_nnz]: position of row index of nonzeros in y_rows
    A_vals[local_nnz]: value of nonzeros

    Every per-PE array is zero-padded to the maximum size over all PEs
    (max_local_nnz, max_local_nnz_cols or max_local_nnz_rows) and stored
    with shape (faby, fabx, max_size); local_nnz, local_nnz_cols and
    local_nnz_rows have shape (faby, fabx, 1) and dtype uint16.

    Assumes the row indices of each CSC column are sorted ascending
    (step 5 relies on it). Raises AssertionError if the CSR/CSC input is
    inconsistent or if local counts or local row/column indices do not fit
    in u16.

    Returns a dict with keys nrows, ncols, nnz, max_local_nnz,
    max_local_nnz_cols, max_local_nnz_rows, mat_vals_buf, mat_rows_buf,
    mat_col_loc_buf, mat_col_len_buf, mat_col_idx_buf, y_rows_init_buf,
    local_nnz, local_nnz_cols and local_nnz_rows.
    """
    assert 0 == csrRowPtr[0], "CSR must be base-0"
    assert 0 == cscColPtr[0], "CSC must be base-0"
    assert nnz == csrRowPtr[nrows], "CSR has wrong nnz"
    assert nnz == cscColPtr[ncols], "CSC has wrong nnz"

    u16_max = np.iinfo(np.uint16).max

    bx = (ncols + fabx - 1) // fabx # number of columns of a block
    by = (nrows + faby - 1) // faby # number of rows of a block

    # Local indices col_l in [0, bx) and row_l in [0, by) are stored as u16
    # (A_colidx, y_rows); larger blocks would not be representable.
    assert bx - 1 <= u16_max,\
       "LOCAL COLUMN INDEX WILL OVERFLOW, TRY USING A LARGER FABRIC"
    assert by - 1 <= u16_max,\
       "LOCAL ROW INDEX WILL OVERFLOW, TRY USING A LARGER FABRIC"

    local_nzrows = np.zeros((faby, fabx, 1), dtype = np.int32)
    local_nzcols = np.zeros((faby, fabx, 1), dtype = np.int32)
    local_nnz = np.zeros((faby, fabx, 1), dtype = np.int32)

    # counted[b] remembers the last column (or row) token seen in block row
    # (or block column) b, so each nonzero column/row is handled once per block
    max_grid_dim = max(faby, fabx)
    counted = np.zeros(max_grid_dim, dtype = np.int32)

    # step 1: compute local_nzcols and local_nnz
    counted[:] = -1 # invalid token
    for col in range(ncols):
        col_b = col // bx # column block index
        for colidx in range(cscColPtr[col], cscColPtr[col+1]):
            row_b = int(cscRowInd[colidx]) // by # row block index
            local_nnz[(row_b, col_b)] += 1
            # col is a nonzero column of block (row_b, col_b); count it only
            # for its first nonzero in this block
            if counted[row_b] != col:
                local_nzcols[(row_b, col_b)] += 1
                counted[row_b] = col

    # step 2: compute local_nzrows
    counted[:] = -1 # invalid token
    for row in range(nrows):
        row_b = row // by
        for colidx in range(csrRowPtr[row], csrRowPtr[row+1]):
            col_b = int(csrColInd[colidx]) // bx
            # row is a nonzero row of block (row_b, col_b); count it only
            # for its first nonzero in this block
            if counted[col_b] != row:
                local_nzrows[(row_b, col_b)] += 1
                counted[col_b] = row

    # step 3: compute maximum dimension of Aij
    max_local_nnz = local_nnz.max()
    max_local_nnz_cols = local_nzcols.max()
    max_local_nnz_rows = local_nzrows.max()

    assert max_local_nnz < u16_max,\
       "LOCAL NUMBER OF NONZEROS WILL OVERFLOW, TRY USING A LARGER FABRIC"
    assert max_local_nnz_cols < u16_max,\
       "LOCAL NUMBER OF NZCOLS WILL OVERFLOW, TRY USING A LARGER FABRIC"
    assert max_local_nnz_rows < u16_max,\
       "LOCAL NUMBER OF NZROWS WILL OVERFLOW, TRY USING A LARGER FABRIC"
    # no data overflows u16, we can convert the data to u16
    local_nnz = local_nnz.astype(np.uint16)
    local_nzrows = local_nzrows.astype(np.uint16)
    local_nzcols = local_nzcols.astype(np.uint16)

    #     spmv kernel                      actual storage in preprocess
    # ------------------------------------------------------------------
    # mat_vals_buf[max_local_nnz]           A_vals[local_nnz]
    # mat_rows_buf[max_local_nnz]           A_rows[local_nnz]
    # mat_col_loc_buf[max_local_nnz_cols]   A_colloc[local_nzcols]
    # mat_col_len_buf[max_local_nnz_cols]   A_collen[local_nzcols]
    # mat_col_idx_buf[max_local_nnz_cols]   A_colidx[local_nzcols]
    # y_rows_init_buf[max_local_nnz_rows]   y_rows[local_nzrows]
    #
    # To prepare the data for spmv, each PE allocates the maximum dimension
    # max_local_nnz, max_local_nnz_cols or max_local_nnz_rows
    A_vals = np.zeros((faby, fabx, max_local_nnz), dtype = np.float32)
    A_rows = np.zeros((faby, fabx, max_local_nnz), dtype = np.uint16)
    A_colloc = np.zeros((faby, fabx, max_local_nnz_cols), dtype = np.uint16)
    A_collen = np.zeros((faby, fabx, max_local_nnz_cols), dtype = np.uint16)
    A_colidx = np.zeros((faby, fabx, max_local_nnz_cols), dtype = np.uint16)
    y_rows = np.zeros((faby, fabx, max_local_nnz_rows), dtype = np.uint16)

    # step 4: compute y_rows
    # y_row_pos[(row_b, col_b, row_l)] is the position of row_l in
    # y_rows[row_b, col_b, :]. Step 5 looks positions up here instead of
    # scanning y_rows for every nonzero (which was O(nnz * max_local_nnz_rows)).
    y_row_pos = {}
    local_pos = np.zeros((faby, fabx), dtype = np.int32)
    counted[:] = -1 # invalid token
    for row in range(nrows):
        # row = row_b * by + row_l
        row_b = row // by
        row_l = row - row_b * by
        for colidx in range(csrRowPtr[row], csrRowPtr[row+1]):
            col_b = int(csrColInd[colidx]) // bx
            # first nonzero of this row in block (row_b, col_b):
            # append row_l as the next nonzero row of Aij
            if counted[col_b] != row:
                pos = int(local_pos[(row_b, col_b)])
                y_rows[(row_b, col_b, pos)] = row_l
                y_row_pos[(row_b, col_b, row_l)] = pos
                local_pos[(row_b, col_b)] = pos + 1 # advance to next nonzero row in Aij
                counted[col_b] = row

    # step 5: compute A_colloc, A_colidx, A_collen and A_rows
    #  y_rows is computed in step 4 because A_rows must be constructed by using y_rows

    # "local_pos" keeps track of the position of nonzero column in A_colidx
    local_pos = np.zeros((faby, fabx), dtype = np.int32)
    counted[:] = -1 # invalid token
    for col in range(ncols):
        # col = col_b * bx + col_l
        # where col_b is the column block index
        #       col_l is local column index
        col_b = col // bx
        col_l = col - col_b * bx
        for colidx in range(cscColPtr[col], cscColPtr[col+1]):
            row = int(cscRowInd[colidx])
            val = cscVal[colidx]
            # row = row_b * by + row_l
            # where row_b is the row block index
            #       row_l is local row index
            row_b = row // by
            row_l = row - row_b * by
            # Suppose Aij is block (row_b, col_b); Aij(row_l,col_l) is nonzero
            if counted[row_b] != col:
                # pos = position of nonzero column index in A_colidx and A_collen
                # A_collen[pos] is accumulated nonzero rows
                # A_colidx[pos] is the nonzero local column index
                pos = local_pos[(row_b, col_b)]
                # only record nonzero local column index once
                A_colidx[(row_b, col_b, pos)] = col_l
                # A_colloc[0] = 0
                # A_colloc[j] = A_colloc[j-1] + A_collen[j-1]
                if 0 < pos:
                    A_colloc[(row_b, col_b, pos)] = A_colloc[(row_b, col_b, pos-1)] + A_collen[(row_b, col_b, pos-1)]
                local_pos[(row_b, col_b)] = pos + 1 # advance to next nonzero column in Aij
                counted[row_b] = col
            # else: "pos" is still the current nonzero column of this block.
            # This holds because CSC row indices are sorted ascending: all
            # nonzeros of column col in one row block are visited
            # contiguously, so the previous nonzero was in the same block.
            pos_start = A_colloc[(row_b, col_b, pos)] # position of 1st row index of Aij(:, col_l) in A_rows
            pos_rel_rowidx = A_collen[(row_b, col_b, pos)] # offset of this nonzero within Aij(:, col_l)
            pos_rowidx = pos_rel_rowidx + pos_start
            # A_rows records the position of row_l in y_rows;
            # spmv uses y_rows to store the result of outer-product of A*x
            A_rows[(row_b, col_b, pos_rowidx)] = y_row_pos[(row_b, col_b, row_l)]
            A_vals[(row_b, col_b, pos_rowidx)] = val
            A_collen[(row_b, col_b, pos)] = pos_rel_rowidx + 1 # move to next nonzero Aij(row_l, col_l)

    matrix_info = {}
    matrix_info['nrows'] = nrows # number of rows of the matrix
    matrix_info['ncols'] = ncols # number of columns of the matrix
    matrix_info['nnz'] = nnz # number of nonzeros of the matrix
    matrix_info['max_local_nnz'] = max_local_nnz
    matrix_info['max_local_nnz_cols'] = max_local_nnz_cols
    matrix_info['max_local_nnz_rows'] = max_local_nnz_rows
    matrix_info['mat_vals_buf'] = A_vals
    matrix_info['mat_rows_buf'] = A_rows
    matrix_info['mat_col_loc_buf'] = A_colloc
    matrix_info['mat_col_len_buf'] = A_collen
    matrix_info['mat_col_idx_buf'] = A_colidx
    matrix_info['y_rows_init_buf'] = y_rows
    matrix_info['local_nnz'] = local_nnz
    matrix_info['local_nnz_cols'] = local_nzcols
    matrix_info['local_nnz_rows'] = local_nzrows

    return matrix_info

hypersparse_spmv/layout.csl

// Module parameters of the hypersparse spmv layout:
//   colors      : 6 colors; colors[0..1] carry the north/south trains and
//                 colors[2..5] the west/east trains
//   entrypoints : 14 local task IDs, one per task declared below
//   width/height: dimensions of the core rectangle
param colors:[6]color;
param entrypoints:[14]local_task_id;
param width : i16 ;   // width of the core
param height: i16 ;   // height of the core

// phase 1: north <-> south trains
const c0: color = colors[0];
const c1: color = colors[1];

// phase 2: west <-> east trains
// (named c4..c7 although they map to colors[2..5])
const c4: color = colors[2];
const c5: color = colors[3];
const c6: color = colors[4];
const c7: color = colors[5];

// entrypoints: fixed assignment of the 14 local task IDs to kernel tasks
const init: local_task_id = entrypoints[0];
const compute_north: local_task_id = entrypoints[1];
const compute_south: local_task_id = entrypoints[2];
const tx_north: local_task_id = entrypoints[3];
const tx_south: local_task_id = entrypoints[4];
const rx_north: local_task_id = entrypoints[5];
const rx_south: local_task_id = entrypoints[6];
const rx_east: local_task_id = entrypoints[7];
const rx_west: local_task_id = entrypoints[8];
const tx_west: local_task_id = entrypoints[9];
const tx_east: local_task_id = entrypoints[10];
const compute_local: local_task_id = entrypoints[11];
const curr_rx_north_done: local_task_id = entrypoints[12];
const curr_rx_south_done: local_task_id = entrypoints[13];

// invariant parameters (same on every PE): the core dimensions and the
// task IDs declared above, grouped into one anonymous struct
const invariants = .{
    .prows = height,
    .pcols = width,
    .init = init,
    // column compute
    .compute_north = compute_north,
    .compute_south = compute_south,
    .compute_local = compute_local,
    .curr_rx_north_done = curr_rx_north_done,
    .curr_rx_south_done = curr_rx_south_done,
    .tx_north = tx_north,
    .tx_south = tx_south,
    .rx_north = rx_north,
    .rx_south = rx_south,
    // reduction
    .rx_west = rx_west,
    .rx_east = rx_east,
    .tx_west = tx_west,
    .tx_east = tx_east,
};


// Choose the four west/east reduction-train colors for column `col_id`.
// The roles alternate with column parity, so a color one column transmits on
// is the color its neighbour receives on:
//   even.tx_east_train (c4) == odd.rx_east_train (c4)
//   even.tx_west_train (c7) == odd.rx_west_train (c7)
//   odd.tx_west_train  (c5) == even.rx_west_train (c5)
//   odd.tx_east_train  (c6) == even.rx_east_train (c6)
fn get_west_east_train_colors(col_id: u16) comptime_struct {
    const is_odd_col = (col_id % 2) != 0;
    if (is_odd_col) {
        return .{
            .rx_west_train = c7,
            .rx_east_train = c4,
            .tx_west_train = c5,
            .tx_east_train = c6,
        };
    }
    // even column
    return .{
        .rx_west_train = c5,
        .rx_east_train = c6,
        .tx_west_train = c7,
        .tx_east_train = c4,
    };
}

// North/south input-vector train colors. Every row gets the same pair
// (north_train = c0, south_train = c1). `row_id` is currently unused; it is
// kept so the row index can be passed uniformly.
fn get_north_south_train_colors(row_id: u16) comptime_struct {
    const ns_colors = .{
        .north_train = c0,
        .south_train = c1,
    };
    return ns_colors;
}

fn get_params(px:i16, py:i16) comptime_struct {
    // Build the compile-time parameter struct for the PE at (px, py) inside
    // the core rectangle:
    //
    //         --> px = pcol_id
    //          pcols
    //       +----------+
    // prows |  core    |  | py = prow_id
    //       |          |  V
    //       +----------+
    //
    // Result = invariants ++ west/east train colors ++ column flags
    //       ++ north/south train colors ++ row flags
    const last_col: i16 = width - 1;
    const last_row: i16 = height - 1;

    // column part: reduction-train colors alternate with column parity
    const col_part = @concat_structs(
        @concat_structs(invariants, get_west_east_train_colors(px)),
        .{
            .is_first_col = px == 0,
            .is_last_col = px == last_col,
        }
    );

    // row part: same train colors on every row, plus row-position flags
    const row_part = @concat_structs(
        get_north_south_train_colors(py),
        .{
            .is_first_row = py == 0,
            .is_second_row = py == 1,
            .is_last_row = py == last_row,
            .is_second_last_row = py == (last_row - 1),
        }
    );

    return @concat_structs(col_part, row_part);
}

hypersparse_spmv/pe.csl

// To continue next command, f_callback = sys_mod.unblock_cmd_stream
// (post_processing() calls it once the last iteration has finished)
param f_callback : fn ()void;

// queue ids used for the fabric uthreads below (RX_*_Q / TX_*_Q)
param input_queues:[4]u16;
param output_queues:[2]u16;

// explicit DSR allocation
param dest_dsr_ids:[6]u16;
param src1_dsr_ids:[6]u16;


// fabric colors
param north_train: color;
param south_train: color;
param rx_west_train: color;
param rx_east_train: color;
param tx_west_train: color;
param tx_east_train: color;

// local task ids
param init: local_task_id;
param compute_north: local_task_id;
param compute_south: local_task_id;
param compute_local: local_task_id;
param curr_rx_north_done: local_task_id;
param curr_rx_south_done: local_task_id;
param tx_north: local_task_id;
param tx_south: local_task_id;
param rx_north: local_task_id;
param rx_south: local_task_id;
param tx_west: local_task_id;
param tx_east: local_task_id;
param rx_west: local_task_id;
param rx_east: local_task_id;

// in matrix params
param ncols: u32;       // full matrix size x
param nrows: u32;       // full matrix size y
param max_local_nnz: u16;   // max of local number of nonzeros (sizes mat_vals_buf / mat_rows_buf)
param max_local_nnz_cols: u16;  // max of local number of nnz cols
param max_local_nnz_rows: u16;  // max of local number of nnz rows
// in vector params
param local_vec_sz: u16;            // local vector block size per PE (padded)
// out vector params
param local_out_vec_sz: u16;        // local vector block size for final output vector (padded)
param y_pad_start_row_idx: u16;     // local row idx where padding starts in y

// fabric params
param pcols: u16;   // fab size x
param prows: u16;   // fab size y

param is_first_col: bool; // pcol_id == 0
param is_last_col: bool;  // pcol_id == (pcols - 1)
param is_first_row: bool; // prow_id == 0
param is_second_row: bool; // prow_id == 1
param is_last_row: bool;   // prow_id == (prows - 1)
param is_second_last_row: bool; // prow_id == (prows - 2)

// The structure of the matrix prepared by the caller
// The content can be modified for each spmv() call
// data buffers
// input matrix
param mat_vals_buf : *[max_local_nnz]f32;      // in matrix values (sparse): 4B

// per nonzero: index into the local sparse y buffers (compute_*_fn uses it as y_vals index)
param mat_rows_buf : *[max_local_nnz]u16;      // in matrix relative row offsets: 2B
                                                // need this in preprocessing: 2B
param mat_col_idx_buf : *[max_local_nnz_cols]u16;   // column idx of nnz cols (max possible size is nnz)
param mat_col_loc_buf : *[max_local_nnz_cols]u16;   // col location in mat_vals_buf and mat_rows_buf (max nnz)
param mat_col_len_buf : *[max_local_nnz_cols]u16;   // col length (nnz rows in a col)
// precomputed output vector (sparse format) local rows index information
param y_rows_init_buf : *[max_local_nnz_rows]u16;       // init -- this should not be modified

param local_nnz : *[1]u16;         // actual local number of nonzeros
param local_nnz_cols : *[1]u16;    // actual local number of nnz cols
param local_nnz_rows : *[1]u16;    // actual local number of nnz rows

// input vector: for north-going and south-going trains
// buffer storing data for tx
// NOTE(review): declared without an initializer; presumably assigned at runtime before use — confirm
var x_tx_buf : *[local_vec_sz]f32;       // in vector values (dense): 4B

// final reduced local output vector (dense)
// WARNING: the pointer of unknown size cannot pass the compilation
//     The declaration "var y_local_buf : [*]f32;" emits the following error
//  dereferencing a pointer to an unknown number of elements ('[*]f32') is illegal
//
// If we want to ping-pong the spmv by
//    spmv(x, y) // y = A*x
//    spmv(y, x) // x = A*y
// we need to make sure local_vec_sz = local_out_vec_sz, otherwise compilation fails
// because of mismatch of the dimensions
//
var y_local_buf : *[local_out_vec_sz]f32;

// The coordinate (px, py) is decided at runtime
// px = pcol_id
// py = prow_id
var pcol_id: u16 = 0;
var prow_id: u16 = 0;

const fabric = @import_module("<layout>");
// Column (x) coordinate of this PE, as reported by the <layout> module.
fn get_x_coord() u16 {
    const x: u16 = fabric.get_x_coord();
    return x;
}
// Row (y) coordinate of this PE, as reported by the <layout> module.
fn get_y_coord() u16 {
    const y: u16 = fabric.get_y_coord();
    return y;
}

// tsc library
const timestamp = @import_module("<time>");
// timestamps of the reduction: start is taken in start_reduce(), end in post_processing()
var tsc_reduce_start_buffer = @zeros([timestamp.tsc_size_words]u16);
var tsc_reduce_end_buffer = @zeros([timestamp.tsc_size_words]u16);

// These magical addresses from the ISA are in words, and we need bytes, so we
// multiply by 2.
// for UT priority: one config register per queue 0..7 (ce_inpqN_cfg = (0x7F58 + N) * 2);
// tx_north_done/tx_south_done OR UT_HI_PRI into ce_inpq2_cfg/ce_inpq3_cfg
const ce_inpq0_cfg: u16 = 0x7F58 * 2;
const ce_inpq1_cfg: u16 = 0x7F59 * 2;
const ce_inpq2_cfg: u16 = 0x7F5A * 2;
const ce_inpq3_cfg: u16 = 0x7F5B * 2;
const ce_inpq4_cfg: u16 = 0x7F5C * 2;
const ce_inpq5_cfg: u16 = 0x7F5D * 2;
const ce_inpq6_cfg: u16 = 0x7F5E * 2;
const ce_inpq7_cfg: u16 = 0x7F5F * 2;
const UT_MED_PRI: u16 = 0x100;  // bit 8
const UT_HI_PRI: u16 = 0x200;   // bit 9

// for task priority
// const task_pri_cfg: u16 = 0x7E09 * 2;
// const COMP_LO_PRI: u16 = 0x73FD;    // all comp tasks at 17, 26, 27 at task_pri=0, else at 1

////// HACKY WAY ////////////
// some value to let everyone load their ELFs first.
// tsc value is u48, it is assumed that tssync will be performed beforehand, otherwise this is futile
// NOTE: This is only needed to measure performance reliably. Functionality is not affected.
// The u48 value is stored as three u16 words, low word first (e.g. 1K = { 0x3e8, 0x0, 0x0 }).
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0xa000, 0x21db, 0x5d };    // 400B cycles
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0x2c00, 0x7da0, 0x51 };    // 350B cycles
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0xb800, 0xd964, 0x45 };    // 300B cycles
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0x4400, 0x3529, 0x3a };    // 250B cycles
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0xd000, 0x90ed, 0x2e };    // 200B cycles
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0x4000, 0x40be, 0x25 };    // 160B cycles
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0xe800, 0x4876, 0x17 };    // 100B cycles
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0x0900, 0x3d, 0x0 };       // 4M cycles (0x3D0900; 0x3d0 would be ~64M)
// var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0x9c40, 0x0, 0x0 };        // 40K cycles
var TSC_VALUE_TO_WAIT_UNTIL = [3]u16 { 0x3e8, 0x0, 0x0 };         // 1K cycles

// WARNING: reserve input/output queue 0 for memcpy module
// uthreads for fabric data movement
const RX_NORTH_Q: u16 = input_queues[0];
const RX_SOUTH_Q: u16 = input_queues[1];
const TX_NORTH_Q: u16 = output_queues[0];
const TX_SOUTH_Q: u16 = output_queues[1];
// reduction trains, corresponding rx and tx are not active simultaneously
// NOTE: the two phases are exclusive, so uthreads can actually be reused from north-south
// (TX_WEST_Q shares output_queues[0] with TX_NORTH_Q, TX_EAST_Q shares output_queues[1] with TX_SOUTH_Q)
const TX_WEST_Q: u16 = output_queues[0];
const TX_EAST_Q: u16 = output_queues[1];
const RX_WEST_Q: u16 = input_queues[2];
const RX_EAST_Q: u16 = input_queues[3];

// 1-element placeholder used as the initial target of x_tx_buf_dsd
const dummy_f32 = @zeros([1]f32);

// double buffers storing rx data
var x_north_buf0 = @zeros([local_vec_sz]f32);   // in vector values (dense): 4B
var x_south_buf0 = @zeros([local_vec_sz]f32);   // in vector values (dense): 4B
var x_north_buf1 = @zeros([local_vec_sz]f32);   // in vector values (dense): 4B
var x_south_buf1 = @zeros([local_vec_sz]f32);   // in vector values (dense): 4B

// output vector (sparse): to store partial computed output vectors for north and south trains
var y_vals_north_buf = @zeros([max_local_nnz_rows]f32);       // 4B
var y_rows_north_buf = @zeros([max_local_nnz_rows]u16);       // 2B
var y_vals_south_buf = @zeros([max_local_nnz_rows]f32);       // 4B
var y_rows_south_buf = @zeros([max_local_nnz_rows]u16);       // 2B

// buffers for east and west trains
// NOTE: north and south buffers are reused for double buffering
var y_vals_west_buf = @zeros([max_local_nnz_rows]f32);    // rx/tx vals on west-train during reduction (sparse): 4B
var y_rows_west_buf = @zeros([max_local_nnz_rows]u16);    // rx/tx rows on west-train during reduction (sparse): 2B
var y_vals_east_buf = @zeros([max_local_nnz_rows]f32);    // rx/tx vals on east-train during reduction (sparse): 4B
var y_rows_east_buf = @zeros([max_local_nnz_rows]u16);    // rx/tx rows on east-train during reduction (sparse): 2B


// need to keep track of both, the size to use, and the start offset into the buffers.
// (one length word per double-buffer slot, one offset per slot)
// var rx_east_buf_len = @zeros([2, 1]u16);
var rx_east_buf0_len = @zeros([1]u16);
var rx_east_buf1_len = @zeros([1]u16);
// var tx_west_buf_len = @zeros([2, 1]u16);
var tx_west_buf0_len = @zeros([1]u16);
var tx_west_buf1_len = @zeros([1]u16);
var tx_west_buf_off = @zeros([2]i16);

var rx_west_buf0_len = @zeros([1]u16);
var rx_west_buf1_len = @zeros([1]u16);
var tx_east_buf0_len = @zeros([1]u16);
var tx_east_buf1_len = @zeros([1]u16);
var tx_east_buf_off = @zeros([2]i16);


// NOTE: fabin fabout DSDs should be comptime consts for async!!
// This means we cannot use variable lengths transfers,
// which means that the functions set_dsd_length, set_dsd_addr, increment_dsd_offset, etc are all useless here.

//-------------- allocation DSR explicitly for fabin/fabout
// src1: 1, 2, 3, 4, 6, 7
// dest: 1, 2, 3, 4, 5, 6
// NOTE(review): the DSR numbers above and the queue numbers in the comments
// below describe one particular choice of src1_dsr_ids/dest_dsr_ids and
// input_queues/output_queues; the code only uses the params — confirm
// against the caller before relying on these numbers.
// RX_NORTH_Q --> rx_north_dsd fabin (queue 4)
const rx_north_dsr = @get_dsr(dsr_src1, src1_dsr_ids[0]);
// RX_SOUTH_Q --> rx_south_dsd fabin (queue 1)
const rx_south_dsr = @get_dsr(dsr_src1, src1_dsr_ids[1]);
// RX_WEST_Q --> rx_west_dsd fabin (queue 6)
const rx_west_dsr = @get_dsr(dsr_src1, src1_dsr_ids[2]);
// RX_EAST_Q --> rx_east_dsd fabin (queue 7)
const rx_east_dsr = @get_dsr(dsr_src1, src1_dsr_ids[3]);

// destination DSRs for memory writes of received data
const mem1d_rx_east_dsr = @get_dsr(dsr_dest, dest_dsr_ids[0]);
const mem1d_rx_west_dsr = @get_dsr(dsr_dest, dest_dsr_ids[1]);
const mem1d_rx_south_dsr = @get_dsr(dsr_dest, dest_dsr_ids[2]);
const mem1d_rx_north_dsr = @get_dsr(dsr_dest, dest_dsr_ids[3]);

// TX_NORTH_Q --> tx_north_dsd fabout (queue 2)
//                tx_north_ctrl_adv_dsd
//                tx_north_ctrl_rst_dsd
const tx_north_dsr = @get_dsr(dsr_dest, dest_dsr_ids[4]);
// TX_SOUTH_Q --> tx_south_dsd fabout (queue 3)
//                tx_south_ctrl_adv_dsd
//                tx_south_ctrl_rst_dsd
const tx_south_dsr = @get_dsr(dsr_dest, dest_dsr_ids[5]);
// TX_WEST_Q --> tx_west_dsd fabout (queue 2)
// shares dest_dsr_ids[4] with tx_north_dsr (same output queue, phases never overlap)
const tx_west_dsr = @get_dsr(dsr_dest, dest_dsr_ids[4]);
// TX_EAST_Q --> tx_east_dsd fabout (queue 3)
// shares dest_dsr_ids[5] with tx_south_dsr (same output queue, phases never overlap)
const tx_east_dsr = @get_dsr(dsr_dest, dest_dsr_ids[5]);

// idea: TX_SOUTH_Q and TX_EAST_Q use the same output queue 3
// so these two operations must be nonoverlapping, we can use
// the same DSR for mem1d.
//    TX_SOUTH_Q used in tx_south_task()
//     TX_EAST_Q used in tx_east_task()
//
// the same holds for TX_NORTH_Q and TX_WEST_Q
//    TX_NORTH_Q used in tx_north_task()
//     TX_WEST_Q used in tx_west_task()

const mem1d_south_dsr = @get_dsr(dsr_src1, src1_dsr_ids[4]);
const mem1d_east_dsr = @get_dsr(dsr_src1, src1_dsr_ids[4]);

const mem1d_north_dsr = @get_dsr(dsr_src1, src1_dsr_ids[5]);
const mem1d_west_dsr = @get_dsr(dsr_src1, src1_dsr_ids[5]);

//--------------

// fab DSDs

// 1. compute phase: north and south trains for input vector
// rx_north listens on south_train and rx_south on north_train;
// tx_north sends on north_train and tx_south on south_train.
const rx_north_dsd = @get_dsd(fabin_dsd, .{
    .extent = local_vec_sz,                 // fp32 => 1 per wavelet
    .fabric_color = south_train,
    .input_queue = @get_input_queue(RX_NORTH_Q),
});
const rx_south_dsd = @get_dsd(fabin_dsd, .{
    .extent = local_vec_sz,
    .fabric_color = north_train,
    .input_queue = @get_input_queue(RX_SOUTH_Q),
});
const tx_north_dsd = @get_dsd(fabout_dsd, .{
    .extent = local_vec_sz,                 // fp32 => 1 per wavelet
    .fabric_color = north_train,
    .output_queue = @get_output_queue(TX_NORTH_Q),
});
const tx_south_dsd = @get_dsd(fabout_dsd, .{
    .extent = local_vec_sz,
    .fabric_color = south_train,
    .output_queue = @get_output_queue(TX_SOUTH_Q),
});
// control-wavelet DSDs: carry the switch commands built by switch_wavelet()
const tx_north_ctrl_adv_dsd = @get_dsd(fabout_dsd, .{
    .extent = 2,    // two switch wavelets
    .control = true,
    .fabric_color = north_train,
    .output_queue = @get_output_queue(TX_NORTH_Q),
});
const tx_south_ctrl_adv_dsd = @get_dsd(fabout_dsd, .{
    .extent = 2,    // two switch wavelets
    .control = true,
    .fabric_color = south_train,
    .output_queue = @get_output_queue(TX_SOUTH_Q),
});
const tx_north_ctrl_rst_dsd = @get_dsd(fabout_dsd, .{
    .extent = 1,    // one switch wavelet (reset)
    .control = true,
    .fabric_color = north_train,
    .output_queue = @get_output_queue(TX_NORTH_Q),
});
const tx_south_ctrl_rst_dsd = @get_dsd(fabout_dsd, .{
    .extent = 1,    // one switch wavelet (reset)
    .control = true,
    .fabric_color = south_train,
    .output_queue = @get_output_queue(TX_SOUTH_Q),
});

// 2. reduce phase: west and east trains for partial output vectors (sparse: vals + rows)
const tx_west_dsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .fabric_color = tx_west_train,
    .output_queue = @get_output_queue(TX_WEST_Q),
});

const tx_east_dsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .fabric_color = tx_east_train,
    .output_queue = @get_output_queue(TX_EAST_Q),
});

// fabin DSDs

// 2. reduce phase: west and east trains for partial output vectors (sparse: vals + rows)
const rx_west_dsd = @get_dsd(fabin_dsd, .{
    .extent = 1,
    .fabric_color = rx_east_train,      // east-train, rx from west
    .input_queue = @get_input_queue(RX_WEST_Q),
});
const rx_east_dsd = @get_dsd(fabin_dsd, .{
    .extent = 1,
    .fabric_color = rx_west_train,      // west-train, rx from east
    .input_queue = @get_input_queue(RX_EAST_Q),
});

// input vector buf mem DSDs for north and south trains
// local x segment (used for local compute + tx)
// NOTE: this array should not be updated
// NOTE(review): declared `var` and initially aimed at the 1-element dummy_f32;
// presumably re-pointed at x_tx_buf at runtime before use — confirm.
var x_tx_buf_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{local_vec_sz} -> dummy_f32[i],
});
// incoming x segments from north and south (double buffers)
const north_buf0_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{local_vec_sz} -> x_north_buf0[i],
});
const north_buf1_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{local_vec_sz} -> x_north_buf1[i],
});
const south_buf0_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{local_vec_sz} -> x_south_buf0[i],
});
const south_buf1_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{local_vec_sz} -> x_south_buf1[i],
});

// reduction trains
// partial output vector buf mem DSDs

// west train

// rx size of west train coach (from east)
const y_rx_east_buf0_len_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> rx_east_buf0_len[i],
});
const y_rx_east_buf1_len_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> rx_east_buf1_len[i],
});

// tx size of west train coach
const y_tx_west_buf0_len_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> tx_west_buf0_len[i],
});
const y_tx_west_buf1_len_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> tx_west_buf1_len[i],
});

// dsd for init y bufs (this array should not be modified)
const y_rows_init_buf_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{max_local_nnz_rows} -> y_rows_init_buf[i],
});

// west train partial output buffers: buf0 uses the dedicated y_*_west_buf,
// buf1 reuses the north buffers (double buffering)
const y_rows_west_buf0_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{max_local_nnz_rows} -> y_rows_west_buf[i],
});
const y_vals_west_buf0_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{max_local_nnz_rows} -> y_vals_west_buf[i],
});
const y_rows_west_buf1_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{max_local_nnz_rows} -> y_rows_north_buf[i],
});
const y_vals_west_buf1_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{max_local_nnz_rows} -> y_vals_north_buf[i],
});

// east train

// rx size of east train coach (from west)
const y_rx_west_buf0_len_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> rx_west_buf0_len[i],
});
const y_rx_west_buf1_len_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> rx_west_buf1_len[i],
});

// tx size of east train coach
const y_tx_east_buf0_len_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> tx_east_buf0_len[i],
});
const y_tx_east_buf1_len_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> tx_east_buf1_len[i],
});

// east train partial output buffers: buf0 uses the dedicated y_*_east_buf,
// buf1 reuses the south buffers (double buffering)
const y_rows_east_buf0_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{max_local_nnz_rows} -> y_rows_east_buf[i],
});
const y_vals_east_buf0_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{max_local_nnz_rows} -> y_vals_east_buf[i],
});
const y_rows_east_buf1_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{max_local_nnz_rows} -> y_rows_south_buf[i],
});
const y_vals_east_buf1_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{max_local_nnz_rows} -> y_vals_south_buf[i],
});


// misc vars

// status of each rx, tx and train
var is_rx_north_done: bool = false;
var is_tx_north_done: bool = false;
var is_rx_south_done: bool = false;
var is_tx_south_done: bool = false;
var is_compute_north_done: bool = false;
var is_compute_south_done: bool = false;
var is_local_compute_done: bool = false;
var is_compute_north_started: bool = false;
var is_compute_south_started: bool = false;
var is_north_train_running: bool = false;
var is_south_train_running: bool = false;

// NORTH TRAIN: buffer and status
// (the north train is received via rx_south_dsd, which listens on north_train)
var curr_rx_south_buf: u16 = 0;         // buffer to use for rx (current rx)
var curr_rx_south_compute_buf: u16 = 0; // buffer to use for compute (current compute)
// free = true => rx can be performed
var rx_south_buf_free = [2]bool { false, false };
// ready = true => rx done, compute can be performed
var rx_south_buf_ready = [2]bool { false, false };

// SOUTH TRAIN: buffer and status
// (the south train is received via rx_north_dsd, which listens on south_train)
var curr_rx_north_buf: u16 = 0;
var curr_rx_north_compute_buf: u16 = 0;
// free = true => rx can be performed
var rx_north_buf_free = [2]bool { false, false };
// ready = true => rx done, compute can be performed
var rx_north_buf_ready = [2]bool { false, false };

// tx task state
var tx_north_task_state: u16 = 0;
var tx_south_task_state: u16 = 0;

// west-east trains
var is_west_train_running: bool = false;
var is_east_train_running: bool = false;
var is_rx_east_done: bool = false;
var is_rx_west_done: bool = false;
var is_tx_west_done: bool = false;
var is_tx_east_done: bool = false;

// whether rx or tx is active (i.e. corresponding uthread may be active)
var is_tx_west_active: bool = false;
var is_tx_east_active: bool = false;
var is_rx_east_active: bool = false;
var is_rx_west_active: bool = false;

// which task to execute for tx/rx (0: data length, 1: row indices, 2: vals, 3: done)
var tx_west_task_state: u16 = 0;
var tx_east_task_state: u16 = 0;
var rx_west_task_state: u16 = 0;
var rx_east_task_state: u16 = 0;

// the current buffer for west train
var curr_tx_west_buf: u16 = 0;  // should start with 0
var curr_rx_east_buf: u16 = 1;  // should start with 1
var tx_west_buf_ready = [2]bool { false, false };   // if the bufs are ready for tx
var rx_east_buf_avail = [2]bool { false, false };   // if the bufs are avail for rx

// the current buffer for east train
var curr_tx_east_buf: u16 = 0;  // should start with 0
var curr_rx_west_buf: u16 = 1;  // should start with 1
var tx_east_buf_ready = [2]bool { false, false };   // if the bufs are ready for tx
var rx_west_buf_avail = [2]bool { false, false };   // if the bufs are avail for rx

// north train counts
var tx_north_count: u16 = 0;
var rx_south_count: u16 = 0;
var compute_north_count: u16 = 0;

// south train counts
var tx_south_count: u16 = 0;
var rx_north_count: u16 = 0;
var compute_south_count: u16 = 0;

// west train counts
var tx_west_count: u16 = 0;
var rx_east_count: u16 = 0;

// east train counts
var tx_east_count: u16 = 0;
var rx_west_count: u16 = 0;

// initialize the x indices for current local chunks of x vector
// (compute_*_fn treat the current x segment as covering [x_low_idx, x_high_idx))
var north_train_x_low_idx: u16 = 0;
var north_train_x_high_idx: u16 = 0;
var south_train_x_low_idx: u16 = 0;
var south_train_x_high_idx: u16 = 0;

// index into the mat col array for the next segment of x to process
// north: walked backwards by compute_north_fn (u16, so running past index 0 wraps to 0xFFFF)
// south: walked forwards by compute_south_fn
var north_train_start_idx: u16 = 0;
var south_train_start_idx: u16 = 0;

// calculate the local low and high row index values in dense output for this PE
// (start_reduce keeps rows in [y_local_low, y_local_high) locally, sends lower rows west, higher rows east)
var y_local_low: u16 = 0;
var y_local_high: u16 = 0;

// remaining iterations; post_processing() decrements it and re-runs init while > 0
var iter_counter: i16 = 0;

// WARNING: iter_count is set by RPC
var iter_count: i16;


// switch related:

// 2-bit switch command codes, packed into control wavelets by switch_wavelet()
const SWITCH_NOP = 0;   // nop
const SWITCH_ADV = 1;   // advance
const SWITCH_RST = 2;   // reset
const SWITCH_TRD = 3;   // teardown

// two advance commands to switch both rx and tx
var ctrl_adv = @constants([2]u32, switch_wavelet(advance_switch_cmd()));
const ctrl_adv_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{2} -> ctrl_adv[i]
});
// reset command (sent by tx_north_done/tx_south_done on interior rows)
var ctrl_rst = @constants([1]u32, switch_wavelet(reset_switch_cmd()));
const ctrl_rst_dsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> ctrl_rst[i]
});


// construct control wavelet using command list
//
// Packs eight 2-bit switch commands into a 32-bit control wavelet
// (used at file scope to initialize ctrl_adv / ctrl_rst):
//   bits 16..20 : null_color (31)
//   each command k takes 2 bits followed by one `noce` bit; the bit cursor
//   starts at 22 and wraps modulo 30, so the 8 slots land at
//     command bits 22, 25, 28, 1, 4, 7, 10, 13 (2 bits each)
//     noce bits    24, 27,  0, 3, 6, 9, 12, 15
// NOTE(review): `noce` presumably means "do not deliver to the CE" — confirm
// against the ISA control-wavelet format.
fn switch_wavelet(cmds: [8]u16) u32 {
    const noce = true;
    const null_color = 31;

    var ctrl_wavelet: u32 = 0;
    ctrl_wavelet |= (null_color & 0x1f) << 16;

    var curr_bit_pos: u16 = 22;
    var i: u16 = 0;
    while (i < 8): (i += 1) {
        ctrl_wavelet |= @as(u32, cmds[i] & 0x3) << curr_bit_pos;
        curr_bit_pos = (curr_bit_pos + 2) % 30;   // wrap after bit 29
        ctrl_wavelet |= @as(u32, noce) << curr_bit_pos;
        curr_bit_pos = curr_bit_pos + 1;
    }
    return ctrl_wavelet;
}

// Command list that advances the current PE's switch: slot 0 is
// SWITCH_ADV, the remaining seven slots are SWITCH_NOP.
fn advance_switch_cmd() [8]u16 {
    var cmd_list = @constants([8]u16, SWITCH_NOP);
    cmd_list[0] = SWITCH_ADV;
    return cmd_list;
}

// Command list that resets the current PE's switch to its steady state:
// slot 0 is SWITCH_RST, the remaining seven slots are SWITCH_NOP.
fn reset_switch_cmd() [8]u16 {
    var cmd_list = @constants([8]u16, SWITCH_NOP);
    cmd_list[0] = SWITCH_RST;
    return cmd_list;
}


// perform local compute using buffer for input x-vec (north)
//
// For each local nonzero column whose index lies in [x_low_idx, x_high_idx),
// accumulate y_vals_buf[mat_rows_buf[k]] += x[col] * mat_vals_buf[k] over the
// column's nonzeros. Columns are visited backwards starting at
// north_train_start_idx, and the stopping index is stored back there for the
// next call of the north train. y_rows_buf is not touched; y_rows is
// constructed in preprocessing.
fn compute_north_fn(y_vals_buf: *[max_local_nnz_rows]f32,
                    y_rows_buf: *[max_local_nnz_rows]u16,
                    x_buf: *[local_vec_sz]f32,
                    x_low_idx: u16,
                    x_high_idx: u16) void {
    // mat_bufs are stationary, y_bufs (sparse) are stationary, x_buf is shifting
    var i: u16 = north_train_start_idx;
    // start_idx is initialized elsewhere (original note: to `local_nnz_cols`) and goes in reverse.
    // `i` is unsigned, so a guard of `i >= 0` is always true. After column 0 is
    // processed, `i -= 1` wraps to 0xFFFF, and the next test would read
    // mat_col_idx_buf out of bounds. Bounding `i` by the buffer size stops the
    // walk at the wrap-around and on later calls with an exhausted index.
    // This assumes max_local_nnz_cols < 0xFFFF.
    while (i < max_local_nnz_cols and (mat_col_idx_buf.*)[i] >= x_low_idx and (mat_col_idx_buf.*)[i] < x_high_idx) : (i -= 1) {
        var local_col_idx = (mat_col_idx_buf.*)[i]; // local col idx
        var local_col_len = (mat_col_len_buf.*)[i]; // num nz per corresponding col
        var local_col_loc = (mat_col_loc_buf.*)[i]; // loc into mat_vals_bufs
        var x_val = (x_buf.*)[local_col_idx - x_low_idx];
        var j: u16 = 0;
        while (j < local_col_len and local_col_loc + j < (local_nnz.*)[0]) : (j += 1) {
            var y_idx = (mat_rows_buf.*)[local_col_loc + j];    // index into y_vals
            (y_vals_buf.*)[y_idx] += x_val * (mat_vals_buf.*)[local_col_loc + j];
        }
    }
    // store it for next call of the same train (north and south keep separate states);
    // 0xFFFF here means every column has been consumed.
    north_train_start_idx = i;
} // compute_north_fn()

// perform local compute using buffer for input x-vec (south)
//
// Walks the local nonzero columns forward from south_train_start_idx. Every
// column whose index is below x_high_idx has its nonzeros multiplied by the
// matching x value and accumulated into y_vals_buf. Callers guarantee the
// starting column index is >= x_low_idx, so only x_high_idx is checked.
// y_rows_buf is not touched; row indices come from preprocessing.
fn compute_south_fn(y_vals_buf: *[max_local_nnz_rows]f32,
                    y_rows_buf: *[max_local_nnz_rows]u16,
                    x_buf: *[local_vec_sz]f32,
                    x_low_idx: u16,
                    x_high_idx: u16) void {
    const n_cols: u16 = (local_nnz_cols.*)[0];
    const n_nz: u16 = (local_nnz.*)[0];

    var col: u16 = south_train_start_idx;
    while (col < n_cols and (mat_col_idx_buf.*)[col] < x_high_idx) : (col += 1) {
        const first: u16 = (mat_col_loc_buf.*)[col];   // offset into mat_vals_buf / mat_rows_buf
        const count: u16 = (mat_col_len_buf.*)[col];   // nonzeros in this column
        const xv: f32 = (x_buf.*)[(mat_col_idx_buf.*)[col] - x_low_idx];
        var k: u16 = 0;
        while (k < count and first + k < n_nz) : (k += 1) {
            const pos: u16 = first + k;
            (y_vals_buf.*)[(mat_rows_buf.*)[pos]] += xv * (mat_vals_buf.*)[pos];
        }
    }
    // resume from here on the next call of the south train
    south_train_start_idx = col;
} // compute_south_fn()

// if something to do after reduction is finished, put it here.
//
// Stamps the end of the reduction, then either restarts spmv (more
// iterations left) or hands control back through f_callback.
fn post_processing() void {
    timestamp.get_timestamp(&tsc_reduce_end_buffer);

    iter_counter -= 1;
    if (iter_counter <= 0) {
        // last iteration finished: continue with the next command
        // (f_callback is e.g. sys_mod.unblock_cmd_stream)
        f_callback();
        return;
    }

    // More iterations remain: run spmv again from the init task.
    // NOTE: with multiple iterations, tsc_reduce_start/end only keep the
    // last iteration's timestamps.
    @activate(init);
}

// Completion handler for the rx from the north (caller not visible here):
// records that the x segment from the north has arrived, then activates the
// tx_south and compute_local tasks.
fn rx_north_done() void {
    is_rx_north_done = true;
    // start the tx of local x segment and
    // perform local compute (if not already done by the other train.)
    @activate(tx_south);
    @activate(compute_local);
}

// Completion handler for the rx from the south (caller not visible here):
// records that the x segment from the south has arrived, then activates the
// tx_north and compute_local tasks.
fn rx_south_done() void {
    is_rx_south_done = true;
    // start the tx of local x segment and
    // perform local compute (if not already done by the other train.)
    @activate(tx_north);
    @activate(compute_local);
}

// Completion handler for the north tx (presumably invoked from the tx_north
// task each time a north transfer finishes — confirm at the caller).
//  * First call: the x-segment tx is done. Interior rows then send one reset
//    control wavelet (ctrl_rst) on the north train so the switches are reset
//    for the next spmv; the async send re-activates tx_north_task when it
//    completes. First/last rows skip the reset and re-activate tx_north.
//  * Later call(s): if the north compute has also finished, the north train
//    is over, and compute_done() may start the reduction.
fn tx_north_done() void {
    if (!is_tx_north_done) {
        is_tx_north_done = true;
        // reset the switch for any subsequent spmv operations to work
        if (is_first_row or is_last_row) {
            // first and last rows do not need to switch
            @activate(tx_north);
        } else {
            @load_to_dsr(tx_north_dsr, tx_north_ctrl_rst_dsd, .{ .async = true, .activate = tx_north_task });
            @load_to_dsr(mem1d_north_dsr, ctrl_rst_dsd);
            @mov32(tx_north_dsr, mem1d_north_dsr, .{ .async = true });
            @bitcast(*u16, ce_inpq2_cfg).* |= UT_HI_PRI;  // NOTE(review): hard-codes queue 2 for TX_NORTH_Q (= output_queues[0]); confirm it matches the queue configuration
        }
    } else {
        // tx is started only after rx is done
        if (is_compute_north_done) {
            is_north_train_running = false;
            compute_done();
        }
    }
}

// Completion handler for the south tx (presumably invoked from the tx_south
// task each time a south transfer finishes — confirm at the caller).
//  * First call: the x-segment tx is done. Interior rows then send one reset
//    control wavelet (ctrl_rst) on the south train so the switches are reset
//    for the next spmv; the async send re-activates tx_south_task when it
//    completes. First/last rows skip the reset and re-activate tx_south.
//  * Later call(s): if the south compute has also finished, the south train
//    is over, and compute_done() may start the reduction.
fn tx_south_done() void {
    if (!is_tx_south_done) {
        is_tx_south_done = true;
        // reset the switch for any subsequent spmv operations to work
        if (is_last_row or is_first_row) {
            // last and first rows do not need to switch
            @activate(tx_south);
        } else {
            @load_to_dsr(tx_south_dsr, tx_south_ctrl_rst_dsd, .{ .async = true, .activate = tx_south_task });
            @load_to_dsr(mem1d_south_dsr, ctrl_rst_dsd);
            @mov32(tx_south_dsr, mem1d_south_dsr, .{ .async = true });
            @bitcast(*u16, ce_inpq3_cfg).* |= UT_HI_PRI; // NOTE(review): hard-codes queue 3 for TX_SOUTH_Q (= output_queues[1]); confirm it matches the queue configuration
        }
    } else {
        // tx is started only after rx is done
        if (is_compute_south_done) {
            is_south_train_running = false;
            compute_done(); // may start the reduction
       }
   }
}

// Called when a train ends (e.g. from tx_north_done / tx_south_done).
// Starts the reduction only when neither the north nor the south train is
// still running and the local compute has finished; otherwise does nothing.
fn compute_done() void {
    const trains_idle = !(is_north_train_running or is_south_train_running);
    if (trains_idle and is_local_compute_done) {
        start_reduce();
    }
}

// Start the east/west reduction phase. Called from compute_done() once both
// north/south trains have stopped and the local compute is done.
//
// 1. Records the current timestamp into tsc_reduce_start_buffer.
// 2. Adds the north- and south-bound partial results element-wise and routes
//    each reduced (row, value) pair by owner: rows below y_local_low go to
//    the west buffers, rows at or above y_local_high go to the east buffers,
//    and rows in [y_local_low, y_local_high) are accumulated into y_local_buf.
// 3. Copies the west data, in reverse order, into the north buffers (buf1),
//    initializes the double-buffer bookkeeping for both trains, and activates
//    the west-bound (tx_west/rx_east) and east-bound (tx_east/rx_west) trains.
fn start_reduce() void {
    // Record the value of the timestamp counter
    timestamp.get_timestamp(&tsc_reduce_start_buffer);

    // construct west and east buffers using the north and south buffers
    // this assumes the data in north and south buffers are sorted w.r.t rows
    var west_idx: u16 = 0;
    var east_idx: u16 = 0;
    var north_idx: u16 = 0;
    var south_idx: u16 = 0;
    var curr_row: u16 = 0;
    var curr_val: f32 = 0.0;
    // north_idx and south_idx always advance together. The row index is read
    // from the north buffer only, so this relies on both buffers listing the
    // same rows in the same order.
    while (north_idx < (local_nnz_rows.*)[0] and south_idx < (local_nnz_rows.*)[0]) {
        curr_row = y_rows_north_buf[north_idx];
        if (curr_row >= y_pad_start_row_idx) {
            // if curr_row is a padding row, we can break because
            // all subsequent row indices will be higher
            break;
        }
        curr_val = y_vals_north_buf[north_idx] + y_vals_south_buf[south_idx];
        north_idx += 1;
        south_idx += 1;
        // now that we have the current reduced data, put it in its right place
        // TODO: Not necessary to compare every time since the row indices are sorted.
        if (curr_row < y_local_low) {
            // this goes to west
            y_vals_west_buf[west_idx] = curr_val;
            y_rows_west_buf[west_idx] = curr_row;
            west_idx += 1;
        } else if (curr_row >= y_local_high) {
            // this goes to east
            y_vals_east_buf[east_idx] = curr_val;
            y_rows_east_buf[east_idx] = curr_row;
            east_idx += 1;
        } else {
            // this is local data (dense)
            (y_local_buf.*)[curr_row - y_local_low] += curr_val;
        }   // if-else
    }   // while

    // reverse the west data so that scanning can be avoided in subsequent PEs to locate local data.
    // with rows sorted reverse, the local data for a PE will always be at the beginning of the valid data (with offset.)
    // use buf1 (the north buffers) for this. Hence, west train will start tx from buf1, unlike east train.
    // note: padding is not sent, so no need to check for it in rx task
    var i: u16 = 0;
    while (i < west_idx) : (i += 1) {
        y_vals_north_buf[i] = y_vals_west_buf[west_idx - 1 - i];
        y_rows_north_buf[i] = y_rows_west_buf[west_idx - 1 - i];
    }

    // buf0: west and east buffers
    // buf1: north and south buffers (reuse)

    // store the tx sizes for east and west trains, and start offsets (init 0)
    // tx starts with buf0 for east train and buf1 for west train
    tx_east_buf0_len[0] = east_idx;  // valid size
    tx_east_buf_off[0] = 0;         // start offset
    tx_west_buf1_len[0] = west_idx;  // valid size
    tx_west_buf_off[1] = 0;         // start offset

    // NOTE: following flags assume everything is initialized to false
    // west train: tx west from buf1, rx east into buf0
    // corner case: first col does not need to tx west, so both available for rx east
    curr_tx_west_buf = 1;
    curr_rx_east_buf = 0;
    rx_east_buf_avail[curr_rx_east_buf] = true;
    tx_west_buf_ready[curr_tx_west_buf] = true;
    if (is_first_col) {
        // edge case: since there is no tx west, buf is also avail for rx
        rx_east_buf_avail[curr_tx_west_buf] = true;
    }

    // east train: tx east from buf0, rx west into buf1
    // last col does not need to tx east, so both available for rx west
    curr_tx_east_buf = 0;
    curr_rx_west_buf = 1;
    rx_west_buf_avail[curr_rx_west_buf] = true;
    tx_east_buf_ready[curr_tx_east_buf] = true;
    if (is_last_col) {
        // edge case: since there is no tx east, buf is also avail for rx
        rx_west_buf_avail[curr_tx_east_buf] = true;
    }

    // // DEBUGGING...
    // post_processing();

    // all data for a particular train is now uni-directional
    // (i.e. for west-bound train, a PE will only receive data for local or PEs to west, never east)
    // start west-bound and east-bound trains
    // meanwhile the north-south set of buffers are available to rx the data
    // the running flags are set before activation; reduce_done() clears-checks them
    is_west_train_running = true;
    @activate(tx_west);
    @activate(rx_east);
    is_east_train_running = true;
    @activate(tx_east);
    @activate(rx_west);
}

// Called whenever an east/west rx or tx stream finishes. Post-processing runs
// only after both the west-bound and east-bound trains have stopped.
fn reduce_done() void {
    if (is_west_train_running or is_east_train_running) {
        return;
    }
    post_processing();
}

// west-bound train local reduction
//
// Reduces the coach just received from the east (current rx_east buffer)
// into y_local_buf. Rows arrive in reverse (descending) order, so the rows
// owned by this PE (>= y_local_low) form a prefix of the valid data. The
// remaining suffix is left in place for tx_west_task: its length goes into
// tx_west_buf{0,1}_len and its start index into tx_west_buf_off.
// buf0 is backed by the "west" arrays, buf1 reuses the "north" arrays.
fn reduce_local_west() void {
    const from_buf0 = curr_rx_east_buf == 0;

    var rows: *[max_local_nnz_rows]u16 = &y_rows_north_buf;
    var vals: *[max_local_nnz_rows]f32 = &y_vals_north_buf;
    var count: u16 = rx_east_buf1_len[0];
    if (from_buf0) {
        rows = &y_rows_west_buf;
        vals = &y_vals_west_buf;
        count = rx_east_buf0_len[0];
    }

    // padding is never sent in variable-sized transfers, so no check for it here
    var k: u16 = 0;
    while (k < count and k < max_local_nnz_rows) : (k += 1) {
        const row = (rows.*)[k];
        if (row < y_local_low) {
            // first row that belongs further west: stop reducing
            break;
        }
        (y_local_buf.*)[row - y_local_low] += (vals.*)[k];
    }

    // entries [k:] continue on the west-bound train
    const forward = count - k;
    if (from_buf0) {
        tx_west_buf0_len[0] = forward;
    } else {
        tx_west_buf1_len[0] = forward;
    }
    tx_west_buf_off[curr_rx_east_buf] = @as(i16, k);
}

// east-bound train local reduction
//
// Reduces the coach just received from the west (current rx_west buffer)
// into y_local_buf. Local rows (if any) start at index 0 and end at the first
// row >= y_local_high. The remaining suffix is left in place for
// tx_east_task: its length goes into tx_east_buf{0,1}_len and its start index
// into tx_east_buf_off.
// buf0 is backed by the "east" arrays, buf1 reuses the "south" arrays.
fn reduce_local_east() void {
    const from_buf0 = curr_rx_west_buf == 0;

    var rows: *[max_local_nnz_rows]u16 = &y_rows_south_buf;
    var vals: *[max_local_nnz_rows]f32 = &y_vals_south_buf;
    var count: u16 = rx_west_buf1_len[0];
    if (from_buf0) {
        rows = &y_rows_east_buf;
        vals = &y_vals_east_buf;
        count = rx_west_buf0_len[0];
    }

    // padding is never sent in variable-sized transfers, so no check for it here
    var k: u16 = 0;
    while (k < count and k < max_local_nnz_rows) : (k += 1) {
        const row = (rows.*)[k];
        if (row >= y_local_high) {
            // first row that belongs further east: stop reducing
            break;
        }
        (y_local_buf.*)[row - y_local_low] += (vals.*)[k];
    }

    // entries [k:] continue on the east-bound train
    const forward = count - k;
    if (from_buf0) {
        tx_east_buf0_len[0] = forward;
    } else {
        tx_east_buf1_len[0] = forward;
    }
    tx_east_buf_off[curr_rx_west_buf] = @as(i16, k);
}


// The rx half (rx_east) of the west-bound train has finished. The train stops
// once both its rx and tx halves are done; reduce_done() is consulted either way.
fn rx_east_done() void {
    is_rx_east_done = true;
    if (is_rx_east_done and is_tx_west_done) {
        is_west_train_running = false;
    }
    reduce_done();
}

// The rx half (rx_west) of the east-bound train has finished. The train stops
// once both its rx and tx halves are done; reduce_done() is consulted either way.
fn rx_west_done() void {
    is_rx_west_done = true;
    if (is_rx_west_done and is_tx_east_done) {
        is_east_train_running = false;
    }
    reduce_done();
}

// The tx half (tx_east) of the east-bound train has finished. The train stops
// once both its rx and tx halves are done; reduce_done() is consulted either way.
fn tx_east_done() void {
    is_tx_east_done = true;
    if (is_rx_west_done and is_tx_east_done) {
        is_east_train_running = false;
    }
    reduce_done();
}

// The tx half (tx_west) of the west-bound train has finished. The train stops
// once both its rx and tx halves are done; reduce_done() is consulted either way.
fn tx_west_done() void {
    is_tx_west_done = true;
    if (is_rx_east_done and is_tx_west_done) {
        is_west_train_running = false;
    }
    reduce_done();
}

// tasks

// rx tasks

// SOUTH train: completion callback for the async receive started by rx_north_task.
task curr_rx_north_done_task() void {
    const filled = curr_rx_north_buf;
    // the block that just landed is ready for compute
    rx_north_buf_ready[filled] = true;
    // the next receive targets the other half of the double buffer
    curr_rx_north_buf = 1 - filled;

    // start the south compute chain only if it has not been started yet
    if (!is_compute_south_started) {
        is_compute_south_started = true;
        @activate(compute_south);
    }
    // rx is not running at this point, so the next receive can start right away
    @activate(rx_north);
}

// NORTH train: completion callback for the async receive started by rx_south_task.
task curr_rx_south_done_task() void {
    const filled = curr_rx_south_buf;
    // the block that just landed is ready for compute
    rx_south_buf_ready[filled] = true;
    // the next receive targets the other half of the double buffer
    curr_rx_south_buf = 1 - filled;

    // start the north compute chain only if it has not been started yet
    if (!is_compute_north_started) {
        is_compute_north_started = true;
        @activate(compute_north);
    }
    // rx is not running at this point, so the next receive can start right away
    @activate(rx_south);
}

// this rx's from the south-bound train (coming from north)
//
// Receives one block into the current south-bound buffer once that buffer is
// free (polling by re-activation otherwise). curr_rx_north_done_task runs
// when the receive completes. With nothing left to receive, rx_north_done()
// is called instead.
task rx_north_task() void {
    if (rx_north_count == 0) {
        rx_north_done();
        return;
    }

    const buf = curr_rx_north_buf;
    if (!rx_north_buf_free[buf]) {
        // buffer still in use: try again later
        @activate(rx_north);
        return;
    }

    // claim the buffer and account for this receive
    rx_north_buf_free[buf] = false;
    rx_north_count -= 1;

    @load_to_dsr(rx_north_dsr, rx_north_dsd, .{ .async = true, .activate = curr_rx_north_done_task } );
    if (buf == 0) {
        @load_to_dsr(mem1d_rx_north_dsr, south_buf0_dsd);
    } else {
        @load_to_dsr(mem1d_rx_north_dsr, south_buf1_dsd);
    }
    @mov32(mem1d_rx_north_dsr, rx_north_dsr, .{ .async = true });

    // set queue to be higher priority than MT (input queue config for UT0)
    @bitcast(*u16, ce_inpq0_cfg).* |= UT_MED_PRI;
}

// this rx's from the north-bound train (coming from south)
//
// Receives one block into the current north-bound buffer once that buffer is
// free (polling by re-activation otherwise); the free check keeps it from
// overwriting the tx buffer, which is sent last, after all rx is done.
// curr_rx_south_done_task runs when the receive completes. With nothing left
// to receive, rx_south_done() is called instead.
task rx_south_task() void {
    if (rx_south_count == 0) {
        rx_south_done();
        return;
    }

    const buf = curr_rx_south_buf;
    if (!rx_south_buf_free[buf]) {
        // buffer still in use: try again later
        @activate(rx_south);
        return;
    }

    // claim the buffer and account for this receive
    rx_south_buf_free[buf] = false;
    rx_south_count -= 1;

    @load_to_dsr(rx_south_dsr, rx_south_dsd, .{ .async = true, .activate = curr_rx_south_done_task });
    if (buf == 0) {
        @load_to_dsr(mem1d_rx_south_dsr, north_buf0_dsd);
    } else {
        @load_to_dsr(mem1d_rx_south_dsr, north_buf1_dsd);
    }
    @mov32(mem1d_rx_south_dsr, rx_south_dsr, .{ .async = true });

    // raise input queue 1 priority (UT_MED_PRI)
    @bitcast(*u16, ce_inpq1_cfg).* |= UT_MED_PRI;
}

// rx for west-bound train from east
//
// State machine re-entered when each async fabric->memory move completes
// (.activate = rx_east_task), and via @activate(rx_east) when polling or
// skipping a segment:
//   state 0: wait until the current rx buffer is available, then receive the
//            coach size (mov16) via y_rx_east_buf{0,1}_len_dsd
//   state 1: receive the row indices (mov16) into the current buffer
//   state 2: receive the values (mov32) into the current buffer
//   state 3: reduce locally (reduce_local_west), mark the buffer ready for
//            tx_west, and flip curr_rx_east_buf
// States 1 and 2 move nothing when rx_east_buf{0,1}_len[0] is 0.
// Once rx_east_count drops below 1, rx_east_done() is called instead.
task rx_east_task() void {
    if (rx_east_count < 1) {
        // nothing more to rx
        rx_east_done();
        return;
    }
    if (rx_east_task_state == 0) {
        @assert(!is_rx_east_active);
        @assert(rx_east_count > 0);
        if (rx_east_buf_avail[curr_rx_east_buf]) {
            rx_east_task_state = 1;
            is_rx_east_active = true;
            // segment 1: rx size
            @load_to_dsr(rx_east_dsr, rx_east_dsd, .{ .async = true, .activate = rx_east_task });
            if (curr_rx_east_buf == 0) {
                @load_to_dsr(mem1d_rx_east_dsr, y_rx_east_buf0_len_dsd);
                @mov16(mem1d_rx_east_dsr, rx_east_dsr, .{ .async = true });
            } else {
                @load_to_dsr(mem1d_rx_east_dsr, y_rx_east_buf1_len_dsd);
                @mov16(mem1d_rx_east_dsr, rx_east_dsr, .{ .async = true });
            }
	        @bitcast(*u16, ce_inpq7_cfg).* |= UT_MED_PRI;
        } else {
            // buffer is not avail to rx next coach, try again
            @activate(rx_east);
        }
    } else if (rx_east_task_state == 1) {
        rx_east_task_state = 2;
        // segment 2: rx rows data; the transfer length is rx_east_buf{0,1}_len[0]
        // (presumably filled in by segment 1 -- confirm against the DSD definitions)
        if (curr_rx_east_buf == 0) {
            if (rx_east_buf0_len[0] == 0) {
                @activate(rx_east);
                return;
            }
            const rx_east_data_dsd = @set_dsd_length(rx_east_dsd, rx_east_buf0_len[0]);
            @load_to_dsr(rx_east_dsr, rx_east_data_dsd, .{ .async = true, .activate = rx_east_task });
            @load_to_dsr(mem1d_rx_east_dsr, y_rows_west_buf0_dsd);
            @mov16(mem1d_rx_east_dsr, rx_east_dsr, .{ .async = true });
        } else {
            if (rx_east_buf1_len[0] == 0) {
                @activate(rx_east);
                return;
            }
            const rx_east_data_dsd = @set_dsd_length(rx_east_dsd, rx_east_buf1_len[0]);
            @load_to_dsr(rx_east_dsr, rx_east_data_dsd, .{ .async = true, .activate = rx_east_task });
            @load_to_dsr(mem1d_rx_east_dsr, y_rows_west_buf1_dsd);
            @mov16(mem1d_rx_east_dsr, rx_east_dsr, .{ .async = true });
        }
        @bitcast(*u16, ce_inpq7_cfg).* |= UT_MED_PRI;
    } else if (rx_east_task_state == 2) {
        rx_east_task_state = 3;
        // segment 3: rx values data; same length as the rows in segment 2
        if (curr_rx_east_buf == 0) {
            if (rx_east_buf0_len[0] == 0) {
                @activate(rx_east);
                return;
            }
            const rx_east_data_dsd = @set_dsd_length(rx_east_dsd, rx_east_buf0_len[0]);
            @load_to_dsr(rx_east_dsr, rx_east_data_dsd, .{ .async = true, .activate = rx_east_task });
            @load_to_dsr(mem1d_rx_east_dsr, y_vals_west_buf0_dsd);
            @mov32(mem1d_rx_east_dsr, rx_east_dsr, .{ .async = true });
        } else {
            if (rx_east_buf1_len[0] == 0) {
                @activate(rx_east);
                return;
            }
            const rx_east_data_dsd = @set_dsd_length(rx_east_dsd, rx_east_buf1_len[0]);
            @load_to_dsr(rx_east_dsr, rx_east_data_dsd, .{ .async = true, .activate = rx_east_task });
            @load_to_dsr(mem1d_rx_east_dsr, y_vals_west_buf1_dsd);
            @mov32(mem1d_rx_east_dsr, rx_east_dsr, .{ .async = true });
        }
        @bitcast(*u16, ce_inpq7_cfg).* |= UT_MED_PRI;
    } else if (rx_east_task_state == 3) {
        rx_east_task_state = 0;     // reset state
        // rx has completed
        is_rx_east_active = false;
        rx_east_count -= 1;
        // current buf can now be used to perform local reduction
        reduce_local_west();

        // rx'd data has been used and is now ready for tx
        tx_west_buf_ready[curr_rx_east_buf] = true;
        // corner case: if this is first col, there is nothing to tx west,
        // so rx east can immediately rx the next coach
        if (is_first_col) {
            rx_east_buf_avail[curr_rx_east_buf] = true;
        } else {
            rx_east_buf_avail[curr_rx_east_buf] = false;
        }
        // flip the curr rx buf
        curr_rx_east_buf = 1 - curr_rx_east_buf;

        // reduction is done, start next rx_east in case there's more (rx_east_count > 0)
        @activate(rx_east);
    }
}


// rx for east-bound train from west
//
// State machine re-entered when each async fabric->memory move completes
// (.activate = rx_west_task), and via @activate(rx_west) when polling or
// skipping a segment:
//   state 0: wait until the current rx buffer is available, then receive the
//            coach size (mov16) via y_rx_west_buf{0,1}_len_dsd
//   state 1: receive the row indices (mov16) into the current buffer
//   state 2: receive the values (mov32) into the current buffer
//   state 3: reduce locally (reduce_local_east), mark the buffer ready for
//            tx_east, and flip curr_rx_west_buf
// States 1 and 2 move nothing when rx_west_buf{0,1}_len[0] is 0.
// Once rx_west_count drops below 1, rx_west_done() is called instead.
task rx_west_task() void {
    if (rx_west_count < 1) {
        // nothing more remains to rx
        rx_west_done();
        return;
    }
    if (rx_west_task_state == 0) {
        @assert(!is_rx_west_active);
        @assert(rx_west_count > 0);
        if (rx_west_buf_avail[curr_rx_west_buf]) {
            rx_west_task_state = 1;
            is_rx_west_active = true;
            // segment 1: rx size
            @load_to_dsr(rx_west_dsr, rx_west_dsd, .{ .async = true, .activate = rx_west_task });
            if (curr_rx_west_buf == 0) {
                @load_to_dsr(mem1d_rx_west_dsr, y_rx_west_buf0_len_dsd);
                @mov16(mem1d_rx_west_dsr, rx_west_dsr, .{ .async = true });
            } else {
                @load_to_dsr(mem1d_rx_west_dsr, y_rx_west_buf1_len_dsd);
                @mov16(mem1d_rx_west_dsr, rx_west_dsr, .{ .async = true });
            }
            @bitcast(*u16, ce_inpq6_cfg).* |= UT_MED_PRI;
        } else {
            // buffer is not avail for rx next coach, try again
            @activate(rx_west);
        }
    } else if (rx_west_task_state == 1) {
        rx_west_task_state = 2;
        // segment 2: rx rows data; the transfer length is rx_west_buf{0,1}_len[0]
        // (presumably filled in by segment 1 -- confirm against the DSD definitions)
        if (curr_rx_west_buf == 0) {
            if (rx_west_buf0_len[0] == 0) {
                @activate(rx_west);
                return;
            }
            const rx_west_data_dsd = @set_dsd_length(rx_west_dsd, rx_west_buf0_len[0]);
            @load_to_dsr(rx_west_dsr, rx_west_data_dsd, .{ .async = true, .activate = rx_west_task });
            @load_to_dsr(mem1d_rx_west_dsr, y_rows_east_buf0_dsd);
            @mov16(mem1d_rx_west_dsr, rx_west_dsr, .{ .async = true });
        } else {
            if (rx_west_buf1_len[0] == 0) {
                @activate(rx_west);
                return;
            }
            const rx_west_data_dsd = @set_dsd_length(rx_west_dsd, rx_west_buf1_len[0]);
            @load_to_dsr(rx_west_dsr, rx_west_data_dsd, .{ .async = true, .activate = rx_west_task });
            @load_to_dsr(mem1d_rx_west_dsr, y_rows_east_buf1_dsd);
            @mov16(mem1d_rx_west_dsr, rx_west_dsr, .{ .async = true });
        }
        @bitcast(*u16, ce_inpq6_cfg).* |= UT_MED_PRI;
    } else if (rx_west_task_state == 2) {
        rx_west_task_state = 3;
        // segment 3: rx values data; same length as the rows in segment 2
        if (curr_rx_west_buf == 0) {
            if (rx_west_buf0_len[0] == 0) {
                @activate(rx_west);
                return;
            }
            const rx_west_data_dsd = @set_dsd_length(rx_west_dsd, rx_west_buf0_len[0]);
            @load_to_dsr(rx_west_dsr, rx_west_data_dsd, .{ .async = true, .activate = rx_west_task });
            @load_to_dsr(mem1d_rx_west_dsr, y_vals_east_buf0_dsd);
            @mov32(mem1d_rx_west_dsr, rx_west_dsr, .{ .async = true });
        } else {
            if (rx_west_buf1_len[0] == 0) {
                @activate(rx_west);
                return;
            }
            const rx_west_data_dsd = @set_dsd_length(rx_west_dsd, rx_west_buf1_len[0]);
            @load_to_dsr(rx_west_dsr, rx_west_data_dsd, .{ .async = true, .activate = rx_west_task });
            @load_to_dsr(mem1d_rx_west_dsr, y_vals_east_buf1_dsd);
            @mov32(mem1d_rx_west_dsr, rx_west_dsr, .{ .async = true });
        }
        @bitcast(*u16, ce_inpq6_cfg).* |= UT_MED_PRI;
    } else if (rx_west_task_state == 3) {
        rx_west_task_state = 0;     // reset state
        // rx has completed
        is_rx_west_active = false;
        rx_west_count -= 1;
        // current buf can now be used to perform local reduction
        reduce_local_east();

        // rx'd data has been used and is now ready for tx
        tx_east_buf_ready[curr_rx_west_buf] = true;
        // corner case: if this is the last col, there is nothing to tx east,
        // so rx west can immediately rx the next coach.
        if (is_last_col) {
            rx_west_buf_avail[curr_rx_west_buf] = true;
        } else {
            rx_west_buf_avail[curr_rx_west_buf] = false;
        }
        // flip the curr rx buf
        curr_rx_west_buf = 1 - curr_rx_west_buf;

        // activate next rx incase there's more (rx_west_count > 0)
        @activate(rx_west);
    }
}


// tx tasks

// Drives the north-bound train from this PE, once per remaining
// tx_north_count. Each round is a small state machine re-entered when its own
// async DSR operation completes:
//   state 0: send x_tx_buf onto the north-bound fabric
//   state 1: send the ctrl_adv_dsd control wavelet(s) (advance commands) for
//            the next PE; the top two rows skip this
//   state 2: round finished, decrement tx_north_count and start the next one
// When tx_north_count is 0, control passes to tx_north_done().
task tx_north_task() void {
    if (tx_north_count == 0) {
        tx_north_done();
        return;
    }

    if (tx_north_task_state == 0) {
        // the tx buffer is always ready here
        tx_north_task_state = 1;
        @load_to_dsr(tx_north_dsr, tx_north_dsd, .{ .async = true, .activate = tx_north_task });
        @load_to_dsr(mem1d_north_dsr, x_tx_buf_dsd);
        @mov32(tx_north_dsr, mem1d_north_dsr, .{ .async = true });
        @bitcast(*u16, ce_inpq2_cfg).* |= UT_HI_PRI;
        return;
    }

    if (tx_north_task_state == 1) {
        tx_north_task_state = 2;
        if (is_first_row or is_second_row) {
            // no switch wavelets from the top two rows
            @activate(tx_north);
        } else {
            @load_to_dsr(tx_north_dsr, tx_north_ctrl_adv_dsd, .{ .async = true, .activate = tx_north_task });
            @load_to_dsr(mem1d_north_dsr, ctrl_adv_dsd);
            @mov32(tx_north_dsr, mem1d_north_dsr, .{ .async = true });
            @bitcast(*u16, ce_inpq2_cfg).* |= UT_HI_PRI;
        }
        return;
    }

    if (tx_north_task_state == 2) {
        // round complete: reset the state and account for it
        tx_north_task_state = 0;
        tx_north_count -= 1;
        @activate(tx_north);
    }
}

// Drives the south-bound train from this PE, once per remaining
// tx_south_count. Each round is a small state machine re-entered when its own
// async DSR operation completes:
//   state 0: send x_tx_buf onto the south-bound fabric
//   state 1: send the ctrl_adv_dsd control wavelet(s) (advance commands) for
//            the next PE; the bottom two rows skip this
//   state 2: round finished, decrement tx_south_count and start the next one
// When tx_south_count is 0, control passes to tx_south_done().
task tx_south_task() void {
    if (tx_south_count == 0) {
        tx_south_done();
        return;
    }

    if (tx_south_task_state == 0) {
        // the tx buffer is always ready here
        tx_south_task_state = 1;
        @load_to_dsr(tx_south_dsr, tx_south_dsd, .{ .async = true, .activate = tx_south_task });
        @load_to_dsr(mem1d_south_dsr, x_tx_buf_dsd);
        @mov32(tx_south_dsr, mem1d_south_dsr, .{ .async = true });
        @bitcast(*u16, ce_inpq3_cfg).* |= UT_HI_PRI;
        return;
    }

    if (tx_south_task_state == 1) {
        tx_south_task_state = 2;
        if (is_last_row or is_second_last_row) {
            // no switch wavelets from the bottom two rows
            @activate(tx_south);
        } else {
            @load_to_dsr(tx_south_dsr, tx_south_ctrl_adv_dsd, .{ .async = true, .activate = tx_south_task });
            @load_to_dsr(mem1d_south_dsr, ctrl_adv_dsd);
            @mov32(tx_south_dsr, mem1d_south_dsr, .{ .async = true });
            @bitcast(*u16, ce_inpq3_cfg).* |= UT_HI_PRI;
        }
        return;
    }

    if (tx_south_task_state == 2) {
        // round complete: reset the state and account for it
        tx_south_task_state = 0;
        tx_south_count -= 1;
        @activate(tx_south);
    }
}

// west-bound train data tx
//
// State machine re-entered when each async memory->fabric move completes
// (.activate = tx_west_task), and via @activate(tx_west) when polling or
// skipping a segment:
//   state 0: wait until the current tx buffer is ready, then send its valid
//            size (mov16) from y_tx_west_buf{0,1}_len_dsd
//   state 1: send tx_west_buf{0,1}_len[0] row indices, starting at
//            tx_west_buf_off[buf]
//   state 2: send the same number of values, starting at tx_west_buf_off[buf]
//   state 4: release the buffer to rx_east and flip curr_tx_west_buf
// There is no state 3. States 1 and 2 move nothing when the length is 0.
// Once tx_west_count drops below 1, tx_west_done() is called instead.
task tx_west_task() void {
    if (tx_west_count < 1) {
        // all tx has been completed, nothing more remains
        tx_west_done();
        return;
    }
    if (tx_west_task_state == 0) {
        @assert(!is_tx_west_active);
        @assert(tx_west_count > 0);
        // start the tx chain if buffer is ready for tx
        if (tx_west_buf_ready[curr_tx_west_buf]) {
            tx_west_task_state = 1;
            is_tx_west_active = true;
            // segment 1: send the valid size first
            @load_to_dsr(tx_west_dsr, tx_west_dsd, .{ .async = true, .activate = tx_west_task });
            if (curr_tx_west_buf == 0) {
                @load_to_dsr(mem1d_west_dsr, y_tx_west_buf0_len_dsd);
                @mov16(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            } else {
                @load_to_dsr(mem1d_west_dsr, y_tx_west_buf1_len_dsd);
                @mov16(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            }
            @bitcast(*u16, ce_inpq2_cfg).* |= UT_HI_PRI;
        } else {
            // curr tx west buf not ready for tx, try again
            @activate(tx_west);
        }
    } else if (tx_west_task_state == 1) {
        tx_west_task_state = 2;
        // segment 2: set the base addr and valid data size to tx the y_rows data
        // note: padding is not sent
        // the offset == 0 and offset != 0 branches differ only in whether
        // @increment_dsd_offset is applied
        if (curr_tx_west_buf == 0) {
            if (tx_west_buf0_len[0] == 0) {
                @activate(tx_west);
                return;
            }
            if (tx_west_buf_off[0] == 0) {
                const y_rows_west_dsd = @set_dsd_length(y_rows_west_buf0_dsd, tx_west_buf0_len[0]);
                @load_to_dsr(tx_west_dsr, tx_west_dsd, .{ .async = true, .activate = tx_west_task });
                @load_to_dsr(mem1d_west_dsr, y_rows_west_dsd);
                @mov16(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            } else {
                const y_rows_west_dsd = @increment_dsd_offset(@set_dsd_length(y_rows_west_buf0_dsd, tx_west_buf0_len[0]), tx_west_buf_off[0], u16);
                @load_to_dsr(tx_west_dsr, tx_west_dsd, .{ .async = true, .activate = tx_west_task });
                @load_to_dsr(mem1d_west_dsr, y_rows_west_dsd);
                @mov16(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            }
        } else {
            if (tx_west_buf1_len[0] == 0) {
                @activate(tx_west);
                return;
            }
            if (tx_west_buf_off[1] == 0) {
                const y_rows_west_dsd = @set_dsd_length(y_rows_west_buf1_dsd, tx_west_buf1_len[0]);
                @load_to_dsr(tx_west_dsr, tx_west_dsd, .{ .async = true, .activate = tx_west_task });
                @load_to_dsr(mem1d_west_dsr, y_rows_west_dsd);
                @mov16(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            } else {
                const y_rows_west_dsd = @increment_dsd_offset(@set_dsd_length(y_rows_west_buf1_dsd, tx_west_buf1_len[0]), tx_west_buf_off[1], u16);
                @load_to_dsr(tx_west_dsr, tx_west_dsd, .{ .async = true, .activate = tx_west_task });
                @load_to_dsr(mem1d_west_dsr, y_rows_west_dsd);
                @mov16(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            }
        }
        @bitcast(*u16, ce_inpq2_cfg).* |= UT_HI_PRI;
    } else if (tx_west_task_state == 2) {
        // next state is 4 (state 3 is unused)
        tx_west_task_state = 4;
        // segment 3: send vals data
        if (curr_tx_west_buf == 0) {
            if (tx_west_buf0_len[0] == 0) {
                @activate(tx_west);
                return;
            }
            if (tx_west_buf_off[0] == 0) {
                const y_vals_west_dsd = @set_dsd_length(y_vals_west_buf0_dsd, tx_west_buf0_len[0]);
                @load_to_dsr(tx_west_dsr, tx_west_dsd, .{ .async = true, .activate = tx_west_task });
                @load_to_dsr(mem1d_west_dsr, y_vals_west_dsd);
                @mov32(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            } else {
                const y_vals_west_dsd = @increment_dsd_offset(@set_dsd_length(y_vals_west_buf0_dsd, tx_west_buf0_len[0]), tx_west_buf_off[0], f32);
                @load_to_dsr(tx_west_dsr, tx_west_dsd, .{ .async = true, .activate = tx_west_task });
                @load_to_dsr(mem1d_west_dsr, y_vals_west_dsd);
                @mov32(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            }
        } else {
            if (tx_west_buf1_len[0] == 0) {
                @activate(tx_west);
                return;
            }
            if (tx_west_buf_off[1] == 0) {
                const y_vals_west_dsd = @set_dsd_length(y_vals_west_buf1_dsd, tx_west_buf1_len[0]);
                @load_to_dsr(tx_west_dsr, tx_west_dsd, .{ .async = true, .activate = tx_west_task });
                @load_to_dsr(mem1d_west_dsr, y_vals_west_dsd);
                @mov32(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            } else {
                const y_vals_west_dsd = @increment_dsd_offset(@set_dsd_length(y_vals_west_buf1_dsd, tx_west_buf1_len[0]), tx_west_buf_off[1], f32);
                @load_to_dsr(tx_west_dsr, tx_west_dsd, .{ .async = true, .activate = tx_west_task });
                @load_to_dsr(mem1d_west_dsr, y_vals_west_dsd);
                @mov32(tx_west_dsr, mem1d_west_dsr, .{ .async = true });
            }
        }
        @bitcast(*u16, ce_inpq2_cfg).* |= UT_HI_PRI;
    } else if (tx_west_task_state == 4) {
        tx_west_task_state = 0; // reset tx state
        // current tx west has completed
        is_tx_west_active = false;
        tx_west_count -= 1;
        // curr_tx_west_buf is now available for rx of next coach.
        rx_east_buf_avail[curr_tx_west_buf] = true;
        // while its not ready for tx
        tx_west_buf_ready[curr_tx_west_buf] = false;
        // corner case: if this is the last col, there is nothing to rx east
        // in this case, the count will be 0 here.

        // flip the curr tx west buf
        curr_tx_west_buf = 1 - curr_tx_west_buf;
    
        // start next tx incase there is more (tx_west_count > 0)
        @activate(tx_west);
    }
}

// east-bound train data tx
// East-bound train data tx.
//
// Re-activation driven state machine: every asynchronous fabric move started
// here is loaded with `.activate = tx_east_task`, so this task runs again when
// that move completes. One coach (the contents of tx buffer curr_tx_east_buf)
// is sent as three segments:
//   state 0: wait until tx_east_buf_ready[curr_tx_east_buf], then send the
//            length field y_tx_east_buf{0,1}_len_dsd (16-bit)      -> state 1
//   state 1: send the row indices y_rows_east_buf{0,1} (16-bit)     -> state 2
//   state 2: send the values y_vals_east_buf{0,1} (32-bit)          -> state 4
//   state 4: bookkeeping for the completed coach, flip tx buffer     -> state 0
// (there is no state 3). In states 1 and 2 an empty buffer (length 0) skips
// the send: the state is still advanced and the task is re-activated directly
// via @activate(tx_east). Once tx_east_count drops below 1, tx_east_done() is
// called instead of doing any further work.
task tx_east_task() void {
    if (tx_east_count < 1) {
        // all tx has been completed
        tx_east_done();
        return;
    }
    if (tx_east_task_state == 0) {
        @assert(!is_tx_east_active);
        @assert(tx_east_count > 0);
        // start the tx chain if buffer is ready for tx, and there is tx count remaining
        if (tx_east_buf_ready[curr_tx_east_buf]) {
            tx_east_task_state = 1;
            is_tx_east_active = true;
            // segment 1: send the valid size first
            @load_to_dsr(tx_east_dsr, tx_east_dsd, .{ .async = true, .activate = tx_east_task });
            if (curr_tx_east_buf == 0) {
                @load_to_dsr(mem1d_east_dsr, y_tx_east_buf0_len_dsd);
                @mov16(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            } else {
                @load_to_dsr(mem1d_east_dsr, y_tx_east_buf1_len_dsd);
                @mov16(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            }
            // NOTE(review): sets UT_HI_PRI in ce_inpq3_cfg after each async send is
            // launched (same pattern in every segment) -- presumably raises the
            // priority of the associated microthread; confirm.
            @bitcast(*u16, ce_inpq3_cfg).* |= UT_HI_PRI;
        } else {
            // curr tx east buf is not ready, try again
            @activate(tx_east);
        }
    } else if (tx_east_task_state == 1) {
        tx_east_task_state = 2;
        // segment 2: set the base addr and valid data size to tx the y_rows data
        // note: padding is not sent, so no need to check for it in rx task
        // a non-zero tx_east_buf_off[k] shifts the DSD start by that many u16 elements
        if (curr_tx_east_buf == 0) {
            if (tx_east_buf0_len[0] == 0) {
                // empty buffer: nothing to send, advance directly
                @activate(tx_east);
                return;
            }
            if (tx_east_buf_off[0] == 0) {
                const y_rows_east_dsd = @set_dsd_length(y_rows_east_buf0_dsd, tx_east_buf0_len[0]);
                @load_to_dsr(tx_east_dsr, tx_east_dsd, .{ .async = true, .activate = tx_east_task });
                @load_to_dsr(mem1d_east_dsr, y_rows_east_dsd);
                @mov16(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            } else {
                const y_rows_east_dsd = @increment_dsd_offset(@set_dsd_length(y_rows_east_buf0_dsd, tx_east_buf0_len[0]), tx_east_buf_off[0], u16);
                @load_to_dsr(tx_east_dsr, tx_east_dsd, .{ .async = true, .activate = tx_east_task });
                @load_to_dsr(mem1d_east_dsr, y_rows_east_dsd);
                @mov16(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            }
        } else {
            if (tx_east_buf1_len[0] == 0) {
                // empty buffer: nothing to send, advance directly
                @activate(tx_east);
                return;
            }
            if (tx_east_buf_off[1] == 0) {
                const y_rows_east_dsd = @set_dsd_length(y_rows_east_buf1_dsd, tx_east_buf1_len[0]);
                @load_to_dsr(tx_east_dsr, tx_east_dsd, .{ .async = true, .activate = tx_east_task });
                @load_to_dsr(mem1d_east_dsr, y_rows_east_dsd);
                @mov16(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            } else {
                const y_rows_east_dsd = @increment_dsd_offset(@set_dsd_length(y_rows_east_buf1_dsd, tx_east_buf1_len[0]), tx_east_buf_off[1], u16);
                @load_to_dsr(tx_east_dsr, tx_east_dsd, .{ .async = true, .activate = tx_east_task });
                @load_to_dsr(mem1d_east_dsr, y_rows_east_dsd);
                @mov16(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            }
        }
        @bitcast(*u16, ce_inpq3_cfg).* |= UT_HI_PRI;
    } else if (tx_east_task_state == 2) {
        tx_east_task_state = 4;
        // segment 3: send vals data (same length/offset as the row indices, 32-bit elements)
        if (curr_tx_east_buf == 0) {
            if (tx_east_buf0_len[0] == 0) {
                // empty buffer: nothing to send, advance directly
                @activate(tx_east);
                return;
            }
            if (tx_east_buf_off[0] == 0) {
                const y_vals_east_dsd = @set_dsd_length(y_vals_east_buf0_dsd, tx_east_buf0_len[0]);
                @load_to_dsr(tx_east_dsr, tx_east_dsd, .{ .async = true, .activate = tx_east_task });
                @load_to_dsr(mem1d_east_dsr, y_vals_east_dsd);
                @mov32(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            } else {
                const y_vals_east_dsd = @increment_dsd_offset(@set_dsd_length(y_vals_east_buf0_dsd, tx_east_buf0_len[0]), tx_east_buf_off[0], f32);
                @load_to_dsr(tx_east_dsr, tx_east_dsd, .{ .async = true, .activate = tx_east_task });
                @load_to_dsr(mem1d_east_dsr, y_vals_east_dsd);
                @mov32(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            }
        } else {
            if (tx_east_buf1_len[0] == 0) {
                // empty buffer: nothing to send, advance directly
                @activate(tx_east);
                return;
            }
            if (tx_east_buf_off[1] == 0) {
                const y_vals_east_dsd = @set_dsd_length(y_vals_east_buf1_dsd, tx_east_buf1_len[0]);
                @load_to_dsr(tx_east_dsr, tx_east_dsd, .{ .async = true, .activate = tx_east_task });
                @load_to_dsr(mem1d_east_dsr, y_vals_east_dsd);
                @mov32(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            } else {
                const y_vals_east_dsd = @increment_dsd_offset(@set_dsd_length(y_vals_east_buf1_dsd, tx_east_buf1_len[0]), tx_east_buf_off[1], f32);
                @load_to_dsr(tx_east_dsr, tx_east_dsd, .{ .async = true, .activate = tx_east_task });
                @load_to_dsr(mem1d_east_dsr, y_vals_east_dsd);
                @mov32(tx_east_dsr, mem1d_east_dsr, .{ .async = true });
            }
        }
        @bitcast(*u16, ce_inpq3_cfg).* |= UT_HI_PRI;
    } else if (tx_east_task_state == 4) {
        tx_east_task_state = 0; // reset tx state
        // current tx east has completed
        is_tx_east_active = false;
        tx_east_count -= 1;
        // curr_tx_east_buf is now available for rx of next coach.
        rx_west_buf_avail[curr_tx_east_buf] = true;
        // while the tx buffer is not ready
        tx_east_buf_ready[curr_tx_east_buf] = false;
        // corner case: if this is the first col, there is nothing to rx west
        // in this case, the count will be 0 here.

        // flip the curr tx east buf
        curr_tx_east_buf = 1 - curr_tx_east_buf;

        // start next tx incase there is more (tx_east_count > 0)
        @activate(tx_east);
    }
}

// local compute tasks

task compute_local_task() void {
    // Accumulate this PE's own x segment into one of the train accumulators.
    // Runs at most once per spmv: the is_local_compute_done latch turns any
    // repeated activation (e.g. from the other train) into a no-op.
    if (!is_local_compute_done) {
        is_local_compute_done = true;

        // index window of the local x segment inside the full x vector
        const own_x_lo = prow_id * local_vec_sz;
        const own_x_hi = own_x_lo + local_vec_sz;

        if (!is_rx_south_done) {
            // north-bound rx is not finished: accumulate via the south-train path
            compute_south_fn(&y_vals_south_buf, &y_rows_south_buf,
                                x_tx_buf,  // caller's local x
                                own_x_lo, own_x_hi);
        } else {
            // north-bound rx is finished: prefer the north-train accumulator
            compute_north_fn(&y_vals_north_buf, &y_rows_north_buf,
                                x_tx_buf,  // caller's local x
                                own_x_lo, own_x_hi);
        }
        compute_done();
    }
} // compute_local_task()

fn compute_north_done_fn() void {
    // Called once the north-bound compute has consumed all received segments.
    // The north train is finished only when its rx (from south) and tx (to
    // north) have also completed; whichever of those events happens last
    // reports completion.
    is_compute_north_done = true;
    const north_train_finished = is_rx_south_done and is_tx_north_done;
    if (!north_train_finished) {
        return;
    }
    is_north_train_running = false;
    compute_done();
}

fn compute_south_done_fn() void {
    // Called once the south-bound compute has consumed all received segments.
    // The south train is finished only when its rx (from north) and tx (to
    // south) have also completed; whichever of those events happens last
    // reports completion.
    is_compute_south_done = true;
    const south_train_finished = is_rx_north_done and is_tx_south_done;
    if (!south_train_finished) {
        return;
    }
    is_south_train_running = false;
    compute_done();
}

task compute_north_task() void {
    // Polling consumer for the north-bound train: each activation either
    // finishes (no segments left), re-polls (current rx buffer not filled
    // yet), or consumes one received x segment and schedules itself again.
    if (compute_north_count == 0) {
        compute_north_done_fn();
        return;
    }

    if (!rx_south_buf_ready[curr_rx_south_compute_buf]) {
        // segment not received yet: try again later
        @activate(compute_north);
        return;
    }

    // claim the buffer and account for the segment
    rx_south_buf_ready[curr_rx_south_compute_buf] = false;
    compute_north_count -= 1;

    if (curr_rx_south_compute_buf != 0) {
        compute_north_fn(&y_vals_north_buf, &y_rows_north_buf,
                            &x_north_buf1,
                            north_train_x_low_idx, north_train_x_high_idx);
    } else {
        compute_north_fn(&y_vals_north_buf, &y_rows_north_buf,
                            &x_north_buf0,
                            north_train_x_low_idx, north_train_x_high_idx);
    }

    // segments arrive from the last row upwards: slide the window one row up
    north_train_x_low_idx -= local_vec_sz;
    north_train_x_high_idx -= local_vec_sz;

    // hand the buffer back to rx and switch to the other one
    rx_south_buf_free[curr_rx_south_compute_buf] = true;
    curr_rx_south_compute_buf = 1 - curr_rx_south_compute_buf;

    @activate(compute_north);
} // compute_north_task()

task compute_south_task() void {
    // Polling consumer for the south-bound train: each activation either
    // finishes (no segments left), re-polls (current rx buffer not filled
    // yet), or consumes one received x segment and schedules itself again.
    if (compute_south_count == 0) {
        compute_south_done_fn();
        return;
    }

    if (!rx_north_buf_ready[curr_rx_north_compute_buf]) {
        // segment not received yet: try again later
        @activate(compute_south);
        return;
    }

    // claim the buffer and account for the segment
    rx_north_buf_ready[curr_rx_north_compute_buf] = false;
    compute_south_count -= 1;

    if (curr_rx_north_compute_buf != 0) {
        compute_south_fn(&y_vals_south_buf, &y_rows_south_buf,
                            &x_south_buf1,
                            south_train_x_low_idx, south_train_x_high_idx);
    } else {
        compute_south_fn(&y_vals_south_buf, &y_rows_south_buf,
                            &x_south_buf0,
                            south_train_x_low_idx, south_train_x_high_idx);
    }

    // segments arrive from the first row downwards: slide the window one row down
    south_train_x_low_idx += local_vec_sz;
    south_train_x_high_idx += local_vec_sz;

    // hand the buffer back to rx and switch to the other one
    rx_north_buf_free[curr_rx_north_compute_buf] = true;
    curr_rx_north_compute_buf = 1 - curr_rx_north_compute_buf;

    @activate(compute_south);
} // compute_south_task()



// Strict "<" on two 3-word values, with word [2] the most significant and
// word [0] the least significant (lexicographic from index 2 down to 0).
fn is_less_than(aval: *[3]u16, bval: *[3]u16) bool {
    if ((aval.*)[2] != (bval.*)[2]) {
        return (aval.*)[2] < (bval.*)[2];
    }
    if ((aval.*)[1] != (bval.*)[1]) {
        return (aval.*)[1] < (bval.*)[1];
    }
    return (aval.*)[0] < (bval.*)[0];
}

// Reset all per-run state before a new spmv: clear every train flag, restore
// the working row-index buffer from y_rows_init_buf, and zero the output
// accumulators.
fn zero_out_flags_and_data() void {
    // ---- north-bound train (rx from south, tx to north) ----
    is_north_train_running = false;
    is_rx_south_done = false;
    is_tx_north_done = false;
    is_compute_north_started = false;
    is_compute_north_done = false;
    rx_south_buf_free[0] = false;
    rx_south_buf_free[1] = false;
    rx_south_buf_ready[0] = false;
    rx_south_buf_ready[1] = false;

    // ---- south-bound train (rx from north, tx to south) ----
    is_south_train_running = false;
    is_rx_north_done = false;
    is_tx_south_done = false;
    is_compute_south_started = false;
    is_compute_south_done = false;
    rx_north_buf_free[0] = false;
    rx_north_buf_free[1] = false;
    rx_north_buf_ready[0] = false;
    rx_north_buf_ready[1] = false;

    // ---- local contribution ----
    is_local_compute_done = false;

    // ---- west-bound train (tx to west, rx from east) ----
    is_west_train_running = false;
    is_tx_west_done = false;
    is_tx_west_active = false;
    is_rx_east_done = false;
    is_rx_east_active = false;
    tx_west_buf_ready[0] = false;
    tx_west_buf_ready[1] = false;
    rx_east_buf_avail[0] = false;
    rx_east_buf_avail[1] = false;

    // ---- east-bound train (tx to east, rx from west) ----
    is_east_train_running = false;
    is_tx_east_done = false;
    is_tx_east_active = false;
    is_rx_west_done = false;
    is_rx_west_active = false;
    tx_east_buf_ready[0] = false;
    tx_east_buf_ready[1] = false;
    rx_west_buf_avail[0] = false;
    rx_west_buf_avail[1] = false;

    // ---- restore the working row indices from the pristine copy ----
    // Copies y_rows_init_buf into y_rows_west_buf1 (synchronous @mov16).
    // NOTE(review): the original design note calls this the working
    // y_*_north_buf, i.e. the buffers presumably alias -- confirm.
    // This runs from init, after the previous spmv completed, so borrowing
    // the south DSRs here does not collide with in-flight traffic.
    @load_to_dsr(tx_south_dsr, y_rows_west_buf1_dsd);
    @load_to_dsr(mem1d_south_dsr, y_rows_init_buf_dsd);
    @mov16(tx_south_dsr, mem1d_south_dsr);

    // ---- zero the output accumulators ----
    var out_idx: u16 = 0;
    while (out_idx < local_out_vec_sz) : (out_idx += 1) {
        (y_local_buf.*)[out_idx] = 0.0;
    }
    var nz_idx: u16 = 0;
    while (nz_idx < (local_nnz_rows.*)[0]) : (nz_idx += 1) {
        y_vals_north_buf[nz_idx] = 0.0;
        y_vals_south_buf[nz_idx] = 0.0;
    }
}

// Start one spmv run: reset state, compute the per-train tx/rx/compute counts
// for this PE's (prow_id, pcol_id), initialize the x index windows and the
// double-buffer flags, then launch the north- and south-bound trains by
// activating rx_south and rx_north. The counts and flags must be set before
// those activations, so statement order matters here.
task init_task() void {

    zero_out_flags_and_data();

    // each train starts with every PE tx'ing out its local data once.
    // this is followed by rx'ing each segment and processing it.
    // router will forward it to the outgoing direction

    // north train counts
    if (is_first_row) {
        tx_north_count = 0;
    } else {
        tx_north_count = 1;
    }
    // one segment from every row below this one
    rx_south_count = prows - prow_id - 1;
    compute_north_count = rx_south_count;   // should be equal
    if (compute_north_count == 0) {
        // if there is nothing to rx for compute, mark it as done
        is_compute_north_done = true;
    }

    // south train counts
    if (is_last_row) {
        tx_south_count = 0;
    } else {
        tx_south_count = 1;
    }
    // one segment from every row above this one
    rx_north_count = prow_id;
    compute_south_count = rx_north_count;   // should be equal
    if (compute_south_count == 0) {
        // if there is nothing to rx for compute, mark it as done
        is_compute_south_done = true;
    }

    // west train counts
    // tx_west_count == 1 + rx_east_count for every column except the first,
    // where (pcols - 0) % pcols == 0 (nothing to send west)
    tx_west_count = (pcols - pcol_id) % pcols;
    rx_east_count = pcols - pcol_id - 1;

    // east train counts
    // tx_east_count == 1 + rx_west_count for every column except the last,
    // where pcols % pcols == 0 (nothing to send east)
    tx_east_count = (pcol_id + 1) % pcols;
    rx_west_count = pcol_id;

    // calculate the local low and high row index values in dense output for this PE
    // NOTE: these are indices into the fully padded output vector
    y_local_low = local_out_vec_sz * pcol_id;          // incl
    y_local_high = local_out_vec_sz * (pcol_id + 1);   // excl
    // all row idx > y_pad_start_idx are padding, need to be left as 0

    // first preprocess row indices:
    // 1. extract local row indices from mat_rows,
    // 2. sort them,
    // 3. and then put them in y_rows,
    // 4. while updating mat_rows to index into y_rows
    // NOTE: This is now done outside of the kernel as a preprocessing step

    // initialize the x indices to start from one end
    // (north train starts with data from last row, south train start from first row)
    north_train_x_low_idx = (prows - 1) * local_vec_sz;             // north train's start low index
    north_train_x_high_idx = north_train_x_low_idx + local_vec_sz;  // north train's start high index
    south_train_x_low_idx = 0;                                      // south train's start low index
    south_train_x_high_idx = south_train_x_low_idx + local_vec_sz;  // south train's start high index

    // NOTE(review): if local_nnz_cols[0] can be 0 and its type is unsigned,
    // the "- 1" below wraps around -- confirm this cannot happen.
    north_train_start_idx = (local_nnz_cols.*)[0] - 1;     // this will be in reverse (decreasing) order
    south_train_start_idx = 0;                      // this will be in increasing order

    // both rx buffers are initially available (north train, rx from south)
    rx_south_buf_free[0] = true;
    rx_south_buf_free[1] = true;
    // neither is ready for compute yet
    rx_south_buf_ready[0] = false;
    rx_south_buf_ready[1] = false;
    curr_rx_south_buf = 0;      // start with 0
    curr_rx_south_compute_buf = 0; // buffer to use for compute (current compute)

    // both rx buffers are initially available (south train, rx from north)
    rx_north_buf_free[0] = true;
    rx_north_buf_free[1] = true;
    // neither is ready for compute yet
    rx_north_buf_ready[0] = false;
    rx_north_buf_ready[1] = false;
    curr_rx_north_buf = 0;      // start with 0
    curr_rx_north_compute_buf = 0; // buffer to use for compute (current compute)

    // start the north-moving train
    // starts with the last row PE sending out its data, followed by ctrl to switch router on PE above.
    // All PEs receive the data in the same order: from prows - 1, prows - 2, prows - 3 ... until 1.
    // Top two PEs do not need to send out any ctrl, and the first row PE never needs to switch.

    // debug: uncomment to disable north train
    // rx_south_count = 0;
    // tx_north_count = 0;
    // is_compute_north_done = true;

    is_north_train_running = true;
    @activate(rx_south);

    // start the south-moving train

    // debug: uncomment to disable south train
    // rx_north_count = 0;
    // tx_south_count = 0;
    // is_compute_south_done = true;

    is_south_train_running = true;
    @activate(rx_north);
}


// compute y = A*x
fn spmv(x : *[local_vec_sz]f32, y: *[local_out_vec_sz]f32) void {
    // Launch y = A*x on this PE. Work is asynchronous: this function only
    // records its arguments and activates init; the last task of the run
    // (post_processing) triggers f_callback once spmv is done.

    // (pcol_id, prow_id) are only known at runtime and no other task sets
    // them, so refresh them on every call
    prow_id = get_y_coord();
    pcol_id = get_x_coord();

    // remember the caller's vectors and retarget the x tx DSD at x
    y_local_buf = y;
    x_tx_buf = x;
    x_tx_buf_dsd = @set_dsd_base_addr(x_tx_buf_dsd, x);

    // Run spmv once. To time several repetitions, replace the 1 below with
    // the desired number of iterations; iter_counter starts from that value
    // because init_task runs iter_count times.
    iter_count = @as(i16, 1);
    iter_counter = iter_count;

    @activate(init);
}


// comptime

comptime {
    // Bind every local task function to its task ID (order is irrelevant).

    // driver
    @bind_local_task(init_task, init);

    // north-bound train: rx from south, tx to north, consume received x
    @bind_local_task(rx_south_task, rx_south);
    @bind_local_task(tx_north_task, tx_north);
    @bind_local_task(compute_north_task, compute_north);
    @bind_local_task(curr_rx_south_done_task, curr_rx_south_done);

    // south-bound train: rx from north, tx to south, consume received x
    @bind_local_task(rx_north_task, rx_north);
    @bind_local_task(tx_south_task, tx_south);
    @bind_local_task(compute_south_task, compute_south);
    @bind_local_task(curr_rx_north_done_task, curr_rx_north_done);

    // this PE's own x segment
    @bind_local_task(compute_local_task, compute_local);

    // west-bound train (tx west, rx east)
    @bind_local_task(tx_west_task, tx_west);
    @bind_local_task(rx_east_task, rx_east);

    // east-bound train (tx east, rx west)
    @bind_local_task(tx_east_task, tx_east);
    @bind_local_task(rx_west_task, rx_west);

} // comptime


// Explicitly bind each receive color to its input queue. The rx paths load
// fabin DSDs into explicit DSRs, so the compiler can no longer emit the
// instruction that configures the input queue register itself.
//
//   color          queue        (queue number)
//   north_train    RX_SOUTH_Q   1
//   south_train    RX_NORTH_Q   4
//   rx_east_train  RX_WEST_Q    6
//   rx_west_train  RX_EAST_Q    7
comptime {
    // vertical trains
    @initialize_queue(@get_input_queue(RX_SOUTH_Q), .{.color = north_train});
    @initialize_queue(@get_input_queue(RX_NORTH_Q), .{.color = south_train});

    // horizontal trains
    @initialize_queue(@get_input_queue(RX_EAST_Q), .{.color = rx_west_train});
    @initialize_queue(@get_input_queue(RX_WEST_Q), .{.color = rx_east_train});
}

// Per-PE router configuration for the four trains.
//
// Vertical trains (north_train / south_train) depend on the PE's row:
//   - first row : north train is a sink (SOUTH -> RAMP), south train is a
//                 source (RAMP -> SOUTH); no switching.
//   - last row  : north train is a source (RAMP -> NORTH), south train is a
//                 sink (NORTH -> RAMP); no switching.
//   - middle    : the train is received and forwarded (e.g. SOUTH -> RAMP+NORTH),
//                 then switched: pos1 drops RAMP from tx, pos2 changes rx to
//                 RAMP so this PE can inject its own data.
// Horizontal trains use fixed single-hop routes, configured only where a
// neighbor exists in that direction.
comptime {

    const north_train_route = .{
      .routes= .{
        .rx = .{ SOUTH },
        .tx = .{ RAMP, NORTH },
      },
      .switches=.{
        .pos1 = .{ .tx = NORTH },   // first change tx to just north
        .pos2 = .{ .rx = RAMP },    // then change rx from ramp
        .pos3 = .{ .invalid = true },
        .ring_mode = false,
        .current_switch_pos = 0,
        .pop_mode = .{ .pop_on_advance = true },
       },
    };

    const north_train_route_first_row = .{  // no switching
      .routes= .{
        .rx = .{ SOUTH },
        .tx = .{ RAMP },
       },
    };
    const north_train_route_last_row = .{  // no switching
      .routes= .{
        .rx = .{ RAMP },
        .tx = .{ NORTH },
       },
    };

    const south_train_route = .{
      .routes= .{
        .rx = .{ NORTH },
        .tx = .{ RAMP, SOUTH },
       },
      .switches=.{
        .pos1 = .{ .tx = SOUTH },   // first change tx to just south
        .pos2 = .{ .rx = RAMP },    // then change rx from ramp
        .pos3 = .{ .invalid = true },
        .ring_mode = false,
        .current_switch_pos = 0,
        .pop_mode = .{ .pop_on_advance = true },
       },
    };
    const south_train_route_first_row = .{  // no switching
      .routes= .{
        .rx = .{ RAMP },
        .tx = .{ SOUTH },
       },
    };

    const south_train_route_last_row = .{   // no switching
      .routes= .{
        .rx = .{ NORTH },
        .tx = .{ RAMP },
       },
    };

    // west-bound train: arrives from EAST, leaves toward WEST
    const west_train_in = .{ .routes= .{ .rx = .{ EAST }, .tx = .{ RAMP } } };
    const west_train_out = .{ .routes= .{ .rx = .{ RAMP }, .tx = .{ WEST } } };
    // east-bound train: arrives from WEST, leaves toward EAST
    const east_train_in = .{ .routes= .{ .rx = .{ WEST }, .tx = .{ RAMP } } };
    const east_train_out = .{ .routes= .{ .rx = .{ RAMP }, .tx = .{ EAST } } };

    if (is_first_row) {
        // first row, (prow_id == 0)
        @set_local_color_config(north_train, north_train_route_first_row);
        @set_local_color_config(south_train, south_train_route_first_row);
    } else if (is_last_row) {
        // last row, (prow_id == prows - 1)
        @set_local_color_config(north_train, north_train_route_last_row);
        @set_local_color_config(south_train, south_train_route_last_row);
    } else {
        // all middle rows
        @set_local_color_config(north_train, north_train_route);
        @set_local_color_config(south_train, south_train_route);
    }

    if (!is_last_col) {
        // all but last col, (pcol_id < pcols - 1): has an east neighbor
        @set_local_color_config(rx_west_train, west_train_in);
        @set_local_color_config(tx_east_train, east_train_out);
    }
    if (!is_first_col) {
        // all but first col, (pcol_id > 0): has a west neighbor
        @set_local_color_config(rx_east_train, east_train_in);
        @set_local_color_config(tx_west_train, west_train_out);
    }
}

allreduce2R1E/layout.csl

// Parameters of the allreduce2R1E layout module.
param colors: [2]color;              // colors[0] -> C_ROUTE, colors[1] -> C_DISPATCH
param entrypoints: [1]local_task_id; // entrypoints[0] -> C_LOCK
param width: i16 ;   // width of the core
param height: i16 ;  // height of the core


const C0: color = colors[0];
const C1: color = colors[1];

// entrypoints of allreduce module
// LOCK runs only if teardown is received and the operation is done
// LOCK performs the state transition
// teardown handler activates LOCK
// the operation blocks LOCK in the beginning and unblocks it when it finishes
const C_LOCK: local_task_id = entrypoints[0];

// Per-PE parameters for the allreduce pe module at (px, py): boundary flags
// of the width x height rectangle plus the shared colors, lock entrypoint and
// rectangle size.
fn get_params(px:i16, py:i16) comptime_struct {
    return .{
        .first_px = (px == 0),
        .last_px = (px == width - 1),
        .first_py = (py == 0),
        .last_py = (py == height - 1),
        .C_ROUTE = C0,
        .C_DISPATCH = C1,
        .C_LOCK = C_LOCK,
        .width = width,
        .height = height
    };
}

allreduce2R1E/pe.csl

// TODO [perf]: if MAX_ZDIM = 1, no need to update length of DSDs

// allreduce module has the following three operations
//  - row reduction
//  - column reduction
//  - broadcasting
//
// It uses two routable colors (C_ROUTE and C_DISPATCH), a single local
// entrypoint (C_LOCK) and a single input/output queue. Any user's kernel can
// combine this module with other modules without running out of resources.
//
// 1. row reduction
//   The reduction is from left to right. The last PE (px=width-1) receives
//   all data from the neighbors one by one. The result is stored in the last
//   PE, other PEs do not change the content.
//
// 2. column reduction
//   The reduction is from north to south. The last PE (py=height-1) receives
//   all data from the neighbors one by one. The result is stored in the last
//   PE, other PEs do not change the content.
//
// 3. broadcast
//   The right-bottom PE (px=width-1, py=height-1) broadcasts the data upwards to
//   whole column, then each PE in this column broadcasts data to its west neighbors.
//

// How to assign explicit DSRs
//
// reduction:
//  last PE: f_send_data --> @fadds(mem_x_buf_dsd, mem_x_buf_dsd, fab_recv_wdsd, .{.async=true, .activate=f_send_data} );
//                 ^                          |
//                 |--------------------------+
//  others: f_send_data --> @mov32(fab_trans_x_wdsd, mem_x_buf_dsd, .{.async=true, .activate=f_send_ctrl} );
//          --> @mov32(fab_trans_ctrl_wdsd, mem_ctrl_buf_dsd, .{.async=true, .activate=f_send_data } );
//          --> f_send_data
//          1st PE: @mov32(fab_trans_ctrl_wdsd, mem_buf_td_dsd, .{.async=true} );
//
// bcast:
//  right-bottom PE: @mov32(fab_trans_x_wdsd, mem_x_buf_dsd, .{.async=true, .activate=f_send_ctrl} );
//                   --> @mov32(fab_trans_ctrl_wdsd, mem_buf_td_dsd, .{.async=true} );
//  others: @mov32(mem_x_buf_dsd, fab_recv_wdsd, .{.async=true} );
//
// Only one dest DSR, one src0 DSR and one src1 DSR are enough because
// - the teardown separates different operations
// - when TD arrives, sender has sent out the data/ctrl
//   the receiver has received all data because there is only one color
// - all DSD operations are serialized
//
// For example:
//   dest_dsr = @get_dsr(dsr_dest, 1);
//   src0_dsr = @get_dsr(dsr_src0, 1);
//   src1_dsr = @get_dsr(dsr_src1, 1);
//
// WARNING: dest_dsr and src0_dsr must be a valid pair, for example (7,1) is invalid

// The sequence of LOCK of { row_reduce, col_reduce, bcast}
//
//  row_reduce blocks LOCK
//  T29 activates LOCK
//  row_reduce unblocks LOCK when it finishes
//
//  LOCK goes to next state
//
//  col_reduce blocks LOCK
//  T29 activates LOCK
//  col_reduce unblocks LOCK when it finishes
//
//  LOCK goes to next state
//
//  bcast blocks LOCK
//  T29 activates LOCK
//  bcast unblocks LOCK when it finishes
//
//  LOCK goes to next state (done)
//

// routable color used for the reduction/broadcast traffic of this module;
// its color-config and switch-config addresses are derived from it below
param C_ROUTE: color;
param C_DISPATCH: color; // routable color to trigger control-wavelet triggered task
                         // the routing is R --> R

// C_DISPATCH guides the control-wavelet triggered tasks, it
// does not bind to a wavelet-triggered task, so it does not
// bind to an input queue.
const LOCAL_DISPATCH: local_task_id = @get_local_task_id(@get_int(C_DISPATCH));

// LOCK runs only if teardown is received and the operation is done
// LOCK performs the state transition
// teardown handler activates LOCK
// the operation blocks LOCK in the beginning and unblocks it when it finishes
param C_LOCK: local_task_id;

// position of this PE in the width x height rectangle
param first_px: bool; // (0 == px)
param last_px: bool;  // ((width-1) == px)
param first_py: bool; // (0 == py)
param last_py: bool;  // ((height-1) == py)

// row reduction needs to receive width-1 neighbors
// column reduction needs to receive height-1 neighbors
param width: i16;
param height: i16;

// f_callback = sys_mod.unblock_cmd_stream, to continue next command
param f_callback: fn ()void;

// last PE uses this ID as the input queue
// others use this ID as the output queue
param queues:[1]u16;

// explicit DSR allocation (one DSR id per role; see the DSR notes above)
param dest_dsr_ids: [1]u16;
param src0_dsr_ids: [1]u16;
param src1_dsr_ids: [1]u16;

param MAX_ZDIM: i16; // maximum size of reduced buffer

// Control task IDs used by this module (fixed numeric IDs 40..42).
const C_SEND_CTRL_ID: u16 = 40;
const C_SEND_DATA_ID: u16 = 41;
const C_STATE_ENTRY_ID: u16 = 42;

const C_SEND_CTRL: control_task_id = @get_control_task_id(C_SEND_CTRL_ID);  // send switch advance
const C_SEND_DATA: control_task_id = @get_control_task_id(C_SEND_DATA_ID);  // send data
const C_STATE_ENTRY: control_task_id = @get_control_task_id(C_STATE_ENTRY_ID); // state machine

const timestamp = @import_module("<time>");

// Zero-initialized buffer of timestamp.tsc_size_words u16 words
// (tsc_size_words = 3). NOTE(review): by its name this presumably holds the
// sampled reference timestamp counter used for synchronization -- confirm.
var tscRefBuffer = @zeros([timestamp.tsc_size_words]u16);

////////////////////////////////////////////////////////////////////////////////
// Main memory (48KB)
////////////////////////////////////////////////////////////////////////////////

// caller's vector, assigned by allreduce(); the module operates on it in place
var x: *[MAX_ZDIM]f32;

// states of the allreduce state machine
const STATE_ROW_REDUCE: i16 = 0;
const STATE_COL_REDUCE: i16 = 1;
const STATE_BCAST: i16 = 2;
const STATE_DONE: i16 = 3;

// At most 4 states, "+1" is to avoid out-of-bound if
// STATE_DONE also dereferences state_seq[4]
var state_seq = @zeros([4+1]i16);
var state_idx: i16 = 0;   // index of the current entry in state_seq
var cur_state: i16 = 0;
var next_state: i16 = 0;

// record the reduction length from the caller
var wvlts_per_pe: u16 = 0;

// number of PEs involved in the reduction: last PE needs to count number of received neighbors
// WARNING: reduce_pes only records number of received PEs
//   row reduction: width-1
//   column reduction: height-1
// If reduce_pes is wrong, simfab shows re-entry error of UT when row reduction and col reduction
// are combined because row reduction has extra UT1 waiting for wavelets
var reduce_pes: i16 = 0;
// 1st PE during the reduction: send TD to others
//   row reduction: {px = 0}
//   column reduction: {py = 0}
var reduce_first_pe: bool;
// last PE during the reduction: receive data from w-1 or h-1 neighbors
//   row reduction: {px = w-1}
//   column reduction: {py = h-1}
var reduce_last_pe: bool;

// last PE uses count_recv_or_send to receive data from w-1 neighbors
// other PEs use count_recv_or_send to send data and control
var count_recv_or_send: i16 = 0;


// explicit DSR handles built from the caller-provided DSR ids
const dest_dsr = @get_dsr(dsr_dest, dest_dsr_ids[0]);
const src0_dsr = @get_dsr(dsr_src0, src0_dsr_ids[0]);
const src1_dsr = @get_dsr(dsr_src1, src1_dsr_ids[0]);

// which action the next LOCAL_DISPATCH activation should perform
const STATE_DISPATCH_SEND_DATA: i16 = 0;
const STATE_DISPATCH_SEND_CTRL: i16 = 1;
const STATE_DISPATCH_STATE_ENTRY: i16 = 2;

// -1 until allreduce() selects STATE_DISPATCH_STATE_ENTRY
var state_dispatch: i16 = -1;


// allreduce(n, in_tensor): entry point of the allreduce module.
//
// How to use:
//   reduce_mod = @import_module("<allreduce/pe>");
//   reduce_mod.allreduce(n, x);
// The caller prepares the input vector x; n (0 < n <= MAX_ZDIM) is the
// number of f32 entries reduced per PE.
//
// This entry point always runs the full schedule
//   STATE_ROW_REDUCE -> STATE_COL_REDUCE -> STATE_BCAST -> STATE_DONE
// (row reduction, then column reduction, then broadcast). The schedule is
// stored in state_seq and the state machine is started by dispatching
// STATE_DISPATCH_STATE_ENTRY through LOCAL_DISPATCH.
//
// When allreduce() finishes, it calls the user's callback.
fn allreduce( n: i16, in_tensor: *[MAX_ZDIM]f32 ) void {

   // remember the caller's buffer and validate the requested length
   x = in_tensor;
   @assert(n > 0 and n <= MAX_ZDIM);
   wvlts_per_pe = @bitcast(u16, n);

   // fixed schedule: row reduce -> column reduce -> broadcast -> done
   state_seq[0] = STATE_ROW_REDUCE;
   state_seq[1] = STATE_COL_REDUCE;
   state_seq[2] = STATE_BCAST;
   state_seq[3] = STATE_DONE;

   // rewind the state machine to the first schedule entry
   state_idx = 0;
   cur_state = STATE_ROW_REDUCE;

   // kick off the state machine via the local dispatch task
   state_dispatch = STATE_DISPATCH_STATE_ENTRY;
   @activate(LOCAL_DISPATCH);
}

//--------------------- system utility for teardown
//
// The *_configure() functions below reprogram C_ROUTE at runtime. Each one does a
// read-modify-write of two per-color registers through raw u16 pointers. The word
// addresses defined here are converted to byte addresses by translate_word_to_bytes().

// ref: old monolith/src/ucode/kernels/lib/pe_address_map.casm
// const TAMAP_FAB_MAP_START_ADDR       = 0x7f20
// ref: monolith/src/ucode/kernels/lib/pe_addr_map_ein.casm
// const TAMAP_FAB_TRAFFIC_MAP_START_ADDR = 0x7f20
// Word address of the fabric per-color config register of color 0.
// D2H_COLOR_CONFIG_ADDR offsets it by the color id to reach the register of C_ROUTE.
const TAMAP_FAB_MAP_START_ADDR : u16 = 0x7f20;
const D2H_COLOR_CONFIG_ADDR : u16 = TAMAP_FAB_MAP_START_ADDR + @get_int(C_ROUTE);

// Fabric per-color config register (at D2H_COLOR_CONFIG_ADDR):
// mask out bits 0:9, i.e. the input/output setting of pos0
// keep bits 10:15, which are
//  bit 15: point to point
//  bit 14: TIP (teardown in progress)
//  bit 12-13: control wavelet pop mode
//  bit 11: color swap for E,W inputs
//  bit 10: color swap for N,S inputs
//
// The *_configure() functions clear bit 14 (TIP) via tile_config.teardown.exit().
// Bits 12-13 are decided at comptime. The C_ROUTE color config in this file
// requests pop_on_advance on every PE.
const MASK_INPUT_OUTPUT_POS0: u16 = 0xfc00;

// bits 0:4 define the initial output switch position
const OUTPUT_WEST: u16  = 0x1;  // bit 0: west output mask
const OUTPUT_EAST: u16  = 0x2;  // bit 1: east output mask
const OUTPUT_SOUTH: u16 = 0x4;  // bit 2: south output mask
const OUTPUT_NORTH: u16 = 0x8;  // bit 3: north output mask
const OUTPUT_RAMP: u16  = 0x10; // bit 4: offramp output mask

// bits 5:9 define the initial input switch position
const INPUT_WEST: u16  = 0x20;  // bit 5: west input mask
const INPUT_EAST: u16  = 0x40;  // bit 6: east input mask
const INPUT_SOUTH: u16 = 0x80;  // bit 7: south input mask
const INPUT_NORTH: u16 = 0x100; // bit 8: north input mask
const INPUT_RAMP: u16  = 0x200; // bit 9: onramp input mask

// Fabric switch configuration
// 0x7f40 - 0x7f57 - colors 0-23. Each address is for a single color
// Bits 14:13 Current Switch position (writes both input and output switch position; reads input position)
// Bits 12 Ring mode (1) (Switch movements Stop on last valid setting if ring mode is 0.)
// Bit 11 Switch position 3 switch select (1=input; 0 = output)
// Bits 10:8 Switch position 3 (5 = INVALID; 4 = CE; 3 = N; 2 = S; 1 = E; 0 = W)
// Bit 7 Switch position 2 switch select (1=input; 0 = output)
// Bits 6:4 Switch position 2 (5 = INVALID; 4 = CE; 3 = N; 2 = S; 1 = E; 0 = W)
// Bit 3 Switch position 1 switch select (1=input; 0 = output)
// Bits 2:0 Switch position 1 (5 = INVALID; 4 = CE; 3 = N; 2 = S; 1 = E; 0 = W)
//
// ref: monolith/src/ucode/kernels/lib/pe_addr_map_fyn.casm
// .const TAMAP_FAB_SWITCH_CFG_START_ADDR = 0x7f40
// Word address of the switch config register of color 0; offset by the color id.
const TAMAP_FAB_SWITCH_CFG_START_ADDR: u16 = 0x7f40;
const D2H_SWITCH_CONFIG_ADDR: u16 = TAMAP_FAB_SWITCH_CFG_START_ADDR + @get_int(C_ROUTE);
// mask bits 14:13
// ANDing with MASK_SWITCH_RESET_POS0 sets bits 14:13 to zero, i.e. moves the switch back to pos0
const MASK_SWITCH_RESET_POS0: u16 = 0x9fff;

// To clear the setting of pos1, set bits 2:0 to zero and keep all other bits unchanged
const MASK_SWITCH_CLEAR_POS1: u16 = 0xfff8;
// Bit 3 is always 1 because "pos1 = {.rx = RAMP}" implies position 1 switch select is "1=input"
// Values for bits 2:0 (pos1): 5 = INVALID (no pos1), 4 = CE (ramp)
const SWITCH_POS1_INVALID: u16 = 0x5;
const SWITCH_POS1_RAMP: u16 = 0x4;

// Convert a 16-bit-word address (as used by the register constants above) into
// the byte address that the @bitcast(*u16, ...) pointer accesses expect.
fn translate_word_to_bytes( addr: u16 ) u16 {
    const bytes_per_word: u16 = 2;
    return addr * bytes_per_word;
}


////////////////////////////////////////////////////////////////////////////////
// DSDs
// data-structure descriptors (DSDs), loaded into data-structure registers (DSRs)
//
// Queues 0,1: input depth 6 wavelets
// Queues 2,3: input depth 4 wavelets
// Queues 4-7: input depth 2 wavelets
//
// queues 0,1: output depth 2 wavelets
// queues 2,3: output depth 6 wavelets
// queues 4,5: output depth 2 wavelets
//
// Length of an operand:
// The length of all other types of DSRs is specified by the length field of its DSD. When
// the bits encoding the length are 0x7fff, the length is infinite.
//
// Length of the vector instruction is then determined in the following order:
// 1. If src0 has a non-zero length, that length is used
// 2. If src1 has a non-zero length, that length is used
// 3. If dst has a non-zero length, that length is used
// 4. if no operands have length (all operands are GPR), length = 1
////////////////////////////////////////////////////////////////////////////////

// One-element placeholder so the memory DSDs below can be declared at comptime;
// they are rebound to the real buffer at runtime.
const dummy = @zeros([1]i16);

// Memory DSD over the user vector x.
// reduce() and broadcast_to_all() rebind it to x and set its length to n.
// Reduction: the last PE accumulates neighbor data into it; the other PEs send
// it to the next PE (east for row reduction, south for column reduction).
// Broadcast: the bottom-right PE sends it; the other PEs receive into it.
var mem_x_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> dummy[i] });

// Outgoing data on C_ROUTE through output queue 0. Used by every non-last PE
// during a reduction and by the bottom-right PE during the broadcast.
// Its length is reset to n at runtime.
var fab_trans_x_wdsd = @get_dsd(fabout_dsd, .{
    .extent = MAX_ZDIM,
    .fabric_color = C_ROUTE,
    .output_queue = @get_output_queue(queues[0])
});

// Single control wavelet on C_ROUTE. It carries either the switch advance
// (ctrl_11) or the teardown (TD) wavelet.
// WARNING: control wavelet must be sent with the same microthread, via the same output buffer,
// otherwise, we may see only one data wavelet, then 2nd is the control wavelet, then
// the remaining data cannot be sent out because the routing is back to {.rx=WEST, .tx=EAST},
// there is no path from RAMP to the router.
const fab_trans_ctrl_wdsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .control = true,
    .fabric_color = C_ROUTE,
    .output_queue = @get_output_queue(queues[0]),
});


// Incoming data on C_ROUTE through input queue 0.
// Row reduction: the last PE receives the data from its w-1 neighbors,
// in the order p0, p1, ..., p{w-2} (column reduction: h-1 neighbors).
// It uses the same queue ID because it does not send, only receives.
// It does not receive ctrl wavelets because of NOCE.
// f_send_data() receives data reduce_pes times.
// Broadcast: every PE except the bottom-right one receives the result through it.
// Its length is reset to n at runtime.
//
var fab_recv_wdsd =  @get_dsd(fabin_dsd, .{
   .extent = MAX_ZDIM,
   .fabric_color = C_ROUTE,
   .input_queue = @get_input_queue(queues[0])
});


// One-word buffer holding the control wavelet built by f_dispatch()
// (control-task index in the upper 16 bits).
var mem_cmd_buf = @zeros([1]u32);

const mem_cmd_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> mem_cmd_buf[i] });

// Sends the control wavelet on C_DISPATCH that triggers f_send_data, f_send_ctrl
// or f_state_entry (selected by the index written by f_dispatch()).
const fab_trans_cmd_wdsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .fabric_color = C_DISPATCH,
    .control = true,
    .output_queue = @get_output_queue(queues[0])
});

////////////////////////////////////////////////////////////////////////////////
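// Tasks
// Control tasks are bound with @bind_control_task(task, id) and local tasks
// with @bind_local_task(task, id); see the comptime block after f_dispatch().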
////////////////////////////////////////////////////////////////////////////////


const switches = @import_module("<memcpy/memcpy_switches>");

// The following arrays hold control-wavelet payloads that update the switch
// position at the recipient PEs. Their values are computed at comptime by the
// memcpy_switches helpers.
//
// ctrl_11 is the switch-advance wavelet sent by f_send_ctrl() on every non-last
// PE during a reduction. As noted in reduce(), it changes the switch of the
// sender and of its neighbor.
var ctrl_11 = [1]u32 { switches.ctrl(switches.switch_cmd_11()) };

// Memory DSD over ctrl_11 (reduce() rebinds it to ctrl_11 again before use).
var mem_ctrl_buf_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> ctrl_11[i] });

// teardown from casm
// teardown_buf[0] = (1 << 31) | (0b111 << 22)
//
// teardown from csl
// from cslang/test-e2e/dynamic_filters/sender.csl
//  31=0x1f = no entrypoint
//
// teardown wavelet = 0x1df 0000
//const teardown_buf = [1]u32{(31 << 16) | 0b111 << 22};
// teardown wavelet = 0x9df 9249
// Teardown (TD) wavelet that ends each state. It is sent by the first PE after a
// reduction and by the bottom-right PE after the broadcast.
const teardown_buf = [1]u32 { switches.ctrl(switches.teardown_cmd_1()) };

const mem_buf_td_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{1} -> teardown_buf[i] });

// Control task that starts the operation for the current state.
//
// C_ROUTE starts in teardown mode ({.teardown = true} in its comptime config),
// and every *_configure() call takes it out of teardown mode again.
// reduce() and broadcast_to_all() block LOCK when they start. The operation
// unblocks LOCK when it finishes, and the teardown handler activates LOCK, so
// f_lock() runs only after both have happened.
// LOCK is deliberately not blocked inside the *_configure() functions: they run
// synchronously, and blocking next to the operation that later unblocks it is
// easier to follow.
//
// Before returning, the index of the following state is prefetched into
// next_state; f_lock() copies it into cur_state.
task f_state_entry() void {

    // terminal state: record the reference timestamp and return to the user
    if ((STATE_ROW_REDUCE != cur_state) and (STATE_COL_REDUCE != cur_state)
        and (STATE_BCAST != cur_state)){
        timestamp.get_timestamp(&tscRefBuffer);
        f_callback();
        return;
    }

    if (STATE_ROW_REDUCE == cur_state){
        // west -> east along each row; the sum lands on the last column
        // (the 1st PE of the row sends the TD that ends this state)
        rowReduce_configure();
        reduce_pes = width-1;
        reduce_first_pe = first_px;
        reduce_last_pe = last_px;
        reduce( wvlts_per_pe );
    }else if (STATE_COL_REDUCE == cur_state){
        // north -> south along each column; the sum lands on the last row
        // (the 1st PE of the column sends the TD that ends this state)
        colReduce_configure();
        reduce_pes = height-1;
        reduce_first_pe = first_py;
        reduce_last_pe = last_py;
        reduce( wvlts_per_pe );
    }else{
        // STATE_BCAST: the bottom-right PE sends the result (and the TD) to all
        bcast_configure();
        broadcast_to_all( wvlts_per_pe );
    }

    // prefetch the next state for f_lock()
    state_idx += 1;
    next_state = state_seq[state_idx];
}

// Start a row or column reduction of n wavelets (the direction is fixed by the
// preceding *_configure() call and the reduce_* variables).
fn reduce( n: u16 ) void {

    // LOCK stays blocked for the whole reduction; f_send_data() unblocks it
    // once this PE has finished its part.
    @block(C_LOCK);

    // f_send_data() counts received chunks (last PE) or sends (other PEs) here
    count_recv_or_send = 0;

    // point the memory DSD at the user vector and trim every DSD to n wavelets
    mem_x_buf_dsd = @set_dsd_base_addr(mem_x_buf_dsd, x);
    mem_x_buf_dsd = @set_dsd_length(mem_x_buf_dsd, n);
    fab_trans_x_wdsd = @set_dsd_length(fab_trans_x_wdsd, n);
    fab_recv_wdsd = @set_dsd_length(fab_recv_wdsd, n);

    // switch-advance payload sent by f_send_ctrl() (changes the switch of this
    // PE and its neighbor); irrelevant on the last PE, which never sends it
    mem_ctrl_buf_dsd =  @set_dsd_base_addr(mem_ctrl_buf_dsd, ctrl_11);

    // begin with f_send_data(), reached through the local dispatcher:
    // the last PE receives from its neighbors, the others send to the east/south
    state_dispatch = STATE_DISPATCH_SEND_DATA;
    @activate(LOCAL_DISPATCH);
}

// Broadcast n wavelets of x from the bottom-right PE (last_px and last_py) to
// every other PE over the routes set by bcast_configure().
// The sender finishes in f_send_ctrl(), which sends the TD and unblocks LOCK.
// Each receiver unblocks LOCK as soon as its copy has arrived.
fn broadcast_to_all( n: u16 ) void {

    // WARNING: block LOCK in the beginning and only
    // unblock LOCK when "broadcast" finishes
    @block(C_LOCK);

    // No PE sends switch advance
    // mem_ctrl_buf_dsd =  @set_dsd_base_addr(mem_ctrl_buf_dsd, ctrl_11);

    // bind the DSDs to the user vector x and trim them to n wavelets
    mem_x_buf_dsd = @set_dsd_base_addr(mem_x_buf_dsd, x);
    mem_x_buf_dsd = @set_dsd_length(mem_x_buf_dsd, n);
    fab_recv_wdsd = @set_dsd_length(fab_recv_wdsd, n);
    fab_trans_x_wdsd = @set_dsd_length(fab_trans_x_wdsd, n);

    if ( last_px and last_py ){
        // Pw-1,h-1 sends data; when the send completes, f_dispatch() routes to
        // f_send_ctrl(), which sends a TD and unblocks LOCK.
        // Equivalent single-statement form:
        //@mov32(fab_trans_x_wdsd, mem_x_buf_dsd, .{.async=true, .activate=f_send_ctrl} );
        state_dispatch = STATE_DISPATCH_SEND_CTRL;
        @load_to_dsr(dest_dsr, fab_trans_x_wdsd, .{.async=true, .activate=f_dispatch} );
        @load_to_dsr(src1_dsr, mem_x_buf_dsd);
        @mov32(dest_dsr, src1_dsr, .{.async=true} );
    }else{
        // other PEs receive data into x and then wait for the TD.
        // LOCK is unblocked after the data is received; the C_ROUTE teardown
        // handler (teardown_allreduce) activates LOCK.
        // Equivalent single-statement form:
        //@mov32(mem_x_buf_dsd, fab_recv_wdsd, .{.async=true, .unblock=C_LOCK} );
        @load_to_dsr(dest_dsr, mem_x_buf_dsd);
        @load_to_dsr(src1_dsr, fab_recv_wdsd, .{.async=true, .unblock=C_LOCK} );
        @mov32(dest_dsr, src1_dsr, .{.async=true} );
    }
}

// Control task driving one step of a row/column reduction.
//
// Last PE (reduce_last_pe): it only receives data. Each activation adds one
// incoming chunk into x and re-activates itself through f_dispatch(). After
// reduce_pes chunks it unblocks LOCK and waits for the TD from the 1st PE.
//
// Other PEs: reduction sequence
//   f_send_data() [send x] --> f_dispatch --> f_send_ctrl() [send switch advance]
//        ^                                          |
//        +------------------ f_dispatch <-----------+
//   then the second f_send_data() call unblocks LOCK, and the 1st PE also sends the TD.
//
// f_send_data() is therefore the last call when the reduction finishes;
// it unblocks LOCK when this PE's part of the operation is done.
task f_send_data() void {
    if (reduce_last_pe){
        // last PE receives data from reduce_pes neighbors
        if (count_recv_or_send < reduce_pes){
            // x += incoming chunk; on completion re-enter f_send_data() via f_dispatch()
            //@fadds(mem_x_buf_dsd, mem_x_buf_dsd, fab_recv_wdsd, .{.async=true, .activate=f_send_data} );
            state_dispatch = STATE_DISPATCH_SEND_DATA;
            @load_to_dsr(src1_dsr, fab_recv_wdsd, .{.async=true, .activate=f_dispatch} );
            @load_to_dsr(src0_dsr, mem_x_buf_dsd);
            @load_to_dsr(dest_dsr, mem_x_buf_dsd);
            @fadds(dest_dsr, src0_dsr, src1_dsr, .{.async=true} );
            count_recv_or_send += 1;
        }else{
            // last PE has received all data from the reduce_pes neighbors
            // wait for TD from 1st PE
            // unblock LOCK; the C_ROUTE teardown handler (teardown_allreduce) activates LOCK
            @unblock(C_LOCK);
        }
    }else{
        // other PE (not last PE) sends data and control
        if (count_recv_or_send < 1){
            // send x; on completion go to f_send_ctrl() via f_dispatch()
            //@mov32(fab_trans_x_wdsd, mem_x_buf_dsd, .{.async=true, .activate=f_send_ctrl} );
            state_dispatch = STATE_DISPATCH_SEND_CTRL;
            @load_to_dsr(dest_dsr, fab_trans_x_wdsd, .{.async=true, .activate=f_dispatch} );
            @load_to_dsr(src1_dsr, mem_x_buf_dsd);
            @mov32(dest_dsr, src1_dsr, .{.async=true} );
            count_recv_or_send += 1; 
        }else{
            // sending is done (including data wavelets and control wavelets)
            @unblock(C_LOCK);
            // only 1st PE sends TD to other PEs
            // the C_ROUTE teardown handler (teardown_allreduce) activates LOCK
            if (reduce_first_pe){
                //@mov32(fab_trans_ctrl_wdsd, mem_buf_td_dsd, .{.async=true} );
                @load_to_dsr(dest_dsr, fab_trans_ctrl_wdsd, .{.async=true} );
                @load_to_dsr(src1_dsr, mem_buf_td_dsd);
                @mov32(dest_dsr, src1_dsr, .{.async=true} );
            }
        }
    }
}


// Control task that sends a single control wavelet on C_ROUTE.
// Broadcast: the TD that ends the state (bottom-right PE only).
// Reduction: the switch advance ctrl_11 (every PE except the last).
task f_send_ctrl() void{
    if (STATE_BCAST == cur_state){
        //broadcast: Pw-1,h-1 only sends the TD
        // unblock LOCK after TD is sent out
        //@mov32(fab_trans_ctrl_wdsd, mem_buf_td_dsd, .{.async=true, .unblock=C_LOCK} );
        @load_to_dsr(dest_dsr, fab_trans_ctrl_wdsd, .{.async=true, .unblock=C_LOCK} );
        @load_to_dsr(src1_dsr, mem_buf_td_dsd);
        @mov32(dest_dsr, src1_dsr, .{.async=true} );
    }else{
        // reduction: other PEs (not last PE) send the switch advance
        //   last PE does not trigger f_send_ctrl because it only receives data
        // on completion f_dispatch() re-enters f_send_data(), which will unblock LOCK
        //@mov32(fab_trans_ctrl_wdsd, mem_ctrl_buf_dsd, .{.async=true, .activate=f_send_data } );
        state_dispatch = STATE_DISPATCH_SEND_DATA;
        @load_to_dsr(dest_dsr, fab_trans_ctrl_wdsd, .{.async=true, .activate=f_dispatch } );
        @load_to_dsr(src1_dsr, mem_ctrl_buf_dsd);
        @mov32(dest_dsr, src1_dsr, .{.async=true} );
    }
}


// State-transition task. It is scheduled only when both conditions hold:
//   1. the current operation (row_reduce, col_reduce or bcast) blocked LOCK
//      at its start and has since unblocked it, and
//   2. the C_ROUTE teardown handler has activated LOCK after the TD arrived.
// It then moves to the state prefetched by f_state_entry() and re-enters
// f_state_entry() through the local dispatcher.
task f_lock() void {
    // request the next f_state_entry() call
    state_dispatch = STATE_DISPATCH_STATE_ENTRY;
    // adopt the prefetched state
    cur_state = next_state;
    @activate(LOCAL_DISPATCH);
}


// Local task that turns the request stored in state_dispatch into a control
// wavelet on C_DISPATCH. The wavelet's index selects f_send_data, f_send_ctrl
// or f_state_entry. C_DISPATCH is routed RAMP -> RAMP, so the wavelet returns
// to this PE and triggers the selected control task.
task f_dispatch() void {

    // only three requests exist: SEND_DATA, SEND_CTRL, STATE_ENTRY (values 0..2)
    @assert( (0 <= state_dispatch) and (2 >= state_dispatch) );

    // default to re-entering the state machine, then override for the sends
    var ctrl_index: u16 = C_STATE_ENTRY_ID;
    if (STATE_DISPATCH_SEND_DATA == state_dispatch){
        ctrl_index = C_SEND_DATA_ID;
    }else if (STATE_DISPATCH_SEND_CTRL == state_dispatch){
        ctrl_index = C_SEND_CTRL_ID;
    }

    // the control wavelet's index field sits in the upper half-word
    mem_cmd_buf[0] = @as(u32, ctrl_index) << 16;

    // async send of the one-word command (equivalent to
    // @mov32(fab_trans_cmd_wdsd, mem_cmd_buf_dsd, .{.async=true}))
    @load_to_dsr(dest_dsr, fab_trans_cmd_wdsd, .{.async=true} );
    @load_to_dsr(src1_dsr, mem_cmd_buf_dsd);
    @mov32(dest_dsr, src1_dsr, .{.async=true} );
}


comptime {
    // local tasks: state transition (LOCK) and the control-wavelet dispatcher
    @bind_local_task( f_lock, C_LOCK);
    @bind_local_task( f_dispatch, LOCAL_DISPATCH);

    // control tasks, selected by the index of the wavelet f_dispatch() sends
    @bind_control_task( f_state_entry, C_STATE_ENTRY);
    @bind_control_task( f_send_data, C_SEND_DATA);
    @bind_control_task( f_send_ctrl, C_SEND_CTRL);
}


//----------------- the following is the routing of C_ROUTE

const tile_config = @import_module("<tile_config>");

// Reprogram C_ROUTE at runtime for a west -> east row reduction:
//   pos0: first PE {RAMP -> EAST}, middle PEs {WEST -> EAST}, last PE {WEST -> RAMP}
//   pos1: {.rx = RAMP} on every PE except the last one, which has no pos1
// and finally leave teardown mode on C_ROUTE.
//
// Both registers are updated read-modify-write, so ring mode, pop mode and
// bit 3 (pos1 switch select) keep their comptime values.
fn rowReduce_configure() void {

    // --- switch config register -------------------------------------------
    // Clear bits 14:13 so the switch is back at pos0. Otherwise some PE could
    // still sit at pos1 after reconfiguration and, for example, send its own
    // data before forwarding the data coming from the west.
    var pos1_bits: u16 = SWITCH_POS1_RAMP;
    if (last_px){
        pos1_bits = SWITCH_POS1_INVALID;
    }
    var switch_addr: u16 = translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR);
    var switch_cfg: u16 = @bitcast(*u16, switch_addr).*;
    switch_cfg = (switch_cfg & MASK_SWITCH_RESET_POS0 & MASK_SWITCH_CLEAR_POS1) | pos1_bits;
    @bitcast(*u16, switch_addr).* = switch_cfg;

    // --- color config register: input/output of pos0 ---------------------
    var pos0_bits: u16 = INPUT_WEST | OUTPUT_EAST;   // 0 < px < width-1: forward
    if (first_px){
        pos0_bits = INPUT_RAMP | OUTPUT_EAST;        // 1st PE injects its own data
    }else if (last_px){
        pos0_bits = INPUT_WEST | OUTPUT_RAMP;        // last PE only receives
    }
    var color_addr: u16 = translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR);
    var color_cfg: u16 = @bitcast(*u16, color_addr).*;
    @bitcast(*u16, color_addr).* = (color_cfg & MASK_INPUT_OUTPUT_POS0) | pos0_bits;

    // --- clear teardown-in-progress (bit 14) ------------------------------
    tile_config.teardown.exit(C_ROUTE);
}


// Reprogram C_ROUTE at runtime for a north -> south column reduction:
//   pos0: first PE {RAMP -> SOUTH}, middle PEs {NORTH -> SOUTH}, last PE {NORTH -> RAMP}
//   pos1: {.rx = RAMP} on every PE except the last one, which has no pos1
// and finally leave teardown mode on C_ROUTE.
fn colReduce_configure() void {

    // (1) setup switch according to config parameters
    // 1. pos0 (color config reg)
    // 2. pos1 (switch config reg)
    //    pos1 = {.rx = RAMP} for all PEs except last PE

    // reset switch position to pos0
    // WARNING: if switch config register does not reset the switch position back to pos0,
    // it is possible that some PE is at pos1 after the switch is reconfigured and the sending
    // pattern is messed up, for example, the PE sends data first, then forwards the data from
    // the north.
    var r_switch_state : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* ;
    // mask bits 14:13 to reset input&output position to pos0
    r_switch_state = r_switch_state & MASK_SWITCH_RESET_POS0;
    // WARNING: all PEs are configured by
    //   - ".pos1 = .{ .rx = RAMP }"  --> bit 3 is 1
    //   - ".ring_mode = true"  --> bit 12 is 1
    //   - ".pop_mode = .{ .pop_on_advance = true }" --> bits 13:12 of fabric per-color config
    // mask bits 2:0 to clear setting of pos1
    r_switch_state = r_switch_state & MASK_SWITCH_CLEAR_POS1;
    if (last_py){
        // last PE does not have pos1
        r_switch_state = r_switch_state | SWITCH_POS1_INVALID;
    }else{
        // others have ".pos1 = .{ .rx = RAMP }"
        r_switch_state = r_switch_state | SWITCH_POS1_RAMP;
    }
    @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* = r_switch_state;

    var r_state : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* ;
    // clear input/output switch pos0 (bits 0:9); bits 10:15 are kept
    r_state = r_state & MASK_INPUT_OUTPUT_POS0 ;

    if (first_py){
        // 1st PE must have {rx = RAMP} to send out the data
        // .rx = .{ RAMP },.tx = .{ SOUTH },
        r_state = r_state | INPUT_RAMP | OUTPUT_SOUTH;
    }else if (last_py){
        // last PE only receives data
        // .rx = .{ NORTH }, .tx = .{ RAMP },
        r_state = r_state | INPUT_NORTH | OUTPUT_RAMP;
    }else{
        // 0 < py < height-1
        // .rx = .{ NORTH }, .tx = .{ SOUTH },
        r_state = r_state | INPUT_NORTH | OUTPUT_SOUTH;
    }
    // update the switch pos0
    @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* = r_state;

    // (2) clear teardown-in-progress bit
    // config_reg[c] ^= mask where mask = 1 << 14
    tile_config.teardown.exit(C_ROUTE);
}


// Reprogram C_ROUTE at runtime for the broadcast from the bottom-right PE.
// Layout for w > 1 and h > 1 (data flows up the last column and west along every row):
//  x <-- x <-- x
//              ^
//              |
//  x <-- x <-- x
//              ^
//              |
//  x <-- x <-- x
//
// pos1 is disabled on every PE; the routing uses pos0 only.
fn bcast_configure() void {

    // (1) setup switch according to config parameters
    // 1. pos0 (color config reg)
    // 2. pos1 (switch config reg)
    //    pos1 = {invalid} for all PEs

    // reset switch position to pos0
    // WARNING: if switch config register does not reset the switch position back to pos0,
    // it is possible that some PE is at pos1 after the switch is reconfigured and the sending
    // pattern is messed up, for example, the PE sends data first, then forwards the data from
    // the west.
    var r_switch_state : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* ;
    // mask bits 14:13 to reset input&output position to pos0
    r_switch_state = r_switch_state & MASK_SWITCH_RESET_POS0;
    // WARNING: all PEs have pos0 only, so disable pos1
    //   no change for ring_mode and pop_mode
    //   - ".ring_mode = true"  --> bit 12 is 1
    //   - ".pop_mode = .{ .pop_on_advance = true }" --> bits 13:12 of fabric per-color config
    // mask bits 2:0 to clear setting of pos1
    r_switch_state = r_switch_state & MASK_SWITCH_CLEAR_POS1;
    r_switch_state = r_switch_state | SWITCH_POS1_INVALID;
    @bitcast(*u16, translate_word_to_bytes(D2H_SWITCH_CONFIG_ADDR) ).* = r_switch_state;

    var r_state : u16 = @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* ;
    // clear input/output switch pos0 (bits 0:9); bits 10:15 are kept
    r_state = r_state & MASK_INPUT_OUTPUT_POS0 ;

    if (last_px){
        // px = w-1
        if (last_py){
            // Pw-1,h-1: send to west and north, { .rx = .{RAMP}, .tx = .{WEST, NORTH} }
            r_state = r_state | INPUT_RAMP | OUTPUT_WEST | OUTPUT_NORTH;
        }else{
            if (first_py){
                // Pw-1,0: { .rx = .{SOUTH}, .tx = .{WEST, RAMP} }
                r_state = r_state | INPUT_SOUTH | OUTPUT_WEST | OUTPUT_RAMP;
            }else{
                // Pw-1,py: 0 < py < h-1, { .rx = .{SOUTH}, .tx = .{WEST, RAMP, NORTH} }
                r_state = r_state | INPUT_SOUTH | OUTPUT_WEST | OUTPUT_RAMP | OUTPUT_NORTH;
            }
        }
    }else{
        if (first_px){
            // px = 0, {.rx = .{EAST}, .tx = .{RAMP}}
            r_state = r_state | INPUT_EAST | OUTPUT_RAMP;
        }else{
            // 0 < px < w-1, { .rx = .{EAST}, .tx = .{WEST, RAMP} }
            r_state = r_state | INPUT_EAST | OUTPUT_RAMP | OUTPUT_WEST;
        }
    }

    // update the switch pos0
    @bitcast(*u16, translate_word_to_bytes(D2H_COLOR_CONFIG_ADDR) ).* = r_state;

    // (2) clear teardown-in-progress bit
    // config_reg[c] ^= mask where mask = 1 << 14
    tile_config.teardown.exit(C_ROUTE);
}

// Teardown handler of C_ROUTE. It runs when the TD wavelet that ends the current
// state arrives:
//   state 1: row-reduce  (TD sent by the 1st PE of each row)
//   state 2: col-reduce  (TD sent by the 1st PE of each column)
//   state 3: bcast       (TD sent by the bottom-right PE)
// It only activates LOCK. f_lock() is picked by the scheduler once the
// operation has also unblocked LOCK.
//
fn teardown_allreduce() void {
    // LOCK can be picked only when the operation finishes
    @activate(C_LOCK);
}

comptime {
    @set_teardown_handler(teardown_allreduce, C_ROUTE);
}

//
// routing of C_ROUTE during a row reduction (data moves east, starting from the
// leftmost PE; "sw_adv" is the switch-advance control wavelet ctrl_11)
//    -->   --->-->   -->-->
//    ^
//    |
//   sw_adv
//    -->      -->   -->-->
//    ^        ^
//    |        |
//   data     data
//            sw_adv
//    -->     --> -->     -->
//    ^                   ^
//    |                   |
//   sw_adv              data
//                      sw_adv
//    -->       -->    --> -->
//    ^         ^
//    |         |
//             data
//             sw_adv
//
comptime {

    // The switch must work for different operations, including
    //   - row reduction
    //   - column reduction
    //   - broadcasting
    // The initial setting must be universal so we can reconfigure the
    // switch for these three operations at runtime
    //
    // We need to set invariant parameters at comptime (runtime does not alter):
    // 1. teardown mode at comptime
    //   {.teardown = true} implies color is in teardown mode at comptime
    // 2. ring mode
    //   ".ring_mode = true"  --> fabric switch config reg sets bit 12 as 1
    // 3. pop on advance
    //   ".pop_mode = .{ .pop_on_advance = true }" --> fabric per-color config reg sets bits 13:12
    // 4. position 1
    //   ".pos1 = .{ .rx = RAMP }"  --> fabric switch config reg sets bit 3 as 1
    //
    // The following (last) PEs do not have position 1:
    //   - "px = width-1" for row reduction
    //   - "py = height-1" for column reduction
    //   - all PEs for broadcasting
    // The runtime resets position 1 (bits 2:0 of fabric switch config) to either
    //   SWITCH_POS1_INVALID to disable position 1 or
    //   SWITCH_POS1_RAMP to reset position 1 back to ".pos1 = .{ .rx = RAMP }"
    // The bit 3 of fabric switch config is always 1 (position 1 switch select is "1=input")
    // If position 1 is disabled, bit 3 is don't care
    // If position 1 is disabled, pop mode is also don't care because of NOCE
    // If position 1 is disabled, ring mode is also don't care
    //
    // Remark: we don't use ".pop_mode = .{ .always_pop = true }" because there is no need
    // to propagate the TD to mux. All PEs have a teardown handler to deal with this TD, so
    // we do not need to pop out an instruction in TD wavelet, for example
    //     0x9df 9249 --> 0x91f 9249
    // (The instruction is NOT teardown 0b111, but 0b100 (NOCE, NOP))
    // (teardown = 0x9df,9249  (cmd1=teardown+NOCE, others=NOP+NOCE))
    //
    // The original setting of row reduction
    // 1st PE: px = 0
    //   .pos0 = .{ .rx = .{ RAMP }, .tx = .{ EAST }}
    //   .pop_mode = .{ .pop_on_advance = true }
    //   .pos1 = .{ .rx = RAMP }
    //   .ring_mode = true
    //   .teardown = true
    // middle: 1st PE < px < last PE
    //   .pos0 = .{ .rx = .{ WEST }, .tx = .{ EAST }}
    //   .pop_mode = .{ .pop_on_advance = true }
    //   .pos1 = .{ .rx = RAMP }
    //   .ring_mode = true
    //   .teardown = true
    // last PE: px = w-1
    //   .pos0 = .{ .rx = .{ WEST }, .tx = .{ RAMP }}
    //   .teardown = true
    //
    // The original setting of column reduction
    // 1st PE: py = 0
    //   .pos0 = .{ .rx = .{ RAMP }, .tx = .{ SOUTH }}
    //   .pop_mode = .{ .pop_on_advance = true }
    //   .pos1 = .{ .rx = RAMP }
    //   .ring_mode = true
    //   .teardown = true
    // middle: 1st PE < py < last PE
    //   .pos0 = .{ .rx = .{ NORTH }, .tx = .{ SOUTH }}
    //   .pop_mode = .{ .pop_on_advance = true }
    //   .pos1 = .{ .rx = RAMP }
    //   .ring_mode = true
    //   .teardown = true
    // last PE: py = h-1
    //   .pos0 = .{ .rx = .{ NORTH }, .tx = .{ RAMP }}
    //   .teardown = true
    //
    const universalConfig = .{
        .routes= .{
            .rx = .{ WEST },
            .tx = .{ EAST },
        },
        .switches=.{
            .pos1 = .{ .rx = RAMP },
            .ring_mode = true,
            .pop_mode = .{ .pop_on_advance = true },
        },
        .teardown = true
    };

    // Both reductions need at least two PEs along their axis.
    // With a single PE in a row (column), every PE takes the receive-only
    // reduce_last_pe path of f_send_data() with reduce_pes == 0. No PE then sends
    // the TD that ends the state, teardown_allreduce() never activates LOCK, and
    // allreduce() hangs. (Assumes last_px / last_py hold for every PE when
    // width / height is 1 - confirm against their definitions.)
    // Reject such layouts at compile time instead of hanging at runtime.
    @comptime_assert(1 < width);
    @comptime_assert(1 < height);

    @set_local_color_config(C_ROUTE, universalConfig);
}

// C_DISPATCH is routed from this PE's ramp straight back to its own ramp. The
// control wavelets that f_dispatch() sends therefore return to the same PE and
// trigger the control task selected by their index.
comptime {
    @set_local_color_config(C_DISPATCH, .{ .routes = .{ .rx = .{RAMP}, .tx = .{RAMP} } } );
}

// binding a color to an input queue.
// This is necessary when an explicit DSR binds to a fabin DSD because
// the compiler no longer can generate the instruction to set up the
// config register of input queue.
// Here input queue 0 (used by fab_recv_wdsd) is bound to C_ROUTE.
comptime {
    @initialize_queue(@get_input_queue(queues[0]), .{.color = C_ROUTE} );
}

commands.sh

#!/usr/bin/env bash

# Compile the SpMV kernel for a 4x4 PE rectangle into ./out, then run it on
# ./data/rmat4.4x4.lb.mtx. Abort on the first failing command.
set -e

# compile: layout.csl -> out
cslc ./layout.csl --arch wse2 \
  --fabric-dims=11,6 --fabric-offsets=4,1 \
  --params=ncols:16 --params=nrows:16 \
  --params=pcols:4 --params=prows:4 \
  --params=max_local_nnz:8 \
  --params=max_local_nnz_cols:4 --params=max_local_nnz_rows:4 \
  --params=local_vec_sz:1 \
  --params=local_out_vec_sz:1 \
  --params=y_pad_start_row_idx:4 \
  -o=out \
  --memcpy --channels=1 --width-west-buf=0 --width-east-buf=0

# run the compiled artifact
cs_python ./run.py \
  --num_pe_cols=4 --num_pe_rows=4 \
  --latestlink out \
  --channels=1 --width-west-buf=0 --width-east-buf=0 \
  --is_weight_one --run-only \
  --infile_mtx=./data/rmat4.4x4.lb.mtx