GEMM with Collective Operations

GEMM with Collective Operations

This program implements the SUMMA matrix multiplication algorithm and serves as an example of using the collectives_2d library together with SdkRuntime and the memcpy framework.

The host code first copies tiles of A and B onto their corresponding PEs. It then uses the remote procedure call (RPC) mechanism to launch the function main, at which point the GEMM computation begins.

We perform GEMM in P steps on a P x P grid of PEs. At each step i, the PEs in the ith column broadcast their home tiles of A to the other PEs in their row, and the PEs in the ith row broadcast their home tiles of B to the other PEs in their column. Once both broadcasts are complete, as determined by x_done() and y_done() both being activated, each PE computes C_tile += Ap * Bp, where Ap and Bp are pointers to either the PE's home tile or the tile it received through the broadcasts.

When computation is complete the host copies back the resulting tiles of C from the device.

layout.csl

// Color/ task ID map
//
//  ID var              ID var              ID var                ID var
//   0 c2d_x_color_0     9 c2d_x_entrypt_1  18                    27 reserved (memcpy)
//   1 c2d_x_color_1    10 c2d_y_entrypt_0  19                    28 reserved (memcpy)
//   2                  11 c2d_y_entrypt_1  20                    29 reserved
//   3                  12 EXIT             21 reserved (memcpy)  30 reserved (memcpy)
//   4 c2d_y_color_0    13 compute_task_id  22 reserved (memcpy)  31 reserved
//   5 c2d_y_color_1    14 x_task_id        23 reserved (memcpy)  32
//   6                  15 y_task_id        24                    33
//   7                  16                  25                    34
//   8 c2d_x_entrypt_0  17                  26                    35

// Program rectangle is P x P: P is used as both the width and the height
// in @set_rectangle and in the memcpy parameter module below.
param P: u16;

// Matrix dimensions on one PE.
// These are forwarded unchanged to pe.csl by @set_tile_code, where the tiles
// are allocated with Mt*Kt (A), Kt*Nt (B) and Mt*Nt (C) elements.
param Mt: u16;
param Kt: u16;
param Nt: u16;

// Parameter module for the memcpy framework, sized to the P x P rectangle.
// Its get_params() is called once per PE column in the layout block.
const memcpy = @import_module("<memcpy/get_params>", .{
  .width = P,
  .height = P
});

// Parameter module for the collectives_2d library.
// Its get_params() is called once per PE in the layout block.
const c2d = @import_module("<collectives_2d/params>");

layout {
  // The program occupies a square region of P x P PEs.
  @set_rectangle(P, P);

  // Visit the PEs column by column.
  for (@range(u16, P)) |col| {
    // The memcpy parameters are requested once per column.
    const col_memcpy_params = memcpy.get_params(col);

    for (@range(u16, P)) |row| {
      // Colors 0-1 and task IDs 8-9 are handed to the x-dimension collectives,
      // colors 4-5 and task IDs 10-11 to the y-dimension collectives
      // (see the color / task ID map at the top of this file).
      const pe_c2d_params = c2d.get_params(col, row, .{
        .x_colors      = .{ @get_color(0),          @get_color(1) },
        .x_entrypoints = .{ @get_local_task_id(8),  @get_local_task_id(9) },
        .y_colors      = .{ @get_color(4),          @get_color(5) },
        .y_entrypoints = .{ @get_local_task_id(10), @get_local_task_id(11) },
      });

      // Every PE runs pe.csl with its own memcpy / collectives parameters
      // and the shared tile dimensions.
      @set_tile_code(col, row, "pe.csl", .{
        .memcpy_params = col_memcpy_params,
        .c2d_params = pe_c2d_params,
        .Mt = Mt, .Kt = Kt, .Nt = Nt,
      });
    }
  }

  // Names and types of the symbols that pe.csl exports to the host:
  // the three tile pointers and the entry function.
  @export_name("A", [*]f32, true);
  @export_name("B", [*]f32, true);
  @export_name("C", [*]f32, true);
  @export_name("main", fn()void);
}

pe.csl

// This program implements the SUMMA matrix multiplication algorithm and is
// written as an example to show how to use the `collectives_2d` library.

// We perform GEMM in `P` many steps on a grid of `P x P` processors.
// At each step `i`, PEs in the `i`th column broadcast their home tiles of `A`
// to other PEs in their row, and PEs in the `i`th row broadcast their home
// tiles of `B` to other PEs in their column. Once both broadcasts are complete
// as determined by `x_done()` and `y_done()` both being activated,
// each PE computes `C_tile += Ap * Bp` where `Ap` and `Bp` are pointers to
// either the PE's home tile or the tile it received through broadcasts.

// Per-PE parameter structs. layout.csl builds them with c2d.get_params() and
// memcpy.get_params() and passes them in through @set_tile_code.
param c2d_params: comptime_struct;
param memcpy_params: comptime_struct;

// Matrix size params
// The tiles below hold Mt*Kt (A), Kt*Nt (B) and Mt*Nt (C) elements.
// NOTE(review): layout.csl declares Mt, Kt and Nt as u16 while they are i16
// here -- confirm that the signedness difference is intended.
param Mt: i16;
param Kt: i16;
param Nt: i16;

// Task IDs
// IDs 8-11 are given to collectives_2d as entrypoints in layout.csl, so the
// tasks owned by this file use 12-15 (see the color / task ID map at the top
// of layout.csl).
const EXIT:            local_task_id = @get_local_task_id(12);
const compute_task_id: local_task_id = @get_local_task_id(13);
const x_task_id:       local_task_id = @get_local_task_id(14);
const y_task_id:       local_task_id = @get_local_task_id(15);

// Collectives along the x dimension. Configured with the x half of
// c2d_params, queues 2 and 4, and DSR ID 1 for dest/src0/src1.
const mpi_x = @import_module("<collectives_2d/pe>", .{
    .dim_params = c2d_params.x,
    .queues = [2]u16{2,4},
    .dest_dsr_ids = [1]u16{1},
    .src0_dsr_ids = [1]u16{1},
    .src1_dsr_ids = [1]u16{1}
    });
// Collectives along the y dimension. Configured with the y half of
// c2d_params, queues 3 and 5, and DSR ID 2 for dest/src0/src1, so that it
// shares no queue or DSR ID with mpi_x.
const mpi_y = @import_module("<collectives_2d/pe>", .{
    .dim_params = c2d_params.y,
    .queues = [2]u16{3,5},
    .dest_dsr_ids = [1]u16{2},
    .src0_dsr_ids = [1]u16{2},
    .src1_dsr_ids = [1]u16{2}
    });

// On WSE-2, memcpy uses input/output queue 0
// On WSE-3, memcpy uses input/output queues 0 and 1
// The collectives above therefore use queues 2-5 only.
const sys_mod = @import_module("<memcpy/memcpy>", memcpy_params);

// Side length of the program rectangle, read from its width
// (layout.csl sets the rectangle to P x P).
const P = @get_rectangle().width;

// This PE's home tile of A, B, C
// `A_tile` and `B_tile` will be populated with initial values by run.py
// These arrays are stored in a column major format.
// C_tile starts at zero and is accumulated into by the compute task.
var A_tile = @zeros([Mt*Kt]f32);
var B_tile = @zeros([Kt*Nt]f32);
var C_tile = @zeros([Mt*Nt]f32);

// Pointers to the home tiles. These are what the comptime block exports to
// the host under the names "A", "B" and "C".
var ptr_A : [*]f32 = &A_tile;
var ptr_B : [*]f32 = &B_tile;
var ptr_C : [*]f32 = &C_tile;

// Temporary buffers for storing in-flight tiles of A and B
// A PE passes these to the broadcast (instead of its home tile) whenever its
// own ID does not match the current step.
var A_buffer = @zeros([Mt*Kt]f32);
var B_buffer = @zeros([Kt*Nt]f32);

// This PE's IDs along the x and y dimensions. They are assigned from
// mpi_x.pe_id and mpi_y.pe_id in main() on step 0.
var px: u16;
var py: u16;

// Bound to x_task_id, which main() passes to mpi_x.broadcast().
// Activates the compute task. compute is kept blocked between steps, so it
// also needs the @unblock issued by y_done before it can run.
task x_done() void {
  @activate(compute_task_id);
}

// Bound to y_task_id, which main() passes to mpi_y.broadcast().
// Unblocks the compute task, which is blocked in the comptime block and
// blocked again at the end of every compute() invocation.
task y_done() void {
  @unblock(compute_task_id);
}

// Index of the current step, in [0, P). Incremented at the end of compute().
var step: u16 = 0;

// Starts one step of the algorithm by issuing the x and y broadcasts.
// Exported to the host as "main" (see the comptime block) and called again
// from compute() for every following step while step != P.
fn main() void {
  @assert(step < P);

  // The first time through we need to initialize our state
  if (step == 0) {
    mpi_x.init();
    mpi_y.init();
    // Cache this PE's ID along each dimension
    px = mpi_x.pe_id;
    py = mpi_y.pe_id;
  }

  // Communicate along both rows and columns
  // A PE whose ID in a dimension equals `step` passes its home tile to that
  // dimension's broadcast; every other PE passes its temporary buffer.
  // `step` is also the first broadcast argument -- presumably the ID of the
  // broadcasting PE; confirm against the collectives_2d documentation.
  const Ap = if (px == step) &A_tile else &A_buffer;
  const Bp = if (py == step) &B_tile else &B_buffer;
  // The tiles are f32 but are handed over as [*]u32 pointers; the counts are
  // the tile sizes in elements. x_task_id / y_task_id are bound to
  // x_done / y_done.
  mpi_x.broadcast(step, @ptrcast([*]u32, Ap), Mt * Kt, x_task_id);
  mpi_y.broadcast(step, @ptrcast([*]u32, Bp), Kt * Nt, y_task_id);
}

// Local work of one step: C_tile += Ap * Bp, followed by either the next
// step or the exit task.
// Bound to compute_task_id, which x_done activates and y_done unblocks.
task compute() void {
  // Same selection as in main(): the home tile if this PE's ID matches the
  // current step, otherwise the buffer that was passed to the broadcast.
  const Ap = if (px == step) &A_tile else &A_buffer;
  const Bp = if (py == step) &B_tile else &B_buffer;

  // Do an fmacs based local GEMM
  // A_dsd covers Mt consecutive f32 values, i.e. one column of the
  // column-major A tile. It is declared over A_tile and then re-based onto Ap.
  var A_dsd  = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{Mt} -> A_tile[i] });
  A_dsd = @set_dsd_base_addr(A_dsd, Ap);

  // k walks the Kt columns of A (and rows of B)
  for (@range(i16, Kt)) |k| {
    // Restart at the first column of C for every k
    var C_dsd = @get_dsd(mem1d_dsd, .{ .tensor_access = |i|{Mt} -> C_tile[i] });

    // j walks the Nt columns of B and C
    for (@range(i16, Nt)) |j| {
      // Element (k, j) of the column-major Kt x Nt tile of B
      const b = Bp.*[j*Kt + k];
      // Column j of C += column k of A, scaled by b
      @fmacs(C_dsd, C_dsd, A_dsd, b);
      // Move on to the next column of C (Mt elements further)
      C_dsd = @increment_dsd_offset(C_dsd, Mt, f32);
    }
    // Move on to the next column of A (Mt elements further)
    A_dsd = @increment_dsd_offset(A_dsd, Mt, f32);
  }

  step += 1;
  // Block again so that the next activation waits for y_done's @unblock
  @block(compute_task_id);

  if (step != P) {
    // More steps remain: issue the next pair of broadcasts
    main();
  } else {
    // All P steps are done: run f_exit
    @activate(EXIT);
  }
}

// Bound to EXIT and activated by compute() once step reaches P.
// Calls the memcpy module's unblock_cmd_stream().
// NOTE(review): presumably this is what lets the host's launch("main") call
// complete and the following memcpy_d2h proceed -- confirm against the
// memcpy documentation.
task f_exit() void {
  // the user must unblock cmd color for every PE
  sys_mod.unblock_cmd_stream();
}

comptime {
  // Attach each task to its local task ID
  @bind_local_task(f_exit, EXIT);
  @bind_local_task(compute, compute_task_id);
  @bind_local_task(x_done, x_task_id);
  @bind_local_task(y_done, y_task_id);
  // compute starts out blocked; y_done unblocks it and x_done activates it
  @block(compute_task_id);

  // Export the tile pointers and the entry function under the names that
  // layout.csl declares with @export_name
  @export_symbol(ptr_A, "A");
  @export_symbol(ptr_B, "B");
  @export_symbol(ptr_C, "C");
  @export_symbol(main);
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime     # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder    # pylint: disable=no-name-in-module

# Command-line interface.
#   --name    directory with the compiled artifacts: out.json is read from it
#             and it is passed to SdkRuntime. It is required, because without
#             it args.name would be None and the script would go on to open
#             "None/out.json".
#   --cmaddr  IP:port of the CS system; None is passed to SdkRuntime when the
#             option is omitted.
parser = argparse.ArgumentParser()
parser.add_argument("--name", required=True, help="the test name")
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()

# Read back the parameters that were recorded when the program was compiled
with open(f"{args.name}/out.json", encoding='utf-8') as json_file:
  compile_data = json.load(json_file)

# P is the side of the kernel rectangle; Mt, Kt, Nt are the per-PE tile sizes
P, Mt, Kt, Nt = (int(compile_data['params'][key]) for key in ('P', 'Mt', 'Kt', 'Nt'))

# Global problem size: every dimension is split into P tiles,
# so A is M x K, B is K x N and C is M x N
M, K, N = P * Mt, P * Kt, P * Nt

# Settings shared by the memcpy calls below
memcpy_dtype = MemcpyDataType.MEMCPY_32BIT
memcpy_order = MemcpyOrder.ROW_MAJOR

# A fixed seed keeps the inputs, and therefore CI results, reproducible
np.random.seed(seed=7)

# A is drawn before B so the random stream is consumed in a fixed order
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)

# Runtime for the compiled program found in args.name
runner = SdkRuntime(args.name, cmaddr=args.cmaddr)

# Handles for the symbols that layout.csl exports as "A", "B" and "C"
sym_A, sym_B, sym_C = (runner.get_id(symbol) for symbol in ("A", "B", "C"))

# Load the program and start running it
runner.load()
runner.run()

# The core rectangle is square:
#   w = number of PE columns, h = number of PE rows
w = h = P

# How to transform a 2-D tensor into a cliff distribution with
# column-major local tensor
#
# Example: w=2, h=2, A is 4-by-4 (lh-by-lw)
# A = |  0  1  2  3 |
#     |  4  5  6  7 |
#     |  8  9 10 11 |
#     | 12 13 14 15 |
# A1 = A.reshape(2,2,2,2) of the form (h,lh,w,lw)
# A1 = | | 0  1|  | 4  5| |
#      | | 2  3|, | 6  7| |
#      |                  |
#      | | 8  9|  |12 13| |
#      | |10 11|, |14 15| |
# A2 = A1.transpose(0, 2, 3, 1) of the form (h, w, lw, lh)
# so the local tensor lh-by-lw is col-major
# A2 = | | 0  4|  | 2  6| |
#      | | 1  5|, | 3  7| |
#      |                  |
#      | | 8 12|  |10 14| |
#      | | 9 13|, |11 15| |
# A3 = A2.reshape(2,2,4)
# A3 = |  0  4  1  5 |
#      |  2  6  3  7 |
#      |  8 12  9 13 |
#      | 10 14 11 15 |
# A3 is h-w-l

# Host -> device copy of A. Each PE receives its Mt x Kt tile as a
# column-major vector of Mt*Kt elements:
#   (h, Mt, w, Kt) -> transpose -> (h, w, Kt, Mt) -> flatten -> (h, w, Mt*Kt)
A1 = A.reshape(h, Mt, w, Kt)
A2 = A1.transpose(0, 2, 3, 1)
A3 = A2.reshape(h, w, Mt*Kt)
# Use the shared memcpy_dtype / memcpy_order settings for every transfer
runner.memcpy_h2d(sym_A, A3.ravel(), 0, 0, w, h, Mt*Kt,
    streaming=False, data_type=memcpy_dtype, order=memcpy_order, nonblock=True)

# Host -> device copy of B, same transformation with Kt x Nt tiles:
#   (h, Kt, w, Nt) -> transpose -> (h, w, Nt, Kt) -> flatten -> (h, w, Kt*Nt)
B1 = B.reshape(h, Kt, w, Nt)
B2 = B1.transpose(0, 2, 3, 1)
B3 = B2.reshape(h, w, Kt*Nt)
runner.memcpy_h2d(sym_B, B3.ravel(), 0, 0, w, h, Kt*Nt,
    streaming=False, data_type=memcpy_dtype, order=memcpy_order, nonblock=True)

# Launch the exported device function "main" (blocking call: nonblock=False)
runner.launch("main", nonblock=False)

# Device -> host copy of C: Mt*Nt 32-bit words from each of the w x h PEs.
# The words are received as uint32 and reinterpreted as float32 at the end.
# Use the shared memcpy_dtype / memcpy_order settings, as for the inputs.
C3_1d_u32 = np.zeros(h*w*Mt*Nt, np.uint32)
runner.memcpy_d2h(C3_1d_u32, sym_C, 0, 0, w, h, Mt*Nt,
    streaming=False, data_type=memcpy_dtype, order=memcpy_order, nonblock=False)
# C3 is h-by-w-l or
# C3 is of the form (h, w, Nt, Mt) where local tensor Mt-by-Nt is column-major
C3 = C3_1d_u32.reshape((h, w, Nt, Mt))
# C2 is of the form (h, Mt, w, Nt)
C2 = C3.transpose(0, 3, 1, 2)
# C1 is of the form (M, N)
C1 = C2.reshape(M, N)
# C has the correct data type (same bits, viewed as float32)
C = C1.view(np.float32)

runner.stop()

# Check the result against a host-side reference
C_expected = np.dot(A, B)

# assert_allclose takes (actual, desired) and checks
#   absolute(actual - desired) <= (atol + rtol * absolute(desired))
# so the device result goes first and the reference second. That way the
# tolerance scales with the reference and failure reports label the arrays
# correctly.
np.testing.assert_allclose(C, C_expected, rtol=1e-05, atol=1e-06)

print("SUCCESS")

commands.sh

#!/usr/bin/env bash

# Abort as soon as any command fails
set -e

# Compile layout.csl into the directory "out"
cslc --arch=wse3 ./layout.csl \
  --fabric-dims=11,6 \
  --fabric-offsets=4,1 \
  --params=P:4,Mt:14,Kt:14,Nt:14 \
  --memcpy \
  --channels=1 \
  -o out

# Run the host program against the compiled output
cs_python run.py --name out