Shortcuts

Source code for xformers.ops.fmha.triton

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import replace
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple

import torch

from ... import _is_triton_available
from ..common import register_operator

# This implementation needs pre-MLIR triton.
# The backward pass is not stable/well tested,
# and it also does not have the latest upstream improvements.
# NOTE: `False and ...` deliberately disables this branch at runtime; since
# `TYPE_CHECKING` is False at runtime, only static type checkers follow it.
if TYPE_CHECKING or (False and _is_triton_available()):
    try:
        from flash_attn.flash_attn_triton import (
            _flash_attn_backward,
            _flash_attn_forward,
        )
    except ImportError:
        # Fall back to loading the vendored flash-attention file directly.
        # The importlib submodules are imported explicitly: a bare
        # `import importlib` does not guarantee that `importlib.util`,
        # `importlib.machinery` or `importlib.abc` are available as attributes.
        import importlib.abc
        import importlib.machinery
        import importlib.util
        import pathlib
        import sys
        import types

        def import_module_from_path(path: str) -> types.ModuleType:
            """Import a module from the given path, w/o __init__.py

            The module is registered in ``sys.modules`` under the file's stem
            (``'path/x.py' -> 'x'``) before its code is executed.

            Raises:
                ImportError: if no loadable module spec can be built for ``path``.
            """
            module_path = pathlib.Path(path).resolve()
            module_name = module_path.stem  # 'path/x.py' -> 'x'
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            # Explicit checks instead of `assert`, which is stripped under `python -O`.
            if not isinstance(spec, importlib.machinery.ModuleSpec) or not isinstance(
                spec.loader, importlib.abc.Loader
            ):
                raise ImportError(f"Cannot build a module spec for {module_path}")
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            return module

        # NOTE(review): this path is relative to the current working directory,
        # so it presumably only resolves when run from the repository root.
        flash_attn = import_module_from_path(
            "third_party/flash-attention/flash_attn/flash_attn_triton.py"
        )
        _flash_attn_backward = flash_attn._flash_attn_backward
        _flash_attn_forward = flash_attn._flash_attn_forward

    triton_flash_backward = _flash_attn_backward
    triton_flash_forward = _flash_attn_forward
else:
    # Kernels unavailable: the operators below report "triton is not available".
    triton_flash_backward = None
    triton_flash_forward = None

from .attn_bias import LowerTriangularMask
from .common import (
    AttentionBwOpBase,
    AttentionFwOpBase,
    Context,
    Gradients,
    Inputs,
    check_lastdim_alignment_stride1,
)


def _prepare_inputs(inp: Inputs) -> Inputs:
    """Normalize attention inputs into the layout the triton kernels consume.

    * A 3-D tensor ``attn_bias`` of shape ``(B * h, M, N)`` is reshaped to
      ``(B, h, M, N)``, where ``B`` is ``query.shape[0]``. Any other bias
      (``None``, masks, tensors of other rank) is passed through unchanged.
    * ``query``/``key``/``value`` are made contiguous only when their last
      dimension does not already have stride 1; otherwise the same tensor
      object is kept.

    Returns a copy of ``inp`` (via ``dataclasses.replace``) with those fields
    substituted.
    """
    bias = inp.attn_bias
    if isinstance(bias, torch.Tensor) and bias.ndim == 3:
        batch = inp.query.shape[0]
        n_heads = bias.shape[0] // batch
        _, q_len, k_len = bias.shape
        bias = bias.reshape(batch, n_heads, q_len, k_len)

    def _with_unit_last_stride(t: torch.Tensor) -> torch.Tensor:
        # The kernels need the last dimension to be densely packed.
        if t.stride(-1) != 1:
            return t.contiguous()
        return t

    return replace(
        inp,
        attn_bias=bias,
        query=_with_unit_last_stride(inp.query),
        key=_with_unit_last_stride(inp.key),
        value=_with_unit_last_stride(inp.value),
    )


@register_operator
class FwOp(AttentionFwOpBase):
    """Operator that computes memory-efficient attention using \
        `Tri Dao's <https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py>`_ \
        implementation, based on `Phil Tillet's code <https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py>`_
    """

    OPERATOR = triton_flash_forward
    SUPPORTED_DEVICES = {"cuda"}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
    SUPPORTED_DTYPES = {torch.half, torch.bfloat16}
    SUPPORTED_MAX_K = 128
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        LowerTriangularMask,
        # TODO: backwards accuracy is failing for a few cases, perhaps we want to disable this for now.
        # torch.Tensor,
    }
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    NAME = "tritonflashattF"

    @staticmethod
    def _triton_newer_than_2_0_0(version: str) -> bool:
        """Return True if ``version``'s numeric release segment is above 2.0.0.

        Only the leading dotted integers are compared, so
        ``"2.0.0.dev20221105"`` counts as ``(2, 0, 0)`` and ``"10.0.0"`` as
        ``(10, 0, 0)``. Plain string comparison is wrong here:
        ``"10.0.0" > "2.0.0"`` is False, while any suffixed 2.0.0 build
        compares greater than ``"2.0.0"``. A version string that does not start
        with a number is treated as newer, i.e. unsupported.
        """
        import re

        match = re.match(r"\d+(?:\.\d+)*", version)
        if match is None:
            return True
        release = tuple(int(part) for part in match.group(0).split("."))
        return release > (2, 0, 0)

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect reasons why this operator cannot handle ``d``.

        Extends the base-class checks with last-dim alignment (8) for
        q/k/v, triton availability, the sm80 minimum and the triton version.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
        check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
        if cls.OPERATOR is None:
            reasons.append("triton is not available")
        if d.device.type == "cuda":
            # Has only been tested on 8.0 / 9.0.
            # Fails on 7.5 with illegal memory access
            if torch.cuda.get_device_capability(d.device) < (8, 0):
                reasons.append(
                    "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
                )
        if _is_triton_available():
            import triton

            # Compare versions numerically, not as strings.
            if cls._triton_newer_than_2_0_0(triton.__version__):
                reasons.append("Only work on pre-MLIR triton for now")
        return reasons

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the triton forward kernel.

        A ``LowerTriangularMask`` bias selects the causal kernel; a tensor bias
        is passed through as ``bias``. ``needs_gradient`` is not consulted: a
        ``Context`` holding ``lse`` and ``out`` is always returned.
        """
        inp = _prepare_inputs(inp)
        # The kernel also returns the softmax scale it used; it is not needed here.
        out, lse, _softmax_scale = cls.OPERATOR(
            q=inp.query,
            k=inp.key,
            v=inp.value,
            bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None,
            softmax_scale=inp.scale_float,
            causal=isinstance(inp.attn_bias, LowerTriangularMask),
        )
        return out, Context(lse=lse, out=out)
@register_operator
class BwOp(AttentionBwOpBase):
    # Shares FwOp's docstring; this is the backward counterpart of that kernel.
    __doc__ = FwOp.__doc__

    OPERATOR = triton_flash_backward
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES
    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    NAME = "tritonflashattB"

    @staticmethod
    def _triton_newer_than_2_0_0(version: str) -> bool:
        """Return True if ``version``'s numeric release segment is above 2.0.0.

        Only the leading dotted integers are compared, so
        ``"2.0.0.dev20221105"`` counts as ``(2, 0, 0)`` and ``"10.0.0"`` as
        ``(10, 0, 0)``. Plain string comparison is wrong here:
        ``"10.0.0" > "2.0.0"`` is False, while any suffixed 2.0.0 build
        compares greater than ``"2.0.0"``. A version string that does not start
        with a number is treated as newer, i.e. unsupported.
        """
        import re

        match = re.match(r"\d+(?:\.\d+)*", version)
        if match is None:
            return True
        release = tuple(int(part) for part in match.group(0).split("."))
        return release > (2, 0, 0)

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect reasons why this operator cannot handle ``d``.

        Extends the base-class checks with last-dim alignment (8) for
        q/k/v, triton availability, an exact sm80 requirement and the triton
        version.
        """
        reasons = super(BwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
        check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
        if cls.OPERATOR is None:
            reasons.append("triton is not available")
        if d.device.type == "cuda":
            # Stricter than the forward op: the backward kernel is only
            # accepted on compute capability exactly 8.0.
            if torch.cuda.get_device_capability(d.device) != (8, 0):
                reasons.append("requires A100 GPU")
        if _is_triton_available():
            import triton

            # Compare versions numerically, not as strings.
            if cls._triton_newer_than_2_0_0(triton.__version__):
                reasons.append("Only work on pre-MLIR triton for now")
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        """Run the triton backward kernel and return ``dq``/``dk``/``dv``.

        The gradient buffers are allocated here with ``torch.empty_like`` and
        handed to ``cls.OPERATOR``, which writes into them.
        """
        inp = _prepare_inputs(inp)
        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
        with torch.inference_mode():
            grads = Gradients(
                dq=torch.empty_like(inp.query),
                dk=torch.empty_like(inp.key),
                dv=torch.empty_like(inp.value),
            )
            cls.OPERATOR(
                grad,
                inp.query,
                inp.key,
                inp.value,
                ctx.out,
                # NOTE(review): 128 presumably matches the padding the kernel
                # expects for the logsumexp buffer — confirm against the kernel.
                ctx.get_padded_lse(128),
                grads.dq,
                grads.dk,
                grads.dv,
                bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None,
                softmax_scale=inp.scale_float,
                causal=isinstance(inp.attn_bias, LowerTriangularMask),
            )
        return grads