Topic 5: Switches

Fabric switches permit limited runtime control of routes.

In this example, the sender's program (send.csl) initializes the default route of its color to receive wavelets from the ramp and forward them to the PE’s north neighbor. It also defines routes for switch positions 1, 2, and 3. When the router receives a so-called control wavelet carrying a switch-advance opcode, the hardware moves to the next switch position and updates the route accordingly.

For the payload of the control wavelet, the code creates a special wavelet using the helper function ctrl().

Switches can be helpful not just to change the routing configuration in limited ways at runtime, but also to reduce the number of colors used. For instance, this same example could be rewritten to use four colors and four routes, but by using fabric switches, it uses just one color.

layout.csl

// color/ task ID map
//
//  ID var           ID var             ID var                ID var
//   0                9 STARTUP         18                    27 reserved (memcpy)
//   1 channel       10                 19                    28 reserved (memcpy)
//   2 out           11                 20                    29 reserved
//   3               12                 21 reserved (memcpy)  30 reserved (memcpy)
//   4 D2H           13                 22 reserved (memcpy)  31 reserved
//   5               14                 23 reserved (memcpy)  32
//   6               15                 24                    33
//   7               16                 25                    34
//   8 main_task_id  17                 26                    35
//

//  +---------------+
//  | north (d2h)   |
//  +---------------+
//  | core (3-by-3) |
//  +---------------+
//  | south (nop)   |
//  +---------------+

param MEMCPYD2H_DATA_1_ID: i16; // color ID of the memcpy D2H streaming channel (given via --params)

// Color ID shared by the sender and the receivers. It is also handed to
// send.csl, which writes it into every control wavelet it builds.
const colorValue = 1;

// Colors
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const channel: color = @get_color(colorValue);
const out: color = @get_color(2);

// Task IDs
const main_task_id: local_task_id = @get_local_task_id(8);
const STARTUP: local_task_id = @get_local_task_id(9);
const channel_task_id: data_task_id = @get_data_task_id(channel);

// memcpy parameters for the whole 3-wide, 5-tall rectangle
// (edge rows included).
const memcpy = @import_module("<memcpy/get_params>", .{
  .width = 3,
  .height = 5,
  .MEMCPYD2H_1 = MEMCPYD2H_DATA_1,
});

layout {
  // 3 columns x 5 rows:
  //   row 0    -> memcpy north edge (D2H)
  //   rows 1-3 -> the 3-by-3 core
  //   row 4    -> memcpy south edge (no-op)
  @set_rectangle(3, 5);

  // North edge: every column runs memcpy's D2H code, which forwards data
  // arriving on `out` from the core to the host.
  for (@range(i16, 3)) |col| {
    @set_tile_code(col, 0, "memcpyEdge/north.csl", .{
      .memcpy_params = memcpy.get_params(col),
      .USER_OUT_1 = out,
      .STARTUP = STARTUP,
    });
  }

  // Per-column memcpy parameters used by the core PEs below.
  const memcpy_params_0 = memcpy.get_params(0);
  const memcpy_params_1 = memcpy.get_params(1);
  const memcpy_params_2 = memcpy.get_params(2);

  // Core row r lives at rectangle row core_y + r (row 0 belongs to "north").
  const core_y = 1;

  // The center PE of the core sends one control wavelet to each of its four
  // adjacent neighbors over the single color `channel`. Those four neighbors
  // run recv.csl. The four corner PEs get no instructions or routes beyond
  // memcpy.
  @set_tile_code(1, core_y + 1, "send.csl", .{
    .memcpy_params = memcpy_params_1,
    .txColor = channel,
    .main_task_id = main_task_id,
    .colorValue = colorValue,
  });

  // Top neighbor: receives from SOUTH. It is the last one served by the
  // sender's switch sequence, so it reports completion to the host.
  @set_tile_code(1, core_y + 0, "recv.csl", .{
    .memcpy_params = memcpy_params_1,
    .rxColor = channel, .outColor = out,
    .rx_task_id = channel_task_id,
    .inDir = SOUTH, .fin = true,
  });

  // Left neighbor: receives from EAST.
  @set_tile_code(0, core_y + 1, "recv.csl", .{
    .memcpy_params = memcpy_params_0,
    .rxColor = channel, .outColor = out,
    .rx_task_id = channel_task_id,
    .inDir = EAST, .fin = false,
  });

  // Right neighbor: receives from WEST.
  @set_tile_code(2, core_y + 1, "recv.csl", .{
    .memcpy_params = memcpy_params_2,
    .rxColor = channel, .outColor = out,
    .rx_task_id = channel_task_id,
    .inDir = WEST, .fin = false,
  });

  // Bottom neighbor: receives from NORTH.
  @set_tile_code(1, core_y + 2, "recv.csl", .{
    .memcpy_params = memcpy_params_1,
    .rxColor = channel, .outColor = out,
    .rx_task_id = channel_task_id,
    .inDir = NORTH, .fin = false,
  });

  // South edge: memcpy no-op code in every column.
  for (@range(i16, 3)) |col| {
    @set_tile_code(col, 4, "memcpyEdge/south.csl", .{
      .memcpy_params = memcpy.get_params(col),
      .STARTUP = STARTUP,
    });
  }

  // Core corners (columns 0 and 2 of core rows 0 and 2) only import memcpy.
  for (@range(i16, 0, 3, 2)) |dy| {
    for (@range(i16, 0, 3, 2)) |col| {
      @set_tile_code(col, core_y + dy, "empty.csl", .{
        .memcpy_params = memcpy.get_params(col),
      });
    }
  }
}

send.csl

// Not a complete program; the top-level source file is layout.csl.
param memcpy_params: comptime_struct;

// Numeric color ID matching txColor (layout.csl derives both from the same
// value). ctrl() writes it into every control wavelet.
param colorValue;

// Colors
param txColor: color; // color the control wavelets are sent on

// Task IDs
param main_task_id: local_task_id; // local task that emits the wavelets

// ----------
// Every PE must import the memcpy module; otherwise I/O cannot propagate
// data to its destination. memcpy reserves input queue 0 and output queue 0.
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);
// ----------

// Fabric output of one wavelet on txColor. `.control = true` makes every
// wavelet written through this DSD a control wavelet.
const dsd = @get_dsd(fabout_dsd, .{
  .fabric_color = txColor,
  .extent = 1,
  .control = true,
});

// Switch opcodes, passed as the `opcode` argument of ctrl().
const opcode_nop = 0;
const opcode_switch_advance = 1;
const opcode_switch_reset = 2;
const opcode_teardown = 3;

// Builds the 32-bit payload of a control wavelet.
//   ce_filter: when true, sets the filter bit that disables delivery from the
//              destination router to the destination CE
//   opcode:    switch opcode (one of the opcode_* constants)
//   data:      16-bit data, stored in the low half-word of the result
// Result layout: (colorValue | opcode << 6 | ce_filter << 8) << 16 | data.
// The fields are not masked, so colorValue must fit in 6 bits and opcode in
// 2 bits, or they will overlap the next field.
fn ctrl(ce_filter: bool, opcode: i16, data: u16) u32 {
  // Fields of the upper half-word.
  const color_bits: u32 = @as(u32, colorValue);
  const opcode_bits: u32 = @as(u32, opcode) << @as(u32, 6);
  const filter_bits: u32 = @as(u32, ce_filter) << @as(u32, 8);
  const upper: u32 = color_bits | opcode_bits | filter_bits;

  // Upper half-word sits above the 16 data bits.
  return (upper << @as(u32, 16)) | @as(u32, data);
}

task mainTask() void {
  // One color reaches all four neighbors: each control wavelet leaves on the
  // current switch position of txColor and then advances it, giving the
  // order WEST, EAST, SOUTH, NORTH. All payloads are folded at compile time.
  const to_west = comptime ctrl(false, opcode_switch_advance, 0xaa);
  const to_east = comptime ctrl(false, opcode_switch_advance, 0xbb);
  const to_south = comptime ctrl(false, opcode_switch_advance, 0xcc);
  const to_north = comptime ctrl(false, opcode_switch_advance, 0xdd);

  @mov32(dsd, to_west);
  @mov32(dsd, to_east);
  @mov32(dsd, to_south);
  @mov32(dsd, to_north);
}

comptime {
  @bind_local_task(mainTask, main_task_id);
  @activate(main_task_id);

  // Switch position 0 (the default route): ramp -> NORTH.
  const default_route = .{ .rx = .{ RAMP }, .tx = .{ NORTH } };

  // Positions 1-3 override only the transmit direction. Starting at
  // position 1 and wrapping from 3 back to 0 (ring mode) yields the transmit
  // order WEST, EAST, SOUTH, NORTH as successive control wavelets go out.
  const switch_config = .{
    .pos1 = .{ .tx = WEST },
    .pos2 = .{ .tx = EAST },
    .pos3 = .{ .tx = SOUTH },
    .current_switch_pos = 1,
    .ring_mode = true,
  };

  @set_local_color_config(txColor, .{ .routes = default_route, .switches = switch_config });
}

recv.csl

// Not a complete program; the top-level source file is layout.csl.
param memcpy_params: comptime_struct;

// true only on the PE that signals completion back to the host
param fin: bool;
// direction from which this PE receives on rxColor
param inDir: direction;

// Colors
param rxColor: color;  // carries the incoming wavelet
param outColor: color; // carries the completion signal

// Task IDs
param rx_task_id: data_task_id; // data task fired by wavelets arriving on rxColor

// ----------
// Every PE must import the memcpy module; otherwise I/O cannot propagate
// data to its destination. memcpy reserves input queue 0 and output queue 0.
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);
// ----------

// One-wavelet fabric output on outColor.
const dsd = @get_dsd(fabout_dsd, .{ .extent = 1, .fabric_color = outColor });

// Last value received on rxColor; exported so the host can read it back.
export var global: u16 = 0;

task rxTask(data: u16) void {
  // Remember what arrived.
  global = data;

  // Only the designated PE reports completion.
  if (!fin) {
    return;
  }
  @mov16(dsd, 0);
}

comptime {
  @bind_data_task(rxTask, rx_task_id);

  // rxColor: fabric (from inDir) -> this PE's ramp.
  const rx_route = .{ .rx = .{ inDir }, .tx = .{ RAMP } };
  @set_local_color_config(rxColor, .{ .routes = rx_route });

  // outColor: this PE's ramp -> NORTH neighbor.
  @set_local_color_config(outColor, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ NORTH } } });
}

empty.csl

// Not a complete program; the top-level source file is layout.csl.

// ----------
// Every PE needs to import memcpy module otherwise the I/O cannot
// propagate the data to the destination.

param memcpy_params: comptime_struct; // supplied per column by layout.csl

// Importing memcpy is this PE's only job (memcpy reserves input queue 0 and
// output queue 0).
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);
// ----------

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.debug.debug_util import debug_util
from cerebras.sdk.sdk_utils import memcpy_view
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

parser = argparse.ArgumentParser()
# --name is the cslc output directory (it holds out.json and the compiled
# program). It is required: without it the script would try to open
# "None/out.json" and fail with a confusing error.
parser.add_argument('--name', required=True, help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata written by cslc to recover the D2H color ID
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
MEMCPYD2H_DATA_1 = int(params["MEMCPYD2H_DATA_1_ID"])
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")

memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

runner.load()
runner.run()

print("step 1: streaming D2H at P1.0 (end of communication)")
# Streaming D2H requires a u32 host buffer.
out_tensors_u32 = np.zeros(1, np.uint32)
d2h_options = dict(
    streaming=True,
    data_type=memcpy_dtype,
    order=MemcpyOrder.COL_MAJOR,
    nonblock=False,
)
# Positional args: presumably (x=1, y=0, width=1, height=1, elements per PE=1);
# confirm against SdkRuntime.memcpy_d2h.
runner.memcpy_d2h(out_tensors_u32, MEMCPYD2H_DATA_1, 1, 0, 1, 1, 1, **d2h_options)
# Keep only the low 16 bits of each u32 element.
result_tensor = memcpy_view(out_tensors_u32, np.dtype(np.int16))

runner.stop()

debug_mod = debug_util(dirname, cmaddr=args.cmaddr)
# Fabric coordinates of the program rectangle's origin; they must match
# --fabric-offsets in commands.sh. Despite the names, this is the origin of
# the whole 3x5 rectangle (row 0 is the memcpy north row), not of the 3x3 core.
core_offset_x = 4
core_offset_y = 1
print(f"=== core rectangle starts at {core_offset_x}, {core_offset_y}")
# PE coordinates below are relative to the rectangle origin.
# sender PE is P1.2
# top PE of sender PE is P1.1
result_top = debug_mod.get_symbol(core_offset_x+1, core_offset_y+1, "global", np.uint16)
# left PE of sender PE is P0.2
result_left = debug_mod.get_symbol(core_offset_x+0, core_offset_y+2, "global", np.uint16)
# right PE of sender PE is P2.2
result_right = debug_mod.get_symbol(core_offset_x+2, core_offset_y+2, "global", np.uint16)
# bottom PE of sender PE is P1.3
result_bottom = debug_mod.get_symbol(core_offset_x+1, core_offset_y+3, "global", np.uint16)

# Values follow the sender's switch order: WEST 0xaa, EAST 0xbb,
# SOUTH 0xcc, then NORTH 0xdd. These are integers, so compare exactly.
np.testing.assert_equal(result_top, 0xdd)
np.testing.assert_equal(result_left, 0xaa)
np.testing.assert_equal(result_right, 0xbb)
np.testing.assert_equal(result_bottom, 0xcc)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

set -e

# Compile layout.csl into ./out. The fabric offsets must match
# core_offset_x/core_offset_y in run.py. MEMCPYD2H_DATA_1_ID is read back
# from out/out.json by run.py.
cslc ./layout.csl \
  --fabric-dims=10,7 \
  --fabric-offsets=4,1 \
  -o out \
  --params=MEMCPYD2H_DATA_1_ID:4 \
  --memcpy \
  --channels=1 \
  --width-west-buf=0 \
  --width-east-buf=0

# No --cmaddr is given, so run.py passes cmaddr=None to the runtime.
cs_python run.py --name out