Operations#

GEMM#

Ease-of-use interface for constructing, compiling, and running GEMMs.

The Gemm interface is meant to allow one to easily instantiate, compile, and run GEMM operations in CUTLASS via Python, without specifying many configuration parameters. Under the hood, the interface will select sensible default parameters for the many template parameters for CUTLASS GEMMs.

Note: optimal performance is not to be expected from this interface. To achieve optimal performance, one should specify and tune each configuration parameter.

The simplest example of using this interface is the following:

# A, B, C, and D are torch/numpy/cupy tensor objects
plan = cutlass.op.Gemm(A, B, C, D)
plan.run()
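
For instance, a self-contained version of the above using PyTorch might look like the following (a minimal sketch; any of the supported tensor libraries would work the same way):

import torch
import cutlass

# Random row-major fp32 operands on the GPU; shapes are arbitrary here.
A = torch.rand((128, 256), device='cuda')
B = torch.rand((256, 64), device='cuda')
C = torch.zeros((128, 64), device='cuda')
D = torch.zeros((128, 64), device='cuda')

# Data types and layouts are inferred from the tensors themselves.
plan = cutlass.op.Gemm(A, B, C, D)
plan.run()  # computes D = 1.0 * (A @ B) + 0.0 * C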

One can also use the interface by specifying data types of operands at construction and using different tensor objects with these data types at runtime:

# The following is shorthand for:
#        cutlass.op.Gemm(element_A=torch.float32, element_B=torch.float32,
#                        element_C=torch.float32, element_D=torch.float32,
#                        element_accumulator=torch.float32,
#                        layout=cutlass.LayoutType.RowMajor)
plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)

A0 = torch.rand((128, 256), device='cuda')
B0 = torch.rand((256, 64), device='cuda')
C0 = torch.zeros((128, 64), device='cuda')
D0 = torch.zeros((128, 64), device='cuda')
plan.run(A0, B0, C0, D0)

A1 = torch.rand((32, 128), device='cuda')
B1 = torch.rand((128, 256), device='cuda')
C1 = torch.zeros((32, 256), device='cuda')
D1 = torch.zeros((32, 256), device='cuda')
plan.run(A1, B1, C1, D1)

The interface additionally enables one to decouple the compilation of the underlying CUTLASS kernel from its execution:

plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan.compile()

# Do other work...

plan.run(A0, B0, C0, D0)

# Do other work...

plan.run(A1, B1, C1, D1)

Elementwise activation functions are easily fused to the GEMM via the interface:

plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
plan.activation = cutlass.epilogue.relu
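
For example, a ReLU epilogue fused into a full plan might look like the following (a minimal sketch; the set of available activations can be queried with activations(), documented below):

import torch
import cutlass

plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)
plan.activation = cutlass.epilogue.relu  # D = relu(alpha * (A @ B) + beta * C)

A = torch.rand((64, 32), device='cuda')
B = torch.rand((32, 16), device='cuda')
C = torch.zeros((64, 16), device='cuda')
D = torch.zeros((64, 16), device='cuda')
plan.run(A, B, C, D)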

Operations can also be run asynchronously:

plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)
args = plan.run(A0, B0, C0, D0, sync=False)

# Do other work...

args.sync()
class cutlass.op.gemm.Gemm(A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0, element_accumulator=None, element=None, layout=None, element_A=None, element_B=None, element_C=None, element_D=None, layout_A=None, layout_B=None, layout_C=None, cc=None, kernel_cc=None)[source]#

Bases: OperationBase

Constructs a Gemm object.

The data types and layouts of operands A, B, and C, along with the data type of output D and that used for accumulation, are bound to the Gemm object throughout its lifetime – these are not to be changed after a Gemm has been constructed.

The constructor has optional parameters for flexibly setting these values. The following constructor invocations are equivalent:

# Use F32 for A, B, C, D, and accumulation. All operands are row major.

# Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
# for operands to the same values.
Gemm(element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)

# Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
Gemm(element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, element_C=cutlass.DataType.f32,
    element_D=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)

# Set the data types and elements from existing tensors. Note that one can use different tensors when
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
# have the same data type and layout as those passed in here).
# A, B, C, and D are row-major torch.Tensor objects of type torch.float32
Gemm(A=A, B=B, C=C, D=D)

# Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (the layout of D is
# the same as that of C, at present)
Gemm(element=cutlass.DataType.f32, layout_A=cutlass.LayoutType.RowMajor,
    layout_B=cutlass.LayoutType.RowMajor, layout_C=cutlass.LayoutType.RowMajor)

# Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
# and layouts will inherit those passed in via the generic ``element`` and ``layout``
Gemm(element_A=cutlass.DataType.f32, layout_B=cutlass.LayoutType.RowMajor,
    element=cutlass.DataType.f32, layout=cutlass.LayoutType.RowMajor)
The order of precedence for setting the data type and layout of a given operand/output is as follows (a sketch illustrating these rules appears after the list):
  1. If the tensor type is specified (e.g., A), use the data type and layout inferred from this tensor

  2. Otherwise, if the data type/layout (e.g., element_A, layout_A) is specified, use those

  3. Otherwise, use the generic values (e.g., element, layout)
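
For instance, the following construction (a sketch combining all three mechanisms) takes A's data type and layout from the tensor itself, B's data type from ``element_B``, and everything else from the generic ``element`` and ``layout``:

import torch
import cutlass

A = torch.rand((128, 64), dtype=torch.float16, device='cuda')

plan = cutlass.op.Gemm(A=A,                                 # rule 1: inferred from the tensor A
                       element_B=cutlass.DataType.f16,      # rule 2: explicit per-operand setting
                       element=cutlass.DataType.f16,        # rule 3: generic fallback for C, D, accumulator
                       layout=cutlass.LayoutType.RowMajor)  # rule 3: generic fallback for B and C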

Parameters:
  • cc (int) – compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90

  • kernel_cc (int) – compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80

  • A – tensor representing data type and layout of operand A

  • B – tensor representing data type and layout of operand B

  • C – tensor representing data type and layout of operand C

  • D – tensor representing data type and layout of operand D

  • alpha – scalar parameter alpha from GEMM computation that scales the product of operands A and B

  • beta – scalar parameter beta from GEMM operation that scales operand C

  • element_accumulator (cutlass.DataType) – data type to be used in accumulation of the product of operands A and B

  • element (cutlass.DataType) – generic data type to be used for operands A, B, C, D, as well as the accumulation data type

  • layout (cutlass.LayoutType) – generic layout type to be used for operands A, B, C, and D

  • element_A (cutlass.DataType) – data type to be used for operand A

  • element_B (cutlass.DataType) – data type to be used for operand B

  • element_C (cutlass.DataType) – data type to be used for operand C

  • element_D (cutlass.DataType) – data type to be used for operand D

  • layout_A (cutlass.LayoutType) – layout of operand A

  • layout_B (cutlass.LayoutType) – layout of operand B

  • layout_C (cutlass.LayoutType) – layout of operand C

  • layout_D (cutlass.LayoutType) – layout of operand D

property activation#

Returns the type of the activation function currently in use by the GEMM

compile(tile_description=None, alignment_A=None, alignment_B=None, alignment_C=None, print_module=False)[source]#

Emits and compiles the kernel currently specified. If tile_description and any of the alignment parameters are set, the kernel will be chosen using this tile description and alignments. Otherwise, a default tile description and alignment will be used.

Parameters:
  • tile_description (cutlass.backend.TileDescription) – tile description specifying shapes and operand types to use in the kernel

  • alignment_A (int) – alignment of operand A

  • alignment_B (int) – alignment of operand B

  • alignment_C (int) – alignment of operand C

  • print_module (bool) – whether to print the emitted C++ code

Returns:

operation that was compiled

Return type:

cutlass.backend.GemmOperationUniversal
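
For instance, compilation with the default configuration might look like the following (a minimal sketch):

import numpy as np
import cutlass

plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)

# Compile with a default tile description and alignments, printing the
# emitted C++ for inspection.
operation = plan.compile(print_module=True)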

construct(tile_description=None, alignment_A=None, alignment_B=None, alignment_C=None)[source]#

Constructs a cutlass.backend.GemmOperationUniversal based on the input parameters and current kernel specification of the Gemm object.

Parameters:
  • tile_description (cutlass.backend.TileDescription) – tile description specifying shapes and operand types to use in the kernel

  • alignment_A (int) – alignment of operand A

  • alignment_B (int) – alignment of operand B

  • alignment_C (int) – alignment of operand C

Returns:

operation that was constructed

Return type:

cutlass.backend.GemmOperationUniversal

property opclass: OpcodeClass#

Returns the opcode class currently in use by the GEMM

Returns:

opcode class currently in use

Return type:

cutlass.OpcodeClass

run(A=None, B=None, C=None, D=None, alpha=None, beta=None, batch_count=1, sync=True, print_module=False)[source]#

Runs the kernel currently specified. If it has not already been, the kernel is emitted and compiled. Tensors holding operands and outputs of the kernel are sourced either from the A, B, C, D, alpha, and beta parameters provided in this call, or from those passed in on the construction of this object – one of the two must be specified.

By default, this call returns only once the kernel has completed. To launch the kernel and immediately return, set sync=False. In this case, it is the responsibility of the caller to synchronize the results of the kernel before attempting to access outputs by calling sync() on the arguments returned from this call.

Parameters:
  • A – tensor holding operand A

  • B – tensor holding operand B

  • C – tensor holding operand C

  • D – tensor holding output D

  • alpha – scalar parameter alpha from GEMM computation that scales the product of operands A and B

  • beta – scalar parameter beta from GEMM operation that scales operand C

  • batch_count (int) – number of GEMMs in the batch

  • sync (bool) – whether the call should wait for the kernel to complete before returning

  • print_module (bool) – whether to print the emitted C++ code

Returns:

arguments passed in to the kernel

Return type:

cutlass.backend.GemmArguments
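
As an example, alpha and beta may be overridden per call rather than fixed at construction (a minimal sketch):

import torch
import cutlass

plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor)

A = torch.rand((128, 64), device='cuda')
B = torch.rand((64, 96), device='cuda')
C = torch.rand((128, 96), device='cuda')
D = torch.zeros((128, 96), device='cuda')

# D = 2.0 * (A @ B) + 0.5 * C; the default sync=True blocks until completion.
args = plan.run(A, B, C, D, alpha=2.0, beta=0.5)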

property swizzling_functor#

Returns the type of the swizzling functor currently being used by the GEMM

Returns:

swizzling functor type

tile_descriptions()[source]#

Returns a list of valid tile descriptions for the operation

Returns:

list of valid tile descriptions for the operation

Return type:

list
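
For example, one might enumerate the valid tile descriptions and compile with a particular one (a sketch; which descriptions are valid depends on the target device and the operand types):

import numpy as np
import cutlass

plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)

tiles = plan.tile_descriptions()
print(f'{len(tiles)} tile descriptions available')

# Compile using the first valid description rather than the default.
plan.compile(tile_description=tiles[0])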

Grouped GEMM#

Ease-of-use interface for constructing, compiling, and running grouped GEMMs.

The GroupedGemm interface is meant to allow one to easily instantiate, compile, and run grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters. Under the hood, the interface will select sensible default parameters for the many template parameters for CUTLASS grouped GEMMs.

Note: optimal performance is not to be expected from this interface. To achieve optimal performance, one should specify and tune each configuration parameter.

The simplest example of using this interface is the following:

# A0, A1, B0, B1, C0, C1, D0, and D1 are torch/numpy/cupy tensor objects
plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
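
A self-contained version of the above might look like the following (a minimal sketch using PyTorch; note that problem sizes may differ across the group):

import torch
import cutlass

plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16,
                              layout=cutlass.LayoutType.RowMajor)

# Two independent GEMM problems with different sizes.
A0 = torch.rand((128, 64), dtype=torch.float16, device='cuda')
B0 = torch.rand((64, 32), dtype=torch.float16, device='cuda')
C0 = torch.zeros((128, 32), dtype=torch.float16, device='cuda')
D0 = torch.zeros_like(C0)

A1 = torch.rand((16, 256), dtype=torch.float16, device='cuda')
B1 = torch.rand((256, 8), dtype=torch.float16, device='cuda')
C1 = torch.zeros((16, 8), dtype=torch.float16, device='cuda')
D1 = torch.zeros_like(C1)

plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
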
class cutlass.op.gemm_grouped.GroupedGemm(A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0, element_accumulator=None, element=None, layout=None, element_A=None, element_B=None, element_C=None, element_D=None, layout_A=None, layout_B=None, layout_C=None, cc=None)[source]#

Bases: Gemm

Constructs a GroupedGemm object.

The data types and layouts of operands A, B, and C, along with the data type of output D and that used for accumulation, are bound to the GroupedGemm object throughout its lifetime – these are not to be changed after a GroupedGemm has been constructed.

The constructor has optional parameters for flexibly setting these parameters. Please see the constructor for Gemm for examples of these.

Parameters:
  • cc (int) – compute capability of device to generate kernels for

  • A – tensor representing data type and layout of operands A

  • B – tensor representing data type and layout of operands B

  • C – tensor representing data type and layout of operands C

  • D – tensor representing data type and layout of operands D

  • alpha – scalar parameter alpha from GEMM computation that scales the product of operands A and B

  • beta – scalar parameter beta from GEMM operation that scales operand C

  • element_accumulator (cutlass.DataType) – data type to be used in accumulation of the product of operands A and B

  • element (cutlass.DataType) – generic data type to be used for operands A, B, C, D, as well as the accumulation data type

  • layout (cutlass.LayoutType) – generic layout type to be used for operands A, B, C, and D

  • element_A (cutlass.DataType) – data type to be used for operand A

  • element_B (cutlass.DataType) – data type to be used for operand B

  • element_C (cutlass.DataType) – data type to be used for operand C

  • element_D (cutlass.DataType) – data type to be used for operand D

  • layout_A (cutlass.LayoutType) – layout of operand A

  • layout_B (cutlass.LayoutType) – layout of operand B

  • layout_C (cutlass.LayoutType) – layout of operand C

  • layout_D (cutlass.LayoutType) – layout of operand D

construct(tile_description=None, alignment_A=None, alignment_B=None, alignment_C=None)[source]#

Constructs a cutlass.backend.GemmOperationGrouped based on the input parameters and current kernel specification of the Gemm object.

Parameters:
  • tile_description (cutlass.backend.TileDescription) – tile description specifying shapes and operand types to use in the kernel

  • alignment_A (int) – alignment of operand A

  • alignment_B (int) – alignment of operand B

  • alignment_C (int) – alignment of operand C

Returns:

operation that was constructed

Return type:

cutlass.backend.GemmOperationGrouped

run(A, B, C, D, alpha=None, beta=None, sync=True, print_module=False)[source]#

Runs the kernel currently specified.

By default, this call returns only once the kernel has completed. To launch the kernel and immediately return, set sync=False. In this case, it is the responsibility of the caller to synchronize the results of the kernel before attempting to access outputs by calling sync() on the arguments returned from this call.

Parameters:
  • A (list) – list of tensors holding operand A for each problem in the group

  • B (list) – list of tensors holding operand B for each problem in the group

  • C (list) – list of tensors holding operand C for each problem in the group

  • D (list) – list of tensors holding output D for each problem in the group

  • alpha – scalar parameter alpha from GEMM computation that scales the product of operands A and B

  • beta – scalar parameter beta from GEMM operation that scales operand C

  • sync (bool) – whether the call should wait for the kernel to complete before returning

  • print_module (bool) – whether to print the emitted C++ code

Returns:

arguments passed in to the kernel

Return type:

cutlass.backend.GemmGroupedArguments

property swizzling_functor#

Returns the type of the swizzling functor currently being used by the GEMM

Returns:

swizzling functor type

Operation#

Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)

class cutlass.op.op.OperationBase(cc=None, kernel_cc=None)[source]#

Bases: object

Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)

Parameters:
  • cc (int) – compute capability of the device for which kernels should be compiled

  • kernel_cc (int) – compute capability of the kernels to generate

activations()[source]#

Returns possible activation functions that can be used

Returns:

list of activation functions that can be used

Return type:

list

swizzling_functors()[source]#

Returns possible swizzling functions that can be used

Returns:

list of swizzling functions that can be used

Return type:

list
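
For example, one can inspect the options exposed by these methods before configuring a plan (a minimal sketch):

import numpy as np
import cutlass

plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)

print(plan.activations())         # activation functions available for fusion
print(plan.swizzling_functors())  # swizzling functors available

plan.activation = cutlass.epilogue.relu  # select one of the listed activations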