Topic 11: Collective Communications

The <collectives_2d> library can be used for communication between PEs in the same row or column. It mimics the collective operations provided by the Message Passing Interface (MPI) in other parallel programming environments.
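For readers who know MPI, the primitives used in this example map roughly onto the following mpi4py calls. This is only an analogy to ground the terminology; mpi4py is not part of the SDK or of this example.

# Rough MPI analogy (mpi4py assumed here; not part of this example).
# One row (or column) of PEs plays the role of an MPI communicator.
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
chunk = 3

buf = np.ones(chunk * size, np.float32) if rank == 0 else np.zeros(chunk * size, np.float32)
comm.Bcast(buf, root=0)                           # ~ mpi_x.broadcast
total = np.empty_like(buf) if rank == 0 else None
comm.Reduce(buf, total, op=MPI.SUM, root=0)       # ~ mpi_x.reduce_fadds

part = np.empty(chunk, np.int32)
send = np.arange(chunk * size, dtype=np.int32) if rank == 0 else None
comm.Scatter(send, part, root=0)                  # ~ mpi_y.scatter
out = np.empty(chunk * size, np.int32) if rank == 0 else None
comm.Gather(part, out, root=0)                    # ~ mpi_y.gather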

This example showcases each of the currently available communication primitives while using the library across two independent dimensions. The communication tasks along the two dimensions execute asynchronously.

task_x uses the broadcast primitive to transmit data from the first PE in every row to every other PE in the same row. After the data is received, reduce_fadds computes the element-wise sum of the broadcast_recv buffers across the row and transmits the result back to the first PE in every row.
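A minimal host-side sketch of the arithmetic task_x performs on one row (sizes and names are illustrative, not part of the example): the root's broadcast_data reaches the other Pw - 1 PEs, and the float reduction returns the element-wise sum of every PE's broadcast_recv, including the root's untouched zeros, to the first PE. This matches the (Pw - 1) * ones check in run.py below.

import numpy as np

Pw, Nx = 4, 12                                    # illustrative sizes
broadcast_data = np.ones(Nx, dtype=np.float32)    # sent by the row's first PE

# Every non-root PE receives a copy; the root's broadcast_recv stays zero.
broadcast_recv = [np.zeros(Nx, np.float32)] + [broadcast_data.copy() for _ in range(Pw - 1)]

# reduce_fadds: element-wise float sum of broadcast_recv across the row,
# delivered back to the first PE.
faddh_result = np.sum(broadcast_recv, axis=0)
assert np.array_equal(faddh_result, np.full(Nx, Pw - 1, np.float32))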

task_y operates concurrently along every column of PEs. The task first uses scatter to distribute chunk_size-element slices of scatter_data across the PEs in the same column. It then uses gather to collect the slices distributed by scatter. Because gather is the inverse of scatter, the collective round trip transfers the contents of scatter_data into gather_recv.
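Similarly, a sketch of the task_y round trip on one column (illustrative sizes and names): scatter hands each of the Ph PEs one chunk_size-element slice of the root's scatter_data, and gather reassembles the slices on the root, reproducing the original buffer.

import numpy as np

Ph, chunk_size = 4, 3                             # illustrative sizes
Ny = Ph * chunk_size
scatter_data = np.arange(Ny, dtype=np.int32)      # held by the column's first PE

# scatter: PE k in the column receives slice k of chunk_size elements
scatter_recv = [scatter_data[k * chunk_size:(k + 1) * chunk_size] for k in range(Ph)]

# gather: the slices are collected back onto the first PE, in order
gather_recv = np.concatenate(scatter_recv)
assert np.array_equal(gather_recv, scatter_data)  # gather inverts scatter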

layout.csl

// color/task ID map
//
//  ID var              ID var             ID var                ID var
//   0 c2d_x_color_0     9                 18                    27 reserved (memcpy)
//   1 c2d_x_color_1    10 c2d_x_entrypt_0 19                    28 reserved (memcpy)
//   2                  11 c2d_x_entrypt_1 20                    29 reserved
//   3                  12 c2d_y_entrypt_0 21 reserved (memcpy)  30 reserved (memcpy)
//   4 c2d_y_color_0    13 c2d_y_entrypt_1 22 reserved (memcpy)  31 reserved
//   5 c2d_y_color_1    14                 23 reserved (memcpy)  32
//   6                  15 task_x_id       24                    33
//   7                  16 task_y_id       25                    34
//   8                  17                 26                    35
//

param Pw:         u16; // kernel width
param Ph:         u16; // kernel height
param chunk_size: u16; // Num elements to send/recv in collectives

// Colors
const c2d_x_color_0: color = @get_color(0);
const c2d_x_color_1: color = @get_color(1);
const c2d_y_color_0: color = @get_color(4);
const c2d_y_color_1: color = @get_color(5);

// Task IDs
const c2d_x_entrypt_0: local_task_id = @get_local_task_id(10);
const c2d_x_entrypt_1: local_task_id = @get_local_task_id(11);
const c2d_y_entrypt_0: local_task_id = @get_local_task_id(12);
const c2d_y_entrypt_1: local_task_id = @get_local_task_id(13);
const task_x_id:       local_task_id = @get_local_task_id(15);
const task_y_id:       local_task_id = @get_local_task_id(16);

const c2d = @import_module("<collectives_2d/params>");

const memcpy = @import_module( "<memcpy/get_params>", .{
  .width = Pw,
  .height = Ph
});

layout {
  @set_rectangle(Pw, Ph);

  var Px: u16 = 0;
  while (Px < Pw) : (Px += 1) {
    var Py: u16 = 0;
    while (Py < Ph) : (Py += 1) {
      const params = c2d.get_params(Px, Py, .{
        .x_colors      = .{ c2d_x_color_0,   c2d_x_color_1 },
        .x_entrypoints = .{ c2d_x_entrypt_0, c2d_x_entrypt_1 },
        .y_colors      = .{ c2d_y_color_0,   c2d_y_color_1 },
        .y_entrypoints = .{ c2d_y_entrypt_0, c2d_y_entrypt_1 },
      });
      const memcpy_params = memcpy.get_params(Px);
      @set_tile_code(Px, Py, "pe_program.csl", .{
        .memcpy_params = memcpy_params,
        .c2d_params = params,
        .chunk_size = chunk_size,
        .task_x_id = task_x_id,
        .task_y_id = task_y_id });
    }
  }

  // export symbol names
  @export_name("broadcast_data", [*]u32, true);
  @export_name("scatter_data", [*]u32, true);
  @export_name("broadcast_recv", [*]u32, true);
  @export_name("faddh_result", [*]u32, true);
  @export_name("gather_recv", [*]u32, true);

  @export_name("f_run_x", fn()void);
  @export_name("f_run_y", fn()void);
}

pe_program.csl

param c2d_params: comptime_struct;
param memcpy_params: comptime_struct;

param chunk_size: u16; // Number of elements to send/recv in collectives

// Task IDs
param task_x_id: local_task_id; // Task ID for callback for collectives in x direction
param task_y_id: local_task_id; // Task ID for callback for collectives in y direction

const sys_mod = @import_module( "<memcpy/memcpy>", memcpy_params);

const rect_height = @get_rectangle().height;
const rect_width = @get_rectangle().width;

const mpi_x = @import_module("<collectives_2d/pe>", .{
    .dim_params = c2d_params.x,
    .queues = [2]u16{2,4},
    .dest_dsr_ids = [1]u16{1},
    .src0_dsr_ids = [1]u16{1},
    .src1_dsr_ids = [1]u16{1}
    });
const mpi_y = @import_module("<collectives_2d/pe>", .{
    .dim_params = c2d_params.y,
    .queues = [2]u16{3,5},
    .dest_dsr_ids = [1]u16{2},
    .src0_dsr_ids = [1]u16{2},
    .src1_dsr_ids = [1]u16{2}
    });


const Nx = chunk_size * rect_width;
const Ny = chunk_size * rect_height;

// broadcast_data and scatter_data supplied by run.py
var broadcast_data = @zeros([Nx]u32);
var broadcast_recv = @zeros([Nx]u32);
var faddh_result = @zeros([Nx]u32);

var scatter_data = @zeros([Ny]u32);
var scatter_recv = @zeros([Ny]u32);
var gather_recv = @zeros([Ny]u32);

var ptr_broadcast_data: [*]u32 = &broadcast_data;
var ptr_scatter_data: [*]u32 = &scatter_data;
var ptr_broadcast_recv: [*]u32 = &broadcast_recv;
var ptr_faddh_result: [*]u32 = &faddh_result;
var ptr_gather_recv: [*]u32 = &gather_recv;

var task_x_state: u16 = 0;
task task_x() void {
   switch (task_x_state) {
      0 => {
         mpi_x.init();
         var send_buf = @ptrcast([*]u32, &broadcast_data);
         var recv_buf = @ptrcast([*]u32, &broadcast_recv);
         if (mpi_x.pe_id == 0) {
            mpi_x.broadcast(0, send_buf, Nx, task_x_id);
         } else {
            mpi_x.broadcast(0, recv_buf, Nx, task_x_id);
         }

         task_x_state += 1;
      },
      1 => {
         var send_buf = @ptrcast([*]f32, &broadcast_recv);
         var recv_buf = @ptrcast([*]f32, &faddh_result);

         mpi_x.reduce_fadds(0, send_buf, recv_buf, Nx, task_x_id);

         task_x_state += 1;
      },
      else => {
         // WARNING: the user must unblock cmd color for every PE
         sys_mod.unblock_cmd_stream();
         return;
      }
   }
}

var task_y_state: u16 = 0;
task task_y() void {
   switch (task_y_state) {
      0 => {
         mpi_y.init();
         var send_buf = @ptrcast([*]u32, &scatter_data);
         var recv_buf = @ptrcast([*]u32, &scatter_recv);

         mpi_y.scatter(0, send_buf, recv_buf, chunk_size, task_y_id);

         task_y_state += 1;
      },
      1 => {
         var send_buf = @ptrcast([*]u32, &scatter_recv);
         var recv_buf = @ptrcast([*]u32, &gather_recv);

         mpi_y.gather(0, send_buf, recv_buf, chunk_size, task_y_id);

         task_y_state += 1;
      },
      else => {
         // WARNING: the user must unblock cmd color for every PE
         sys_mod.unblock_cmd_stream();
         return;
      }
   }
}

comptime {
   @bind_local_task(task_x, task_x_id);
   @bind_local_task(task_y, task_y_id);
}

fn f_run_x() void {
   @activate(task_x_id);

   // terminate when task_x finishes
}

fn f_run_y() void {
   @activate(task_y_id);

   // terminate when task_y finishes
}

comptime{
  @export_symbol(ptr_broadcast_data, "broadcast_data");
  @export_symbol(ptr_scatter_data, "scatter_data");
  @export_symbol(ptr_broadcast_recv, "broadcast_recv");
  @export_symbol(ptr_faddh_result, "faddh_result");
  @export_symbol(ptr_gather_recv, "gather_recv");
  @export_symbol(f_run_x);
  @export_symbol(f_run_y);
}

run.py

#!/usr/bin/env cs_python

import argparse
import json
import numpy as np

from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module

parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument('--cmaddr', help='IP:port for CS system')
args = parser.parse_args()
dirname = args.name

# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
  compile_data = json.load(json_file)
params = compile_data["params"]
Pw = int(params["Pw"])
Ph = int(params["Ph"])
chunk_size = int(params["chunk_size"])
print(f"Pw = width of the core = {Pw}")
print(f"Ph = height of the core = {Ph}")
print(f"chunk_size = {chunk_size}")

Nx = Pw*chunk_size
Ny = Ph*chunk_size

print(f"Nx = {Nx}, Ny = {Ny}")

memcpy_dtype = MemcpyDataType.MEMCPY_32BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)

sym_broadcast_data = runner.get_id("broadcast_data")
sym_scatter_data = runner.get_id("scatter_data")
sym_broadcast_recv = runner.get_id("broadcast_recv")
sym_faddh_result = runner.get_id("faddh_result")
sym_gather_recv = runner.get_id("gather_recv")

runner.load()
runner.run()

print("step 1: copy mode H2D(broadcast_data) to 1st column PEs")
broadcast_data = np.ones((Ph, 1, Nx)).astype(np.float32)
runner.memcpy_h2d(sym_broadcast_data, broadcast_data.ravel(), 0, 0, 1, Ph, Nx, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=True)

print("step 2: copy mode H2D(scatter_data) to 1st row PEs")
scatter_data = np.ones((1, Pw, Ny)).astype(np.int32)
runner.memcpy_h2d(sym_scatter_data, scatter_data.ravel(), 0, 0, Pw, 1, Ny, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=True)

print("step 3: call f_run_x to test broadcast and reduction")
runner.launch("f_run_x", nonblock=False)

print("step 4: call f_run_y to test scatter and gather")
runner.launch("f_run_y", nonblock=False)

print("step 5: copy mode D2H(broadcast_recv)")
# broadcast on x: Px=0 broadcasts data to all other PEs
# broadcast_recv(y, x=0) = 0
# broadcast_recv(y, x !=0) = ones
broadcast_recv_1d = np.zeros(Ph*Pw*Nx, np.float32)
runner.memcpy_d2h(broadcast_recv_1d, sym_broadcast_recv, 0, 0, Pw, Ph, Nx, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)
broadcast_recv = broadcast_recv_1d.reshape((Ph, Pw, Nx))

print("step 6: copy mode D2H(faddh_result) from 1st column PEs")
# reduce(broadcast_recv) to Px=0
faddh_result_1d = np.zeros(Ph*Nx, np.float32)
runner.memcpy_d2h(faddh_result_1d, sym_faddh_result, 0, 0, 1, Ph, Nx, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)
faddh_result = faddh_result_1d.reshape((Ph, 1, Nx))

print("step 7: copy mode D2H(gather_recv) from 1st row PEs")
gather_recv_1d = np.zeros(Pw*Ny, np.int32)
runner.memcpy_d2h(gather_recv_1d, sym_gather_recv, 0, 0, Pw, 1, Ny, \
    streaming=False, data_type=memcpy_dtype, order=MemcpyOrder.ROW_MAJOR, nonblock=False)
gather_recv = gather_recv_1d.reshape((1, Pw, Ny))

runner.stop()

# verify broadcast on x-direction
correct_broadcast_recv = np.ones(Nx).astype(np.float32)
for y in range(Ph):
  for x in range(Pw):
    if x == 0:
      continue
    np.testing.assert_equal(broadcast_recv[y, x], correct_broadcast_recv)

# verify faddh_result at 1st column PEs
# reduce on x: reduce(broadcast_recvs) to Px=0
# where broadcast_recvs(y, x=0) = 0
#       broadcast_recvs(y, x != 0) = ones
correct_faddh_result = np.full(Nx, (Pw-1), dtype=np.float32)
for y in range(Ph):
  np.testing.assert_equal(faddh_result[y, 0], correct_faddh_result)

# verify gather_recv at 1st row PEs
correct_gather_recv = np.ones(Ny).astype(np.int32)
for x in range(Pw):
  np.testing.assert_equal(gather_recv[0, x], correct_gather_recv)

print("SUCCESS")

commands.sh

#!/usr/bin/env bash

set -e

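# Compile for a 15 x 15 kernel with chunk_size 3. The 22 x 17 fabric leaves
# room around the kernel for the memcpy infrastructure at offset (4,1); without
# --cmaddr, cs_python runs the compiled program on the fabric simulator.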
cslc --arch=wse3 ./layout.csl --fabric-dims=22,17 --fabric-offsets=4,1 \
--params=Pw:15,Ph:15,chunk_size:3 -o out \
--memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
cs_python run.py --name out