
当MoE遇上TPU:一次让推理效率翻倍的技术跃迁
在大语言模型的世界里,效率就是生命线。
2026年的今天,稀疏专家混合(MoE)架构已经成为支撑千亿参数模型的事实标准——它用"只激活部分专家"的稀疏性换来了参数量级的膨胀,却始终面临一个尴尬的现实:数据移动的开销像一只蛀虫,不断蚕食着计算资源的效率。
直到最近,SGLang团队在TPU v7x上完成了一次"外科手术式"的优化。他们用Fused MoE V2——一个融合了scatter、专家FFN和gather的Pallas核——将MoE预填充延迟一口气砍掉了53%。这意味着什么?原来需要5.16毫秒完成的预填充计算,现在2.42毫秒就能搞定。更惊人的是,在SGLang解码基准测试中,16块TPU v7x芯片的吞吐量已经达到了16块H200 GPU的1.77倍。
这不只是数字的胜利,更像是TPU在AI推理战场上的又一次宣言。
MoE的隐痛:被数据移动拖累的计算
要理解这次优化的意义,我们得先搞清楚MoE架构为什么会"慢"。
以inclusionAI发布的Ling-2.6-1T为例——这是目前最具代表性的稀疏MoE模型之一。它拥有1万亿总参数,但每次前向传播只激活63亿参数;它有256个路由专家,采用top-8路由加共享专家的混合策略。理论上,这种设计应该让计算效率远超同等规模的全密集模型。
然而现实是,MoE的计算过程天然存在一个效率黑洞:每个token需要被分散路由到不同的专家模块进行处理,然后再收集汇总。这个"分散-计算-聚合"的过程涉及大量的数据移动——token要从全局内存搬运到各个专家的本地计算单元,处理完成后又要聚合回来。当专家数量众多、路由策略复杂时,数据移动的开销甚至会超过实际计算本身。
打个不太准确但形象的比喻:就像一个快递仓库,商品本身可能只需要5分钟分拣,但装卸搬运却要花45分钟。这就是MoE架构的隐痛——算力够用,但搬运成了瓶颈。
传统的解法是逐个优化scatter(分散)和gather(聚合)操作,或者优化专家FFN的计算密度。但这种"各自为战"的优化思路存在天花板——数据移动和计算之间存在难以掩盖的间隙,总是有一部分硬件资源在等待。
SGLang团队决定换一种思路。
Fused MoE V2:让数据移动"消失"在计算中
这次发布的Fused MoE V2核,本质上是一个"三合一"的融合算子——它把scatter、专家FFN、gather三个操作融合成单一的内核,在Pallas框架下实现。
这听起来像是工程上的小改进,但背后的逻辑非常精妙。
传统流水线中,scatter完成后要等所有token都分散到位,计算单元才开始工作;gather完成后要等数据聚合完毕,下一步才能启动。但在Fused MoE V2中,这三个操作被编织成一张"流水线网"——当第一批token完成scatter进入FFN计算时,第二批token的scatter已经启动;当第一批token完成FFN进入gather阶段时,它的计算结果会立即参与聚合,而不必等待所有token都处理完毕。
换句话说,数据移动不再"打断"计算,而是被巧妙地"隐藏"在计算的间隙中——或者说,计算被巧妙地填充进了数据移动的间隙。这是一种以空间换时间、以流水线换串行的思路,但实现起来需要对TPU的硬件特性和Pallas编程模型有极深的理解。
我查阅了Pallas的相关资料,它是Google为TPU量身定制的稀疏计算编程框架,核心思想就是通过融合算子减少内存访问、提升数据局部性。Fused MoE V2正是这一理念的极致实践——它不只是优化了单个操作,而是重新设计了整个数据流。
这个核还支持混合KV/循环内存池,结合GLA(线性注意力)机制和单控制器数据并行能力,构成了一套完整的推理优化方案。但最核心的突破,依然是MoE数据移动的隐藏。
性能数据:53%的延迟降幅是如何炼成的
让我们把镜头拉近,看看具体的数据。
预填充阶段(Prefill,处理输入prompt的计算):
- 优化前延迟:5.16ms
- 优化后延迟:2.42ms
- 降幅:53%
这个数字足够震撼。预填充是LLM推理中计算最密集的阶段,53%的延迟降低意味着模型可以更快地开始生成第一个token,用户感知的"响应延迟"会大幅缩短。如果仅替换MoE核,预填充吞吐量就能提升24.8%——这是一个不需要改动模型结构、不需要重新训练的"即插即用"式优化。
解码阶段(Decoding,逐token生成的计算):
- 优化前核延迟:0.249ms
- 优化后核延迟:0.211ms
- 降幅:约15%
- 吞吐量提升:18.5%-35.3%
解码阶段的优化幅度相对较小,这是可以理解的——解码时每个token的计算量本身就较小,MoE的占比也相对较低。但15%的延迟降低和最高35.3%的吞吐量提升,对于实际部署中需要每秒处理成千上万次生成请求的场景,依然意义重大。
终极对比:在SGLang官方解码基准测试中,16块TPU v7x芯片对阵16块H200 GPU:
- 当mc=128(中等并发)时,TPU吞吐量是H200的1.29倍
- 当mc=512(高并发)时,TPU吞吐量是H200的1.77倍
这个结果很有意思——并发越高,TPU的优势越明显。这说明Fused MoE V2不仅优化了单核效率,还通过单控制器数据并行等技术提升了整体的扩展性。当负载增加时,TPU能够更充分地利用融合算子带来的效率优势。
行业启示:推理战场的格局正在改写
这次发布让我想到一个更大的图景。
过去几年,NVIDIA的GPU凭借CUDA生态的先发优势和HBM内存的高带宽,几乎垄断了AI训练和推理市场。Google的TPU虽然在内部使用广泛,但一直被视为"专有武器",难以在更广泛的市场形成竞争。
然而,JAX生态的崛起正在改变游戏规则。
SGLang最初是作为LLM推理服务框架出现的,最早支持的是PyTorch生态。但这次SGLang-JAX的发布表明,JAX/Pallas已经具备了与主流框架正面竞争的能力。Fused MoE V2这样的优化表明,TPU的软件栈正在快速成熟,不再是那个"需要深度定制才能用"的异类。
从更宏观的视角看,这次优化的成功也验证了一个趋势:在AI基础设施层面,软硬协同优化正在成为提升效率的主要途径。Google的TPU团队+Pallas框架+SGLang团队的协作,产出的是一个跨越多个技术栈的系统级优化。这种整合能力,是单一厂商难以复制的。
当然,GPU生态也不会坐以待毙。NVIDIA的Blackwell架构、CUDA Graph优化、TensorRT-LLM的持续迭代,都在提升MoE推理效率。但至少在TPU这个赛道上,我们看到了一种可能性的打开:当硬件特性能被正确理解,当编程模型能提供足够的灵活度,当融合算子能真正消除瓶颈——效率的提升,往往比我们想象的更剧烈。
---
说实话,这次SGLang-JAX的更新可能不会像某些"重磅发布"那样引发社交媒体的狂欢。但对于真正在大规模部署MoE模型的团队来说,这可能是一个改变游戏规则的更新。
53%的预填充延迟降低、1.77倍的吞吐量对比——这些数字背后是TPU生态在AI推理领域的一次重要证明。我个人判断,随着JAX/Pallas生态的进一步完善,我们会看到更多类似的"效率跃迁"出现在TPU上。
这不是终点,而是开始。
