Topic 7: Filters

Fabric filters allow a PE to selectively accept incoming wavelets. This example demonstrates so-called range filters, which decide, based on the upper 16 bits of a wavelet's contents, whether that wavelet is forwarded to the CE. Specifically, PE #0 sends all 12 wavelets eastward to the other PEs, while each recipient PE accepts and processes only a quarter of the incoming wavelets. See Filter Configuration Semantics for other possible filter configurations.
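
To make the selection rule concrete, here is a small host-side Python sketch (illustrative only; the pack helper mirrors the one in send.csl below) that models how range filters with bounds [peId * 3, peId * 3 + 2] partition the twelve wavelets among the three recipient PEs, while PE #0 processes indices 0 through 2 itself:

import struct

def pack(index, value):
    # Mirror send.csl's pack(): 16-bit index above a binary16 (f16) value.
    return (index << 16) | int.from_bytes(struct.pack("<e", value), "little")

wavelets = [pack(i, 10.0 + i) for i in range(12)]

for pe_id in range(1, 4):  # the three PEs running recv.csl
    accepted = [w >> 16 for w in wavelets
                if pe_id * 3 <= (w >> 16) <= pe_id * 3 + 2]
    print(f"PE {pe_id} accepts wavelets with indices {accepted}")

# PE 1 accepts wavelets with indices [3, 4, 5]
# PE 2 accepts wavelets with indices [6, 7, 8]
# PE 3 accepts wavelets with indices [9, 10, 11]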

layout.csl

// color/ task ID map
//
//  ID var           ID var            ID var                ID var
//   0                9 STARTUP        18                    27 reserved (memcpy)
//   1 dataColor     10                19                    28 reserved (memcpy)
//   2 resultColor   11                20                    29 reserved
//   3 H2D           12                21 reserved (memcpy)  30 reserved (memcpy)
//   4 D2H           13                22 reserved (memcpy)  31 reserved
//   5               14                23 reserved (memcpy)  32
//   6               15                24                    33
//   7               16                25                    34
//   8 main_task_id  17                26                    35

//  +-------------+
//  | north (D2H) |
//  +-------------+
//  | core        |
//  +-------------+
//  | south (nop) |
//  +-------------+

// IDs for memcpy streaming colors
param MEMCPYH2D_DATA_1_ID: i16;
param MEMCPYD2H_DATA_1_ID: i16;

// Colors
const MEMCPYH2D_DATA_1: color = @get_color(MEMCPYH2D_DATA_1_ID);
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const dataColor:        color = @get_color(1);
const resultColor:      color = @get_color(2);

// Task IDs
const STARTUP:      local_task_id = @get_local_task_id(9);
const main_task_id: local_task_id = @get_local_task_id(8);
const recv_task_id: data_task_id  = @get_data_task_id(dataColor);

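// The memcpy layout module needs the dimensions of the program rectangle;
// the width and height here match the @set_rectangle(4, 3) call below.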
const memcpy = @import_module( "<memcpy/get_params>", .{
    .width = 4,
    .height = 3,
    .MEMCPYH2D_1 = MEMCPYH2D_DATA_1,
    .MEMCPYD2H_1 = MEMCPYD2H_DATA_1
    });

layout {
  @set_rectangle(4, 3);

  for (@range(i16, 4)) |pe_x| {
    const memcpy_params = memcpy.get_params(pe_x);

    // The north PEs only run D2H
    @set_tile_code(pe_x, 0, "memcpyEdge/north.csl", .{
      .memcpy_params = memcpy_params,
      .USER_OUT_1 = resultColor,
      .STARTUP = STARTUP,
    });
  }

  const memcpy_params_0 = memcpy.get_params(0);
  const memcpy_params_1 = memcpy.get_params(1);
  const memcpy_params_2 = memcpy.get_params(2);
  const memcpy_params_3 = memcpy.get_params(3);

  @set_tile_code(0, 1, "send.csl", .{
    .peId = 0,
    .memcpy_params = memcpy_params_0,
    .exchColor = dataColor,
    .resultColor = resultColor,
    .main_task_id = main_task_id
  });

  const recvStruct = .{ .recvColor    = dataColor,
                        .resultColor  = resultColor,
                        .recv_task_id = recv_task_id };
  @set_tile_code(1, 1, "recv.csl", @concat_structs(recvStruct, .{
    .peId = 1,
    .memcpy_params = memcpy_params_1,
  }));
  @set_tile_code(2, 1, "recv.csl", @concat_structs(recvStruct, .{
    .peId = 2,
    .memcpy_params = memcpy_params_2,
  }));
  @set_tile_code(3, 1, "recv.csl", @concat_structs(recvStruct, .{
    .peId = 3,
    .memcpy_params = memcpy_params_3,
  }));

  for (@range(i16, 4)) |pe_x| {
    const memcpy_params = memcpy.get_params(pe_x);
    // The south PEs do nothing
    @set_tile_code(pe_x, 2, "memcpyEdge/south.csl", .{
      .memcpy_params = memcpy_params,
      .STARTUP = STARTUP
    });
  }
}

send.csl

param memcpy_params: comptime_struct;

param peId: u16;

// Colors
param exchColor:        color;
param resultColor:      color;

// Task IDs
param main_task_id: local_task_id;

// ----------
// Every PE needs to import the memcpy module; otherwise, the I/O cannot
// propagate data to its destination.

// memcpy module reserves input queue 0 and output queue 0
const sys_mod = @import_module( "<memcpy/memcpy>", memcpy_params);
// ----------

/// Helper function to pack a 16-bit index and a 16-bit float value into one
/// 32-bit wavelet.
fn pack(index: u16, data: f16) u32 {
  return (@as(u32, index) << 16) | @as(u32, @bitcast(u16, data));
}
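
// Example (illustrative): pack(3, 13.0) yields 0x0003_4A80, since the
// index 3 fills the upper 16 bits and 0x4A80 is the binary16 encoding
// of 13.0.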

const size = 12;
const data = [size]u32 {
  pack(0, 10.0),  pack( 1, 11.0), pack( 2, 12.0),
  pack(3, 13.0),  pack( 4, 14.0), pack( 5, 15.0),
  pack(6, 16.0),  pack( 7, 17.0), pack( 8, 18.0),
  pack(9, 19.0),  pack(10, 20.0), pack(11, 21.0),
};

/// Function to send all data values to all east neighbors.
fn sendDataToEastTiles() void {
  const inDsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{size} -> data[i]
  });

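  // Output queue 0 is reserved by the memcpy module (see note above), so
  // this fabric output uses queue 2 instead.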
  const outDsd = @get_dsd(fabout_dsd, .{
    .extent = size,
    .fabric_color = exchColor,
    .output_queue = @get_output_queue(2)
  });

  // WARNING: "async" is necessary otherwise CE has no resource
  // to run memcpy kernel
  @mov32(outDsd, inDsd, .{.async=true});
}

/// Function to process (divide by 2) the first three values and send the
/// result to the north neighbor (halo PE).
const num_wvlts: u16 = 3;
var buf = @zeros([num_wvlts]f16);
var ptr_buf : [*]f16 = &buf;

fn processAndSendSubset() void {
  const outDsd = @get_dsd(fabout_dsd, .{
    .extent = num_wvlts,
    .fabric_color = resultColor,
    .output_queue = @get_output_queue(1)
  });
  const bufDsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{num_wvlts} -> buf[i]
  });

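  // Recover each f16 value from the lower 16 bits of the packed word, then
  // halve it, just as the recipient PEs do with their own subsets.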
  var idx: u16 = 0;
  while (idx < num_wvlts) : (idx += 1) {
    const payload = @as(u16, data[idx] & 0xffff);
    const floatValue = @bitcast(f16, payload);
    buf[idx] = floatValue / 2.0;
  }
  // WARNING: "async" is necessary; otherwise, the CE has no resources left
  // to run the memcpy kernel.
  @fmovh(outDsd, bufDsd, .{.async = true});
}

task mainTask() void {
  sendDataToEastTiles();
  processAndSendSubset();
}

comptime {
  @activate(main_task_id);
  @bind_local_task(mainTask, main_task_id);

  @set_local_color_config(exchColor, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ EAST } } });
  @set_local_color_config(resultColor, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ NORTH } } });
}

recv.csl

param memcpy_params: comptime_struct;

param peId: u16;

// Colors
param recvColor:        color;
param resultColor:      color;

// Task IDs
param recv_task_id: data_task_id; // data task receives data along recvColor

// ----------
// Every PE needs to import the memcpy module; otherwise, the I/O cannot
// propagate data to its destination.

// memcpy module reserves input queue 0 and output queue 0
const sys_mod = @import_module( "<memcpy/memcpy>", memcpy_params);
// ----------

/// The recipient simply halves the value in the incoming wavelet and sends the
/// result to the north neighbor (halo PE).
var buf = @zeros([1]f16);
task recvTask(data: f16) void {
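  // Block the incoming data task so a new wavelet cannot overwrite buf while
  // the asynchronous send below is still reading it; the .unblock field of
  // @fmovh re-enables the task once the send completes.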
  @block(recvColor);
  buf[0] = data / 2.0;
  const outDsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .fabric_color = resultColor,
    .output_queue = @get_output_queue(1)
  });
  const bufDsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> buf[i]
  });
  // WARNING: "async" is necessary; otherwise, the CE has no resources left
  // to run the memcpy kernel.
  @fmovh(outDsd, bufDsd, .{.async = true, .unblock = recv_task_id});
}

comptime {
  @bind_data_task(recvTask, recv_task_id);

  const baseRoute = .{
    .rx = .{ WEST }
  };

  const filter = .{
    // Each PE should only accept three wavelets starting with the one whose
    // index field contains the value peId * 3.
    .kind = .{ .range = true },
    .min_idx = peId * 3,
    .max_idx = peId * 3 + 2,
  };
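
  // For the three recv.csl PEs this evaluates to:
  //   peId 1 -> min_idx 3, max_idx  5
  //   peId 2 -> min_idx 6, max_idx  8
  //   peId 3 -> min_idx 9, max_idx 11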

  if (peId == 3) {
    // This is the last PE, don't forward the wavelet further to the east.
    const txRoute = @concat_structs(baseRoute, .{ .tx = .{ RAMP } });
    @set_local_color_config(recvColor, .{.routes = txRoute, .filter = filter});
  } else {
    // Otherwise, forward incoming wavelets to both CE and to the east neighbor.
    const txRoute = @concat_structs(baseRoute, .{ .tx = .{ RAMP, EAST } });
    @set_local_color_config(recvColor, .{.routes = txRoute, .filter = filter});
  }

  // Send result wavelets to the north neighbor (i.e., the halo PE).
  @set_local_color_config(resultColor, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ NORTH } } });
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.sdk_utils import memcpy_view
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
MEMCPYD2H_DATA_1 = int(params["MEMCPYD2H_DATA_1_ID"])
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")

memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

runner.load()
runner.run()

print("step 1: streaming D2H at P0.0")
# The D2H buffer must be of type u32
out_tensors_u32 = np.zeros(4*3, np.uint32)
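# Stream 3 result wavelets from each of the 4 PEs in the north halo row
# (region origin (0, 0), width 4, height 1, 3 elements per PE).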
runner.memcpy_d2h(out_tensors_u32, MEMCPYD2H_DATA_1, 0, 0, 4, 1, 3, \
    streaming=True, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)
# Strip the upper 16 bits of each u32 to recover the f16 results
result = memcpy_view(out_tensors_u32, np.dtype(np.float16))

runner.stop()

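# Each source value (10 + i) is halved somewhere on the core row: PE #0
# handles indices 0-2 itself and PEs #1-#3 handle indices 3-11, so the halo
# row yields (10 + i) / 2 for i = 0..11.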
oracle = [5, 5.5, 6, 6.5, 7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5]
np.testing.assert_allclose(result, oracle, atol=0.0001, rtol=0)
print("SUCCESS!")

commands.sh

#!/usr/bin/env bash

set -e

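# Fabric dims of 11,5 leave room for the 4x3 program rectangle at offset
# (4,1) plus the additional rows and columns the memcpy infrastructure
# requires.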
cslc ./layout.csl --fabric-dims=11,5 --fabric-offsets=4,1 -o out \
--params=MEMCPYH2D_DATA_1_ID:3 \
--params=MEMCPYD2H_DATA_1_ID:4 \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
cs_python run.py --name out