返回博客列表

vLLM fused_moe Kernel 深度剖析:Know Why & Know How

·AI

vLLM fused_moe Kernel 深度剖析:Know Why & Know How

上一篇 vLLM 深度解析聚焦在 PagedAttention 和 Continuous Batching,这两个机制解决了 LLM 推理中「内存管理」和「调度」的核心问题。但如果你部署的是 MoE(Mixture of Experts)模型——Mixtral 8x7B、DeepSeek V3、Llama 4 Maverick——你会发现瓶颈转移了:MoE 层的计算效率成为决定吞吐量的关键因素。

vLLM 中解决这个问题的核心组件是 fused_moe kernel。这篇文章从 MoE 的本质出发,逐层拆解这个 kernel 的设计动机、Triton 实现细节、性能优化原理,以及代码级的 know-how。


一、Know Why:为什么需要 Fused Kernel

1.1 MoE 架构回顾

MoE 的核心思想是:不让每个 token 经过所有参数,而是通过一个路由器(Router/Gate)选择少量专家(Expert)进行计算。 这样模型的总参数量可以很大(知识容量大),但每个 token 的实际计算量保持在可控范围。

MoE 层的基本结构:

输入 hidden_states [num_tokens, hidden_dim]
         │
         ▼
    ┌─────────┐
    │  Router  │  线性层:hidden_dim → num_experts
    │ (Gate)   │  输出:每个 token 对每个 expert 的得分
    └────┬────┘
         │ Top-K 选择 + Softmax 归一化
         ▼
    ┌────────────────────────────────────────────┐
    │  Expert 0   Expert 1   Expert 2   ...      │
    │  (FFN)      (FFN)      (FFN)               │
    │  W1: gate_up_proj   W2: down_proj          │
    └────────────────────────────────────────────┘
         │
         ▼ 加权求和:Σ(weight_i × expert_i(x))
    输出 [num_tokens, hidden_dim]

以 Mixtral 8x7B 为例:8 个 expert,每个 token 选 top-2 个 expert。DeepSeek V3 更极端:256 个 expert,每个 token 选 top-8。

1.2 Naive 实现的问题

一个「教科书式」的 MoE 前向传播是这样的:

# Naive MoE 实现(不要在生产中使用)
def naive_moe_forward(hidden_states, router, experts, top_k):
    # Step 1: 路由计算
    router_logits = router(hidden_states)              # [M, num_experts]
    topk_weights, topk_ids = torch.topk(
        torch.softmax(router_logits, dim=-1), k=top_k  # [M, top_k]
    )

    # Step 2: 逐 expert 收集 token 并计算
    output = torch.zeros_like(hidden_states)
    for expert_id in range(num_experts):
        # 找出分配给这个 expert 的 token
        mask = (topk_ids == expert_id).any(dim=-1)      # [M]
        if not mask.any():
            continue
        expert_input = hidden_states[mask]               # 动态索引 → 不连续内存
        expert_output = experts[expert_id](expert_input) # 单独的 GEMM kernel launch

        # 加权累加回输出
        for k in range(top_k):
            token_mask = topk_ids[:, k] == expert_id
            weight = topk_weights[:, k][token_mask]
            output[token_mask] += weight.unsqueeze(-1) * expert_output[...]

    return output

这个实现有四个致命的性能问题

问题一:kernel launch 开销巨大。 每个 expert 需要独立的 GEMM 调用。Mixtral 有 8 个 expert,每层至少 16 次 GEMM(gate_up + down,每个 expert 各一次)。DeepSeek V3 有 256 个 expert,这个数字变成 512 次。GPU kernel launch 的开销大约 5-10μs,512 次就是 2.5-5ms,这在 decode 阶段(单步总时间可能只有 10-20ms)是灾难性的。

问题二:动态索引导致内存不连续。 hidden_states[mask] 产生的 tensor 在 GPU 内存中不是连续的。后续的 GEMM 无法利用 coalesced memory access,带宽利用率暴跌。

问题三:每个 expert 的 batch size 很小。 如果 8 个 expert 平均分配 token,每个 expert 只拿到 batch_size / 8 的 token。小 batch 的 GEMM 严重 underutilize GPU 的计算单元——Tensor Core 需要足够大的矩阵才能达到峰值吞吐。

问题四:多次全局内存读写。 输入数据被反复从 HBM 读取(每个 expert 读一次),中间结果写回 HBM 后又被加权累加读取。HBM 带宽成为瓶颈。

1.3 Fused Kernel 的核心动机

Fused kernel 的目标是把上述「多次独立 kernel launch + 多次全局内存往返」压缩成尽可能少的 kernel launch,并在 kernel 内部完成所有操作。

Naive 实现的 GPU 操作时间线:
┌─Router GEMM─┐┌─Softmax─┐┌─TopK─┐┌─Scatter─┐┌─Expert0 GEMM─┐┌─Expert1 GEMM─┐...┌─Gather─┐┌─Weight─┐
└──────────────┘└─────────┘└──────┘└─────────┘└──────────────┘└──────────────┘   └────────┘└────────┘
  kernel 1        kernel 2   k3       k4          k5               k6              k(N+3)    k(N+4)

Fused 实现的 GPU 操作时间线:
┌─Router─┐┌─TopK+Sort+Pad─┐┌───── Fused Expert GEMM (所有 expert 一次完成) ─────┐┌─Reduce─┐
└────────┘└────────────────┘└────────────────────────────────────────────────────┘└────────┘
 kernel 1      kernel 2                      kernel 3                               kernel 4

具体来说,vLLM 的 fused_moe 融合了以下操作:

  1. Token 排列(Permutation):按 expert 分组排列 token,消除不连续内存访问
  2. 分组 GEMM(Grouped GEMM):所有 expert 的矩阵乘法在一个 kernel 中完成
  3. 反量化(Dequantization):FP8/INT8 的 scale 应用在 GEMM 循环内完成
  4. 路由权重乘法:在 kernel 输出阶段直接乘以 routing weight

一句话总结:fused_moe 把「排列 → 乘法 → 反量化 → 加权」四个步骤融合进一个 Triton kernel,减少 kernel launch 次数 10-100 倍,消除中间结果的 HBM 往返。


二、核心数据流:从 Token 到 Expert 再到输出

在深入 kernel 实现之前,先建立完整的数据流心智模型:

输入:hidden_states [M, K]M=token数, K=hidden_dim)
  │
  ├─① Router: Linear(K → E)  → router_logits [M, E]
  │
  ├─② fused_topk()           → topk_weights [M, top_k]  (float32)
  │                             topk_ids [M, top_k]      (int32)
  │
  ├─③ moe_align_block_size() → sorted_token_ids [EM]     (排序+填充后的token索引)
  │                             expert_ids [EM/BLOCK_M]   (每个block的expert编号)
  │                             num_tokens_post_pad       (填充后的总token数)
  │
  ├─④ fused_moe_kernel(W1)   → intermediate [EM, 2N]     (gate_up projection)
  │   (Triton kernel #1)         ↕ SiLU激活: silu_and_mul
  │                             intermediate [EM, N]
  │
  ├─⑤ fused_moe_kernel(W2)   → output [EM, K]           (down projection)
  │   (Triton kernel #2)
  │
  └─⑥ moe_sum()              → final [M, K]             (top_k 专家输出加权求和)

权重形状:
  W1 (gate_up): [E, 2N, K]  → 每个expert的gate和up projection合并
  W2 (down):    [E, K, N]   → 每个expert的down projection

其中 E=专家数, N=intermediate_size, K=hidden_dim

这里有一个关键的设计选择:token 的排序(sorting)和填充(padding)在 kernel 外部完成(步骤③),而不是在 kernel 内部。 这是因为排序是一个全局操作(需要知道所有 token 的 expert 分配),而 kernel 的并行结构要求每个 thread block 独立工作。


三、Token Routing 详解

3.1 Top-K 路由:fused_topk

def fused_topk(
    hidden_states: torch.Tensor,     # [M, hidden_dim]
    gating_output: torch.Tensor,     # [M, num_experts]
    topk: int,
    renormalize: bool,
    scoring_func: str = "softmax",   # 或 "sigmoid"(DeepSeek V3 使用)
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    返回:
      topk_weights: [M, top_k]  每个token选中的expert的权重
      topk_ids:     [M, top_k]  每个token选中的expert的ID
      token_expert_indices: [M, top_k]  全局索引
    """

标准路由使用 softmax + top-k。但 DeepSeek V2/V3 使用了分组 Top-K(Grouped TopK),这是一个更复杂的两阶段路由:

# DeepSeek V3 的分组路由逻辑(简化)
class GroupedTopk:
    def forward(self, router_logits):
        scores = torch.sigmoid(router_logits)  # 注意:sigmoid 而非 softmax

        # 阶段1:选择 top_group 个专家组
        # 256个expert分成多个group,每组内取top-2求和作为组得分
        group_scores = scores.view(M, num_groups, experts_per_group)
        if e_score_correction_bias is not None:
            group_scores += bias  # 负载均衡偏置
        top_group_scores = group_scores.sum(top-2-per-group)
        selected_groups = top_group_scores.topk(topk_group)

        # 阶段2:在选中的组内选择 top_k 个expert
        mask_non_selected_groups(scores)
        topk_weights, topk_ids = scores.topk(top_k)

        # 偏置只影响选择,不影响最终权重
        return topk_weights, topk_ids

3.2 Block 对齐:moe_align_block_size

这是整个 fused_moe 数据流中最关键的预处理步骤。Triton kernel 的并行粒度是 BLOCK_SIZE_M 个 token,所以每个 expert 分到的 token 数量必须对齐到 BLOCK_SIZE_M 的倍数。

def moe_align_block_size(
    topk_ids: torch.Tensor,           # [M, top_k]
    block_size: int,                   # = BLOCK_SIZE_M
    num_experts: int,
    expert_map: torch.Tensor | None,   # EP 时的全局→本地映射
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    输入: topk_ids = [[2, 5], [0, 2], [5, 3], [2, 0]]  (4 tokens, top_2)
    block_size = 4

    处理过程(假设 6 个 expert):

    1. 统计每个 expert 的 token 数:
       Expert 0: 2 tokens (token 1,3)
       Expert 2: 3 tokens (token 0,1,3)
       Expert 3: 1 token  (token 2)
       Expert 5: 2 tokens (token 0,2)

    2. 每个 expert 的 token 数向上对齐到 block_size=4:
       Expert 0: 2 → pad to 4
       Expert 2: 3 → pad to 4
       Expert 3: 1 → pad to 4
       Expert 5: 2 → pad to 4

    3. 生成排序后的 token ID 数组(padding 用 M*top_k 填充):
       sorted_token_ids = [
         1, 3, PAD, PAD,    # Expert 0 的 token(对齐到4)
         0, 1, 3, PAD,      # Expert 2 的 token
         2, PAD, PAD, PAD,  # Expert 3 的 token
         0, 2, PAD, PAD,    # Expert 5 的 token
       ]
       注意:这里的 token ID 是 topk 展开后的索引

    4. 生成 expert_ids(每 block_size 个 token 对应一个 expert):
       expert_ids = [0, 2, 3, 5]

    5. num_tokens_post_pad = 16

    返回: (sorted_token_ids, expert_ids, num_tokens_post_pad)
    """

为什么需要 padding? Triton kernel 的每个 thread block 处理 BLOCK_SIZE_M 个 token。如果 expert 0 只分到 2 个 token 但 BLOCK_SIZE_M=16,剩下 14 个 slot 就需要填充。kernel 内部通过 token_mask = offs_token < num_valid_tokens 来跳过 padding token 的计算,但内存访问模式保持对齐。

Expert Parallel 场景: 当使用 EP(Expert Parallelism)时,expert_map 将不在当前 rank 上的 expert 映射为 -1。kernel 检测到 expert_ids[block] == -1 时直接写零并返回。


四、Triton Kernel 实现剖析

4.1 Kernel 签名

@triton.jit
def fused_moe_kernel(
    # 矩阵指针
    a_ptr,                    # 输入激活 [M, K]
    b_ptr,                    # Expert 权重 [E, K, N] 或 [E, N, K]
    c_ptr,                    # 输出 [EM, N]
    a_scale_ptr,              # 激活量化 scale
    b_scale_ptr,              # 权重量化 scale
    topk_weights_ptr,         # 路由权重
    sorted_token_ids_ptr,     # 排序后的 token 索引
    expert_ids_ptr,           # 每个 block 的 expert ID
    num_tokens_post_padded_ptr,
    # 矩阵维度
    N, K, EM, num_valid_tokens,
    # Stride 参数(14 个)
    stride_am, stride_ak,
    stride_be, stride_bk, stride_bn,
    stride_cm, stride_cn,
    stride_asm, stride_ask,
    stride_bse, stride_bsk, stride_bsn,
    # 量化分组
    group_n: tl.constexpr,
    group_k: tl.constexpr,
    # 元参数
    BLOCK_SIZE_M: tl.constexpr,   # Token 分块大小 (16-128)
    BLOCK_SIZE_N: tl.constexpr,   # 输出维度分块 (32-1024)
    BLOCK_SIZE_K: tl.constexpr,   # 归约维度分块 (32-256)
    GROUP_SIZE_M: tl.constexpr,   # L2 Cache 分组 (1-32)
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):

4.2 Grid 调度:Grouped Ordering for L2 Cache Reuse

kernel 的 grid 大小是 (num_pid_m * num_pid_n,),即一个一维 grid。关键在于 program_id 到 (pid_m, pid_n) 的映射方式

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)       # M 方向的 block 数
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)         # N 方向的 block 数

# Grouped ordering:GROUP_SIZE_M 行为一组,组内先遍历 M 再遍历 N
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

为什么这样做? 考虑 GEMM 的数据访问模式:

标准行优先遍历(Row-major):
pid 0 → (M=0, N=0)  读 B[:, 0:64]
pid 1 → (M=0, N=1)  读 B[:, 64:128]
pid 2 → (M=0, N=2)  读 B[:, 128:192]
...
pid P → (M=1, N=0)  读 B[:, 0:64]B 的第一列要重新从 HBM 加载!

Grouped ordering(GROUP_SIZE_M=4):
pid 0 → (M=0, N=0)  读 B[:, 0:64]B 列块在 L2 中
pid 1 → (M=1, N=0)  读 B[:, 0:64]   ← L2 命中!
pid 2 → (M=2, N=0)  读 B[:, 0:64]   ← L2 命中!
pid 3 → (M=3, N=0)  读 B[:, 0:64]   ← L2 命中!
pid 4 → (M=0, N=1)  读 B[:, 64:128]
pid 5 → (M=1, N=1)  读 B[:, 64:128] ← L2 命中!
...

PyTorch 官方博客报告,这种 grouped ordering 在 A100 上提升高达 4 倍,H100 上提升 4.4 倍。对于 MoE 场景尤其重要,因为每个 expert 的 B 矩阵不同——L2 cache 命中率直接决定了是否需要反复从 HBM 加载权重。

4.3 Token 路由映射

kernel 内部如何从 sorted_token_ids 映射回原始输入?

# 加载当前 block 对应的排序后 token ID
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)

# 有效 token 掩码(排除 padding)
token_mask = offs_token < num_valid_tokens

# 加载当前 block 的 expert ID
off_experts = tl.load(expert_ids_ptr + pid_m)

# Expert Parallel:不在当前 rank 的 expert 写零并返回
if off_experts == -1:
    # write zeros to output ...
    return

# ★ 关键:从排序后的 token 索引恢复原始 token 位置
# sorted_token_ids 中每个 token 出现 top_k 次
# offs_token // top_k 得到原始 token 在 hidden_states 中的行索引
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am
                  + offs_k[None, :] * stride_ak)

# Expert 权重指针:选择当前 expert 的权重矩阵
b_ptrs = b_ptr + (off_experts * stride_be
                  + offs_k[:, None] * stride_bk
                  + offs_bn[None, :] * stride_bn)

这里 offs_token // top_k 是一个精巧的设计:因为每个 token 被选中了 top_k 个 expert,所以在 sorted_token_ids 中每个原始 token 出现了 top_k 次。除以 top_k 恢复原始索引。

4.4 GEMM 主循环

# 累加器初始化为 float32(即使输入是 FP16/BF16/FP8)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    # 带掩码的内存加载
    a = tl.load(a_ptrs,
                mask=token_mask[:, None] &
                     (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                other=0.0)
    b = tl.load(b_ptrs,
                mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
                other=0.0)

    if use_fp8_w8a8:
        if group_k > 0 and group_n > 0:
            # ★ Block-wise FP8 量化(DeepSeek V3 方案)
            # 每 group_k 行 × group_n 列一个 scale
            k_start = k * BLOCK_SIZE_K
            offs_ks = k_start // group_k
            a_scale = tl.load(a_scale_ptr + offs_ks * stride_ask, ...)
            b_scale = tl.load(b_scale_ptr + offs_ks * stride_bsk + ...)
            # scale 在循环内乘:保留精度
            accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
        else:
            # Tensor-wise FP8:循环后统一乘 scale
            accumulator = tl.dot(a, b, acc=accumulator)
    elif use_int8_w8a16:
        accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
    else:
        accumulator += tl.dot(a, b)

    # 移动到下一个 K 块
    a_ptrs += BLOCK_SIZE_K * stride_ak
    b_ptrs += BLOCK_SIZE_K * stride_bk

几个关键细节:

  1. 累加器始终是 float32。 即使输入是 FP8,tl.dot 的结果也用 float32 累加,避免精度损失。
  2. Block-wise FP8 的 scale 在循环内乘。 这是 DeepSeek V3 提出的量化方案:每 128×128 的权重子矩阵一个 scale,每 1×128 的激活子向量一个 scale。在循环内乘 scale 意味着每个 K 块都有独立的精度校正,比循环后统一乘更精确。
  3. tl.dot(a, b) 调用 Tensor Core。 Triton 编译器会将其映射到硬件的矩阵乘加指令(HMMA/IMMA),实现近峰值吞吐。

4.5 路由权重与输出

# 在 kernel 内部直接乘以路由权重(可选)
if MUL_ROUTED_WEIGHT:
    moe_weight = tl.load(topk_weights_ptr + offs_token,
                         mask=token_mask, other=0)
    accumulator = accumulator * moe_weight[:, None]

# 类型转换(FP32 → compute_type)
if use_fp8_w8a8 and not (group_k > 0 and group_n > 0):
    # Tensor-wise FP8:循环后乘 scale
    accumulator = (accumulator * a_scale * b_scale).to(compute_type)
elif use_int8_w8a16:
    accumulator = (accumulator * b_scale).to(compute_type)
else:
    accumulator = accumulator.to(compute_type)

# 写入输出
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)

MUL_ROUTED_WEIGHT 参数控制是否在 kernel 内乘权重。在 W2(down projection)阶段通常开启——这样输出已经是加权的,后续只需要对 top_k 个 expert 的输出做求和(moe_sum),不需要再单独乘权重。


五、两阶段 Expert 计算

一个完整的 MoE FFN 不是单次 GEMM,而是两阶段

def fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, ...):
    """
    Stage 1: gate_up = hidden_states × W1, 然后 SiLU 激活
    Stage 2: output = activated × W2, 然后加权求和

    W1 shape: [E, 2N, K]  (gate 和 up 合并)
    W2 shape: [E, K, N]
    """
    # ③ Token 排序 + 填充
    sorted_token_ids, expert_ids, num_tokens_post_pad = \
        moe_align_block_size(topk_ids, BLOCK_SIZE_M, num_experts, expert_map)

    # ④ Stage 1: Gate+Up Projection
    #   输入量化(可选 FP8/INT8)
    a1, a1_scale = moe_kernel_quantize_input(hidden_states, ...)
    #   Fused GEMM: [M, K] × [E, 2N, K]^T → [EM, 2N]
    fused_moe_kernel[grid](a1, w1, intermediate, ...)

    # ④ 激活函数: SiLU(gate) * up → [EM, N]
    #   对于 SwiGLU: intermediate 的前 N 列是 gate, 后 N 列是 up
    #   silu_and_mul 是一个高效的 fused CUDA op
    ops.silu_and_mul(intermediate_out, intermediate_in)

    # ⑤ Stage 2: Down Projection
    a2, a2_scale = moe_kernel_quantize_input(intermediate_out, ...)
    #   Fused GEMM: [EM, N] × [E, K, N]^T → [EM, K]
    fused_moe_kernel[grid](a2, w2, output,
                           MUL_ROUTED_WEIGHT=True, ...)  # 在 kernel 内乘权重

    # ⑥ Top-K 求和
    ops.moe_sum(output, output_final)

内存复用优化: 实现中使用 chunking 策略,将 token 分成多个 chunk 处理。不同 chunk 之间复用中间缓冲区:

# 内存复用:cache1 用完后 cache3 可以复用其空间
# "We can reuse the memory between these because
#  by the time we need cache3, we're done with cache1"
for chunk_start in range(0, M, CHUNK_SIZE):
    chunk = hidden_states[chunk_start:chunk_start + CHUNK_SIZE]
    # Stage 1 用 cache1 (intermediate)
    # 激活后结果写到 cache2
    # Stage 2 用 cache2, 输出写到 cache3(可复用 cache1 的空间)

六、性能分析:为什么快

6.1 Kernel Launch 开销对比

实现方式 Kernel Launch 次数/层 说明
Naive (逐 expert) 2E + 5 每个 expert 2 次 GEMM + routing/scatter/gather
Fused (vLLM) ~4 sort+pad, fused_gemm_w1, activation, fused_gemm_w2
DeepSeek V3 (E=256) ~517 → ~4 ~130 倍减少
Mixtral (E=8) ~21 → ~4 ~5 倍减少

Kernel launch overhead 在 decode 阶段尤为致命(batch size 小,每个 kernel 的实际计算量很少),fused kernel 的优势在小 batch 时最为显著。

6.2 内存带宽利用

Naive 实现的 HBM 访问:
   hidden_states: E 次(每个 expert 都要读)       E × M × K
   W1/W2:        2E 次(每个 expert 读自己的权重)  2E × K × N
  写中间结果:       2E                              2E × M × N
   HBM 访问  E × (2MK + 4KN + 2MN)

Fused 实现的 HBM 访问:
   hidden_states: 1 次(排序后连续读取)            M × K × top_k
   W1/W2:        2 次(L2 cache grouped ordering)  2 × E × K × N(但 L2 命中率高)
  写输出:          1                                M × K × top_k
   HBM 访问  2MK×top_k + 2EKN(有效带宽更高)

关键差异:

  • Input 只读一次:sorted_token_ids 让同一个 expert 的 token 在内存中连续,coalesced access
  • 权重 L2 复用:grouped ordering 让相邻 thread block 复用 B 矩阵的 L2 cache line
  • 无中间结果写回:gate_up 的结果直接在 register/shared memory 中做激活,不写 HBM

6.3 Compute Utilization

Naive 实现中每个 expert 的 GEMM 大小是 [M/E, K] × [K, N]。当 M=32, E=8 时,每个 expert 只有 4 个 token——这个 GEMM 太小了,Tensor Core 根本跑不满。

Fused 实现将所有 expert 的计算打包成一个大 kernel,grid size = ceil(EM/BLOCK_M) × ceil(N/BLOCK_N)。更多的 thread block 意味着更好的 SM occupancy 和 wave quantization。

6.4 实测性能数据

指标 数值 来源
DeepSeek-R1 吞吐(H200 Wide-EP) 2,200 tok/s/GPU vLLM 官方博客
vLLM vs SGLang(DeepSeek-R1 H200) +4.33%(1500 vs 1438 tok/s) 社区 benchmark
Grouped ordering 提升 A100 最高 4×,H100 最高 4.4× PyTorch 博客
共享专家融合(DeepSeek) ITL 提升最高 34% vLLM PR #15502
SplitK 分解 vs Data Parallel 18-20% 提升 PyTorch 博客

七、FP8 量化:三种模式

fused_moe kernel 原生支持三种 FP8 量化粒度,这对 MoE 模型至关重要——MoE 的参数量大,量化能显著降低显存占用和带宽需求。

7.1 Tensor-wise FP8

最粗粒度:整个 activation tensor 一个 scale,整个 weight tensor 一个 scale。

# 循环后统一 dequantize
accumulator = tl.dot(a, b, acc=accumulator)  # FP8 乘加
# ... 循环结束 ...
result = accumulator * a_scale * b_scale  # 单次 scale 乘

优点:简单、开销最低。缺点:对异常值敏感,精度损失较大。

7.2 Channel-wise FP8

Per-token activation scale + per-output-channel weight scale:

# a_scale: [M, 1]   每个 token 一个 scale
# b_scale: [1, N]   每个输出通道一个 scale
result = accumulator * a_scale[:, None] * b_scale[None, :]

精度和开销的平衡点,适合大多数场景。

7.3 Block-wise FP8(DeepSeek V3 方案)

这是 DeepSeek V3 论文提出的量化策略,也是目前精度最高的 FP8 方案:

权重量化:每 128×128 子矩阵一个 scale
激活量化:每 1×128 子向量一个 scale

Weight [K, N]:
┌────┬────┬────┬────┐
│s00 │s01 │s02 │s03 │  每个 128×128 block
├────┼────┼────┼────┤  有独立的 scale
│s10 │s11 │s12 │s13 │
├────┼────┼────┼────┤
│s20 │s21 │s22 │s23 │
└────┴────┴────┴────┘

Activation [M, K]:
每个 token 的每 128 个元素有独立的 scale
[a_s0_0, a_s0_1, a_s0_2, ...]  token 0
[a_s1_0, a_s1_1, a_s1_2, ...]  token 1

在 kernel 的 GEMM 循环内部,每处理一个 K 块就加载对应的 scale 并立即乘入累加器:

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
    a = tl.load(a_ptrs, ...)
    b = tl.load(b_ptrs, ...)

    # 加载这个 K 块对应的 scale
    k_start = k * BLOCK_SIZE_K
    offs_ks = k_start // group_k          # group_k = 128
    a_scale = tl.load(a_scale_ptr + offs_ks * stride_ask)  # [BLOCK_M, 1]
    b_scale = tl.load(b_scale_ptr + offs_ks * stride_bsk)  # [1, BLOCK_N]

    # 在循环内乘 scale,保留精度
    accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]

为什么 block-wise 更精确? FP8 的动态范围只有 ~16(E4M3),如果一个 tensor 中既有很大的值又有很小的值,统一 scale 会让小值 underflow。128×128 的 block 粒度大幅缩小了同一 scale 内的数值范围差异。

7.4 多后端支持

当 N 较大(>512)时,vLLM 还提供了 Triton 之外的 FP8 后端:

后端 适用场景 量化格式
Triton kernel 通用,所有 FP8 模式 E4M3, per-tensor/channel/block
DeepGEMM N > 512,FP8 only E4M3, 128-block symmetric
CUTLASS FP8 + W4A8 混合 E4M3, channel/tensor scale

后端选择由 FusedMoEConfig 中的 oracle 决定,会根据矩阵大小和硬件自动选择最优实现。


八、Expert Parallel(EP)

MoE 模型的 expert 数量可以非常多(DeepSeek V3 有 256 个),单卡放不下所有 expert 的权重。Expert Parallel 将 expert 分布到多张 GPU 上。

8.1 Expert 分布策略

def determine_expert_map(ep_size, ep_rank, global_num_experts, strategy):
    """
    Linear(默认):
      Rank 0: experts [0, 1, ..., 63]
      Rank 1: experts [64, 65, ..., 127]
      Rank 2: experts [128, 129, ..., 191]
      Rank 3: experts [192, 193, ..., 255]

    Round-robin(更均匀的负载):
      Rank 0: experts [0, 4, 8, ...]
      Rank 1: experts [1, 5, 9, ...]
      Rank 2: experts [2, 6, 10, ...]
      Rank 3: experts [3, 7, 11, ...]
    """

8.2 通信模式

EP 的数据流包含两次 All-to-All 通信:

     token dispatch                expert compute             token combine
┌──────────────────┐          ┌──────────────────┐          ┌──────────────────┐
 GPU 0: 全部 token  ──A2A──→│ GPU 0: Expert 0-63 ──A2A──→│ GPU 0: 全部输出  
 GPU 1: 全部 token  ──A2A──→│ GPU 1: Expert 64+  ──A2A──→│ GPU 1: 全部输出  
 ...                         ...                         ...              
└──────────────────┘          └──────────────────┘          └──────────────────┘
  每个 GPU 有所有 token        每个 GPU 只有自己的 expert     结果聚合回所有 GPU

8.3 通信后端

vLLM 支持多种 EP 通信后端,各有适用场景:

后端 特点 最佳场景
DeepEP HT 高吞吐模式,基于 nvshmem Prefill 阶段,多节点
DeepEP LL 低延迟模式 Decode 阶段,单节点
PPLX Perplexity 开发的灵活方案 单节点(通常更快)
NCCL AllGather 标准集合通信 回退方案

8.4 动态负载均衡(EPLB)

MoE 的一个固有问题是 expert 负载不均衡——某些 expert 可能被大量 token 选中(热点),而另一些几乎没有流量。vLLM 实现了运行时动态重平衡:

# 运行时监控 expert 负载,动态调整 expert 到 GPU 的映射
def set_eplb_state(moe_layer_idx, expert_load_view,
                   logical_to_physical_map, logical_replica_count):
    """
    允许热门 expert 在多个 GPU 上有副本,
    冷门 expert 共享 GPU 空间
    """

九、模型集成实例

9.1 Mixtral 8x7B / 8x22B

Mixtral 是 fused_moe kernel 最初的目标模型。配置相对简单:

Expert 数量: 8
Top-K: 2
Hidden dim: 4096
Intermediate: 14336
激活函数: SiLU (SwiGLU)
路由: softmax + top-2 + renormalize

8 个 expert 的参数量相当于 7B × 8 = 56B(但每个 token 只激活 2 个 expert,等效计算量约 14B)。fused kernel 将 8 个 expert 的计算合并为 2 次 kernel launch(W1 + W2),相比 naive 的 16 次大幅降低了开销。

9.2 DeepSeek V3 / R1

DeepSeek V3 对 fused_moe 的要求高出一个数量级:

Expert 数量: 256 (routed) + 1 (shared)
Top-K: 8
Hidden dim: 7168
Intermediate: 2048 (每个 expert 较小,但数量多)
激活函数: SiLU (SwiGLU)
路由: GroupedTopK + sigmoid + e_score_correction_bias
量化: Block-wise FP8 (128×128 weight blocks, 1×128 activation)

DeepSeek V3 的几个特殊需求推动了 fused_moe 的重要演进:

① Grouped TopK 路由。 256 个 expert 分成多个组,先选组再选 expert。这需要 GroupedTopk 专用实现。

② Block-wise FP8。 DeepSeek V3 论文提出的量化方案,要求 kernel 在 GEMM 循环内乘 scale。这是对 kernel 内部逻辑的侵入式修改。

③ 共享专家融合(Shared Expert Fusion)。 DeepSeek V3 有一个 shared expert(所有 token 都经过)和 256 个 routed expert。SharedFusedMoE 将 shared expert 的计算与 routed expert 的计算融合,避免额外的 kernel launch。这一优化在 vLLM PR #15502 中实现,带来了最高 34% 的 ITL(Inter-Token Latency)改善。

④ Expert Parallel。 256 个 expert 不可能放在一张卡上。DeepSeek V3 的部署通常使用 Wide-EP(宽专家并行),将 expert 分散到多达 32-64 张 GPU 上。这需要高效的 All-to-All 通信和动态负载均衡。

9.3 Llama 4 Scout / Maverick

Meta 的 Llama 4 系列也采用了 MoE 架构,vLLM 为其提供了专用的 Llama4 routing method type。


十、代码级 Know-How

10.1 关键数据结构

# FusedMoE 层的核心参数
class FusedMoE(CustomOp):
    def __init__(self,
        num_experts: int,           # Expert 数量
        top_k: int,                 # 每个 token 选几个 expert
        hidden_size: int,           # 输入维度
        intermediate_size: int,     # FFN 中间维度
        tp_size: int = 1,           # Tensor Parallel size
        ep_size: int = 1,           # Expert Parallel size
        renormalize: bool = True,   # 是否重归一化路由权重
        quant_config: Any = None,   # 量化配置
        activation: str = "silu",   # 激活函数
        # ... 约 30 个参数
    ):
        # 权重:
        # w13_weight: [num_experts, 2 * intermediate_size, hidden_size]
        #   gate 和 up 合并存储(13 = W1 和 W3)
        # w2_weight:  [num_experts, hidden_size, intermediate_size]
        #   down projection

为什么 W1 和 W3 合并为 w13? 因为 SwiGLU 激活的公式是 SiLU(x × W1) * (x × W3),W1 和 W3 的输入相同(都是 x),可以合并成一个 GEMM:x × [W1; W3],输出的前半部分过 SiLU,后半部分直接相乘。这比两次单独的 GEMM 快 ~2 倍。

10.2 Kernel Config 调优

fused_moe 的性能对 block size 参数非常敏感。vLLM 提供了三层配置机制:

# 配置优先级(从高到低):
# 1. 用户自定义配置
#    VLLM_TUNED_CONFIG_FOLDER=/path/to/configs
#
# 2. 设备特定的预调优配置
#    vllm/model_executor/layers/fused_moe/configs/
#    文件命名:E={experts},N={dim},dtype={type},block_shape={shape}.json
#
# 3. 动态默认值
def get_default_config(M, E, N, K, topk, dtype, is_marlin, block_shape):
    if M <= 32:
        # 小 batch:memory-bound,用小 M tile + 大 K tile
        return {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,
                "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1}
    elif M <= 128:
        return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}
    else:
        # 大 batch:compute-bound,用大 tile
        return {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32}

调优 tips:

  • BLOCK_SIZE_M 必须整除 padding 后的 token 数。太大会浪费(padding 多),太小会 launch 太多 thread block
  • BLOCK_SIZE_K 必须整除 K 维度(_ensure_block_size_k_divisible 会自动处理)
  • GROUP_SIZE_M 影响 L2 cache 复用率。对大矩阵用大值(16-32),小矩阵用 1
  • ROCm 平台 num_stages 上限为 2,超过会 OOM

10.3 调用链路

从最外层到最内层的调用链:

model.forward()
  └─ MoELayer.forward()
       └─ FusedMoE.forward()                        # layer.py
            ├─ self.router(hidden_states)            # 路由计算
            ├─ FusedMoEMethodBase.apply()             # 分发到具体实现
            │    └─ fused_experts()                   # fused_moe.py
            │         └─ fused_experts_impl()
            │              ├─ moe_align_block_size()  # Token 排序+填充
            │              ├─ fused_moe_kernel[grid]  # Triton kernel (W1)
            │              ├─ silu_and_mul()           # 激活
            │              ├─ fused_moe_kernel[grid]  # Triton kernel (W2)
            │              └─ moe_sum()               # Top-K 聚合
            └─ (如果有 shared expert) shared_expert + moe_sum

10.4 Modular Kernel 架构

vLLM 从 v0.10 开始将 fused_moe 重构为模块化架构,将计算分解为三个可替换的阶段:

# 模块化 pipeline:Prepare → Experts → Finalize
class FusedMoEModularKernel:
    prepare_finalize: FusedMoEPrepareAndFinalize  # 量化 + 通信(EP)
    experts: FusedMoEExperts                       # 实际 GEMM 计算

# 不同的 Expert 后端实现(可替换):
#   TritonExperts          — 标准 Triton kernel(通用)
#   BatchedTritonExperts   — 批处理格式 [E, max_tokens, K]
#   DeepGemmExperts        — DeepGEMM FP8 后端(N > 512)
#   CutlassExpertsFp8      — CUTLASS FP8 后端
#   CutlassExpertsW4A8Fp8  — 4-bit 权重 + 8-bit 激活

# 不同的 Prepare/Finalize 后端:
#   StandardPrepareAndFinalize     — 单卡,直接排序
#   DeepEPPrepareAndFinalize       — DeepEP All-to-All 通信
#   PPLXPrepareAndFinalize         — PPLX 通信后端

这种模块化设计让添加新的 expert 计算后端(比如新硬件的专用 kernel)或新的通信策略(比如新的 EP 方案)变得简单——只需实现对应的接口,不需要修改核心逻辑。

10.5 如何扩展

添加新的 Expert 后端:

class MyCustomExperts(FusedMoEExperts):
    def apply(self,
              output: torch.Tensor,          # 输出 buffer
              hidden_states: torch.Tensor,    # 输入激活
              w1: torch.Tensor,               # Gate+Up 权重
              w2: torch.Tensor,               # Down 权重
              topk_weights: torch.Tensor,     # 路由权重
              topk_ids: torch.Tensor,         # Expert 分配
              activation: str = "silu",
              global_num_experts: int = -1,
              expert_map: torch.Tensor | None = None,
              w1_scale: torch.Tensor | None = None,
              w2_scale: torch.Tensor | None = None,
              # ...
    ):
        # 实现你的 expert 计算逻辑
        ...

添加自定义 Tuning Config:

# 创建配置文件
mkdir -p /my_configs
cat > /my_configs/E=256,N=2048,dtype=fp8_w8a8,block_shape=128_128.json << 'EOF'
{
    "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    "32": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3},
    "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_warps": 8, "num_stages": 4},
    "128": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
EOF

# 使用自定义配置启动
export VLLM_TUNED_CONFIG_FOLDER=/my_configs
python -m vllm.entrypoints.openai.api_server --model deepseek-ai/DeepSeek-V3 ...

配置文件的 key 是 batch size(M),value 是对应的 block size 参数。运行时根据实际 batch size 选择最近的配置。


十一、总结

fused_moe kernel 是 vLLM 服务 MoE 模型的核心性能引擎。它的设计体现了几个重要的系统优化思想:

① 减少 kernel launch。 把 O(E) 次 GEMM 压缩成 O(1) 次,通过 token 排序 + 分组 GEMM 实现。这是最直接的性能来源。

② 优化内存访问模式。 Grouped ordering 提升 L2 cache 命中率,sorted token IDs 确保 coalesced memory access,block alignment 消除 padding 的计算浪费。

③ 融合计算与量化。 FP8 的 dequantize 不是一个单独的 kernel,而是在 GEMM 循环内完成。block-wise FP8 甚至在每个 K 块内独立 dequantize,以极小的开销换取显著的精度提升。

④ 模块化与可扩展。 Prepare → Experts → Finalize 的三阶段抽象,让不同的硬件后端(Triton/DeepGEMM/CUTLASS)和通信方案(DeepEP/PPLX/NCCL)可以自由组合。

从 Mixtral 的 8 个 expert 到 DeepSeek V3 的 256 个 expert,MoE 模型的规模在快速增长。fused_moe kernel 作为这一趋势的基础设施层,其设计和优化思路值得每一个关注 AI Infra 的工程师深入理解。


参考资源