Reading Note: Triton
date
Feb 2, 2025
slug
triton
status
Published
tags
MLSys
summary
type
Post
Overview

- 传统深度学习库(如 cuBLAS、cuDNN)在为常用算子提供高度优化的实现时,依赖手写微内核,而新型或“非标准”算子(例如某些新型卷积变体、稀疏算子)则往往缺乏高效实现。为了解决这一问题,需要一个既能表达灵活性又能达到高性能的编程抽象和编译系统。
- 基于 Tile 的抽象:Triton 的基本设计思想就是围绕“tile”(即具有静态形状的多维子数组)这一概念构建:
- 表达层面:通过 Triton-C 语言以类似 C(或 CUDA-C)的语法表达张量程序,但程序中的基本操作单位是“tile”,而非传统意义上的标量或单纯的数组。
- 中间表示层:基于 LLVM 构建 Triton-IR,该 IR 扩展了标准 LLVM-IR,以支持 tile 级别的数据流和控制流分析,从而能够捕捉和优化 tile 内部的运算和数据传输。
- 生成与优化层:Triton-JIT 利用一系列 tile 级别的优化传递(例如分层 tiling、内存合并、共享内存分配与同步等),将 Triton-IR 编译为高效的 GPU 代码。
- Triton 的目标:通过引入 tile 这一抽象和相应的编译优化,Triton 希望让用户(或上层 DSL)能够以简单、直观的方式描述深度学习算子,同时自动生成与手写内核性能相当甚至更优的 GPU 实现,并保持良好的可移植性。
Core Components
Triton-C 前端语言
- 语法特点
- Triton-C 扩展了 ANSI C(特别是 CUDA-C)的语法,增加了用于声明多维 tile 的语法(如
int tile[16, 16]
),以及对参数化 tile 的支持(通过 tunable 关键字)。 - 内置了一些专门的函数和操作符,如
dot
(矩阵乘法)、trans
(转置)、get_global_range
等,用于生成 tile 级操作。
- 编程模型
- 采用一种类似于 SPMD(Single Program Multiple Data)的编程模型,但每个 kernel 本身是单线程的,由编译器自动并行化。
- 内置了 Numpy 风格的广播语义,允许对 tile 进行自动扩展,使得数组间运算符合形状匹配规则。
Triton-IR
- 中间表示的设计
- 基于 LLVM 的 IR,但扩展了数据类型,支持类似
i32<8, 8>
这样的 tile 类型。 - 提供了专门的指令用于 tile 操作,如 reshape、broadcast、以及专用的矩阵乘法(dot)和转置(trans)指令。
- 数据流与控制流分析
- 为了处理 tile 级操作,Triton-IR 对控制流做了扩展,引入了基于 Predicated SSA(PSSA)形式的机制,能够在 tile 内对不同分支进行 predication,从而实现安全的 tile 级条件计算。
Triton-JIT 与优化传递
- 编译流程
- 将 Triton-C 编译为 Triton-IR,然后经过一系列优化传递后生成高效的 LLVM bitcode,最终生成 GPU 可执行代码。
- 机器无关优化
- 预取(Pre-Fetching):自动检测循环中的 tile 级内存访问,插入预取指令以隐藏内存延迟。
- Peephole 优化:针对 tile 操作链(例如连续转置操作)的局部优化,利用代数等价关系简化计算。
- 机器相关优化
- 分层 Tiling(Hierarchical Tiling):在硬件层面,利用多级 tiling(tile → micro-tile → nano-tile)适应 GPU 的计算单元、寄存器、共享内存以及 DRAM 的访问粒度,实现最佳的利用率。
- 内存合并(Memory Coalescing):通过对线程内排列进行调整,保证多个线程访问连续内存位置,从而减少内存事务数。
- 共享内存分配与同步:自动判断哪些 tile 内的数据应存入共享内存,并插入必要的同步原语(barrier)以保证数据一致性。
- 自动调优
- Triton-JIT 自身能够提取各个优化传递的元参数(如不同 tiling 参数),并通过穷举搜索或更高级的自动调优机制,在给定的参数空间中选择最优配置。论文中以 tiling 参数(tile 大小、micro-tile 大小、nano-tile 大小)的组合为例进行了自动调优。
Summary
Triton 作为一种抽象层级介于 cuda 和高级深度学习编译器之间的中间语言,在运行效率和开发效率之间取得了一个平衡。