This note focuses on interfaces and workflow without going into implementation details.
User Frontends
Write Python functions
Arguments
- Function arguments are either pointers (no type annotation) or constexprs (annotated as triton.language.constexpr).
- Data types are embedded at compilation by providing argument types (for example, *fp32, fp32, etc.).
Returns
- Can return tensors or None.
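Putting the argument conventions together, a minimal kernel might look like the sketch below (a standard vector-add; the kernel name and block size are illustrative, and launching it requires triton and a CUDA GPU):

```python
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr,   # pointer arguments: no annotation
               n_elements,              # runtime scalar argument
               BLOCK: tl.constexpr):    # compile-time constant
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)
```

Here x_ptr, y_ptr, and out_ptr are inferred as pointers from the call-site tensor arguments, while BLOCK is baked into the compiled binary.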
triton.language
tensor
To represent N-D blocks of pointers or values, with common operator overloading.
tensor.to
To perform a data type cast.
program_id
To distinguish program instances through 3-dimensional program ids.
arange / zeros
To make vectors/matrices/tensors. All operators work on these vectors/matrices/tensors with multi-threading.
load / store
To load from / store to DRAM: load(pointer, mask=None, ...) and store(pointer, value, mask=None, ...).
The pointer argument can be:
- A single element pointer, to handle a scalar.
- An element-wise tensor of pointers (from arange), to handle tensors.
- A block pointer (from make_block_ptr), to handle tensors.
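The mask semantics of load can be emulated in plain NumPy to see what happens to out-of-bounds lanes (the function name and the `other` default mirror tl.load's interface, but this is an illustration, not Triton's implementation):

```python
import numpy as np

def emulated_load(buffer, offsets, mask, other=0.0):
    # Emulates tl.load(ptr + offsets, mask=mask, other=other):
    # masked-off lanes yield `other` instead of touching memory.
    out = np.full(offsets.shape, other, dtype=buffer.dtype)
    out[mask] = buffer[offsets[mask]]
    return out

dram = np.arange(10, dtype=np.float32)   # pretend DRAM buffer of 10 floats
offs = 6 + np.arange(8)                  # a block of 8 offsets: 6..13
mask = offs < dram.shape[0]              # only offsets 6..9 are in bounds
block = emulated_load(dram, offs, mask)  # -> [6, 7, 8, 9, 0, 0, 0, 0]
```

The same mask passed to store suppresses the out-of-bounds writes, which is how kernels handle sizes that are not multiples of the block size.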
dot / where / maximum / exp / sum / …
Mathematical operations.
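These primitives typically compose into patterns like a numerically stable row-wise softmax; in NumPy notation (standing in for the corresponding tl operations):

```python
import numpy as np

def softmax_rows(x):
    # The max/exp/sum pattern as it composes inside a kernel:
    # subtract the row max for numerical stability, exponentiate, normalize.
    m = x.max(axis=1, keepdims=True)
    e = np.exp(x - m)
    return e / e.sum(axis=1, keepdims=True)

probs = softmax_rows(np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))
```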
Compilation Workflow
- Use triton.jit to wrap a Python function as a triton.JITFunction.
@triton.jit
def mha(...):
    pass
- Use triton.compile to compile a triton.JITFunction into a CUDA kernel.
# 2.1 Lowers python AST to TTIR (Triton IR) in MLIR.
ttir = triton.compiler.ast_to_ttir(...)
# 2.2 Lowers TTIR to TTGIR (Triton GPU IR) in MLIR.
ttgir = triton.compiler.ttir_to_ttgir(...)
# 2.3 Lowers TTGIR to LLIR (LLVM IR) in MLIR.
llir = triton.compiler.ttgir_to_llir(...)
# 2.4 Lowers LLIR to PTX in LLVM.
ptx = triton.compiler.llir_to_ptx(...)
# 2.5 Lowers LLVM to Cubin (Machine code) using NVRTC.
cubin = triton.compiler.ptx_to_cubin(...)
- numWarps (default 4) and threadsPerWarp (default 32) are configurable and impact performance.
  - A warp is a hardware concept between block and thread: a block consists of warps of threads, and all threads inside a warp execute the same instruction.
  - Context switching transfers control between warps; programmers write code to coalesce memory accesses and to manage control-flow divergence.
- numStages (default 3) is configurable to optimize TTGIR, overlapping/hiding memory latencies using software pipelining.
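The staged lowering above can be sketched as a chain of passes, each consuming the previous stage's module. The stage names follow this note; the lambdas below are placeholders, not the real triton.compiler entry points (whose exact signatures vary across versions):

```python
def run_pipeline(module, stages):
    # Each lowering pass consumes the previous stage's module and
    # produces the next one; intermediate artifacts are kept for inspection.
    artifacts = {}
    for name, lower in stages:
        module = lower(module)
        artifacts[name] = module
    return artifacts

# Placeholder passes standing in for ast_to_ttir, ttir_to_ttgir, etc.
stages = [
    ("ttir",  lambda m: f"ttir({m})"),
    ("ttgir", lambda m: f"ttgir({m})"),
    ("llir",  lambda m: f"llir({m})"),
    ("ptx",   lambda m: f"ptx({m})"),
    ("cubin", lambda m: f"cubin({m})"),
]
artifacts = run_pipeline("mha_ast", stages)
```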
Serving Workflow
Customized serving solution:
- Serialize the cubin into a text file.
- Embed the text file as C++ (.cc) data.
- An op constructor loads kernels from the embedded files.
- An op functor handles inputs/outputs and launches kernels.
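The serialize-and-embed steps could be done by emitting the cubin bytes as C++ source, similar to `xxd -i` output (the function and symbol names here are hypothetical):

```python
def cubin_to_cc_array(name, data, per_line=12):
    # One possible way to serialize cubin machine code into embeddable
    # C++ source, so the serving binary can load it at op construction.
    lines = []
    for i in range(0, len(data), per_line):
        chunk = data[i:i + per_line]
        lines.append("  " + ", ".join(f"0x{b:02x}" for b in chunk) + ",")
    return (f"const unsigned char {name}[] = {{\n"
            + "\n".join(lines)
            + f"\n}};\nconst unsigned int {name}_len = {len(data)};\n")

snippet = cubin_to_cc_array("mha_cubin", b"\x7fELF\x02\x01")
```

On the C++ side, the op constructor would pass the embedded bytes to the CUDA driver (e.g. cuModuleLoadData) to obtain a launchable kernel handle.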