Optimizing a Rust GPU matmul kernel
I read the excellent post Optimizing a WebGPU Matmul Kernel for 1TFLOP+ Performance by Zach Nussbaum and thought it might be fun to reimplement it with Rust GPU.
We'll follow Zach's original post closely, comparing and contrasting Rust with the WGSL and TypeScript from his post.
At the end, I'll show some unique benefits of using Rust on the GPU.
A big thank you to Zach for allowing me to reimplement his blog post!
The complete runnable code can be found on GitHub.
What is Rust GPU?
Rust GPU is a project that allows you to write code for GPUs using the Rust programming language. GPUs are typically programmed using specialized languages like WGSL, GLSL, MSL, or HLSL. Rust GPU changes this by letting you use Rust to write GPU programs (often called "shaders" or "kernels").
These Rust GPU programs are then compiled into SPIR-V, a low-level format that most GPUs understand. Since SPIR-V is the format Vulkan uses, Rust GPU makes it possible to integrate Rust-based GPU programs into any Vulkan-compatible workflow1.
For more details, check out the Rust GPU website or the GitHub repository.
How does Rust GPU work?
Rust GPU focuses purely on compiling your Rust code into SPIR-V. This compiled code is what the GPU executes. However, Rust GPU doesn't dictate how you handle CPU-to-GPU communication or data transfer. You're free to choose a host CPU library written in whatever language fits your project. Some popular options in Rust include:
- ash: Low-level Vulkan bindings for Rust, providing maximum control over Vulkan operations.
- vulkano: A higher-level Vulkan library that simplifies common tasks.
- wgpu: A cross-platform library that abstracts GPU operations across Vulkan, DirectX, Metal, and WebGPU.
But again, you don't have to use Rust for the CPU-side when using Rust on the GPU—any language will do.
What will we use?
In Zach's post, he writes his GPU programs in WGSL. These programs and their data are sent to and from the GPU via TypeScript, which talks to the WebGPU implementation built into the browser.
We'll take a different approach: writing GPU programs in Rust via Rust GPU and managing everything—including the CPU-side code—in Rust. This means both the GPU programs and the code controlling them will be written in the same language. If you are familiar with web programming, what we are doing is conceptually similar to JavaScript running on both the server and the client.
Using Rust for both CPU and GPU has advantages, like consistent tooling and shared code. But it also means we need to be clear about which code runs where. I've tried to make sure this distinction is easy to follow.
To handle communication between our code on the CPU and GPU, we'll use wgpu. wgpu is a high-level Rust library that implements the WebGPU API. On the web, it works directly with the browser's WebGPU implementation. On native platforms, it translates API calls to the platform's GPU API (Vulkan, DirectX, or Metal). This lets us run the same code on a wide range of platforms, including Windows, Linux, macOS2, iOS3, Android, and the web4.
By using Rust GPU and wgpu, we have a clean, portable setup with everything written in Rust.
GPU program basics
The smallest unit of execution is a thread, which executes the GPU program.
Workgroups are groups of threads that run in parallel and can access the same shared memory (they're called thread blocks in CUDA).
We can dispatch many of these workgroups at once. CUDA calls this a grid (which is made of thread blocks).
Workgroups and dispatched workgroups are defined in 3D. The size of a workgroup is defined by compute(threads(x, y, z)), where the number of threads per workgroup is x * y * z.
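As a quick worked example of that arithmetic (the numbers here are arbitrary and for illustration only, not taken from the kernels below):

fn main() {
    // With compute(threads(16, 16, 1)), each workgroup runs 16 * 16 * 1 = 256 threads.
    let threads_per_workgroup = 16 * 16 * 1;
    // Dispatching an (8, 8, 1) grid launches 64 workgroups...
    let workgroups_dispatched = 8 * 8 * 1;
    // ...for a total of 16,384 threads doing work at once.
    assert_eq!(threads_per_workgroup * workgroups_dispatched, 16_384);
}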
Writing the kernel
Kernel 1: Naive kernel
The simplest way to compute the product of matrices A and B and write it to matrix C is: for each row of A (there are M rows), iterate over the columns of A (there are K columns), multiply each element by the corresponding value of B, and accumulate the sum.
Here, we have our first difference from Zach's post. In WGSL, you must define inputs at the top-level scope:
struct Dimensions {
    M: u32,
    K: u32,
    N: u32,
}

@group(0) @binding(0) var<uniform> dimensions: Dimensions;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, read_write> result: array<f32>;
And then write your kernel:
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    let row = index / dimensions.N;
    let col = index % dimensions.N;

    if (index < dimensions.M * dimensions.N) {
        var sum = 0.0;
        for (var i: u32 = 0u; i < dimensions.K; i = i + 1u) {
            sum = sum + a[row * dimensions.K + i] * b[i * dimensions.N + col];
        }
        result[row * dimensions.N + col] = sum;
    }
}
With Rust GPU, we specify the inputs as arguments to the kernel and configure them with procedural macros:
#![no_std]

use settings::Dimensions;
use spirv_std::glam::UVec3;
use spirv_std::spirv;

#[spirv(compute(threads(1)))]
pub fn matmul(
    #[spirv(global_invocation_id)] global_id: UVec3,
    #[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions,
    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] a: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] b: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] result: &mut [f32],
) {
    let index = global_id.x;
    let row = index / dimensions.n;
    let col = index % dimensions.n;

    if index < dimensions.m * dimensions.n {
        let mut sum = 0.0;
        for i in 0..dimensions.k {
            let a_val = a[(row * dimensions.k + i) as usize];
            let b_val = b[(i * dimensions.n + col) as usize];
            sum += a_val * b_val;
        }
        result[(row * dimensions.n + col) as usize] = sum;
    }
}
This code looks like normal Rust code but runs entirely on the GPU.
There are a couple of things to note about the Rust implementation:
- The kernel uses the regular Rust #![no_std] attribute, which is required because GPUs do not have access to Rust's standard library (std). Instead, you rely on core and spirv_std to provide std-like functionality.
- Libraries are imported via use. The module system works exactly the same as regular Rust.
- We're importing a vendored copy of glam. This is the exact glam crate from crates.io.
- The inner loop (for i in 0..dimensions.k) uses Rust's for syntax with a range. This is a higher-level abstraction compared to manually iterating with an index in other shader languages like WGSL, GLSL, or HLSL.
- Read-only inputs are immutable references (&Dimensions / &[f32]) and writable outputs are mutable references (&mut [f32]). This feels very familiar to anyone used to writing Rust.
What's with all the usize?
Rust defines usize as the native pointer width of the hardware the code is running on. This is important because Rust uses usize for indexing slices to ensure that access is properly pointer-aligned.
On most GPU hardware, usize is effectively equivalent to u32. But the Rust compiler doesn't assume that. It can't, because doing so could introduce problems—like if you ran this code on hardware where usize is actually u64. Rust won't let you implicitly treat a u32 as a usize. You have to explicitly cast it, essentially telling the compiler "I know this is safe for my target hardware."
This explicitness might seem tedious but it is one of the ways Rust prevents subtle bugs. It forces you to think about whether your assumptions about hardware alignment and pointer sizes are correct, making your code more portable and reliable.
Matrix multiplication is a pathological case with lots of indexing and row and column calculations. Most Rust GPU code does not have nearly as many usize casts as these examples.
Dispatching workgroups
Each workgroup, since it's only one thread (#[spirv(compute(threads(1)))]), processes one result[i, j].
To calculate the full matrix, we need to launch as many workgroups as there are entries in the m * n matrix. Here we specify that (UVec3::new(m * n, 1, 1)) on the CPU:
impl GridComputation for Naive {
    fn workgroup(&self) -> UVec3 {
        UVec3::new(1, 1, 1)
    }

    fn dispatch_count(&self, m: u32, n: u32) -> UVec3 {
        UVec3::new(m * n, 1, 1)
    }
}
The dispatch_count() function runs on the CPU and is used by the CPU-to-GPU API (in our case wgpu) to configure and dispatch work to the GPU:
let dispatch_count = <T as GridComputation>::dispatch_count(&self.variant, m, n);
...
compute_pass.dispatch_workgroups(dispatch_count.x, dispatch_count.y, dispatch_count.z);
This code appears more complicated than it needs to be. I abstracted the CPU-side code that talks to the GPU using generics and traits so I could easily slot in different kernels and their settings while writing the blog post.
You could just hardcode the value for simplicity.
Kernel 2: Moarrr threads!
With the first kernel, we're only able to compute small square matrices due to limits on the number of workgroups you can dispatch at once.
Since we're launching one workgroup per entry, a 256x256 matrix needs 65,536 workgroups, which already exceeds the 65,535 we can dispatch in a single dimension!
Remember this line?
#[spirv(compute(threads(1)))]
We can reduce the number of dispatched workgroups by increasing the number of threads per workgroup!
If we update our GPU code
#[spirv(compute(threads(256)))]
we can reduce the number of total dispatched workgroups per dimension:
impl GridComputation for Workgroup256 {
    fn workgroup(&self) -> UVec3 {
        UVec3::new(256, 1, 1)
    }

    fn dispatch_count(&self, m: u32, n: u32) -> UVec3 {
        let workgroup = self.workgroup();
        let threads_needed = m * n;
        // Ceil division: integer division truncates in Rust (unlike the
        // `Math.ceil()` used in the TypeScript), so without rounding up we
        // could dispatch too few workgroups (or even 0).
        // We'll also cap the value to a maximum of 65,535 to comply with
        // hardware limits.
        let x = ((threads_needed as f32 / workgroup.x as f32).ceil() as u32).min(65_535);
        UVec3::new(x, 1, 1)
    }
}
With these two small changes we can handle larger matrices without hitting hardware workgroup limits.
Kernel 3: Calculating with 2D workgroups
However, doing all the computation in "1 dimension" still limits the matrix size we can calculate.
Although we don't change much about our code, distributing the work in 2 dimensions lets us bypass these limits and launch more (and larger) workgroups. This allows us to calculate a 4096x4096 matmul.
We update our compute(threads(256)) to compute(threads(16, 16)), and make the small change to row and col from Zach's post to increase speed:
#![no_std]

use settings::Dimensions;
use spirv_std::glam::UVec3;
use spirv_std::spirv;

#[spirv(compute(threads(16, 16)))]
pub fn matmul(
    #[spirv(global_invocation_id)] global_id: UVec3,
    #[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions,
    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] a: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] b: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] result: &mut [f32],
) {
    let row = global_id.x as usize;
    let col = global_id.y as usize;

    if row < dimensions.m as usize && col < dimensions.n as usize {
        let mut sum = 0.0;
        for i in 0..dimensions.k as usize {
            sum += a[row * dimensions.k as usize + i] * b[i * dimensions.n as usize + col];
        }
        result[row * dimensions.n as usize + col] = sum;
    }
}
And we need to tweak the workgroup dispatch count calculation on the CPU, as we are in 2D now and using the y value:
impl GridComputation for Workgroup2d {
    fn workgroup(&self) -> UVec3 {
        UVec3::new(16, 16, 1)
    }

    fn dispatch_count(&self, m: u32, n: u32) -> UVec3 {
        let w = self.workgroup();
        // One thread per output element: divide each dimension by the
        // workgroup size in that dimension and round up so the whole matrix
        // is covered.
        let x = ((m as f32) / (w.x as f32)).ceil() as u32;
        let y = ((n as f32) / (w.y as f32)).ceil() as u32;
        UVec3::new(x, y, 1)
    }
}
Kernel 4: Kernel tiling
Another thing to consider is how much work each thread does.
Up to now, each thread only computes one entry. But launching each workgroup has overhead, and each thread can compute more than one element!
If the time saved by launching fewer workgroups outweighs the extra work per thread, we should see a big speedup.
To do so, we calculate 4 results per thread (a 1x4 tile).
#![no_std]

use settings::Dimensions;
use settings::TILE_SIZE;
use spirv_std::glam::UVec3;
use spirv_std::spirv;

#[spirv(compute(threads(16, 16)))]
pub fn matmul(
    #[spirv(global_invocation_id)] global_id: UVec3,
    #[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions,
    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] a: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] b: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] result: &mut [f32],
) {
    let row = global_id.y as usize;
    let col = (global_id.x * TILE_SIZE) as usize;

    if row >= dimensions.m as usize || col >= dimensions.n as usize {
        return;
    }

    let mut sum00: f32 = 0.0;
    let mut sum01: f32 = 0.0;
    let mut sum02: f32 = 0.0;
    let mut sum03: f32 = 0.0;

    for i in 0..dimensions.k as usize {
        let a_elem = a[row * dimensions.k as usize + i];
        if col < dimensions.n as usize {
            sum00 += a_elem * b[i * dimensions.n as usize + col];
        }
        if col + 1 < dimensions.n as usize {
            sum01 += a_elem * b[i * dimensions.n as usize + col + 1];
        }
        if col + 2 < dimensions.n as usize {
            sum02 += a_elem * b[i * dimensions.n as usize + col + 2];
        }
        if col + 3 < dimensions.n as usize {
            sum03 += a_elem * b[i * dimensions.n as usize + col + 3];
        }
    }

    if col < dimensions.n as usize {
        result[row * dimensions.n as usize + col] = sum00;
    }
    if col + 1 < dimensions.n as usize {
        result[row * dimensions.n as usize + col + 1] = sum01;
    }
    if col + 2 < dimensions.n as usize {
        result[row * dimensions.n as usize + col + 2] = sum02;
    }
    if col + 3 < dimensions.n as usize {
        result[row * dimensions.n as usize + col + 3] = sum03;
    }
}
The kernel looks roughly the same as before except we've unrolled the computation and are calculating TILE_SIZE results per thread. We also need some error checking for when our matrices don't fit nicely.
But this code is kinda gross...it looks like the opaque GPU code we are used to. Let's make it nice!
#![no_std]

use settings::Dimensions;
use settings::TILE_SIZE;
use spirv_std::glam::UVec3;
use spirv_std::spirv;

#[spirv(compute(threads(16, 16)))]
pub fn matmul(
    #[spirv(global_invocation_id)] global_id: UVec3,
    #[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions,
    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] a: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] b: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] result: &mut [f32],
) {
    let row = global_id.y as usize;
    let col = (global_id.x * TILE_SIZE) as usize;

    if row >= dimensions.m as usize || col >= dimensions.n as usize {
        return;
    }

    // Compute sums for each offset directly
    let mut sums = [0.0; TILE_SIZE as usize];
    for i in 0..dimensions.k as usize {
        let a_elem = a[row * dimensions.k as usize + i];
        for offset in 0..TILE_SIZE as usize {
            if col + offset < dimensions.n as usize {
                let b_elem = b[i * dimensions.n as usize + col + offset];
                sums[offset] += a_elem * b_elem;
            }
        }
    }

    // Write results back
    for offset in 0..TILE_SIZE as usize {
        if col + offset < dimensions.n as usize {
            result[row * dimensions.n as usize + col + offset] = sums[offset];
        }
    }
}
Much better.
We can take this a step further and calculate a 2D block of results per thread! Instead of calculating 4 elements in a single row, we can calculate 4 elements for each of 4 rows (a 2D tile).
#![no_std]

use settings::Dimensions;
use settings::{TILE_M, TILE_N};
use spirv_std::glam::UVec3;
use spirv_std::spirv;

#[spirv(compute(threads(16, 16)))]
pub fn matmul(
    #[spirv(global_invocation_id)] global_id: UVec3,
    #[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions,
    #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] a: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] b: &[f32],
    #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] result: &mut [f32],
) {
    let row = (global_id.y * TILE_M) as usize;
    let col = (global_id.x * TILE_N) as usize;

    // Initialize sums array to zeros
    // Note: This is uglier than it needs to be to work around
    // https://github.com/Rust-GPU/rust-gpu/issues/46
    let mut sums: [[f32; TILE_N as usize]; TILE_M as usize] = Default::default();

    // Compute the 2D tile
    for k in 0..dimensions.k as usize {
        for i in 0..TILE_M as usize {
            let a_element = if row + i < dimensions.m as usize {
                a[(row + i) * dimensions.k as usize + k]
            } else {
                0.0
            };
            for j in 0..TILE_N as usize {
                let b_element = if col + j < dimensions.n as usize {
                    b[k * dimensions.n as usize + (col + j)]
                } else {
                    0.0
                };
                sums[i][j] += a_element * b_element;
            }
        }
    }

    // Write results
    for i in 0..TILE_M as usize {
        for j in 0..TILE_N as usize {
            let output_row = row + i;
            let output_col = col + j;
            if output_row < dimensions.m as usize && output_col < dimensions.n as usize {
                result[output_row * dimensions.n as usize + output_col] = sums[i][j];
            }
        }
    }
}
Each thread now calculates a 4x4 grid of the output matrix and we see a slight improvement over the last kernel.
To stay true to the spirit of Zach's original blog post, we'll wrap things up here and leave the "fancier" experiments for another time.
A note on performance
I didn't include performance numbers as I have a different machine than Zach. The complete runnable code can be found on GitHub and you can run the benchmarks yourself with cargo bench.
You can also check out real-world projects using Rust GPU such as autograph and renderling.
Reflections on porting to Rust GPU
Porting to Rust GPU went quickly, as the kernels Zach used were fairly simple. Most of my time was spent on concerns that were not specifically about writing GPU code: for example, deciding how much to abstract versus how much to keep the code easy to follow, and whether everything should be available at runtime or each kernel should be its own compilation target. The code is not great as it is still blog post code!
My background is not in GPU programming, but I do have Rust experience. I joined the Rust GPU project because I tried to use standard GPU languages and knew there must be a better way.
Writing these GPU kernels felt like writing any other Rust code (other than debugging, more on that later) which is a huge win to me. Not just the language itself, but the entire development experience.
Rust-specific party tricks
Rust lets us write code for both the CPU and GPU in ways that are often impossible—or at least less elegant—with other languages. I'm going to highlight some benefits I experienced while working on this blog post.
Shared code across GPU and CPU
In GPU programming, we often need to pass data between the CPU and GPU. For example, our GPU kernel expects a Dimensions struct as input:
use settings::Dimensions;
...
pub fn matmul(
    ...
    #[spirv(uniform, descriptor_set = 0, binding = 0)] dimensions: &Dimensions,
We create an instance of Dimensions on the CPU and send it to the GPU via wgpu, where the Rust kernel loads and uses it.
// This is a `uniform` buffer instead of `storage` buffer because the data is
// the same for all workgroups, it is read-only, and it is small enough to fit
// in a single buffer (`uniform` buffers are limited to 64 KB on most GPUs
// and often less on older GPUs).
let dimensions = Dimensions::new(m, k, n);
let dimensions_buffer = create_buffer_init(
    &self.device,
    "Dimensions Buffer",
    &[dimensions],
    wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
);
This means the code on the CPU and GPU needs to agree on the definition of Dimensions!
In many GPU programming ecosystems, this would involve manually keeping the definitions in sync across different languages—one for the CPU, one for the GPU. This is tedious and error-prone.
With Rust, it's straightforward: we move the Dimensions struct into its own crate, and both the CPU and GPU code depend on that crate. Now the type definition lives in one place and both platforms use it directly.
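For illustration, a minimal sketch of what such a shared crate could look like (the real settings crate in the repository may use different derives and helpers, for example to make the struct uploadable as a uniform buffer via bytemuck):

// shared/settings/src/lib.rs (sketch)
#![no_std]

#[repr(C)]
#[derive(Copy, Clone)]
pub struct Dimensions {
    pub m: u32,
    pub k: u32,
    pub n: u32,
}

impl Dimensions {
    pub fn new(m: u32, k: u32, n: u32) -> Self {
        Self { m, k, n }
    }
}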
This approach eliminates duplication and guarantees consistency. If we need to make changes, those changes propagate to both the CPU and GPU automatically, reducing the risk of mismatches and making refactoring far safer.
This kind of consistency across CPU and GPU is something you don't often see in other GPU programming ecosystems. Bespoke codegen solutions are often created to accomplish the same thing Rust has built in.
Running and debugging shaders on the CPU
GPU code can be notoriously hard to debug. While developing this kernel, I ran into a bug I couldn't figure out. GPU debugging tools are limited and printf-style debugging often isn't available. But what if we could run the GPU kernel on the CPU, where we have access to tools like standard debuggers and good ol' printf/println?
With Rust GPU, this was straightforward. By using standard Rust cfg() directives I made the GPU-specific annotations (#[spirv(...)]) disappear when compiling for the CPU. The result? The kernel became a regular Rust function. On the GPU, it behaves like a shader. On the CPU, it's just a function you can call directly.
Here's what it looks like in practice using the 2D tiling kernel from before:
//! This shader can run on both the CPU and the GPU.
//!
//! The GPU-specific attributes are only used when compiling for the GPU, otherwise they
//! are stripped away and the shader entrypoint becomes a normal function that can be
//! called from the CPU.
#![no_std]

use settings::Dimensions;
use settings::{TILE_M, TILE_N};

#[cfg(target_arch = "spirv")]
use spirv_std::spirv;

#[cfg(target_arch = "spirv")]
use spirv_std::glam;

#[cfg(not(target_arch = "spirv"))]
use glam;

use glam::UVec3;

#[cfg_attr(target_arch = "spirv", spirv(compute(threads(16, 16))))]
pub fn matmul(
    #[cfg_attr(target_arch = "spirv", spirv(global_invocation_id))] global_id: UVec3,
    #[cfg_attr(target_arch = "spirv", spirv(uniform, descriptor_set = 0, binding = 0))]
    dimensions: &Dimensions,
    #[cfg_attr(
        target_arch = "spirv",
        spirv(storage_buffer, descriptor_set = 0, binding = 1)
    )]
    a: &[f32],
    #[cfg_attr(
        target_arch = "spirv",
        spirv(storage_buffer, descriptor_set = 0, binding = 2)
    )]
    b: &[f32],
    #[cfg_attr(
        target_arch = "spirv",
        spirv(storage_buffer, descriptor_set = 0, binding = 3)
    )]
    result: &mut [f32],
) {
    let row = (global_id.y * TILE_M as u32) as usize;
    let col = (global_id.x * TILE_N as u32) as usize;

    // Initialize sums array to zeros
    let mut sums: [[f32; TILE_N as usize]; TILE_M as usize] = Default::default();

    // Compute the 2D tile
    for k in 0..dimensions.k as usize {
        for i in 0..TILE_M as usize {
            let a_element = if row + i < dimensions.m as usize {
                a[(row + i) * dimensions.k as usize + k]
            } else {
                0.0
            };
            for j in 0..TILE_N as usize {
                let b_element = if col + j < dimensions.n as usize {
                    b[k * dimensions.n as usize + (col + j)]
                } else {
                    0.0
                };
                sums[i][j] += a_element * b_element;
            }
        }
    }

    // Write results
    for i in 0..TILE_M as usize {
        for j in 0..TILE_N as usize {
            let output_row = row + i;
            let output_col = col + j;
            if output_row < dimensions.m as usize && output_col < dimensions.n as usize {
                result[output_row * dimensions.n as usize + output_col] = sums[i][j];
            }
        }
    }
}
The logic in the kernel hasn't changed, it is exactly the same as the GPU-only code from before.
You'll also notice that on the GPU it uses glam from spirv_std but on the CPU it uses glam from crates.io:
#[cfg(target_arch = "spirv")]
use spirv_std::glam;
#[cfg(not(target_arch = "spirv"))]
use glam;
This is enabled by the standard Rust ecosystem tooling around dependencies:
# Dependencies when run on either the CPU or GPU
[dependencies]
settings = { path = "../../shared/settings" }
# Dependencies when run on the CPU
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
glam.workspace = true
# Dependencies when run on the GPU
[target.'cfg(target_arch = "spirv")'.dependencies]
spirv-std.workspace = true
Testing the kernel in isolation is useful, but it does not reflect how the GPU executes it with multiple invocations across workgroups and dispatches. To test the kernel end-to-end, I needed a test harness that simulated this behavior on the CPU.
Building the harness was straightforward thanks to Rust. By enforcing the same invariants as the GPU, I could validate the kernel under the same conditions the GPU would run it:
fn multiply(
    &self,
    a: &[f32],
    b: &[f32],
    m: u32,
    k: u32,
    n: u32,
) -> Result<Vec<f32>, MatrixMultiplyError> {
    // Initialize the result vector with zeros as that is what the GPU does.
    let mut result = vec![0.0; (m * n) as usize];

    // Retrieve workgroup and dispatch configurations. These tell us how to iterate.
    let workgroup = <T as GridComputation>::workgroup(&self.variant);
    let dispatch = <T as GridComputation>::dispatch_count(&self.variant, m, n);

    // Define dimensions as (m, k, n)
    let dimensions = Dimensions::new(m, k, n);

    // Iterate over the dispatch grid
    for gwx in 0..dispatch.x {
        for gwy in 0..dispatch.y {
            for wx in 0..workgroup.x {
                for wy in 0..workgroup.y {
                    // Calculate global indices
                    let x = gwx * workgroup.x + wx;
                    let y = gwy * workgroup.y + wy;

                    if x < m && y < n {
                        // Define global id
                        let global_id = UVec3::new(x, y, 1);

                        // Perform the matmul operation for element (x, y). NOTE:
                        // This is the EXACT SAME CODE THAT RUNS ON THE GPU, RUNNING
                        // ON THE CPU. This is the power of rust-gpu.
                        <T as Cpu>::call(
                            &self.variant,
                            global_id,
                            &dimensions,
                            &a,
                            &b,
                            &mut result,
                        );
                    }
                }
            }
        }
    }

    Ok(result)
}
Again, this code appears more complicated than it needs to be. I abstracted the CPU testing harness code using generics and traits so I could easily slot in different kernels and their settings while writing the blog post.
You could just call the kernel function directly in nested loops.
Tests
By moving the kernel code to the CPU, I could write tests that ran quickly and entirely on the CPU. This eliminated the need to serialize tests and offload them to the GPU (which is a shared and limited resource).
This approach has several benefits. First, it significantly reduced the feedback loop during development, allowing me to catch issues faster. Second, it ensured the tests could be run in any environment where the Rust toolchain is available—no GPU required. This is especially relevant in CI environments such as GitHub Actions that do not have a GPU by default.
For example, my test for a small matrix multiplication kernel running in the harness on the CPU looked like this:
#[test]
fn test_single_threaded_matmul_2x1x1() {
    let m = 2;
    let k = 1;
    let n = 1;
    let a = vec![1.0, 2.0];
    let b = vec![3.0];
    let expected = vec![3.0, 6.0];

    let variant = crate::variants::Isomorphic;
    let matrix_multiplier =
        block_on(SingleThreadedMatMul::new(variant)).expect("Failed to create");

    let result = matrix_multiplier
        .multiply(&a, &b, m, k, n)
        .expect("Matrix multiplication failed");

    assert_eq!(result, expected);
}
Benchmarks
I wanted to run benchmarks similar to those in the original blog post. Because I was using Rust, this was simple. I used criterion with cargo bench, just like any other Rust project.
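As an illustration, a minimal criterion benchmark looks like the sketch below. The function being measured is a stand-in CPU matmul for demonstration, not the actual benchmark code from the repository:

use criterion::{black_box, criterion_group, criterion_main, Criterion};

// Stand-in for the kernel under test; the real benchmarks dispatch to the GPU.
fn naive_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0; m * n];
    for row in 0..m {
        for col in 0..n {
            let mut sum = 0.0;
            for i in 0..k {
                sum += a[row * k + i] * b[i * n + col];
            }
            c[row * n + col] = sum;
        }
    }
    c
}

fn bench_matmul(c: &mut Criterion) {
    let (m, k, n) = (64, 64, 64);
    let a = vec![1.0f32; m * k];
    let b = vec![1.0f32; k * n];
    c.bench_function("naive matmul 64x64x64", |bench| {
        bench.iter(|| naive_matmul(black_box(&a), black_box(&b), m, k, n))
    });
}

criterion_group!(benches, bench_matmul);
criterion_main!(benches);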
This required no new tools or workflows. The tools I already knew worked seamlessly.
More importantly, this approach benefits anyone working on the project. Any Rust engineer can run these benchmarks with no additional setup—cargo bench is a standard part of the Rust ecosystem.
Formatting
Rust GPU code is formatted with rustfmt, following the same standards as all Rust code. This not only ensured my GPU code looked identical to my CPU code, it made my GPU code consistent with the entire Rust ecosystem. Leveraging standard tools like rustfmt minimizes cognitive overhead and avoids the hassle of configuring third-party formatters of varying quality.
Lint
Linting GPU code in Rust works the same way as for CPU code. Running cargo clippy
highlighted issues and enforced consistent code quality. Though I didn't have any,
custom lint configurations are applied to Rust GPU kernels as well. Lints ensure that
GPU code is held to the same high standards as the rest of the project.
Documentation
Writing doc comments and running cargo doc generates documentation for GPU kernels, exactly as it does in regular Rust. While some ecosystems offer similar tools, Rust's integration is built-in and works seamlessly for both CPU and GPU code. There's no special setup required.
But wait, there's more!
The kernel in Zach's blog post is intentionally simple. That makes it easy to follow, but it also means the Rust code looks very similar to WGSL. While this is fine for an introductory example, it doesn't demonstrate Rust's real strengths for GPU programming. These strengths—reusing existing libraries, traits, enums, generics, and more—become much more important as projects grow in complexity.
Leverage the existing Rust ecosystem
Rust's no_std ecosystem offers a wide array of libraries that can be used in environments without the standard library. Traditionally this has meant embedded devices, but a lot of the same assumptions apply to GPUs! As a consequence, you can reuse existing no_std and no-alloc libraries from crates.io in your GPU code without the authors explicitly adding GPU support. This is uniquely enabled by Rust GPU's implementation choices and Rust's registries. Sharing and reusing code from the greater Rust ecosystem is a superpower when writing GPU programs that will massively compound over time.
Traits
Traits are one of Rust's most powerful tools and they work with Rust GPU. Traits let you
define zero-cost, reusable, type-safe behavior. For example, if you have multiple kernels
for different matrix multiplication strategies, you can define a MatrixMultiplication
trait and implement it for each variation. This eliminates duplication and makes your
code easier to extend.
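As a sketch of the idea (the trait and its signature here are illustrative, not the exact ones from the repository):

// Illustrative only: a trait abstracting over matmul strategies so CPU-side
// code can treat every kernel variant the same way.
pub trait MatrixMultiplication {
    /// Multiply row-major `a` (m x k) by `b` (k x n), returning `c` (m x n).
    fn multiply(&self, a: &[f32], b: &[f32], m: u32, k: u32, n: u32) -> Vec<f32>;
}

/// Marker type for the naive strategy from earlier.
pub struct Naive;

impl MatrixMultiplication for Naive {
    fn multiply(&self, a: &[f32], b: &[f32], m: u32, k: u32, n: u32) -> Vec<f32> {
        let (m, k, n) = (m as usize, k as usize, n as usize);
        let mut c = vec![0.0; m * n];
        for row in 0..m {
            for col in 0..n {
                c[row * n + col] = (0..k).map(|i| a[row * k + i] * b[i * n + col]).sum();
            }
        }
        c
    }
}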
Enums and zero-sized types
GPU code is notoriously hard to read, but Rust's enums and zero-sized types (ZSTs) can make it much more understandable. Enums let you explicitly encode states or modes. For example, you can define tiling strategies or precision levels using enums instead of relying on constants or magic numbers.
ZSTs take this further by encoding configurations directly into the type system. For example, you could represent different kernel configurations as ZSTs. This approach ensures invalid configurations are impossible, improving both readability and safety.
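A hedged sketch of what that could look like (all of the names here are made up for illustration):

use core::marker::PhantomData;

// An enum for a choice made at runtime...
pub enum Precision {
    F32,
    F16,
}

// ...and zero-sized marker types for choices fixed at compile time. They carry
// no data at runtime, but the type system prevents mixing up configurations.
pub struct Tile1d;
pub struct Tile2d;

pub struct KernelConfig<Tiling> {
    pub precision: Precision,
    _tiling: PhantomData<Tiling>,
}

impl<Tiling> KernelConfig<Tiling> {
    pub fn new(precision: Precision) -> Self {
        Self { precision, _tiling: PhantomData }
    }
}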
Generics
Generics are another feature missing from this kernel but are a powerful tool in Rust
GPU. They allow you to write flexible kernels that work across different data types or
memory layouts. For instance, you can write a single function that supports both f32
and f64
without duplicating code, all while maintaining type safety and performance.
Error handling with Result
Rust GPU also supports error handling using Result. Encoding errors in the type system makes it clear where things can go wrong and forces you to handle those cases. This is particularly useful for validating kernel inputs or handling the many edge cases in GPU logic.
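A small sketch of the pattern, with a made-up error type (validation like this would typically run on the CPU before dispatching work):

#[derive(Debug)]
pub enum MatmulError {
    WrongInputLength { which: &'static str, expected: usize, got: usize },
}

// Check that the input slices match the claimed dimensions before dispatch,
// surfacing failures through `Result` instead of silently producing garbage.
pub fn validate_inputs(a: &[f32], b: &[f32], m: u32, k: u32, n: u32) -> Result<(), MatmulError> {
    if a.len() != (m * k) as usize {
        return Err(MatmulError::WrongInputLength {
            which: "a",
            expected: (m * k) as usize,
            got: a.len(),
        });
    }
    if b.len() != (k * n) as usize {
        return Err(MatmulError::WrongInputLength {
            which: "b",
            expected: (k * n) as usize,
            got: b.len(),
        });
    }
    Ok(())
}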
Iterators
Rust's iterators don't appear in this kernel, but they're another way Rust GPU simplifies complex logic. Instead of manual loops with indices, you can use iterators to express your logic more clearly.
Iterators reduce the chance of off-by-one errors and make the intent of the code much clearer.
Rust GPU's support for iterators is not complete but we are looking to improve it in the future.
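For example, the naive kernel's inner loop could be written with iterators, as in this sketch (b is row-major, so step_by(n) walks down a column):

// Dot product of row `row` of `a` with column `col` of `b`, without manual
// index arithmetic in the loop body.
fn dot_row_col(a: &[f32], b: &[f32], row: usize, col: usize, k: usize, n: usize) -> f32 {
    a[row * k..row * k + k]
        .iter()
        .zip(b[col..].iter().step_by(n))
        .map(|(a_val, b_val)| a_val * b_val)
        .sum()
}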
Conditional compilation
While I briefly touched on it a couple of times, this kernel doesn't really show the full power of conditional compilation. With #[cfg(...)] and cargo "features", you can adapt kernels to different hardware or configurations without duplicating code. GPU languages like WGSL or GLSL offer preprocessor directives, but these tools lack standardization across projects. Rust GPU leverages the existing Cargo ecosystem, so conditional compilation follows the same standards all Rust developers already know.
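A tiny sketch of the idea; the feature name is made up and the settings crate does not necessarily do this:

// Pick a tile size at compile time via a Cargo feature instead of duplicating
// kernel code per configuration.
#[cfg(feature = "small_tiles")]
pub const TILE_SIZE: u32 = 2;

#[cfg(not(feature = "small_tiles"))]
pub const TILE_SIZE: u32 = 4;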
Come join us!
Rust GPU only recently became a community-managed project. We're eager to add more users and contributors!
We will be working on revamping the onboarding and documentation soon. To follow along or get involved, check out the rust-gpu repo on GitHub.
Footnotes
1. Why not CUDA? That is covered by Rust CUDA, a related project that I am planning on rebooting soon!
2. Technically wgpu uses MoltenVK or translates to Metal on macOS.
3. Technically wgpu uses MoltenVK or translates to Metal on iOS.
4. Technically wgpu translates SPIR-V to GLSL (WebGL) or WGSL (WebGPU) via naga on the web.