返回博客列表

vLLM 与 NVIDIA 加速库:从 CUTLASS 到全栈 Kernel 选择的深度剖析

·AI

vLLM 与 NVIDIA 加速库:从 CUTLASS 到全栈 Kernel 选择的深度剖析

前两篇 vLLM 文章分别剖析了 PagedAttention 和 fused_moe kernel。但如果你深入 vLLM 的 csrc/ 目录,会发现一个更庞大的世界:数十万行 C++/CUDA 代码,调用了 CUTLASS、cuBLAS、FlashInfer、Triton、cuSPARSELt 等多个 NVIDIA 加速库,针对不同的 GPU 架构(Turing → Ampere → Ada → Hopper → Blackwell)分别实现了不同的 kernel。

这篇文章的目标是:把 vLLM 的 kernel 全景图讲清楚。从计算瓶颈出发,搞明白每个加速库在 vLLM 中的角色、调用场景、选择逻辑,然后深入 CUTLASS 的源码级实现。


一、LLM 推理的计算瓶颈

1.1 两个阶段,两种瓶颈

LLM 推理分为 Prefill 和 Decode 两个阶段,它们的计算特征截然不同:

Prefill 阶段(处理完整 prompt):
  矩阵形状:[M=prompt_len, K=hidden_dim] × [K, N=hidden_dim]
  M 可能很大(数百到数千),是典型的 compute-bound GEMM
  瓶颈:GPU Tensor Core 算力

Decode 阶段(逐 token 生成):
  矩阵形状:[M=1~batch_size, K=hidden_dim] × [K, N=hidden_dim]
  M 很小(通常 1-64),GEMM 退化为矩阵-向量乘
  瓶颈:HBM 内存带宽(权重加载主导)

一个 70B 参数的 BF16 模型,权重占 140GB。Decode 阶段每生成一个 token 需要读取全部权重——在 H100(3.35 TB/s HBM 带宽)上,仅权重加载就需要 ~42ms。这就是为什么量化(FP8/INT8/INT4)如此重要:把权重压缩到一半或四分之一,直接提升 2-4 倍的带宽利用率。

1.2 三类核心计算

LLM 推理中需要加速的计算可以归为三类:

计算类型 典型操作 特征 主要库
密集 GEMM Linear 层(QKV/FFN) 大矩阵乘,高算力需求 cuBLAS, CUTLASS
量化 GEMM FP8/INT8/INT4 矩阵乘 需要 fused dequantize + scale CUTLASS, Triton
注意力计算 Self-Attention IO-bound,需要特殊优化 FlashAttention, FlashInfer

关键洞察:不同的计算类型需要不同的加速库,没有一个库能通吃所有场景。 这就是 vLLM 需要集成多个加速库的根本原因。


二、加速库全景

2.1 vLLM 的 Kernel 技术栈

vLLM Kernel 全景图:

┌─────────────────────────────────────────────────────────┐
│                    vLLM Python Layer                      │
│  model_executor/layers/                                  │
│  ├── linear.py          → GEMM dispatch                  │
│  ├── quantization/      → 量化后端选择                    │
│  ├── attention/         → 注意力后端选择                  │
│  └── fused_moe/         → MoE kernel 选择                │
└─────────────┬───────────────────────────────────────────┘
              │ torch.ops / custom ops
              ▼
┌─────────────────────────────────────────────────────────┐
│                  vLLM C++/CUDA Layer                      │
│  csrc/                                                   │
│  ├── quantization/                                       │
│  │   ├── w8a8/cutlass/     → CUTLASS FP8/INT8 GEMM      │
│  │   ├── cutlass_w4a8/     → CUTLASS W4A8 混合精度       │
│  │   ├── gptq_marlin/      → Marlin W4A16 kernel         │
│  │   └── fp8/              → FP8 工具函数                 │
│  ├── sparse/cutlass/       → CUTLASS 2:4 稀疏 GEMM       │
│  ├── attention/            → 注意力 kernel                │
│  │   ├── mla/              → CUTLASS MLA (Blackwell)     │
│  │   └── paged/            → PagedAttention kernel       │
│  └── moe/                  → MoE CUTLASS/MXFP8 kernel    │
└─────────────┬───────────────────────────────────────────┘
              │
              ▼
┌─────────────────────────────────────────────────────────┐
│              底层加速库                                    │
│                                                          │
│  ┌──────────┐ ┌──────────┐ ┌───────────┐ ┌───────────┐ │
│  │ CUTLASS  │ │  cuBLAS  │ │ FlashInfer│ │  Triton   │ │
│  │ (开源    │ │ (闭源    │ │ (JIT 编译 │ │ (Python   │ │
│  │  模板库) │ │  运行时) │ │  注意力)  │ │  DSL)     │ │
│  └──────────┘ └──────────┘ └───────────┘ └───────────┘ │
│                                                          │
│  ┌──────────┐ ┌──────────┐ ┌───────────┐               │
│  │cuSPARSELt│ │ TRT-LLM  │ │  Marlin   │               │
│  │(稀疏运算)│ │ (NVIDIA  │ │ (手写PTX  │               │
│  │          │ │  kernel) │ │  W4 kernel)│               │
│  └──────────┘ └──────────┘ └───────────┘               │
└─────────────────────────────────────────────────────────┘

2.2 各库的角色定位

CUTLASS(CUDA Templates for Linear Algebra Subroutines)

NVIDIA 开源的 C++ 模板库,vLLM 中最核心的 kernel 后端。与 cuBLAS 的关键区别是:CUTLASS 是编译期特化——你在 C++ 模板参数中指定 tile 大小、数据类型、调度策略,编译器生成专用 kernel。这使得 epilogue fusion(在 GEMM 输出阶段融合 scale、bias、activation)成为可能。

在 vLLM 中负责:

  • FP8 W8A8 量化 GEMM(Hopper/Blackwell)
  • INT8 W8A8 量化 GEMM(Ampere/Hopper)
  • W4A8 混合精度 GEMM
  • Block-wise FP8 量化(DeepSeek V3 方案)
  • 2:4 结构化稀疏 GEMM
  • MoE Grouped GEMM(FP8/MXFP8)
  • MLA 注意力(Blackwell)

cuBLAS

NVIDIA 闭源的 BLAS 运行时库,通过 torch.mm / torch.matmul 间接调用。不需要编译、不需要配置,但无法做 epilogue fusion——GEMM 结果必须写回 HBM 后再做后处理。

在 vLLM 中负责:

  • 未量化的 BF16/FP16 密集 GEMM(这是默认路径)
  • 当 CUTLASS 不适用时的 fallback

FlashInfer

MLSys 2025 最佳论文,模块化的注意力引擎。通过 JIT 编译和 Jinja 模板生成特化 kernel,支持多种后端(FlashAttention、cuDNN、CUTLASS、TRT-LLM)。

在 vLLM 中负责:

  • 主要的注意力计算后端
  • Paged KV Cache 的高效访问
  • Blackwell 上的 FP8 scaled_mm(通过 CUTLASS)
  • BF16 MoE kernel(通过 TRT-LLM kernel)
  • CUTLASS MoE 的封装(FP8/MXFP4/NVFP4)

Triton

OpenAI 开发的 Python DSL,编译为 GPU kernel。写起来像 NumPy,跑起来接近 CUDA。

在 vLLM 中负责:

  • fused_moe kernel(MoE 模型的核心路径)
  • 当 CUTLASS 不适用时的 MoE fallback(如 Blackwell 上 batch_size ≤ 8)
  • 自定义 fused kernel(如 silu_and_mulrotary_embedding
  • Flash Attention 的部分变体

Marlin

独立的手写 CUDA kernel(使用内联 PTX 汇编),专门针对 W4A16(4-bit 权重、16-bit 激活)场景优化。

在 vLLM 中负责:

  • GPTQ 量化模型的推理
  • AWQ 量化模型的推理
  • W4A16 场景下通常比 CUTLASS 更快(因为是手工针对这个场景极致优化的)

三、CUTLASS 核心概念速览

在深入 vLLM 的 CUTLASS 代码之前,先理解 CUTLASS 的关键抽象。

3.1 分层 Tile 分解

CUTLASS 将 GEMM 分解为三层 tile 结构,映射到 GPU 的线程层级:

GEMM: C[M, N] = A[M, K] × B[K, N]

            ┌─────────── N ───────────┐
            │  ┌───┐ ┌───┐ ┌───┐     │
        M   │  │CTA│ │CTA│ │CTA│ ... │  ← Thread Block Tile(CTA 级别)
            │  │Tile│ │Tile│ │Tile│   │    128×128, 从 GMEM 加载到 SMEM
            │  └───┘ └───┘ └───┘     │
            │  ┌───┐ ┌───┐ ┌───┐     │
            │  │   │ │   │ │   │     │
            └─────────────────────────┘

CTA Tile 内部:
┌──────────── 128 ────────────┐
│  ┌────┐ ┌────┐              │
│  │Warp│ │Warp│  ...         │  ← Warp Tile(Warp 级别)
│  │ 0  │ │ 1  │              │    64×64, 从 SMEM 加载到 Register
│  └────┘ └────┘              │
│  ┌────┐ ┌────┐              │
│  │Warp│ │Warp│  ...         │
│  │ 2  │ │ 3  │              │
│  └────┘ └────┘              │
└─────────────────────────────┘

Warp Tile 内部:
  每个线程持有小块 register tile
  通过 Tensor Core (MMA/WGMMA) 执行矩阵乘加

3.2 三代 Tensor Core 指令

指令 架构 执行单元 特点
WMMA Volta/Turing 32 threads (1 warp) 16×16×16, 高层 API
MMA (mma.sync) Ampere (SM80) 32 threads (1 warp) 混合精度, 同步执行
WGMMA Hopper (SM90) 128 threads (4 warps) 异步执行, 操作数 B 可直接从 SMEM 读取, 最小 64×64 output

Hopper 的 WGMMA 是一个质变:操作数 B 不需要加载到 register,直接用 64-bit 描述符从 SMEM 访问。这大幅降低了 register 压力,让更多 register 留给累加器。

3.3 TMA(Tensor Memory Accelerator)

Hopper 引入的硬件单元,专门负责 GMEM ↔ SMEM 的数据搬运:

Ampere (SM80) 的数据搬运:
  32 个线程各自计算地址 → 各自发起 cp.async → 各自等待完成
  线程数据地址计算开销大,占用大量指令发射槽

Hopper (SM90) 的 TMA:
  1 个线程创建 copy descriptor(描述 tensor 形状、stride、layout)
  → 硬件自动完成地址生成 + 数据搬运 + layout 转换
  → 其余 127 个线程继续做计算

  本质:把数据搬运从「软件驱动」变成「硬件驱动」

TMA 使得 Warp Specialization 成为可能:专门的 producer warp 负责数据搬运,consumer warp 专注计算,两者异步流水线化。

3.4 Epilogue Fusion

这是 CUTLASS 相比 cuBLAS 最大的优势:

cuBLAS 的执行流程:
  GEMM kernel → 写 C 到 HBM → 读 C → Scale kernel → 写到 HBM → 读 → Bias kernel → ...
  每一步都有 HBM 往返,带宽浪费巨大

CUTLASS Epilogue Fusion:
  GEMM kernel → (累加器在 register 中) → Scale → Bias → Activation → Cast → 写到 HBM
  只有一次 HBM 写入,中间结果全在 register 中

vLLM 在 CUTLASS 中使用了 Epilogue Visitor Tree (EVT),将后处理操作组织成 DAG:

// vLLM 的 ScaledEpilogueBias:D = a_scale * b_scale * (A_q × B_q) + bias
//
// EVT 节点链:
//   AccFetch → Multiply(b_scale) → Multiply(a_scale) → Add(bias) → Store
//
// 所有操作在 GEMM 的最后一步、结果还在 register 中时完成

这对量化推理至关重要:FP8 GEMM 的反量化(乘 scale)和 bias 加法在一个 kernel 内完成,不需要额外的 HBM 往返。


四、CUTLASS 在 vLLM 中的源码剖析

4.1 架构适配:一套代码,多代 GPU

vLLM 的 CUTLASS kernel 覆盖从 Turing (SM75) 到 Blackwell (SM120) 的六代 GPU:

// csrc/cutlass_extensions/common.hpp
// 架构守卫模板:确保 kernel 只在正确的 GPU 上编译和运行

template <typename T>
using enable_sm75_to_sm80 = ...;  // Turing (RTX 2000)

template <typename T>
using enable_sm80_to_sm89 = ...;  // Ampere (A100)

template <typename T>
using enable_sm89_to_sm90 = ...;  // Ada Lovelace (RTX 4090, L40S)

template <typename T>
using enable_sm90_or_later = ...; // Hopper (H100) 及以上

template <typename T>
using enable_sm100f_only = ...;   // Blackwell 数据中心 (B100/B200)

template <typename T>
using enable_sm120_only = ...;    // RTX Blackwell (RTX 5090)

运行时通过 SM 版本号分发到对应的 kernel:

// csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
void cutlass_scaled_mm(torch::Tensor& c, ...) {
    int32_t version_num = get_sm_version_num();

    if (version_num >= 120)     cutlass_scaled_mm_sm120(...);
    else if (version_num >= 100) cutlass_scaled_mm_sm100(...);
    else if (version_num >= 90)  cutlass_scaled_mm_sm90(...);
    else if (version_num == 89)  cutlass_scaled_mm_sm89(...);
    else if (version_num >= 80)  cutlass_scaled_mm_sm80(...);
    else if (version_num >= 75)  cutlass_scaled_mm_sm75(...);
}

4.2 FP8 W8A8 GEMM(Hopper SM90)

这是 vLLM 中最常用的 CUTLASS kernel 路径。核心在于根据矩阵形状选择不同的 tile 配置:

// csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh
// 7 种 FP8 kernel 配置,按 M 大小分发

// M > 128(大 batch, Prefill 阶段)
TileShape = Shape<_128, _128, _128>;
ClusterShape = Shape<_2, _1, _1>;
KernelSchedule = KernelTmaWarpSpecializedPingpongFP8FastAccum;

// M >= 8192 && K >= 6144(超大 batch, 协作调度)
TileShape = Shape<_256, _128, _128>;
ClusterShape = Shape<_2, _1, _1>;
KernelSchedule = KernelTmaWarpSpecializedCooperativeFP8FastAccum;

// M ∈ (64, 128]
TileShape = Shape<_64, _128, _128>;
ClusterShape = Shape<_2, _1, _1>;
KernelSchedule = KernelTmaWarpSpecializedPingpongFP8FastAccum;

// M ∈ (16, 64], N ≤ 1280(小 batch, swap_ab=true)
TileShape = Shape<_64, _16, _256>;
ClusterShape = Shape<_1, _4, _1>;
KernelSchedule = KernelTmaWarpSpecializedFP8FastAccum;

// M ∈ (16, 64], N > 1280(小 batch, swap_ab=true)
TileShape = Shape<_64, _64, _256>;
ClusterShape = Shape<_1, _1, _1>;

// M ∈ [1, 16], N ≤ 1280(极小 batch / Decode, swap_ab=true)
TileShape = Shape<_64, _16, _256>;
ClusterShape = Shape<_1, _2, _1>;

// M ∈ [1, 16], N > 1280
TileShape = Shape<_64, _16, _256>;
ClusterShape = Shape<_1, _1, _1>;

几个关键设计决策:

① swap_ab 优化。 当 M 很小时(Decode 阶段),GEMM 矩阵 A 的行数远小于列数。直接计算会导致大量 padding 浪费。swap_ab 转置 A 和 B:C = A × B → C^T = B^T × A^T,让短的维度变成 N(输出维度),减少 thread block 中的空闲线程。

② Pingpong vs Cooperative。 Pingpong 调度让两个 consumer warp group 交替执行 GEMM 和 epilogue,最大化 Tensor Core 利用率。Cooperative 调度让所有 warp group 协作完成同一个 tile,适合超大 batch(M >= 8192)。

③ ClusterShape。 Hopper 引入了 Thread Block Cluster——多个 CTA 可以通过分布式共享内存协作。Shape<_2, _1, _1> 表示 M 方向 2 个 CTA 组成一个 cluster,共享数据加载。

4.3 INT8 W8A8 GEMM(Ampere SM80)

Ampere 使用 CUTLASS 2.x API,没有 TMA 和 WGMMA,但有 cp.asyncmma.sync

// csrc/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh

// M > 128
TileShape = GemmShape<128, 128, 64>;
WarpShape = GemmShape<64, 64, 64>;
InstructionShape = GemmShape<16, 8, 32>;  // INT8 Tensor Core 指令
Stages = 5;  // 5 级流水线
SharedMemory = 81920;  // 80KB SMEM

// M ∈ (32, 64]
TileShape = GemmShape<64, 128, 128>;
SharedMemory = 122880;  // 120KB SMEM

// M ∈ (16, 32]
TileShape = GemmShape<32, 64, 128>;
SharedMemory = 61440;

// M ∈ [1, 16](Decode)
TileShape = GemmShape<16, 64, 128>;
SharedMemory = 51200;

注意 Stages 的选择:5 级流水线意味着 SMEM 中同时有 5 个 K-tile 的数据在飞行,最大化计算和访存的重叠。代价是更高的 SMEM 占用。

4.4 Block-wise FP8(DeepSeek V3 方案,SM90)

DeepSeek V3 的量化方案是每 128×128 权重子矩阵一个 scale,每 1×128 激活子向量一个 scale。CUTLASS kernel 需要在 GEMM 内部处理这些细粒度 scale:

// csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh

template <typename OutType>
void cutlass_gemm_blockwise_sm90_fp8_dispatch(...) {
    // ScaleGranularity: M=1, N=128, K=128
    // 即:激活按 [1, 128] 分块, 权重按 [128, 128] 分块
    cutlass_gemm_caller_blockwise<
        cutlass_3x_gemm_fp8_blockwise<
            OutType,
            1,    // scale_m_granularity = 1 (per-token)
            128,  // scale_n_granularity = 128
            128,  // scale_k_granularity = 128
            Shape<_128, _128, _128>,  // TileShape
            Shape<_1, _2, _1>,        // ClusterShape
            cutlass::epilogue::TmaWarpSpecializedCooperative,
            cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
        >>(out, a, b, a_scales, b_scales);
}

FP8BlockScaledAccum 是关键:它在 GEMM 的 K-维度迭代中,每处理 128 个元素就从 scale tensor 中加载对应的 scale 并乘入累加器,而不是在循环结束后统一乘。这与 Triton fused_moe kernel 中的 block-wise FP8 逻辑对应。

4.5 W4A8 混合精度 GEMM

4-bit 权重 × FP8 激活,基于 CUTLASS example 55(Hopper mixed-dtype GEMM):

// csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu

using MmaType = cutlass::float_e4m3_t;   // FP8 E4M3 激活
using QuantType = cutlass::int4b_t;       // INT4 权重(压缩存储)
using PackFactor = 8;  // 8 个 INT4 打包进一个 int32

// 10 种 tile 配置,按 M/N/K 分发
// 核心启发式:
if (M <= 16) schedule = "128x16_1x1x1";
else if (M <= 32) schedule = "128x32_1x1x1";
else if (M <= 64) schedule = "128x64_1x1x1";
else if (M <= 128) schedule = "128x128_1x1x1";
else if (M <= 256) schedule = "128x256_1x1x1";
else schedule = "128x256_2x1x1";  // 最大 batch 用 cluster

INT4 权重在 kernel 内部被实时解压为 FP8,与 FP8 激活做 Tensor Core 乘加。权重需要预处理为 LayoutAtomQuant 格式以适配 Tensor Core 的访问模式。

4.6 2:4 结构化稀疏 GEMM

2:4 稀疏是 Ampere 引入的硬件特性:每 4 个元素中有 2 个为零,Sparse Tensor Core 可以跳过零值的计算,理论吞吐翻倍。

// csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh

using CollectiveMainloop = cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90,
    cutlass::arch::OpClassSparseTensorOp,  // ← 使用 Sparse Tensor Core
    ElementAB, cutlass::layout::RowMajor, AlignmentAB,
    ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
    ElementAcc, TileShape, ClusterShape, Stages,
    KernelSchedule
>::CollectiveOp;

稀疏 kernel 配置(FP8,SM90):

M 范围 TileShape ClusterShape 调度策略
M > 256 128×128×128 1×2×1 WarpSpecializedFP8FastAccum
M ∈ (128, 256] 128×128×256 1×1×1 CooperativeFP8FastAccum
M ∈ (64, 128] 64×128×256 1×1×1 PingpongFP8FastAccum
M ≤ 64 64×64×256 1×1×1 WarpSpecializedFP8FastAccum

vLLM 内置了 StructuredSparseCompressor,将密集矩阵转换为 2:4 压缩格式(非零值 + 索引元数据),压缩开销不到总执行时间的 5%。

4.7 MoE Grouped GEMM(SM90/SM100)

MoE 模型需要在一个 kernel 中执行多个不同大小的 GEMM(每个 expert 的 token 数量不同)。CUTLASS 的 GemmUniversalMode::kGrouped 天然支持这种场景:

// csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh

template <typename ElementAB_, typename ElementC_, typename ArchTag_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape,
          typename KernelSchedule, typename EpilogueSchedule,
          bool swap_ab_ = false>
struct cutlass_3x_group_gemm {
    // 使用指针数组(pointer array)指定每个 group 的 A、B、C 矩阵
    using ProblemShape = GroupProblemShape<Shape<int, int, int>>;
    // 每个 group 可以有不同的 M(token 数),但 N、K 相同

    using KernelType = cutlass::gemm::kernel::GemmUniversal<
        ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
};

Python 层的 MoE CUTLASS 调度(cutlass_moe.py):

class CutlassExpertsFp8Base:
    # 支持的量化方案:
    SUPPORTED_W_A = [
        (kFp8StaticChannelSym, kFp8DynamicTokenSym),   # per-channel 权重 + per-token 激活
        (kFp8StaticTensorSym, kFp8DynamicTensorSym),   # per-tensor 权重 + per-tensor 激活
        (kFp8StaticTensorSym, kFp8StaticTensorSym),    # 全静态
    ]

    def apply(self, hidden_states, w1, w2, topk_weights, topk_ids, ...):
        # 1. 量化输入到 FP8
        a1q, a1_scale = quantize_input(hidden_states)
        # 2. CUTLASS Grouped GEMM #1: gate_up projection
        ops.cutlass_moe_mm(a1q, w1, ...)
        # 3. 激活函数 (SiLU)
        silu_and_mul(intermediate)
        # 4. 量化中间结果到 FP8
        a2q, a2_scale = quantize_input(intermediate)
        # 5. CUTLASS Grouped GEMM #2: down projection
        ops.cutlass_moe_mm(a2q, w2, ...)
        # 6. Unpermute 输出

    # swap_ab 优化:小 batch 时转置
    swap_ab = a1q.size(0) <= 64

4.8 CUTLASS MLA(Multi-head Latent Attention,Blackwell)

DeepSeek V3 的 MLA 注意力机制在 Blackwell 上有专用的 CUTLASS 实现:

// csrc/attention/mla/sm100_cutlass_mla_kernel.cu

using TileShape = Shape<_128, _128, Shape<_512, _64>>;
// H=128 heads, K=128 block size, D=(512 latent + 64 rope)

using FmhaKernel = Sm100FmhaMlaKernelTmaWarpspecialized<
    TileShape, Element, ElementAcc, ElementOut, ElementAcc,
    TileScheduler, /*kIsCpAsync=*/!IsPaged128>;

using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;

这是 CUTLASS 在注意力领域的扩展——不仅做 GEMM,还做 fused multi-head attention。

4.9 Epilogue 体系

vLLM 在 CUTLASS epilogue 框架上实现了 6 种自定义 epilogue:

Epilogue 类型 公式 使用场景
TrivialEpilogue D = acc 无量化
ScaledEpilogue D = s_a × s_b × acc 对称量化,无 bias
ScaledEpilogueBias D = s_a × s_b × acc + C 对称量化 + bias
ScaledEpilogueColumnBias 同上,列方向 bias swap_ab 场景
ScaledEpilogueBiasAzp D = s_a × s_b × (acc - azp_adj) + C 非对称量化
ScaledEpilogueBiasAzpToken D = s_a × s_b × (acc - azp×azp_adj) + C 逐 token 非对称
ScaledEpilogueArray 指针数组 scale MoE Grouped GEMM

所有 epilogue 使用 CUTLASS 的 EVT 系统,通过 Sm90EVT 节点链式组合 multiply、subtract、add 操作。


五、Kernel 选择逻辑:什么时候用什么

5.1 总体决策树

                      输入请求
                         │
              ┌──────────┴──────────┐
              │ 量化 GEMM?          │
              │ (有 scale 参数)      │
              └──────────┬──────────┘
                    ╱          ╲
                  是              否
                  │               │
          ┌───────┴───────┐     cuBLAS
          │ SM 版本检查    │   (torch.mm)
          └───────┬───────┘
     ┌─────┬──────┼──────┬─────┐
   SM75  SM80   SM89   SM90  SM100+
     │     │      │      │      │
  CUTLASS CUTLASS CUTLASS CUTLASS CUTLASS
   2.x    2.x    2.x    3.x    3.x
   INT8   INT8   FP8   FP8    FP8
          INT8   INT8   INT8   Block FP8
                        Block  W4A8
                        W4A8   Sparse
                        Sparse MXFP8
                        MoE    MoE+MLA

5.2 MoE Kernel 选择

MoE 场景的 kernel 选择更复杂,因为有多个可行后端:

# vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py

class TritonOrCutlassExperts(FallbackExperts):
    """Blackwell 上的 MoE kernel 选择策略"""

    def _select_experts_impl(self, hidden_states, w1, w2):
        if self.is_sm100 and hidden_states.shape[0] <= 8:
            # Blackwell 小 batch:CUTLASS grouped GEMM 的 launch 开销太大
            # 回退到 Triton(单 kernel,更低 launch overhead)
            return self.fallback_experts  # TritonExperts
        else:
            # 其他情况:CUTLASS grouped GEMM 更快
            return self.experts  # CutlassExpertsFp8

完整的 MoE Expert 后端层级:

后端 数据格式 量化 最佳场景
TritonExperts Standard [M×top_k, K] BF16/FP16 通用,无量化
CutlassExpertsFp8 Standard FP8 per-tensor/channel SM90+ 大 batch
BatchedTritonExperts Batched [E, max_tokens, K] BF16/FP16 EP 通信后
DeepGemmExperts Standard FP8 block-wise (128) N > 512,SM90+
BatchedDeepGemmExperts Batched FP8 block-wise EP + FP8
CutlassExpertsW4A8Fp8 Standard W4A8 INT4 权重 + FP8 激活
FlashInferExperts - FP8/MXFP4/NVFP4/BF16 SM90+ 全能后端

5.3 FP8 能力检查

// 编译时和运行时的能力检查

bool cutlass_scaled_mm_supports_fp8(int64_t capability) {
    if (capability >= 90) return CUDA_VERSION >= 12000;  // Hopper: CUDA 12.0+
    if (capability >= 89) return CUDA_VERSION >= 12040;  // Ada: CUDA 12.4+
    return false;
}

bool cutlass_scaled_mm_supports_block_fp8(int64_t capability) {
    if (capability >= 100) return CUDA_VERSION >= 12080;  // Blackwell
    if (capability >= 90) return CUDA_VERSION >= 12000;   // Hopper
    return false;  // Block FP8 仅 SM90+
}

bool cutlass_group_gemm_supported(int64_t capability) {
    if (capability >= 100) return CUDA_VERSION >= 12080;
    if (capability >= 90) return CUDA_VERSION >= 12030;
    return false;  // Grouped GEMM (MoE) 仅 SM90+
}

六、性能对比:CUTLASS vs cuBLAS vs Triton

6.1 各库的性能特征

                    小 M (Decode)              大 M (Prefill)
                    M = 1~16                   M = 256~8192
                    ┌───────────┐              ┌───────────┐
cuBLAS (BF16)       │ ■■■■      │ 基准          │ ■■■■■■■■  │ 基准
                    │           │              │           │
CUTLASS FP8         │ ■■■■■■■■  │ 1.5-1.9x     │ ■■■■■■■■■ │ 1.3-1.9x
(含 dequant)       │           │ (带宽节省)    │           │ (计算+带宽)
                    │           │              │           │
Triton FP8          │ ■■■■■     │ ~1.2x        │ ■■■■■■    │ ~1.0x
(fused_moe)        │           │              │           │ (接近 cuBLAS)
                    │           │              │           │
Marlin W4A16        │ ■■■■■■■■■ │ 最快          │ ■■■■■■    │ 一般
(手写 PTX)         │           │ (极致带宽优化)│           │
                    └───────────┘              └───────────┘

注:性能倍数为相对 cuBLAS BF16 的近似值,随硬件和矩阵大小变化

6.2 CUTLASS vs cuBLAS

维度 CUTLASS cuBLAS
开源 是(C++ 模板) 否(闭源运行时)
优化方式 编译期特化 运行时 kernel 选择
Epilogue Fusion 支持任意融合 不支持
量化 GEMM 原生支持 FP8/INT8/INT4 + scale 有限(需额外 kernel 做 dequant)
BF16 Dense 与 cuBLAS 持平 通常最优(有时用 CUTLASS 生成的 kernel)
定制化 完全可控 tile/schedule/epilogue 固定 API
编译开销 大(大量模板实例化) 无(预编译)

关键结论:cuBLAS 在未量化 BF16 Dense GEMM 上通常是最优选择(这也是 PyTorch 默认用它的原因)。CUTLASS 的优势在于量化 GEMM(需要 fused dequantize)和自定义操作(如 MoE grouped GEMM)。

实际上,cuBLAS 内部也使用 CUTLASS 生成的 kernel。两者不是对立关系,而是抽象层级不同:cuBLAS 提供易用的运行时接口,CUTLASS 提供底层的编译期定制能力。

6.3 CUTLASS vs Triton

维度 CUTLASS Triton
语言 C++ 模板 Python DSL
开发效率 低(需要深度 GPU 知识) 高(类 NumPy 语法)
峰值性能 更高(直接映射硬件指令) 接近但通常稍低
硬件适配 每代 GPU 专用优化 编译器自动适配
典型性能差距 - FP8 GEMM: CUTLASS 比 Triton 快 ~2.3x
适用场景 生产级 kernel、极致性能 原型开发、自定义 fused kernel

vLLM 中两者互补使用:

  • CUTLASS:量化 GEMM(性能关键路径)、稀疏 GEMM、MoE Grouped GEMM
  • Triton:fused_moe kernel(token routing + GEMM 融合)、自定义 activation kernel、FlashAttention 变体

6.4 实测性能数据

场景 性能 来源
CUTLASS FP8 GEMM 190.49 TFLOPS(vs Triton 82.65 TFLOPS,2.3x) vLLM benchmark
FP8 CUTLASS vs PyTorch BF16 最高 1.9x 加速 vLLM 官方
2:4 稀疏 + FP8 额外 ~30% 延迟降低 vLLM 文档
Blackwell cuBLAS nvjet vs FlashInfer CUTLASS (BF16) cuBLAS 快 1.7x vLLM 博客
FlashInfer+CUTLASS+TRT-LLM 联合优化 +38% 吞吐, +13% 延迟改善 vLLM InferenceMAX

七、实际模型案例

7.1 DeepSeek V3 / R1(256 experts, FP8 block-wise)

DeepSeek V3 是 vLLM CUTLASS 集成的最极端测试场景:

推理路径:
Linear 层 → CUTLASS block-wise FP8 GEMM (SM90)
  ├── ScaleGranularity: M=1, N=128, K=128
  ├── KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
  └── 与 Triton fused_moe 的 block-wise FP8 对应

MoE 层 → CUTLASS Grouped GEMM FP8 或 DeepGEMM
  ├── 256 experts, top-8 routing
  ├── GroupedTopk 路由
  ├── Expert Parallel (Wide-EP) 跨 32-64 GPU
  └── CUTLASS MoE 或 FlashInfer+CUTLASS hybrid

MLA 注意力 → CUTLASS MLA (Blackwell) 或 FlashInfer
  ├── 512-dim latent + 64-dim rope
  └── Paged KV Cache (block_size=128)

在 H200 Wide-EP 部署下,vLLM 实现了 2,200 tok/s/GPU 的吞吐。

7.2 Llama 3 70B(Dense, BF16/FP8)

BF16 推理:
  所有 Linear 层 → cuBLAS (torch.mm)
  Attention → FlashInfer / FlashAttention
  无量化,无 CUTLASS

FP8 推理(W8A8):
  Linear 层 → CUTLASS FP8 scaled_mm (SM90)
    ├── Prefill: TileShape<128,128,128>, Pingpong
    └── Decode:  TileShape<64,16,256>, swap_ab
  Attention → FlashInfer
  性能提升 1.5-1.9x(vs BF16, 取决于 batch size)

INT8 推理:
  Linear 层 → CUTLASS INT8 scaled_mm (SM80)
    ├── 5-stage pipeline
    └── SplitK + StreamK for load balancing
  Attention → FlashAttention 2

W4A16 推理(GPTQ/AWQ):
  Linear 层 → Marlin kernel(手写 PTX)
  不使用 CUTLASS

7.3 Mixtral 8x7B / 8x22B(MoE, 8 experts)

BF16 MoE:
  Linear 层 → cuBLAS
  MoE 层 → Triton fused_moe kernel
    ├── moe_align_block_size → fused_moe_kernel(W1) → SiLU → fused_moe_kernel(W2)
    └── 8 experts, top-2, SwiGLU activation

FP8 MoE:
  Linear 层 → CUTLASS FP8 scaled_mm
  MoE 层 → CUTLASS Grouped GEMM FP8 (SM90+)
    ├── CutlassExpertsFp8: per-tensor/per-channel scale
    └── 或 Triton fused_moe with FP8 (fallback)

Blackwell (SM100) FP8 MoE:
  MoE 层 → TritonOrCutlassExperts
    ├── batch_size > 8: CutlassExpertsFp8
    └── batch_size ≤ 8: TritonExperts (fallback, 更低 launch overhead)

7.4 Sparse Llama(2:4 稀疏 + FP8)

Sparse 推理路径:
  Linear 层 → CUTLASS Sparse FP8 GEMM
    ├── OpClassSparseTensorOp (Sparse Tensor Core)
    ├── 权重预压缩为 2:4 格式(非零值 + 索引元数据)
    ├── 理论吞吐翻倍(跳过 50% 零值计算)
    └── 实测:30% 吞吐提升 + 20% 延迟降低

精度表现:
  2:4 Sparse Llama FP8 在 Open LLM Leaderboard V1 上恢复了 98.4% 的准确率

八、Know-How:扩展与调优

8.1 添加新的 CUTLASS Kernel

以添加一个新的量化 GEMM 为例:

Step 1:定义 kernel 配置

// csrc/quantization/my_quant/cutlass/my_kernel.cuh

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal.h"

template <typename OutType, typename TileShape, typename ClusterShape>
struct my_cutlass_kernel {
    // 定义输入类型
    using ElementA = cutlass::float_e4m3_t;
    using ElementB = cutlass::float_e4m3_t;
    using ElementAccumulator = float;

    // 选择调度策略
    using KernelSchedule =
        cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;

    // 构建 Collective Mainloop
    using CollectiveMainloop =
        typename cutlass::gemm::collective::CollectiveBuilder<
            cutlass::arch::Sm90,
            cutlass::arch::OpClassTensorOp,
            ElementA, cutlass::layout::RowMajor, 16,
            ElementB, cutlass::layout::ColumnMajor, 16,
            ElementAccumulator,
            TileShape, ClusterShape,
            cutlass::gemm::collective::StageCountAutoCarveout<
                sizeof(typename CollectiveEpilogue::SharedStorage)>,
            KernelSchedule
        >::CollectiveOp;

    // 定义 Epilogue(自定义 scale + bias)
    using CollectiveEpilogue = /* ... ScaledEpilogueBias ... */;

    // 组装 GemmUniversal
    using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
        cutlass::gemm::GroupProblemShape<Shape<int,int,int>>,
        CollectiveMainloop, CollectiveEpilogue>;

    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

Step 2:添加 M-based dispatch

// csrc/quantization/my_quant/cutlass/my_dispatch.cuh

void my_kernel_dispatch(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, ...) {
    uint32_t m = a.size(0);

    if (m > 128) {
        run_kernel<Shape<_128, _128, _128>, Shape<_2, _1, _1>>(...);
    } else if (m > 64) {
        run_kernel<Shape<_64, _128, _128>, Shape<_2, _1, _1>>(...);
    } else if (m > 16) {
        run_kernel<Shape<_64, _64, _256>, Shape<_1, _1, _1>, /*swap_ab=*/true>(...);
    } else {
        run_kernel<Shape<_64, _16, _256>, Shape<_1, _2, _1>, /*swap_ab=*/true>(...);
    }
}

Step 3:注册 PyTorch custom op

// csrc/quantization/my_quant/my_entry.cu

void my_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                  torch::Tensor const& b, torch::Tensor const& a_scales,
                  torch::Tensor const& b_scales) {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    my_kernel_dispatch(c, a, b, a_scales, b_scales);
}

// 在 ops.h 中注册
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
    ops.def("my_scaled_mm", &my_scaled_mm);
}

8.2 Kernel 调优指南

① Tile 大小选择

核心原则:大 M 用大 tile(compute-bound),小 M 用小 M-tile + 大 K-tile(memory-bound)

M > 128:   TileShape<128, 128, 128>  — 大 tile 填满 Tensor Core
M  (16,64]: TileShape<64, 64, 256>  — 小 M-tile 减少 padding,大 K-tile 增加计算密度
M  16:    TileShape<64, 16, 256>   — 极小 M-tile + swap_ab

② Cluster 大小

Hopper 的 Thread Block Cluster 让相邻 CTA 可以通过分布式共享内存协作:

  • Shape<_2, _1, _1>:M 方向 2 个 CTA 一组,共享 A 矩阵的加载
  • Shape<_1, _4, _1>:N 方向 4 个 CTA 一组,适合 N 很大的场景
  • Shape<_1, _1, _1>:无 cluster,适合小矩阵或调试

③ Schedule 选择

Schedule 特点 适用场景
Pingpong 两组 consumer 交替 GEMM+epilogue 中等 batch, 最大化 TC 利用率
Cooperative 所有 warp group 协作同一 tile 大 batch (M>8192)
WarpSpecialized 基础版 producer-consumer 小 batch, 低 launch overhead
FP8FastAccum 后缀,启用 FP8 快速累加 所有 FP8 场景
BlockScaledAccum Block-wise FP8 scale DeepSeek V3 方案

④ swap_ab 优化

当 M ≤ 64 时,考虑 swap_ab。它的效果是把 M 从 row 维度挪到 column 维度,让 thread block 的 M-tile 不需要很大就能覆盖所有行。代价是 epilogue 中的 bias 方向需要从 row 变成 column(ScaledEpilogueBiasScaledEpilogueColumnBias)。

⑤ 编译优化

CUTLASS 的模板实例化会产生巨大的编译时间和二进制大小。vLLM 的做法:

  • 只实例化需要的配置(不是所有组合)
  • 按 SM 版本分拆编译单元(sm80、sm90、sm100 各自独立)
  • 使用 CUTLASS_NVCC_ARCHS 控制编译的架构目标

8.3 性能调试

# vLLM 自带的 CUTLASS benchmark
python benchmarks/cutlass_benchmarks/w8a8_benchmarks.py \
    --dtype fp8 --m 1 16 64 128 256 1024 \
    --n 4096 8192 --k 4096

# MoE CUTLASS benchmark
python benchmarks/kernels/benchmark_cutlass_moe_fp8.py

# Sparse benchmark
python benchmarks/cutlass_benchmarks/sparse_benchmarks.py

# 使用 NSight Compute 分析 kernel 性能
ncu --set full \
    --target-processes all \
    python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3-8B --quantization fp8

关键指标:

  • Compute throughput:Tensor Core 利用率,FP8 理论峰值 ~3958 TFLOPS (H100)
  • Memory bandwidth:HBM 利用率,H100 理论峰值 3.35 TB/s
  • Occupancy:SM 占用率,影响 warp 调度和延迟隐藏
  • L2 cache hit rate:影响 grouped ordering 的效果

九、总结

vLLM 的 kernel 技术栈是一个精心设计的多层次选择系统

cuBLAS   未量化 BF16 密集 GEMM(默认最稳定)
CUTLASS  量化 GEMM + epilogue fusion(FP8/INT8/W4A8/Sparse/MoE)
Triton   自定义 fused kernel(fused_moeactivation)
FlashInfer  注意力计算(JIT 编译多后端)
Marlin   W4A16 权重量化(手写 PTX极致优化)

不存在「最优的单一后端」,只有针对特定场景的最优选择。 vLLM 的价值在于它把这些后端的选择逻辑封装起来:你不需要知道什么时候该用 CUTLASS 的 Pingpong schedule、什么时候该 swap_ab、什么时候该回退到 Triton——vLLM 根据 GPU 架构、矩阵形状、量化类型自动选择。

从更高的视角看,vLLM 的 CUTLASS 集成体现了 LLM 推理优化的一个核心趋势:从通用计算库到领域专用 kernel 的演进。 cuBLAS 是通用的 BLAS 实现;CUTLASS 允许为特定的量化方案和矩阵形状生成专用 kernel;Marlin 进一步为特定的量化格式手写 PTX。每一步都在通用性和性能之间做更极端的取舍。

对于 AI Infra 工程师,理解这个 kernel 栈的意义不仅是「知道 vLLM 怎么工作」,更是理解了 GPU 编程中抽象层级与性能之间的根本张力——这个张力不会随着硬件更新而消失,只会在每一代新架构上以新的形式重现。


参考资源