Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension

This notebook walks through a basic example of using the CUTLASS Python interface to declare a grouped GEMM kernel and export it as a PyTorch CUDA extension.


Background on grouped GEMM

Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides) in a single CUDA kernel. It can be thought of as a generalized version of a pointer-array GEMM, without the requirement that the sizes and strides of each GEMM be the same.

For example, if one has p GEMMs with sizes:

M_1 x N_1 x K_1
M_2 x N_2 x K_2
...
M_p x N_p x K_p

CUTLASS’s grouped GEMM will execute these in a single CUDA kernel.

Grouped GEMM is particularly beneficial for saturating the GPU with many small problems that would insufficiently utilize the device in isolation.
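To make "a single CUDA kernel for many problems" concrete, the sketch below models the general idea: every problem's output is split into tiles, all tiles are placed in one global index space, and a fixed number of persistent threadblocks walk that space with a stride equal to the number of blocks. This is a simplified, hypothetical model (the helper names `tiles_per_problem` and `assign_tiles` are illustrative, not CUTLASS APIs); CUTLASS's actual problem visitor handles this on the device and differs in detail. The tile sizes are free parameters; for instance, the kernel generated later in this notebook uses a 256 x 128 threadblock tile.

```python
from typing import List, Sequence, Tuple

Problem = Tuple[int, int, int]  # (M, N, K) for one GEMM in the group


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling of a / b for positive integers."""
    return (a + b - 1) // b


def tiles_per_problem(problem: Problem, tile_m: int, tile_n: int) -> int:
    """Number of tile_m x tile_n output tiles needed to cover an M x N result."""
    M, N, _ = problem
    return ceil_div(M, tile_m) * ceil_div(N, tile_n)


def assign_tiles(problems: Sequence[Problem], tile_m: int, tile_n: int,
                 num_blocks: int) -> List[List[Tuple[int, int]]]:
    """Flatten every problem's tiles into one global index space and deal
    them out round-robin to a fixed number of persistent threadblocks.

    Returns, for each block, the (problem_index, tile_index) pairs it computes.
    """
    work = [(p, t)
            for p, problem in enumerate(problems)
            for t in range(tiles_per_problem(problem, tile_m, tile_n))]
    blocks: List[List[Tuple[int, int]]] = [[] for _ in range(num_blocks)]
    for global_tile, item in enumerate(work):
        # Block b handles global tiles b, b + num_blocks, b + 2 * num_blocks, ...
        blocks[global_tile % num_blocks].append(item)
    return blocks
```

For example, `assign_tiles([(256, 128, 64), (512, 256, 64)], 256, 128, num_blocks=2)` gives block 0 the pairs `(0, 0), (1, 1), (1, 3)` and block 1 the pairs `(1, 0), (1, 2)`: tiles from differently sized problems are interleaved across the same blocks, which is what lets many small problems keep the whole device busy.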

Declaring a grouped GEMM via the CUTLASS Python interface

A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one simply calls cutlass.op.GroupedGemm.

[1]:
import cutlass
import torch

dtype = torch.float16
plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)

We can then compile and run this operation on a group of GEMMs. We’ll first set up some utility functions to initialize GEMMs.

[2]:
import random
random.seed(2023)

# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K
def initialize(dtype, M, N, K):
    sizes = [(M, K), (K, N), (M, N), (M, N)]
    return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]

# Utility function to generate `problems` GEMMs of random sizes
def generate_problems(problems):
    valid_sizes = [128, 256, 512, 1024]
    As, Bs, Cs, Ds = [], [], [], []
    for _ in range(problems):
        M, N, K = [random.choice(valid_sizes) for _ in range(3)]
        A, B, C, D = initialize(dtype, M, N, K)
        As.append(A)
        Bs.append(B)
        Cs.append(C)
        Ds.append(D)
    return As, Bs, Cs, Ds
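Because `generate_problems` draws each dimension at random, the total amount of work in a group depends on the seed. A small standard-library helper can total it, counting each multiply-add as two floating-point operations (2 * M * N * K per GEMM), and convert a measured runtime into throughput. The names `group_flops` and `tflops` are hypothetical helpers, not part of CUTLASS or PyTorch.

```python
from typing import Iterable, Tuple


def group_flops(shapes: Iterable[Tuple[int, int, int]]) -> int:
    """Total floating-point operations for a group of (M, N, K) GEMMs,
    counting each multiply-add as two operations."""
    return sum(2 * M * N * K for M, N, K in shapes)


def tflops(flops: int, time_us: float) -> float:
    """Convert an operation count and a runtime in microseconds to TFLOP/s."""
    return flops / (time_us * 1e-6) / 1e12
```

With the tensors produced above, where each A is M x K and each B is K x N, the shapes can be collected as `shapes = [(a.shape[0], b.shape[1], a.shape[1]) for a, b in zip(As, Bs)]`.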

We’ll next run a group of 50 GEMMs via the CUTLASS Python interface and via PyTorch.

[3]:
As, Bs, Cs, Ds = generate_problems(50)

plan.run(As, Bs, Cs, Ds, print_module=True)
Ds_torch = [a @ b for a, b in zip(As, Bs)]

for d, d_torch in zip(Ds, Ds_torch):
    assert torch.allclose(d, d_torch)

// Gemm operator cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8
using cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, cutlass::half_t>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAdd
>::GemmKernel;

// Define named type
struct cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_type :
  public cutlass_sm80_tensorop_h16x8x16gemm_grouped_1x1x1_256x128_64x3_tt_align8_base { };
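The generated kernel's epilogue is a `LinearCombination`, which computes D = alpha * (A @ B) + beta * C for each problem. The assertions above compare D against `a @ b` alone, which implies the plan ran with alpha = 1 and beta = 0. The following plain-Python reference (hypothetical helpers on nested lists, intended only to state the semantics, not to be fast) shows what each problem in the group computes independently.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def gemm_reference(A: Matrix, B: Matrix, C: Matrix,
                   alpha: float = 1.0, beta: float = 0.0) -> Matrix:
    """D = alpha * (A @ B) + beta * C for row-major nested lists."""
    M, K, N = len(A), len(B), len(B[0])
    assert all(len(row) == K for row in A), "inner dimensions must match"
    return [[alpha * sum(A[i][k] * B[k][j] for k in range(K)) + beta * C[i][j]
             for j in range(N)]
            for i in range(M)]


def grouped_gemm_reference(As: Sequence[Matrix], Bs: Sequence[Matrix],
                           Cs: Sequence[Matrix], alpha: float = 1.0,
                           beta: float = 0.0) -> List[Matrix]:
    """Apply gemm_reference to each problem; shapes may differ per problem."""
    return [gemm_reference(A, B, C, alpha, beta) for A, B, C in zip(As, Bs, Cs)]
```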

Exporting the CUTLASS kernel to a PyTorch CUDA extension

The procedure above allows one to quickly experiment with CUTLASS kernels. However, one might prefer to use the CUTLASS kernel via a PyTorch CUDA extension. This avoids the runtime overheads associated with the Python portions of the CUTLASS Python interface.

The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. These extensions can either be written out for a later “ahead-of-time” compilation, or be just-in-time compiled and returned to the user.

To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:

[4]:
op = plan.construct()
grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)

The cutlass.emit.pytorch function emits:

* out/grouped_gemm_kernel.cu: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors
* out/grouped_gemm.cpp: This file contains a C++ wrapper around the aforementioned CUTLASS kernel
* setup.py: This file contains the setuptools script for building and installing the generated extension

The extension can be built from within the out directory by running:

cd out
TORCH_CUDA_ARCH_LIST="8.0" python setup.py install

where TORCH_CUDA_ARCH_LIST is set to the compute capability of the device on which the kernel will run ("8.0" here, assuming that one is building for SM80).

See the PyTorch “Custom C++ and CUDA Extensions” tutorial for more details on this.
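The same build can be scripted from Python. The sketch below assumes the compute capability is available as an integer such as 80 (the form passed as `cc=plan.cc` above) and converts it into the "major.minor" string that TORCH_CUDA_ARCH_LIST expects before running `setup.py install` in the output directory. Both helper names are hypothetical.

```python
import os
import subprocess
import sys


def torch_cuda_arch(cc: int) -> str:
    """Convert an integer compute capability (e.g. 80) to the "8.0" form
    used by TORCH_CUDA_ARCH_LIST."""
    major, minor = divmod(cc, 10)
    return f"{major}.{minor}"


def build_extension(sourcedir: str, cc: int) -> None:
    """Run `python setup.py install` inside `sourcedir` for a single arch."""
    env = dict(os.environ, TORCH_CUDA_ARCH_LIST=torch_cuda_arch(cc))
    # check=True raises CalledProcessError if the build fails.
    subprocess.run([sys.executable, "setup.py", "install"],
                   cwd=sourcedir, env=env, check=True)
```

For the kernel in this notebook, the call would be `build_extension("out", plan.cc)`.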

One could then use the kernel in a later PyTorch module by running:

import torch
import grouped_gemm

grouped_gemm.run(As, Bs)

In this case, however, we set jit=True, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly. Under the hood, this leverages the torch.utils.cpp_extension.load method and returns the loaded extension.
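A hand-written equivalent of that JIT step might look roughly like the following sketch. It is not the exact call that cutlass.emit.pytorch makes: the source file names follow the emitted files listed earlier, while the compiler flags and the `include_dirs` argument (which would need to point at the CUTLASS headers) are assumptions for illustration. The argument assembly is kept separate from the actual `load` call so it can be inspected without a GPU toolchain.

```python
import os
from typing import Dict, List


def jit_load_kwargs(name: str, sourcedir: str, cc: int,
                    include_dirs: List[str]) -> Dict[str, object]:
    """Assemble keyword arguments for torch.utils.cpp_extension.load.

    File names mirror the emitted .cpp/.cu pair; the flags and include
    directories are illustrative assumptions.
    """
    return {
        "name": name,
        "sources": [
            os.path.join(sourcedir, f"{name}.cpp"),
            os.path.join(sourcedir, f"{name}_kernel.cu"),
        ],
        "extra_cuda_cflags": [
            "-std=c++17",
            f"-gencode=arch=compute_{cc},code=sm_{cc}",
        ],
        "extra_include_paths": list(include_dirs),
    }


def jit_load(name: str, sourcedir: str, cc: int, include_dirs: List[str]):
    """Compile and import the extension on the fly (requires torch and nvcc)."""
    from torch.utils.cpp_extension import load  # imported lazily
    return load(**jit_load_kwargs(name, sourcedir, cc, include_dirs))
```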

We can then use the extension and compare its results to those of the same GEMMs computed with ordinary PyTorch matrix multiplication:

[5]:
Ds = grouped_gemm.run(As, Bs)
Ds_torch = [a @ b for a, b in zip(As, Bs)]
for d, d_torch in zip(Ds, Ds_torch):
    assert torch.allclose(d, d_torch)

Finally, we can profile our grouped GEMM extension:

[6]:
num_warmup = 20
num_profile = 100

# Warmup iterations
for _ in range(num_warmup):
    Ds = grouped_gemm.run(As, Bs)
    Ds_torch = [a @ b for a, b in zip(As, Bs)]
    torch.cuda.synchronize()

# Timing iterations
import time
grouped = 0
nongrouped = 0
for _ in range(num_profile):
    start = time.time()
    Ds = grouped_gemm.run(As, Bs)
    torch.cuda.synchronize()
    grouped += time.time() - start

    start = time.time()
    Ds_torch = [a @ b for a, b in zip(As, Bs)]
    torch.cuda.synchronize()
    nongrouped += time.time() - start

print('Grouped:     {:.3f} us'.format(grouped * 1e6/num_profile))
print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))
print('Speedup: {:.3f}'.format(nongrouped / grouped))
Grouped:     400.696 us
Non-Grouped: 646.670 us
Speedup: 1.614
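The timing loop above can be factored into a reusable helper. This sketch is an assumption-labeled variant, not part of the notebook: it uses `time.perf_counter`, which is intended for measuring short intervals, and takes a caller-supplied `sync` function so that asynchronous GPU work is included in each measurement.

```python
import time
from typing import Callable


def time_us(fn: Callable[[], object], sync: Callable[[], None],
            warmup: int = 20, iters: int = 100) -> float:
    """Average wall-clock time of fn() in microseconds.

    `sync` should block until queued GPU work finishes; it is called after
    every invocation so that kernel execution, not just launch, is timed.
    """
    for _ in range(warmup):
        fn()
        sync()
    total = 0.0
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()
        total += time.perf_counter() - start
    return total * 1e6 / iters
```

In this notebook it would be called as `time_us(lambda: grouped_gemm.run(As, Bs), torch.cuda.synchronize)` for the grouped kernel and `time_us(lambda: [a @ b for a, b in zip(As, Bs)], torch.cuda.synchronize)` for the per-problem PyTorch baseline.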