GEMV 7: Routes and Fabric DSDs, Part II

GEMV 7: Routes and Fabric DSDs, Part II

Continuing from the previous example, we now break up a single GEMV computation among a 2 x 2 square of PEs.

The host program copies b into the y tensor of the left column of PEs, with PE (0, 0) getting the first M/2 values and PE (0, 1) getting the last M/2 values.

Each PE also gets a corresponding chunk of A. The left PEs get the left N/2 columns and the right PEs get the right N/2 columns, while the upper PEs get the upper M/2 rows and the lower PEs get the lower M/2 rows. In other words, the northwest PE gets the northwest quadrant of A, the northeast PE gets the northeast quadrant of A, and so on.

The host program also copies x into the upper row of PEs, with PE (0, 0) getting the first N/2 values and PE (1, 0) getting the last N/2 values.

When the compute function is launched, the PEs in the top row begin sending their respective elements of x to their routers, along the color x_color. These PEs send the elements of x both to the SOUTH and back down their own RAMP.

On all four PEs, receiving a wavelet along x_color activates recv_x. This task is a wavelet-triggered task (WTT): the wavelet’s data is fed in as an argument to recv_x.

When a PE receives an element of x in the recv_x task, it computes the corresponding piece of Ax and adds it to its local y tensor. When a PE has received all corresponding elements of x along x_color (i.e., the first N/2 values of x for the two left PEs, and the last N/2 values of x for the two right PEs), it has finished computing its local contribution to y.

At this point, the local task reduce is activated. The left column of PEs sends its partial y results along the color ax_color to the EAST, and the right column of PEs receives these partial y results and increments its y tensors by the received values. Once this is done, the right column of PEs contains the final result y, with the first M/2 elements in PE (1, 0) and the last M/2 elements in PE (1, 1).

Last, the host copies y from the right column of PEs, and checks that the result is correct.

Note that in this program’s layout file, we no longer assign a pe_id as a compile-time parameter. Instead, we use the <layout> module in pe_program.csl to determine the coordinates of the PE at runtime. This can reduce compilation time by reducing the number of unique PE programs that need to be compiled. Specifically, by parameterizing a PE’s code (i.e., passing parameters through @set_tile_code) we are creating more unique PE programs as opposed to relying on runtime-evaluated values.

layout.csl

// Total matrix dimensions, supplied at compile time.
// The layout block hands M / 2 and N / 2 to every PE program.
param M: i16;
param N: i16;

// Colors
// Both colors are passed to every PE program and routed per PE in the layout block.
const ax_color: color = @get_color(0); // sends/recvs partial result Ax EAST
const x_color:  color = @get_color(1); // sends/recvs elems x

// This example uses 2x2 PEs
// memcpy.get_params(...) is called with each PE's column index in the layout
// block to produce that PE's memcpy_params.
const memcpy = @import_module("<memcpy/get_params>", .{
  .width = 2,
  .height = 2
});

layout {
  // A is split into four equal quadrants, so each PE is given M / 2 rows and
  // N / 2 columns. Integer division truncates: (M / 2) * 2 equals M only when
  // M is even (likewise for N). Reject odd sizes at compile time instead of
  // building a kernel whose chunks do not add up to the full M x N problem.
  @comptime_assert((M / 2) * 2 == M);
  @comptime_assert((N / 2) * 2 == N);

  // PE coordinates are (column, row)
  @set_rectangle(2, 2);

  // Every PE runs the same program with the same parameters; only the
  // memcpy parameters depend on the PE's column.
  for (@range(i16, 2)) |pe_x| {
    for (@range(i16, 2)) |pe_y| {
      @set_tile_code(pe_x, pe_y, "pe_program.csl", .{
        .memcpy_params = memcpy.get_params(pe_x),
        .M_per_PE = M / 2,
        .N_per_PE = N / 2,
        .ax_color = ax_color,
        .x_color = x_color
      });
    }
  }

  // Routing summary:
  //   ax_color: left column sends from its RAMP to the EAST;
  //             right column receives from the WEST up its RAMP.
  //   x_color:  top row sends from its RAMP both back up its own RAMP and to
  //             the SOUTH; bottom row receives from the NORTH up its RAMP.

  // Top left PE (0, 0)
  @set_color_config(0, 0, ax_color, .{.routes = .{ .rx = .{RAMP}, .tx = .{EAST}  }});
  @set_color_config(0, 0, x_color,  .{.routes = .{ .rx = .{RAMP}, .tx = .{RAMP, SOUTH} }});

  // Top right PE (1, 0)
  @set_color_config(1, 0, ax_color, .{.routes = .{ .rx = .{WEST},  .tx = .{RAMP} }});
  @set_color_config(1, 0, x_color,  .{.routes = .{ .rx = .{RAMP}, .tx = .{RAMP, SOUTH} }});

  // Bottom left PE (0, 1)
  @set_color_config(0, 1, ax_color, .{.routes = .{ .rx = .{RAMP}, .tx = .{EAST} }});
  @set_color_config(0, 1, x_color,  .{.routes = .{ .rx = .{NORTH}, .tx = .{RAMP} }});

  // Bottom right PE (1, 1)
  @set_color_config(1, 1, ax_color, .{.routes = .{ .rx = .{WEST}, .tx = .{RAMP} }});
  @set_color_config(1, 1, x_color,  .{.routes = .{ .rx = .{NORTH}, .tx = .{RAMP} }});

  // export symbol names
  @export_name("A", [*]f32, true);
  @export_name("x", [*]f32, true);
  @export_name("y", [*]f32, true);
  @export_name("compute", fn()void);
}

pe_program.csl

// Parameters forwarded to the memcpy module imported below
param memcpy_params: comptime_struct;

// Matrix dimensions
// Size of this PE's chunk of A: M_per_PE rows by N_per_PE columns
param M_per_PE: i16;
param N_per_PE: i16;

// Colors
param ax_color: color; // sends partial result Ax EAST / recvs it from WEST
param x_color:  color; // carries elems of x; every PE receives them on this color

// Queue IDs
// Each color uses the same queue index for its output and input queue
const ax_color_oq: output_queue = @get_output_queue(2);
const ax_color_iq: input_queue  = @get_input_queue(2);
const x_color_oq:  output_queue = @get_output_queue(3);
const x_color_iq:  input_queue  = @get_input_queue(3);

// Task ID used by exit task to unblock cmd stream
const exit_task_id:   local_task_id = @get_local_task_id(9);

// Task ID used by reduce task
const reduce_task_id: local_task_id = @get_local_task_id(10);

// Data task ID for task recv_x, which consumes x_color wavelets
// On WSE-2, data task IDs are created from colors; on WSE-3, from input queues
const recv_x_task_id: data_task_id =
  if      (@is_arch("wse2")) @get_data_task_id(x_color)
  else if (@is_arch("wse3")) @get_data_task_id(x_color_iq);


// memcpy module provides infrastructure for copying data
// and launching functions from the host
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);

// layout module provides PE coordinates at runtime
const layout_mod = @import_module("<layout>");


// 48 kB of global memory contain A, x, y
// There is no separate b buffer: the host copies b into y before launching compute
var A: [M_per_PE*N_per_PE]f32; // A is stored column major
var x: [N_per_PE]f32;
var y: [M_per_PE]f32;

// DSDs for accessing A, x, y
// A_dsd accesses one column of A (M_per_PE elements); recv_x advances it
// by one column after each received element of x
var A_dsd = @get_dsd(mem1d_dsd, .{ .base_address = &A, .extent = M_per_PE });
var x_dsd = @get_dsd(mem1d_dsd, .{ .base_address = &x, .extent = N_per_PE });
var y_dsd = @get_dsd(mem1d_dsd, .{ .base_address = &y, .extent = M_per_PE });

// ptrs to A, x, y will be advertised as symbols to host
var A_ptr: [*]f32 = &A;
var x_ptr: [*]f32 = &x;
var y_ptr: [*]f32 = &y;


// True when this PE sits in row 0 of the rectangle, as reported at runtime
// by the <layout> module.
fn is_top_row() bool {
  return layout_mod.get_y_coord() == 0;
}

// True when this PE sits in column 0 of the rectangle, as reported at runtime
// by the <layout> module.
fn is_left_col() bool {
  return layout_mod.get_x_coord() == 0;
}


// Local task that combines the partial results of y across a PE row.
// recv_x activates it once this PE has consumed N_per_PE elements of x.
//   - Left-column PEs send their y tensor out on ax_color.
//   - All other PEs receive M_per_PE values on ax_color and add them into y.
// In both cases exit_task is activated when the asynchronous operation finishes.
task reduce() void {
  if (is_left_col()) {
    // Fabric output: M_per_PE wavelets on ax_color through ax_color_oq
    const out_dsd = @get_dsd(fabout_dsd, .{
                      .fabric_color = ax_color, .extent = M_per_PE,
                      .output_queue = ax_color_oq
                    });
    // After fmovs is done, activate exit_task to unblock cmd stream
    @fmovs(out_dsd, y_dsd, .{ .async = true, .activate = exit_task_id });

  } else {
    // Fabric input: M_per_PE wavelets on ax_color through ax_color_iq
    const in_dsd = @get_dsd(fabin_dsd, .{
                     .fabric_color = ax_color, .extent = M_per_PE,
                     .input_queue  = ax_color_iq
                   });
    // y += received partial result
    // After fadds is done, activate exit_task to unblock cmd stream
    @fadds(y_dsd, y_dsd, in_dsd, .{ .async = true, .activate = exit_task_id });
  }
}


// Used to keep track of # of invocations of recv_x task
// when num_recv_x == N_per_PE, we are done receiving x elements
// NOTE(review): nothing in this file resets num_recv_x or the offset of A_dsd,
// so as written the kernel appears to support one compute launch per load.
var num_recv_x: i16 = 0;

// Wavelet-triggered task bound to recv_x_task_id (x_color wavelets).
// x_val is the received element of x. Each invocation accumulates one column
// of the local chunk of A, scaled by x_val, into y.
task recv_x(x_val: f32) void {
  // y += (current column of A) * x_val
  @fmacs(y_dsd, y_dsd, A_dsd, x_val);
  // A is stored column major, so stepping M_per_PE elements moves to the next column
  A_dsd = @increment_dsd_offset(A_dsd, M_per_PE, f32);

  num_recv_x += 1;
  // All N_per_PE elements consumed: the local contribution to y is complete
  if (num_recv_x == N_per_PE) {
    @activate(reduce_task_id);
  }
}


// Host-launched entry point (exported as "compute").
// The top row sends x values along x_color to launch recv_x
// PEs outside the top row do nothing here; they are driven by the x_color
// wavelets they receive.
fn compute() void {
  if (is_top_row()) {
    // Fabric output: N_per_PE wavelets on x_color through x_color_oq
    const send_x_dsd = @get_dsd(fabout_dsd, .{
                         .fabric_color = x_color, .extent = N_per_PE,
                         .output_queue = x_color_oq
                       });
    // Asynchronous send of the local x tensor; no task is activated on completion
    @fmovs(send_x_dsd, x_dsd, .{ .async = true });
  }
}


// Local task bound to exit_task_id. reduce activates it when its asynchronous
// fabric operation completes; it unblocks the memcpy command stream.
task exit_task() void {
  sys_mod.unblock_cmd_stream();
}


// Compile-time setup: bind tasks to their IDs, initialize queues on WSE-3,
// and export the symbols the host accesses by name.
comptime {
  // When exit_task_id is activated, exit_task will execute
  @bind_local_task(exit_task, exit_task_id);

  // reduce is local task activated by ID reduce_task_id
  @bind_local_task(reduce, reduce_task_id);

  // recv_x is wavelet-triggered task (WTT) activated by receiving
  // wavelets along color x_color, which corresponds to recv_x_task_id
  // On WSE-3, these wavelets are received in input queue x_color_iq
  @bind_data_task(recv_x, recv_x_task_id);

  // On WSE-3, we must explicitly initialize input and output queues
  if (@is_arch("wse3")) {
    @initialize_queue(ax_color_oq, .{ .color = ax_color });
    @initialize_queue(ax_color_iq, .{ .color = ax_color });
    @initialize_queue(x_color_oq,  .{ .color = x_color });
    @initialize_queue(x_color_iq,  .{ .color = x_color });
  }

  // Names match the @export_name declarations in layout.csl
  @export_symbol(A_ptr, "A");
  @export_symbol(x_ptr, "x");
  @export_symbol(y_ptr, "y");
  @export_symbol(compute);
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType, MemcpyOrder # pylint: disable=no-name-in-module

# Read arguments
parser = argparse.ArgumentParser()
# --name is mandatory: without it the script would try to open "None/out.json"
parser.add_argument('--name', required=True, help="the test compile output dir")
parser.add_argument('--cmaddr', help="IP:port for CS system")
args = parser.parse_args()

# Get matrix dimensions from compile metadata
with open(f"{args.name}/out.json", encoding='utf-8') as json_file:
  compile_data = json.load(json_file)

# Matrix dimensions
N = int(compile_data['params']['N'])
M = int(compile_data['params']['M'])

# A is split into four equal quadrants across the 2 x 2 PE grid, so both
# dimensions must be positive and even. Fail early with a clear message
# instead of letting `// 2` truncate and the copies below mis-size their data.
if M <= 0 or N <= 0 or M % 2 != 0 or N % 2 != 0:
  raise ValueError(f"M and N must be positive even integers, got M={M}, N={N}")

# Construct A, x, b
A = np.arange(M*N, dtype=np.float32).reshape(M,N)
x = np.full(shape=N, fill_value=1.0, dtype=np.float32)
b = np.full(shape=M, fill_value=2.0, dtype=np.float32)

# Calculate expected y
y_expected = A@x + b

# Size of the N and M dimensions on each PE
N_per_PE = N // 2
M_per_PE = M // 2

# Construct a runner using SdkRuntime
runner = SdkRuntime(args.name, cmaddr=args.cmaddr)

# Get symbols for A, x, y on device
A_symbol = runner.get_id('A')
x_symbol = runner.get_id('x')
y_symbol = runner.get_id('y')

# Load and run the program
runner.load()
runner.run()

# Copy b into y of PEs (0, 0) and (0, 1)
# PE (0, 0) gets first M/2 elements; PE (0, 1) gets last M/2 elements
runner.memcpy_h2d(y_symbol, b, 0, 0, 1, 2, M_per_PE, streaming=False,
  order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False)

# Copy chunks of A into all PEs
# PE (0, 0) gets A[0:M/2, 0:N/2], PE (1, 0) gets A[0:M/2, N/2:N]
# PE (0, 1) gets A[M/2:M, 0:N/2], PE (1, 1) gets A[M/2:M, N/2:N]
# Each chunk on each PE is stored column major
# Axes after reshape: (row block, row in block, col block, col in block);
# the transpose orders them as (row block, col block, col in block, row in block)
A_prepared = A.reshape(2, M_per_PE, 2, N_per_PE).transpose(0, 2, 3, 1).ravel()
runner.memcpy_h2d(A_symbol, A_prepared, 0, 0, 2, 2, M_per_PE*N_per_PE, streaming=False,
  order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False)

# Copy x into PEs (0, 0) and (1, 0)
# PE (0, 0) gets first N/2 elements; PE (1, 0) gets last N/2 elements
runner.memcpy_h2d(x_symbol, x, 0, 0, 2, 1, N_per_PE, streaming=False,
  order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False)

# Launch the compute function on device
runner.launch('compute', nonblock=False)

# Copy y back from PEs (1, 0) and (1, 1)
# First M/2 elements from PE (1, 0); Last M/2 elements from PE (1, 1)
y_result = np.zeros([M], dtype=np.float32)
runner.memcpy_d2h(y_result, y_symbol, 1, 0, 1, 2, M_per_PE, streaming=False,
  order=MemcpyOrder.ROW_MAJOR, data_type=MemcpyDataType.MEMCPY_32BIT, nonblock=False)

# Stop the program
runner.stop()

# Ensure that the result matches our expectation
np.testing.assert_allclose(y_result, y_expected)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

# Abort on the first failing command so run.py never runs against a failed compile
set -e

# Compile layout.csl for the wse3 architecture with M=4, N=6 into ./out
cslc --arch=wse3 ./layout.csl \
  --fabric-dims=9,4 \
  --fabric-offsets=4,1 \
  --params=M:4,N:6 \
  -o out \
  --memcpy --channels 1

# Run the host program against the compile output directory
cs_python run.py --name out