Triton Notes (1)

This note focuses on interfaces and workflow without going into implementation details.

User Frontends

Write Python functions

Arguments

  • Function arguments are either pointers (no type annotation) or compile-time constants (annotated as triton.language.constexpr).

  • Data types are bound at compile time by providing argument type signatures (for example, *fp32, fp32, etc.).

Returns

  • Can return tensors or None.
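The argument conventions above can be sketched with the classic vector-add kernel (names like add_kernel and BLOCK_SIZE are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr,     # pointers: no type annotation
               n_elements,                # runtime scalar
               BLOCK_SIZE: tl.constexpr): # compile-time constant
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements            # guard the tail block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)
```

The pointers' element types (e.g. *fp32) are inferred from the tensors passed at launch, and a distinct kernel is compiled per (type signature, constexpr values) combination.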

triton.language

tensor

Represents an N-D block of pointers or values, with common operators overloaded.

tensor.to

Performs a data-type cast.

program_id

Distinguishes program instances through a 3-dimensional program ID.

arange / zeros

Create vectors/matrices/tensors. All operators work on these blocks element-wise and are automatically parallelized across threads.

load / store

To load from / store to DRAM.

  • load(pointer, mask=None, ...).
  • store(pointer, value, mask=None, ...).

The pointer argument can be

  • A single-element pointer, for scalars.
  • An element-wise tensor of pointers (built from arange), for tensors.
  • A block pointer (built with make_block_ptr), for tensors.
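The last two pointer forms can be contrasted in one sketch (a tile copy; names like copy_tile and BLOCK are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def copy_tile(src_ptr, dst_ptr, M, N, BLOCK: tl.constexpr):
    # (a) element-wise tensor of pointers, built from arange
    rows = tl.arange(0, BLOCK)[:, None]
    cols = tl.arange(0, BLOCK)[None, :]
    ptrs = src_ptr + rows * N + cols                 # BLOCK x BLOCK pointers
    tile = tl.load(ptrs, mask=(rows < M) & (cols < N), other=0.0)

    # (b) block pointer, built with make_block_ptr
    dst = tl.make_block_ptr(base=dst_ptr, shape=(M, N), strides=(N, 1),
                            offsets=(0, 0), block_shape=(BLOCK, BLOCK),
                            order=(1, 0))
    tl.store(dst, tile, boundary_check=(0, 1))
```

With element-wise pointers the programmer computes addresses and masks by hand; a block pointer instead describes the tensor's shape/strides once and lets boundary_check handle out-of-bounds lanes.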

dot / where / maximum / exp / sum / …

Mathematical operations.
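Several of these compose naturally; for instance, a numerically stable row-wise softmax kernel might look like this sketch (names are illustrative):

```python
import triton
import triton.language as tl

@triton.jit
def softmax_row(x_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(0)                  # one program per row
    offs = tl.arange(0, BLOCK)
    mask = offs < n_cols
    x = tl.load(x_ptr + row * n_cols + offs, mask=mask, other=-float('inf'))
    x = x - tl.max(x, axis=0)               # subtract row max for stability
    num = tl.exp(x)
    denom = tl.sum(num, axis=0)
    tl.store(out_ptr + row * n_cols + offs, num / denom, mask=mask)
```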

Compilation Workflow

  1. Use triton.jit to wrap a Python function as a triton.JITFunction.
@triton.jit
def mha(...):
  pass
  2. Use triton.compile to compile a triton.JITFunction into a CUDA kernel.
# 2.1 Lowers python AST to TTIR (Triton IR) in MLIR.
ttir = triton.compiler.ast_to_ttir(...)

# 2.2 Lowers TTIR to TTGIR (Triton GPU IR) in MLIR.
ttgir = triton.compiler.ttir_to_ttgir(...)

# 2.3 Lowers TTGIR to LLIR (LLVM IR) in MLIR.
llir = triton.compiler.ttgir_to_llir(...)

# 2.4 Lowers LLIR to PTX in LLVM.
ptx = triton.compiler.llir_to_ptx(...)

# 2.5 Lowers PTX to cubin (machine code) using ptxas.
cubin = triton.compiler.ptx_to_cubin(...)
  • numWarps (default 4) and threadsPerWarp (default 32) are configurable and impact performance.
    • A warp is a hardware concept between block and thread: a block consists of warps, and a warp consists of threads. All threads inside a warp execute the same instruction.
    • Context switching transfers control between warps. Programmers arrange code so that memory accesses coalesce and control-flow divergence is minimized.
  • numStages (default 3) is configurable and lets TTGIR overlap/hide memory latency through software pipelining.
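These knobs are typically set per launch; a hedged sketch (the kernel, tensors, and values are illustrative, and the num_warps/num_stages keywords assume the standard Triton launch interface):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)

x = torch.randn(4096, device='cuda')
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)   # one program per 1024-element block
scale_kernel[grid](x, out, x.numel(), BLOCK=1024,
                   num_warps=8,          # 8 warps of 32 threads per program
                   num_stages=4)         # deeper software pipeline
```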

Serving Workflow

Customized serving solution

  1. Serialize the cubin into a text file.

  2. Embed the text files as C++ data.

  3. An op constructor loads kernels from the embedded files.

  4. An op functor handles inputs/outputs and launches the kernels.
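Steps 1–2 can be sketched in plain Python (the function name, symbol name, and placeholder bytes are illustrative; a real cubin comes from triton.compile):

```python
def cubin_to_cc(cubin: bytes, symbol: str) -> str:
    """Render raw cubin bytes as a C++ snippet embedding the kernel."""
    body = ", ".join(f"0x{b:02x}" for b in cubin)
    return (f"// auto-generated: embedded cubin for {symbol}\n"
            f"static const unsigned char {symbol}[] = {{{body}}};\n"
            f"static const unsigned int {symbol}_len = {len(cubin)};\n")

# placeholder bytes standing in for a real cubin
print(cubin_to_cc(b"\x7fELF", "mha_cubin"))
```

The op constructor (step 3) can then hand the embedded byte array to the CUDA driver's module-loading API, and the functor (step 4) launches the resolved kernel.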
