Code Now

【模型推理】浅谈CUDA Graph(WIP)

· LuYanFCP

CUDA Graph 算是在推理优化里被提到最多的特性之一了。原理倒是不复杂:把原本一次次 launch 的 kernel 打包成一张静态图,一次性扔给 GPU,省掉中间的调度开销。

但这个"打包"具体怎么做?PyTorch 里怎么用?以及最重要的是——在 H200 上到底能快多少?

这篇文章就是把这几个问题摸清楚的过程记录。

1. LLM 推理在忙什么

Graph 能减少 kernel launch 的开销,但为什么 LLM 推理特别需要这个?先简单说下背景。

老生常谈的问题,LLM 推理分两个阶段:

Prefill

  • 输入是用户的一整段 prompt
  • 一次性做完整的 self-attention 计算,然后讲产生的KVCache写入显存或者其他存储设备,从而在Decode阶段使用。
  • 特点是计算密集,kernel消耗的时间比较多,GPU 利用率相对高,同时对内存IO要要求比较低。总之就compute-bound,算力是决定他的速度的关键。

Decode

  • 自回归过程,每次只生成一个 token
  • 要用到之前计算的 KV Cache,只做必要的更新,对于单个请求推理的时候我们输入model的Tensor维度为[1, D], 其中D是Embedding的维度。
  • 特点是memory-bound, 单个请求计算量少,但是需要load KVCache巨大,而且推理的时间越长,KVCache load的压力就越大。

对于这个推理过程,我们可以认为CPU step2step的将模型每一层计算所需要的kernel相关数据和元数据准备好,调用launch通知GPU执行,然后异步的准备下一次计算所需要的信息的过程。

此时因为不同Prefill和Decode不同特点,CPU与GPU直接的协作和等待有不同的表现

  1. Prefill 由于Kernel计算时间比较长,因此CPU侧的开销其实基本可以被GPU的运算Overlap掉,因此关注点核心应该是GPU Kernel如何跑更快。
Image
  1. Decode阶段,其实就是两个地方了,首先Kernel计算量比较小,因此开销时间很少,这种情况就出现了GPU等CPU等情况,这也是就是为什么我们常说Decode阶段需要开CUDA Graph的,原因,CUDA Graph的目标就是消除GPU等CPU的问题,让GPU计算更高效。
Image

如今其实也有主流的推理框架也支持Prefill阶段的CUDA Graph,这个主要是因为在推理过程中模型有大量的小kernel,比如说Norm这类的Kernel,这种其实也类似于Decode阶段的样子,GPU也需要等待CPU,因此CUDA Graph也是可以计算的,至于细节后面我会再写一篇专门介绍Prefill Piecewise CUDA Graph的一些细节。


2. Pytorch在推理的时候CPU纠结在做什么?

因为像Pytorch这种框架一般对于一个算子支持很多Kernel的实现,因此为了支持这么多Kernel的实现就回有大量的dispatch的代码用来运行时自动路由到不同的kernel上,同时也会有很多准备工作,所有工作完成之后就调用cudaLaunchKernel交给CUDA的runtime层做一些准备工作并发送给GPU。在GPU运行的时候其实也会有一些操作,比如说GPU侧做Kernel环境的准备,在执行后也会有一些清理状态的后处理。 总结来说就如图,一个Kernel从Pytorch发起执行到做完运行有这几个步骤:

  1. CPU:
    1. Pytorch C++ Dispatch
    2. Kernel Launcher Prepare
    3. cudaLaunchKernel
  2. GPU:
    1. Command Processor
    2. Kernel Execute
    3. Post-Process
Image

2.1. Pytorch Dispatch & Kernel Prepare

Pytorch Dispatch: 这个其实是所有支持多种Kernel+使用Python实现执行前端框架的通病。

  1. 跨语言开销:调用过程中必须包含多语言的之间的走bind或者其他机制从Python到C++到调用开销。
  2. 检查并找到需要执行的Kernel函数指针:这个过程不一定发生在C++层,比如说Sglang的很多Kernel的Dispatch就发生在Python层。
  3. 准备数据和元数据:Dispatch的时候由于不同Kernel要求输入的内存布局/参数以及元数据的不同,因此这里框架也需要一个转化。

Kernel Prepare:主要是执行Kernel之前做一些提取裸指针/Grid/Block 维度计算/计算ShareMemory的过程。这些步骤在现代推理引擎中一般由比较专业的Kernel库来完成,比如说Flashinfer/Cutlass/Triton来完成。

  1. 取裸指针:CUDA Kernel 是不认识 PyTorch 的 Tensor 对象的。准备阶段需要调用 .data_ptr(),把 q, k_cache, v_cache 的底层显存物理地址(指针)全抽出来。
  2. Grid/Block 维度计算:比如说推理的时候,每个请求的序列长度都不一样,而且 大部分框架都是用了 Paged KV Cache。因此,在这个阶段,CPU 必须根据当前的 Batch Size、每个 sequence 的真实长度、KV Cache Block 的映射表(Block Tables)来动态计算:我需要启动多少个 CUDA Block?每个 Block 分配多少个线程(Threads)去处理哪个 token? 当前因为现在很多Kernel都是用一些JIT手段提前编译好了这些参数,因此这个阶段也经常去做动态找到Match输入条件的这些预编译的Kernel。
  3. 计算ShareMemory/Stage这些参数。

完成这些步骤之后就完全交给cudaLaunchKernel去准备和执行Kernel。

2.2. Kernel Launch

在代码中Kernel Launch其实就是这样一个调用kernel<<<grid, block>>>(...) 这个语法背后发生的很多事,主要就是给GPU发执行指令之前runtime的准备,比如说:

  1. 参数准备:把 kernel 参数拷到 constant memory
  2. 配置解析:grid、block、shared memory、stream 等配置
  3. 命令入队:生成一个 kernel launch command,插入到 stream 的 command queue
  4. 可能的发送:如果 queue 满了或需要同步,通过 driver 发给 GPU

在推理中,这些步骤每次都要走一遍,无论 kernel 本身多简单。对于计算时间长的kernel(比如 matmul),launch 开销占比可以忽略;但对于 tiny kernel(比如 element-wise 的 rms_norm),launch 时间和执行时间可能差不多长。

这就好比你去餐厅点餐:

  • 点一份牛排(计算时间长的kernel),等餐 30 分钟,点餐 1 分钟,占比很小
  • 点一杯水(tiny kernel),倒水只要 5 秒,但点餐还是要 1 分钟,占比巨大。

2.3. Command Process & Post Process

Command Processor: 负责接收指令、解析grid/dim、并在SM上分配任务。CPU 发来一个 Kernel A,Command Processor 接收、解析、调度。执行完后,它必须停下来,等待 CPU 通过 PCIe 总线发来 Kernel B 的指令。即使 CPU 发得很快,Command Processor 每次处理独立指令、评估 Stream 依赖关系都是有硬件周期开销的。

Cuda Graph就像我们C++编程中使用很多预编译的模版一样,相当于把动态Process过程转化为静态的了,完成Kernel A就直接做Kernel B相当于完全静态的过程,不需要调度。

Post Process:主要是更新状态机和尤其是同步点的信息。

2.4. 这些开销究竟占用了多少时间

已我当前的环境在H200+Qwen3-0.6B作为测试,这边我使用nsysPytorch Profile ClaudeCode总结后的数据,在Batch=1的情况下eager和graph状态耗时分别是15.56ms和4.44ms,总加速到3.5x1。后续的在H200上测试会详细的讲解如何测量这些数据

阶段 位置 Eager Graph 节省/kernel 数据源
PyTorch Dispatch + Kernel Launcher Prepare CPU 4-37 µs/call 取决于 op
(aten::mul/add ~5 µs, aten::mm 12 µs, aten::cudnn_attention_forward 37 µs)
~0 (只在 replay 最外层) ~99% [x] torch.profiler self_cpu_time 实测
cudaLaunchKernel() CPU 3.5 µs/call (nsys) / 4.0 µs/call (profiler) ~0 (graph 只有 1 次 cudaGraphLaunch 0.51 ms) 99%+ [x] [x] nsys cuda_api_sum + torch.profiler 双重验证
Command processor GPU 属于 gap 近零 大部分 -
Kernel Execute GPU 2.06 µs 2.24 µs -0.18 µs [x] nsys cuda_gpu_kern_sum 实测 (graph 略贵因为有 node metadata 开销)
Post Process GPU 属于 gap 近零 大部分 -
cudaLaunchKernel+Post Process GPU 7.22 µs/gap 0.31 µs/gap 6.91 µs [x] 间接实测: (total_gpu - kernel_exec) / kernel_count

3. CUDA Graph

CUDA Graph是什么,其实上面的很多讨论已经回答了这个问题,用一句话总结来说就是:CUDA Graph 是一种将一系列 GPU 操作(如内核执行、内存拷贝)及其前后依赖关系预先定义为一张有向无环图(DAG),然后通过单次指令将其整体打包提交给 GPU 自动调度执行的工作模型,旨在彻底消除 CPU 频繁下达指令所带来的调度开销。

Image

是DAG自然就有节点和边,在CUDA Graph中节点是每个GPU的操作比如说:Kernel launches/内存操作(D2H/H2D)/内存管理(malloc/free)/Host Func/空节点(用于同步和协调)。边其实没有实际的操作,只是代表一种依赖关系,比如并行/依赖。

3.1. Graph vs Stream:从"事件驱动"到"DAG"

以前写 CUDA 代码,默认就是用 stream。stream 本质上是个队列——你按顺序塞进去的操作,GPU 按顺序执行。简单粗暴,但问题也很明显:

// Stream 模型:线性队列,CPU 逐条提交
cudaMemcpyAsync(d_in, h_in, size, H2D, stream);
kernelA<<<grid, block, 0, stream>>>(d_in, d_tmp);  // A
kernelB<<<grid, block, 0, stream>>>(d_tmp, d_out); // B,必须等 A 完事

如果 B 和 C 其实没有依赖关系,可以并行跑,怎么办?传统做法是多开几个 stream,然后用 event 做同步:

cudaStream_t s1, s2;
cudaEvent_t ev;
// ... 创建 stream 和 event

kernelA<<<grid, block, 0, s1>>>();
cudaEventRecord(ev, s1);           // A 完了发个信号

kernelB<<<grid, block, 0, s2>>>();
cudaStreamWaitEvent(s2, ev);       // B 等 A
kernelC<<<grid, block, 0, s2>>>(); // C 也等 A,但 B 和 C 可以并行
Image

代码马上变得又臭又长。而且这只是两个并行分支,实际推理里有几十个 kernel,手动管 event就非常麻烦

Graph 的做法完全不同——你直接画一张依赖图:

Image

不需要手动管 event,runtime 看到图结构就知道 B 和 C 可以一起跑,D 必须等它俩都完。这是声明式(declarative)vs 命令式(imperative)的区别。

图是静态的因此编译器会静态的分配好所有资源,且不需要调度。这也是双刃剑,静态意味着丧失了灵活性,这也是他主要用在推理上的原因。

3.2. Graph 里能放什么?

DAG 自然就有节点和边。节点是实际干活的操作,边只是依赖关系(没有实际计算)。

CUDA Graph 支持的节点类型比我想象的多:

节点类型 实际对应 备注
Kernel kernel<<<...>>>() 标准执行Kernel的动作
Memcpy cudaMemcpyAsync H2D/D2H/D2D 数据传输
Memset cudaMemsetAsync 清零 buffer
Host host callback function CPU 端回调,慎用(会阻塞 GPU)
Child Graph 嵌套子图 可以图里套图,模块化
Empty 空操作 纯粹作为同步点用

对于 LLM 推理来说,绝大部分节点是 Kernel,偶尔有 Memcpy(比如从 host 拷 input_ids)。Host callback 在性能敏感路径上基本不用,因为它会强制 sync,把 GPU 卡住等 CPU 执行 callback。

边(edge)的概念更简单——就是"谁先谁后"的关系。A → B 表示 B 必须等 A 完事。如果两个节点之间没有路径相连,runtime 就会自动让它们并行。

3.3. 三个阶段:Define → Instantiate → Launch

CUDA Graph的生命周期分三步:

1. Define(定义)

构建图的拓扑结构——加节点、连边。这个阶段产物是 cudaGraph_t,只是个"图纸",可以修改(比如加节点、删边)。

2. Instantiate(实例化)

把图纸编译成可执行文件。CUDA 会验证拓扑合法性(比如有没有环)、预分配内部资源、生成执行计划。产物是 cudaGraphExec_t,这一步之后图就"冻结"了,不能直接改结构。

3. Launch(启动)

把实例化的图提交到一个 stream 上执行。一次 launch = 整个 DAG 的所有操作一次性扔给 GPU。

为什么要分三步?核心原因是开销分离。Define 和 Instantiate 的代价只付一次,之后每次推理只需要付 Launch 的极小代价。如果每次都要重新 Define + Instantiate,那 Graph 的优势就没了。

另外,Instantiate 之后也不是完全不能改——可以通过 cudaGraphExecKernelNodeSetParams() 更新节点参数(比如换输入 buffer 的指针),不需要重新 Instantiate。这点对推理很重要,因为每次输入数据会变,但图的结构不变。

3.4. 怎么构建 Graph?两种方式

方式一:Explicit API(手动拼)

直接调用 CUDA API 添加节点和边:

cudaGraph_t graph;
cudaGraphCreate(&graph);

// 添加一个 kernel 节点
cudaGraphNode_t kernelNode;
cudaKernelNodeParams params = {...};
cudaGraphAddKernelNode(&kernelNode, graph, NULL, 0, &params);

// 添加依赖边
cudaGraphAddDependencies(graph, &nodeA, &nodeB, 1);  // A → B

优点是控制精细,缺点是代码啰嗦。一般只在需要极致优化或者图结构动态变化时使用。

方式二:Stream Capture(自动录)

更常用的做法——在你现有的 stream 代码外包一层 capture,CUDA 自动帮你生成图:

cudaGraph_t graph;
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

// 这里写你原来的代码,会被自动录下来
kernelA<<<grid, block, 0, stream>>>();
kernelB<<<grid, block, 0, stream>>>();
// ...

cudaStreamEndCapture(stream, &graph);

PyTorch 的 torch.cuda.CUDAGraph 就是基于 Stream Capture 封装的。这种方式对现有代码侵入小,是推理优化的首选。


4. PyTorch上如何使用CUDA Graph

前面讲的都是 CUDA C++ 层面的 API,但实际项目里我用的是 PyTorch。PyTorch 提供了 torch.cuda.CUDAGraph,把底层的 capture、instantiate、launch 包装成了几行 Python 代码。尤其是对于capture模式来说,官方文档给的example其实就这几行2

import torch

# Enable CUDA Graph mode in PyTorch
model = YourModel().cuda()
static_input = torch.randn(32, 3, 224, 224, device='cuda')

# Warmup
for _ in range(3):
    _ = model(static_input)

# Capture the graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Training loop - replay the graph
for data in dataloader:
    static_input.copy_(data)  # Update input in-place
    g.replay()                # Execute captured operations

代码中其实也涉及到几个核心问题:

  1. 静态的输入和输出
  2. warmup
  3. capture限制

当前Pytorch CUDA Graph API只支持Capture,因为本身pytorch使用的是动态图(eager mode) 和声明式的静态图有一定的理念冲突。如果需要使用python+Explicit API 推荐直接使用cuda-python

4.1. 为什么是静态

Graph 要求"静态"的核心原因:它录制的是指针地址,而不是数据内容

录制时:kernel 从 d_input=0x7f00_1000 读,往 d_output=0x7f00_2000 写
         ↓
图中记录的就是这两个具体地址
         ↓
每次 replay:GPU 去 0x7f00_1000 读,往 0x7f00_2000 写

如果中间你把 d_input 给 free 了,重新 malloc 一个,地址变成 0x7f00_5000,但图还去读 0x7f00_1000——直接段错误。

所以 Graph 的"静态"不是指数据不能变,而是地址不能变。实际推理时的模式是:

# 1. 预分配固定 buffer(整个生命周期不释放)
static_input = torch.empty(max_batch, max_seq_len, device='cuda')

# 2. 每次推理:往固定地址填新数据
static_input.copy_(new_input)

# 3. replay:图从同一个地址读
g.replay()

# 4. 读结果:从固定地址拷出来
result = static_output.clone()

指针不变,数据变。 这是 Graph 能work的前提。

PyTorch 的 static_input.copy_(data)static_output.clone() 这对操作看着有点别扭——一个往里写,一个往外读——但目的都是保护图中的静态地址不被破坏。

除了输入输出,Graph 对模型内部的操作也有不少限制3

内存分配相关(最严格的限制)

  • tensor.resize_() / tensor.resize_as_() —— 可能改变底层存储地址
  • 任何导致 tensor 重新分配内存的操作 —— 比如 cat 后赋值给原变量
  • 动态创建新 tensor —— capture 期间分配的内存地址被钉住,但如果在图执行期间再分配,地址不稳定

同步操作

  • torch.cuda.synchronize() —— capture 期间操作不执行,sync 会死锁
  • tensor.item() / tensor.tolist() —— 把 GPU 数据拷到 CPU,隐含同步
  • print(tensor) —— 同上,要读值就必须 sync

数据依赖的控制流

  • if tensor.sum() > 0: —— 条件依赖 tensor 的值,但 capture 期间值不存在
  • for i in range(tensor.size(0)): —— 如果 size 是动态的,每次可能不一样
  • 任何需要知道 tensor 具体值才能决定走哪条分支的代码

一些安全替代方案

# 不安全
if x.sum() > 0:
    y = x * 2
else:
    y = x + 1

# 安全 —— 用 mask 代替分支
mask = (x > 0).float()
y = mask * (x * 2) + (1 - mask) * (x + 1)

# 不安全
for i in range(batch_size):  # batch_size 是 tensor 的值
    process(x[i])

# 安全 —— 用固定次数的循环
for i in range(MAX_BATCH):  # MAX_BATCH 是常量
    if i < batch_size:
        process(x[i])

简单来说:图的拓扑结构必须是固定的,不能依赖运行时数据做任何决策。

4.2. 为什么 Warmup 是必须的?

一开始我以为 warmup 只是"让性能稳定",后来发现没有它根本 capture 不对。

PyTorch 很多操作是 lazy 的:

  • cuDNN 卷积第一次会 benchmark 不同算法
  • JIT 编译的 kernel 第一次调用时才编译
  • cuBLAS workspace 第一次 matmul 时才分配

这些"第一次"会产生额外的内存分配和 kernel launch。如果在 capture 期间触发,会被录进图里——之后每次 replay 都重新选算法/编译,完全不对。

Warmup 的作用:提前触发所有 lazy 初始化,让 capture 只录到"稳态"的操作序列。

为什么通常跑 3 次?第 1 次触发 cuDNN benchmark 和 JIT 编译,第 2 次某些 allocator 策略可能调整,第 3 次确认真的稳态了。

4.3. Capture 期间的限制

capture期间其实类似我们开发软件的时候的MOCK执行,很多行为都是被mock掉的,而不是真实执行的,只是通过mock+capture快速的从算子流中捕获整体的图,因此除了之前的静态的限制还有如下限制:

  1. 同步操作会死锁
# 错误:capture 期间 sync 会卡住
with torch.cuda.graph(g):
    x = model(input)
    torch.cuda.synchronize()  # 死锁!操作根本没执行,你在等一个永远不会完成的事

sync 要等操作完成,但 capture 期间操作是 mock 的,没真跑,所以 sync 永远等不到。

  1. 不能读 tensor 的值
# 错误:capture 期间 tensor 的值不存在
with torch.cuda.graph(g):
    x = model(input)
    print(x[0])       # 报错:无法读取值
    if x.sum() > 0:   # 报错:条件依赖值
        ...

这限制比想象中严。比如你想打个 log 看中间结果,在 capture 期间是完全做不到的。只能先 capture 完,replay 的时候再想办法看。

  1. 控制流必须静态
# 错误:依赖 tensor 值的控制流
with torch.cuda.graph(g):
    for i in range(seq_len):  # seq_len 如果是 tensor 的值,不行
        process(x[i])

# 正确:用固定次数 + mask
with torch.cuda.graph(g):
    for i in range(MAX_LEN):  # MAX_LEN 是 Python 常量
        mask = (i < seq_len).float()
        out += mask * process(x[i])

实际推理里 seq_len 每次可能不一样,但图结构必须固定。解决办法是用最大长度循环,然后用 mask 把无效位置置零。

5. 真实推理引擎的中的CUDA Graph:

推理引擎中的CUDA Graph使用比较广泛,但是对于推理server,因为请求有多有少,因此在decode推理的时候对静态CUDA Graph也是一个挑战。现代的推理引擎比如说vLLM、Sglang、TensorRT-LLM 通常使用分桶的方式,通过在初始化的时候捕获一组不同输入Batch的CUDA Graph在推理的时候重放,逻辑如图:

Image

这也是sglang参数--cuda-graph-bs 设置参数的含义。

6. CUDA Graph究竟能加速多少

这里我使用Qwen3-0.6B 在H200上进行测试,代码如下


from __future__ import annotations

import json
import statistics
import time

import torch
from transformers import StaticCache

MODEL_ID = "Qwen/Qwen3-0.6B"
DEVICE = "cuda:0"
DTYPE = torch.bfloat16

# 测量参数
PROMPT = "The key difference between CUDA streams and CUDA graphs is"
NUM_DECODE_TOKENS = 128
WARMUP_STEPS = 10

RESULTS_DIR = Path(__file__).resolve().parent.parent / "results"


@dataclass
class LoadedModel:
    model: torch.nn.Module
    tokenizer: AutoTokenizer


def load_model() -> LoadedModel:
    """加载 Qwen3-0.6B。eval + no_grad 在调用侧处理。"""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        attn_implementation="sdpa",
    ).to(DEVICE)
    model.eval()
    return LoadedModel(model=model, tokenizer=tokenizer)


def encode_prompt(tokenizer, prompt: str = PROMPT) -> torch.Tensor:
    """Prompt → input_ids on device. Shape: [1, prompt_len]."""
    return tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)


def _prefill(model, input_ids, cache):
    """Prefill 走 eager. 返回首个 decode token 和 next cache_position."""
    prompt_len = input_ids.shape[1]
    cache_position = torch.arange(prompt_len, dtype=torch.int64, device=DEVICE)
    out = model(
        input_ids=input_ids,
        past_key_values=cache,
        cache_position=cache_position,
        use_cache=True,
    )
    next_tokens = out.logits[:, -1:, :].argmax(dim=-1)
    next_cache_position = cache_position[-1:] + 1
    return next_tokens, next_cache_position


def _eager_decode_step(model, cache, input_ids, cache_position):
    """跑一步 decode. 用于 warmup / 为 capture 准备 cuBLAS workspace."""
    out = model(
        input_ids=input_ids,
        past_key_values=cache,
        cache_position=cache_position,
        use_cache=True,
    )
    return out


@torch.no_grad()
def run_graphed() -> dict:
    loaded = load_model()
    model, tokenizer = loaded.model, loaded.tokenizer
    input_ids = encode_prompt(tokenizer)
    prompt_len = input_ids.shape[1]
    max_cache_len = prompt_len + NUM_DECODE_TOKENS + 16

    # ─────────────────────────────────────────────
    # 一次性 cache: 从头到尾只建一个
    # 原因: graph 捕获到的 cache 指针必须在 replay 时依然有效
    # ─────────────────────────────────────────────
    cache = StaticCache(
        config=model.config, max_cache_len=max_cache_len, device=DEVICE, dtype=DTYPE,
    )

    print(f"[setup] prefill + {WARMUP_STEPS} eager warmup decode steps ...")
    next_tokens, next_cache_position = _prefill(model, input_ids, cache)
    out = None
    for _ in range(WARMUP_STEPS):
        out = _eager_decode_step(model, cache, next_tokens, next_cache_position)
        next_tokens = out.logits[:, -1:, :].argmax(dim=-1)
        next_cache_position = next_cache_position + 1
    torch.cuda.synchronize()

	# 静态输入输出
    static_input_ids = next_tokens.clone()                # [1, 1]
    static_cache_position = next_cache_position.clone()   # [1]

    print("[capture] recording decode graph ...")
    
    s = torch.cuda.Stream()
    g = torch.cuda.CUDAGraph()
    
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            out = model(
                input_ids=static_input_ids,
                past_key_values=cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
    torch.cuda.current_stream().wait_stream(s)
    
    
    with torch.cuda.graph(g):
        out = model(
                input_ids=static_input_ids,
                past_key_values=cache,
                cache_position=static_cache_position,
                use_cache=True,
            )
    static_output_logits = out.logits 

    # 重置 cache 状态 + 重做 prefill
    # 关键: StaticCache 内部有一个 cumulative_length 计数器, 每次 forward 都 +=
    # input_len (不管 cache_position 是否复用). 如果不 reset, 它会一路涨到超过
    # max_cache_len, 导致某处 index_copy_ 触发 "index out of bounds".
    # cache.reset() 会把 keys/values 清零, 并把 cumulative_length.zero_().
    # 注意: 这只是零内容, 地址不变——这正是 graph 依赖的前提, 所以 reset 不影响 g.
    cache.reset()

    next_tokens, next_cache_position = _prefill(model, input_ids, cache)
    static_input_ids.copy_(next_tokens)
    static_cache_position.copy_(next_cache_position)
    torch.cuda.synchronize()


    # 测量: 用 graph.replay() 代替 model() 做 decode loop
    events_pair: list[tuple[torch.cuda.Event, torch.cuda.Event]] = []

    t0 = time.perf_counter()
    for step in range(NUM_DECODE_TOKENS - 1):  # -1 因为 prefill 已经产生了第一个 token
        step_start = torch.cuda.Event(enable_timing=True)
        step_end = torch.cuda.Event(enable_timing=True)

        step_start.record()
        g.replay()
        step_end.record()
        events_pair.append((step_start, step_end))

        # 用 replay 的输出产生下一步输入. copy_() 原地更新静态 buffer
        next_tokens = static_output_logits[:, -1:, :].argmax(dim=-1)
        static_input_ids.copy_(next_tokens)
        static_cache_position.add_(1)  # 原地 +1

    torch.cuda.synchronize()
    t1 = time.perf_counter()
    total_wall_ms = (t1 - t0) * 1000
    per_step_gpu_ms = [s.elapsed_time(e) for s, e in events_pair]

	
	# 统计
    decode_ms = per_step_gpu_ms  # 这里全部是 decode (prefill 在外面)
    result = {
        "model": MODEL_ID,
        "dtype": "bfloat16",
        "cache": "static",
        "attention": "sdpa",
        "mode": "graphed",
        "max_cache_len": max_cache_len,
        "prompt_len": prompt_len,
        "num_decode_tokens": len(decode_ms),
        "decode_per_step_ms": decode_ms,
        "decode_stats": {
            "mean": statistics.mean(decode_ms),
            "median": statistics.median(decode_ms),
            "stdev": statistics.stdev(decode_ms) if len(decode_ms) > 1 else 0.0,
            "min": min(decode_ms),
            "max": max(decode_ms),
        },
        "total_wall_ms": total_wall_ms,
        "total_gpu_ms": sum(per_step_gpu_ms),
        "cpu_overhead_ms": total_wall_ms - sum(per_step_gpu_ms),
    }


    print(f"\n[graphed] decode mean:  {result['decode_stats']['mean']:.3f} ms/step")
    print(f"[graphed] decode stdev: {result['decode_stats']['stdev']:.3f} ms")
    print(f"[graphed] total wall:   {total_wall_ms:.1f} ms")
    print(f"[graphed] CPU overhead: {result['cpu_overhead_ms']:.1f} ms "
          f"({result['cpu_overhead_ms']/total_wall_ms*100:.2f}%)")
    print(f"[graphed] saved → {out_path}")
    return result


if __name__ == "__main__":
    run_graphed()

6.1. 结果

指标 eager (StaticCache) graph 倍率
decode mean 16.43 ms 4.50 ms 3.65×
decode stdev 0.106 ms 0.072 ms 1.47×
decode min 16.26 ms 4.49 ms
decode max 17.00 ms 5.30 ms
total wall 2108 ms 574 ms 3.67×

分桶的结果:

Image

6.2. nsys观察到的现象

所有的Kernel操作不需要等待合并为一个Graph

Image

Eager:

Image

7. 总结

捋完这一圈,关于 CUDA Graph 使用大概有这几个结论:

什么场景值得用

  • Decode 阶段收益最明显。Prefill 的 kernel 大,launch 开销占比小;Decode 全是 tiny kernel,CPU 提交成了瓶颈。实测 Qwen3-0.6B 在 H200 上 decode 能快 3.6 倍,这个差距在更小模型上会更夸张。
  • Batch size 固定或者变化范围小的服务场景。动态 shape 是 Graph 的天敌,要么预分配最大 buffer 浪费内存,要么维护一堆不同尺寸的 graph。
  • 已经做完算子融合的模型。如果 kernel 数量还是上百个,Graph 的收益会被摊薄。

主要的代价和限制

  • 内存。StaticCache + Graph 的静态 buffer,意味着要按最大可能长度预分配,显存占用比动态分配高。
  • 灵活性。控制流必须用 mask 代替分支,调试时不能 print 中间结果,调试特别困哪
  • 首次捕获的 overhead。Warmup + Capture 要花几十到几百毫秒,对冷启动敏感的场景是个问题。

8. 参考


查看原始Issue


  1. nvidia的文档提到的范围是之前硬件+低版本Pytorch上的开销,每一代CUDA都在优化cudaLaunchKernel的时间。对现代 ML 框架(PyTorch eager, TensorFlow, JAX),Python transition成本很低。因此实际数据与占比分布可能与官方数据有差距。 ↩︎

  2. https://docs.nvidia.com/dl-cuda-graph/latest/index.html#quick-start ↩︎

  3. Writing CUDA Graph-Compatible Code ↩︎