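OpenAI/Triton MLIR Chapter 4: Configuring ROCm Triton
This article was first published on GiantPandaCV; do not repost without the author's permission.
While tidying up some Python-based benchmark code recently, I reinstalled Triton on an NVIDIA GPU. The Triton GitHub repo now lists the matching LLVM commit ID and the build details, so I followed them and the installation succeeded without trouble. On an NVIDIA GPU the whole process comes down to the following steps: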
git clone https://github.com/openai/triton.git
cd triton
cd $HOME/llvm-project  # your clone of LLVM.
git checkout 49af6502
mkdir build
cd build
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm"
ninja -j8
export LLVM_BUILD_DIR=$HOME/llvm-project/build
cd <triton install>
LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
LLVM_SYSPATH=$LLVM_BUILD_DIR \
pip install -e python
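If version 3.0.0 shows up, Triton has been installed successfully. After installing Triton, be sure to install Torch as well. I use CUDA 12.1, so the command below works as-is.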
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
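A quick way to confirm both installs is to query the installed package versions from Python. The sketch below uses only the standard library; the helper name `installed_version` is my own, and the version reported for Triton depends on the commit you built.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for name in ("triton", "torch"):
        version = installed_version(name)
        print(f"{name}: {version if version is not None else 'not installed'}")
```

If either line prints "not installed", the corresponding pip step did not land in the Python environment you are currently running.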
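Installing and using Triton on NVIDIA GPUs is familiar ground by now. Next, let's look at how to install and configure Triton on AMD GPUs.
0x00 Software Installation
Regarding the AMD backend: although upstream Triton supports it as a third-party backend, I still recommend the Triton fork that AMD maintains itself, because on the main branch of upstream Triton at the time, building with TRITON_CODEGEN_AMD_HIP_BACKEND=1 did not complete correctly. That led me to AMD's fork.
Follow its installation procedure. I recommend the commands below, which I have verified myself.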
git clone https://github.com/ROCmSoftwarePlatform/triton.git
cd triton
git checkout triton-mlir
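The Triton source is now ready to build, but Triton's backend is built on LLVM, so to generate code that runs on the target device we also need to build LLVM. This tutorial builds LLVM by hand, although using a prebuilt LLVM is fine too. This Triton branch is developed against LLVM commit b1115f8c, so clone LLVM, check out that commit, and build it with the full set of commands below.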
git clone https://github.com/llvm/llvm-project
cd llvm-project
git checkout b1115f8c
mkdir build
cd build
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm"
ninja -j8
Once LLVM is fully built, add its path to your bashrc:
export PATH=/home/llvm-project/build/bin:$PATH
Then go to the Triton directory cloned earlier and run the following commands:
cd triton
# Edit CMakeLists.txt so that the Python module is built:
#   option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" ON)
vim CMakeLists.txt
mkdir build
cd build
cmake ..
make -j8
Once the build finishes without errors, a libtriton.so file appears in the current build directory. Move libtriton.so into triton/python/triton/_C, then add Triton's python path to your bashrc:
export TRITON_HOME=/home/Documents/compiler/triton
export PYTHONPATH=$TRITON_HOME/python:${PYTHONPATH}
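To double-check this manual placement, a small script can verify that the shared library sits where the Python package expects it and that the package directory is on PYTHONPATH. This is only a sketch: the function name `check_triton_layout` is hypothetical, and it assumes the TRITON_HOME layout described in this section.

```python
import os
from pathlib import Path


def check_triton_layout(triton_home, pythonpath):
    """Return a list of human-readable problems; an empty list means the layout looks fine."""
    problems = []
    python_dir = Path(triton_home) / "python"
    lib = python_dir / "triton" / "_C" / "libtriton.so"
    if not lib.is_file():
        problems.append(f"missing {lib}")
    # PYTHONPATH is a list of directories joined by the platform path separator.
    entries = [Path(p) for p in pythonpath.split(os.pathsep) if p]
    if python_dir not in entries:
        problems.append(f"{python_dir} is not on PYTHONPATH")
    return problems


if __name__ == "__main__":
    home = os.environ.get("TRITON_HOME", "")
    issues = check_triton_layout(home, os.environ.get("PYTHONPATH", ""))
    print("OK" if not issues else "\n".join(issues))
```

Run it in a fresh shell after sourcing your bashrc, so that it sees the same environment variables your Python scripts will.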
If Google Test cannot be found during the build, install it as follows:
git clone https://github.com/google/googletest
cd googletest
cmake CMakeLists.txt
make -j8
cp ./lib/libgtest*.a /usr/lib
cd googletest
cp -a include/gtest /usr/include
If pybind11 cannot be found during the build, install it as follows:
pip install pytest
git clone https://github.com/pybind/pybind11.git
cd pybind11
mkdir build
cd build
cmake ..
make check -j 8
sudo make install
As for PyTorch on AMD GPUs, you must install a build that matches your ROCm version. My machine runs ROCm 5.6, so my install command is as follows (for reference only):
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/rocm5.6
You can query the ROCm version with:
dpkg -l | grep rocm
Keep in mind that using PyTorch on AMD GPUs is very similar to using it on NVIDIA GPUs: you still call .cuda() to place a tensor on the device.
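Because ROCm builds of PyTorch reuse the "cuda" device name, the usual way to tell the two platforms apart is torch.version.hip, which the GEMM script in this chapter also checks. A minimal sketch follows; `backend_name` is a hypothetical helper, and the import is guarded so the snippet still runs on a machine without PyTorch.

```python
def backend_name(hip_version, cuda_version):
    """Classify a PyTorch build from its torch.version.hip / torch.version.cuda values."""
    if hip_version is not None:
        return "rocm"
    if cuda_version is not None:
        return "cuda"
    return "cpu"


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        print("build:", backend_name(torch.version.hip, torch.version.cuda))
        if torch.cuda.is_available():
            # The same call works on both NVIDIA and AMD GPUs.
            x = torch.ones(4).cuda()
            print("tensor lives on:", x.device)
```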
0x01 GEMM Code Example
Once everything is built, running the code below benchmarks GEMM on an AMD GPU with Triton against rocBLAS.
import torch
import triton
import triton.language as tl
import sys
import argparse
import pytest
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
    ] if torch.version.hip is None else [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
                      num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
                      num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
                      num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
                      num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
                      num_warps=4, num_stages=0),
    ],
    key=['M', 'N', 'K'],
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0,
})
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetics` section for details
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)
    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`.
@triton.jit
def leaky_relu(x):
    x = x + 1
    return tl.where(x >= 0, x, 0.01 * x)
# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c
# %%
# Unit Test
# ---------
#
# We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).
@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype",
    [(*shape, in_dtype, out_dtype)
     for shape in [(128, 256, 32), (128, 16, 32), (32, 128, 64),
                   (128, 128, 64), (64, 128, 128), (32, 128, 64),
                   (64, 64, 32), (32, 32, 128), (128, 128, 64),
                   (64, 128, 128), (512, 512, 512), (1024, 1024, 1024)]
     for in_dtype, out_dtype in [('int8', 'int8'),
                                 ('float16', 'float16'),
                                 ('bfloat16', 'bfloat16'),
                                 ('float16', 'float32'),
                                 ('float32', 'float32')]]
)
def test_correctness(M, N, K, in_dtype, out_dtype):
    torch.manual_seed(0)
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    rtol = 0 if torch.version.hip is None else 1e-2
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol)
# %%
# Benchmark
# ---------
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices,
# but feel free to arrange this script as you wish to benchmark any other matrix shape.
global verbose
verbose = False
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (9728, 8192, 65536)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['rocblas', 'triton'],
        # Label name for the lines
        line_names=["rocBLAS", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'rocblas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
        global verbose
        if verbose:
            print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})')
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)
def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )
    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    args = parser.parse_args()
    return args
def main():
    # assign to a global verbose var to indicate whether print
    # best tuning config
    global verbose
    args = parse_args()
    verbose = args.v
    benchmark.run(show_plots=True, print_data=True)
if __name__ == '__main__':
    sys.exit(main())
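Conceptually, the triton.autotune decorator in the script benchmarks each candidate config the first time it sees a new value of the key (here M, N, K) and reuses the winner afterwards. The pure-Python sketch below imitates that idea only. It is not Triton's implementation, and `ToyAutotuner` and the cost function are made up for illustration.

```python
class ToyAutotuner:
    """Pick the cheapest config per key and cache the choice."""

    def __init__(self, configs, cost_fn):
        self.configs = configs
        self.cost_fn = cost_fn      # stands in for a real timing measurement
        self.cache = {}             # key -> best config
        self.evaluations = 0        # how many configs were "benchmarked"

    def best_config(self, key):
        if key not in self.cache:
            costs = []
            for cfg in self.configs:
                self.evaluations += 1
                costs.append(self.cost_fn(cfg, key))
            best_index = costs.index(min(costs))
            self.cache[key] = self.configs[best_index]
        return self.cache[key]


def toy_cost(cfg, key):
    """Made-up cost: number of output tiles, standing in for measured runtime."""
    m, n, _ = key
    tiles_m = -(-m // cfg["BLOCK_SIZE_M"])   # ceiling division
    tiles_n = -(-n // cfg["BLOCK_SIZE_N"])
    return tiles_m * tiles_n


configs = [
    {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
    {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128},
]
tuner = ToyAutotuner(configs, toy_cost)
first = tuner.best_config((1024, 1024, 1024))
second = tuner.best_config((1024, 1024, 1024))   # served from the cache
print(first, tuner.evaluations)
```

The second lookup with the same key does not trigger any new evaluations, which is why a benchmark run pays the tuning cost only once per problem size.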
0x10 Detailed Walkthrough of the GEMM Code
First comes the definition of the search space:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
    ] if torch.version.hip is None else [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
                      num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
                      num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
                      num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
                      num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
                      num_warps=4, num_stages=0),
    ],
    key=['M', 'N', 'K'],
)
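The branch taken when torch.version.hip is set is the search space for AMD GPUs. Among the tunable knobs, besides the usual BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, and GROUP_SIZE_M, there is a new one, waves_per_eu. The concept was unfamiliar to me at first, so I asked AMD engineers about it. In summary:
An AMD GPU is made up of compute units (CUs), the counterpart of streaming multiprocessors (SMs) on NVIDIA GPUs. Each CU contains 4 SIMD units (also called execution engines, or EUs). You can think of a SIMD unit as a vector execution unit with a fixed number of registers and ALUs for carrying out computation. When you launch a compute grid, workgroups (the counterpart of thread blocks on NVIDIA GPUs) are scheduled onto CUs. Within a CU, wavefronts (the counterpart of warps on NVIDIA GPUs) are scheduled onto SIMD units. This is where occupancy comes in: it is the number of wavefronts that can run concurrently on each SIMD unit, and it depends on how many resources each wavefront needs versus how many each SIMD unit has. The waves_per_eu parameter focuses on register usage. For example, suppose each SIMD (EU) has 512 registers. If each wavefront needs 256 registers, occupancy is 2. If we set waves_per_eu=3, the compiler tries to bring each wavefront's register usage down to 170 so that occupancy can reach 3. Raising waves_per_eu, however, risks register spilling and slower code, so it may increase occupancy without necessarily improving performance.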
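The register arithmetic in this explanation can be written out directly. The sketch below is a simplified model using the numbers quoted above (512 registers per SIMD); real occupancy also depends on other resources such as LDS, and both function names are illustrative.

```python
def occupancy(regs_per_simd, regs_per_wave):
    """Waves that fit on one SIMD when registers are the only limit."""
    return regs_per_simd // regs_per_wave


def register_budget(regs_per_simd, waves_per_eu):
    """Per-wave register budget needed to reach the requested waves_per_eu."""
    return regs_per_simd // waves_per_eu


print(occupancy(512, 256))        # 2
print(register_budget(512, 3))    # 170
print(occupancy(512, 170))        # 3
```

If a kernel genuinely needs more than the budget, the compiler has to spill registers to memory, which is the performance risk mentioned above.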
Next is the kernel definition itself, which is not fundamentally different from how you would write it for an NVIDIA GPU:
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetics` section for details
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)
    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
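The program-ID remapping at the top of the kernel is easier to see with concrete numbers. The plain-Python sketch below reproduces the same integer arithmetic on the host for a small grid; `map_pid` is a hypothetical helper and does not run on the GPU.

```python
def map_pid(pid, num_pid_m, num_pid_n, group_size_m):
    """Mirror the kernel's pid -> (pid_m, pid_n) mapping."""
    if group_size_m == 1:
        # Plain row-major ordering.
        return pid // num_pid_n, pid % num_pid_n
    num_pid_in_group = group_size_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may contain fewer tile-rows than group_size_m.
    rows_in_group = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (pid % rows_in_group)
    pid_n = (pid % num_pid_in_group) // rows_in_group
    return pid_m, pid_n


# A 4 x 4 grid of output tiles, visited in groups of 2 tile-rows.
order = [map_pid(pid, 4, 4, 2) for pid in range(16)]
print(order[:4])   # [(0, 0), (1, 0), (0, 1), (1, 1)]
```

Consecutive programs walk down a column of tiles inside a group before moving to the next column, so neighbouring programs share the same block of B and nearby blocks of A, which is what promotes L2 reuse.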
Next is the unit test, which checks that Triton's output matches Torch's output within the given tolerance:
def test_correctness(M, N, K, in_dtype, out_dtype):
    torch.manual_seed(0)
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    rtol = 0 if torch.version.hip is None else 1e-2
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol)
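The comparison uses torch.allclose, whose documented criterion is |a - b| <= atol + rtol * |b| for every element. The scalar sketch below, written without PyTorch, shows why setting rtol to 1e-2 on the HIP path makes the check more forgiving for large values; `close` is an illustrative helper.

```python
def close(actual, expected, atol, rtol):
    """Element-wise criterion used by allclose: |a - b| <= atol + rtol * |b|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# A result near 100 that is off by 0.5.
print(close(100.5, 100.0, atol=1e-2, rtol=0))      # False: only atol applies
print(close(100.5, 100.0, atol=1e-2, rtol=1e-2))   # True: tolerance grows to 1.01
```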
After that, you only need to specify the GEMM sizes to benchmark; the input order remains M, N, K. The rest is routine.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (9728, 8192, 65536)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['rocblas', 'triton'],
        # Label name for the lines
        line_names=["rocBLAS", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'rocblas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
        global verbose
        if verbose:
            print(f'SIZE: {M},{N},{K} Best tuning config: ({matmul_kernel.get_best_config()})')
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)
def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )
    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    args = parser.parse_args()
    return args
def main():
    # assign to a global verbose var to indicate whether print
    # best tuning config
    global verbose
    args = parse_args()
    verbose = args.v
    benchmark.run(show_plots=True, print_data=True)
if __name__ == '__main__':
    sys.exit(main())
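The perf lambda in the benchmark converts a measured time into TFLOPS: a GEMM of size M x N x K performs 2 * M * N * K floating-point operations (one multiply and one add per inner-product term), and the time is reported in milliseconds. Here is a standalone sketch with a made-up timing, purely to show the arithmetic:

```python
def tflops(m, n, k, ms):
    """Throughput in TFLOPS for an M x N x K GEMM that took `ms` milliseconds."""
    flops = 2 * m * n * k            # one multiply and one add per inner-product term
    seconds = ms * 1e-3
    return flops * 1e-12 / seconds


# Hypothetical timing: a 4096^3 GEMM finishing in 2.0 ms.
print(round(tflops(4096, 4096, 4096, 2.0), 2))   # 68.72
```

Because a shorter time means higher throughput, the benchmark returns perf(max_ms) as the lower bound and perf(min_ms) as the upper bound.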
A more automated GEMM benchmark tuning script for AMD GPUs will be covered in a later chapter.