📌 本文重點
- 長 prompt / 短 response RL 訓練會浪費 >90% 計算
- 把推理用 KV/prefix cache 思路搬進帶梯度訓練可大幅提速
- 在 Qwen3.5-4B 上實測最高約 7.5x throughput 提升
長 prompt、短 response 的 RLHF/RLAIF 任務(例如對話評分、工具調用評分)有一個非常痛的點:每個樣本都在重算同一段 prompt。對 1000-token prompt、100-token response 的場景,你實際上有 >90% 的 FLOPs 在白白重褾。這篇要講的是:如何把推理時的 KV/prefix cache 思路搬進帶梯度的 RL 訓練,在 Qwen3.5-4B 上實測最高拿到 7.5x 速度提升,並給你一套可以直接落地的工程實作方案。
💡 關鍵: 在長 prompt / 短 response 場景中,重用 prompt 前向計算可將大部分重複 FLOPs 直接省掉,帶來數倍級 throughput 提升。
重點說明
1. 為什麼 RL 訓練會浪費那麼多計算?
典型的 RLHF/RLAIF 術次資料形態:
- prompt:系統 + 多輪對話 + 任務描述(幾百到上千 tokens)
- response:模型生成或候選回答(幾十到一兩百 tokens)
多數開源 RL engine(包括許多自寫 pipeline)會:
[ prompt tokens ][ response tokens ]
T_prompt T_resp
對每一個樣本、每一次 rollout / gradient step,都從頭跑整條序列,雖然 prompt 完全相同,只是 response 不同。這會帶來幾個直接影響:
- GPU 利用率被長 prompt 綁死:
- 你以為自己 batch size 是 64,其實「有效」只有在 response 段,前面 90% 的計算是在重放。
- batch 設計被 context 長度限制:
- 1000+ token prompt 會吃掉大部份 memory,導致你無法疊大 batch,只能靠 gradient accumulation,進一步增加 step latency。
- RL 特有放大器:
- 同一個 prompt 下可能要算多個候選 response、policy/value 多頭、不同 reward function,全都從 prompt 重新 forward 一次。
因此,只要你是「長 prompt / 短 response」型任務,任何一點在 prompt 端節省的 FLOPs,都是純利潤。
2. 把 KV/prefix cache 搬進訓練:核心思路
推理時我們早就習慣用 KV cache/prefix cache:
- 先跑一次 prompt,存下每層的 key/value(或 hidden states)。
- 生成 response 時,只計算增量 token,復用前綴。
在訓練中要做到類似的事情,難點在於:
- 我們需要 完整的 computation graph(for backprop)。
- 不能只存數值(像推理那樣),還要讓 autograd 知道這些值是可導的。
- 不能打壞 attention:response 的 attention 要能看見 prompt token 的 hidden states。
一種工程上可行的做法(簡化描述):
- 把序列拆成兩段圖:prompt graph + response graph。
- prompt 部分:
- 前向一次,拿到 prompt hidden states(例如每層的
h_prompt)與最後一層的 cache-like 表示。 - 保留其 computation graph(不 detach),但不馬上 backward。
- response 部分:
- 再跑一次 LLM,但將 prompt 當成固定 prefix 傳入,使 response token 的 attention 能看到這些 prefix hidden states。
- 在 PyTorch 裡可以透過自訂 forward 函數,把 prompt hidden states 塞回 attention 模組,類似手動實作 prefix cache。
- loss 計算只對 response tokens 做(例如 policy loss、value loss),但梯度會沿著 response→prompt 的 graph 反傳,保證不破壞訓練正確性。
關鍵是:
- 只對 prompt 前向一次,但仍然讓 prompt 參與梯度更新。
- 對同一 prompt 的多個 response,重複使用一份 prompt hidden states(甚至在一個批次中共享)。
在 Qwen3.5-4B 上,reddit 實測:
- prompt : response ≈ 10:1(例如 1000:100)
- RL 任務:長對話 + 短完成
- 快取後在長 prompt/短 response 工作負載下 最高取得 ~7.5x step throughput 提升(取決於實際長度比與 IO/通信開銷)。
💡 關鍵: 當 prompt 與 response 長度比約 10:1 時,只重算 response 部分可在實測中帶來約 7.5 倍 step throughput 提升。
3. 什麼任務最吃紅利?
根據 Qwen3.5-4B 測試經驗與工作負載特性,大致可以這樣判斷:
- 長 prompt / 短 response(
T_prompt / T_resp ≥ 4) - 如:對話 RLHF 評分(用戶上下文很長,模型答覆很短)。
- 工具調用評分:所有工具 schema + log 作為 prompt,再對短 decision 進行 RL。
- 部分代碼 RL:整個大檔案為 prompt,模型只改一小段。
-
這類場景通常可以拿到 3x–7.5x 的實際提速。
-
中 prompt / 中 response(
T_prompt / T_resp ≈ 1) - 如:通用問答 RLHF(prompt 只有一兩句,回答較長)。
-
提速有限,約 1.2x–2x,且實作複雜度可能不值。
-
短 prompt / 長 response(
T_prompt / T_resp < 1) - 基本沒紅利,甚至會因複雜控制流、多段 graph 而變慢。
實務上可以用一條 thumb rule:
如果你平均的 prompt token 數是 response 的 3 倍以上,就應該認真評估導入。
💡 關鍵: 當
T_prompt至少約為T_resp的 3 倍時,引入訓練版 prompt cache 通常才有顯著性價比。
實作範例
以下示例是 PyTorch 為主,偏 pseudo code,但結構與實務工程接近。
1. 資料結構與 DataLoader 改寫
我們先把一個 RL batch 明確拆成 prompt / response:
# 每個樣本:
# prompt_ids: [T_p]
# resp_ids: [T_r]
class RLDataset(torch.utils.data.Dataset):
def __getitem__(self, idx):
item = self.data[idx]
return {
"prompt_ids": item.prompt_ids, # 長
"resp_ids": item.resp_ids, # 短
"reward": item.reward, # 或 advantage
}
def collate_fn(batch):
# padding & batch 組合
prompt_ids = pad_sequence([b["prompt_ids"] for b in batch], batch_first=True)
resp_ids = pad_sequence([b["resp_ids"] for b in batch], batch_first=True)
# 生成對應 mask
prompt_attn_mask = (prompt_ids != pad_token_id)
resp_attn_mask = (resp_ids != pad_token_id)
return {
"prompt_ids": prompt_ids,
"resp_ids": resp_ids,
"prompt_mask": prompt_attn_mask,
"resp_mask": resp_attn_mask,
"reward": torch.tensor([b["reward"] for b in batch]),
}
2. 模型 forward:拆成 prompt graph + response graph
假設你有一個可插拔的 LLM 模型 model,我們新增兩個關鍵 API:
model.forward_prompt(...):只跑 prompt,返回 hidden states(及必要 cache)。model.forward_response_with_prefix(...):給定 prefix hidden states,跑 response。
class RLPromptCacheModel(nn.Module):
def forward_prompt(self, input_ids, attention_mask):
# 返回每層的 hidden,或最後一層即可
# 重要:不要 detach,保持 grad
outputs = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
return outputs.hidden_states # list[Layer][B, T_p, H]
def forward_response_with_prefix(self,
resp_ids,
resp_mask,
prompt_hidden_states,
prompt_mask):
# 這裡需要改造 attention:
# 讓每層 self-attention 的 KV = [prompt, resp]
# 可以在每層 module 裡寫一個 hook,或實作 custom attn。
outputs = self.transformer_with_prefix(
resp_ids=resp_ids,
resp_mask=resp_mask,
prefix_hidden_states=prompt_hidden_states,
prefix_mask=prompt_mask,
)
return outputs.last_hidden_state
核心點:transformer_with_prefix 要做到:
- 對於每層的 self-attention:
- query 來自 response tokens;
- key/value 為
[prefix_hidden_states; resp_hidden]; - 這讓 response token 能正常 attend 到 prompt,並保持完整 graph。
實務上可以參考 FlashAttention / prefix-tuning 的實作方式,直接拼接 prefix hidden 作為額外 token,再控制 mask:
def transformer_with_prefix(...):
# 假設我們把 prefix & response 在 time 維度上串起來
# 注意這裡是邏輯串接,實際可用 concat + mask 控制
concat_hidden = torch.cat([prefix_hidden, resp_emb], dim=1) # [B, T_p+T_r, H]
concat_mask = torch.cat([prefix_mask, resp_mask], dim=1) # [B, T_p+T_r]
# 交給原本的 transformer 做 self-attention
outputs = self.base_transformer(
hidden_states=concat_hidden,
attention_mask=concat_mask,
)
# 只取 response 對應位置的輸出
resp_hidden_out = outputs.last_hidden_state[:, -resp_len:, :]
return resp_hidden_out
3. Loss 計算與 RL head
以 policy gradient 為例,我們只對 response token 做 loss:
prompt_hs = model.forward_prompt(batch["prompt_ids"], batch["prompt_mask"]) # list[L]
resp_logits = model.forward_response_with_prefix(
batch["resp_ids"],
batch["resp_mask"],
prompt_hs,
batch["prompt_mask"],
)
# policy head
logits = policy_head(resp_logits) # [B, T_r, V]
log_probs = F.log_softmax(logits, dim=-1)
# 只對實際採樣到的 token 做 loss
# 假設 resp_ids 是我們的 action
token_logp = log_probs.gather(-1, batch["resp_ids"].unsqueeze(-1)).squeeze(-1)
# 依 RL 演算法計算 advantage 等
loss = -(token_logp * advantage_mask).sum() / num_valid_tokens
loss.backward()
因為 prompt_hs 沒有被 detach,梯度會沿著 response 部分回傳到 prompt 部分,等效於一次走完整個序列,但 prompt 只 forward 一次。
4. 與 gradient checkpointing / mixed precision / DDP 整合
- gradient checkpointing:
- 可以只對 response graph 開啟 checkpoint,prompt graph 一般不需要再切。
-
若 prompt 特別長,可在 prompt 段也設 checkpoint,但要注意不要把 cache 給破壞(照 layer 切即可)。
-
mixed precision (
AMP/Fp16/bf16): - 保持 prompt & response forward 使用同一個
torch.cuda.amp.autocast區塊。 -
prompt cached hidden 和 response 的精度必須一致,避免 dtype mismatch。
-
DDP/FSDP: - 基本原則:prompt forward 也在每個 rank 上做一次,不要跨 rank 共用 hidden,避免額外通信。
- 對
FSDP來說,prompt hidden 是 activation,照樣會被 shard/rebuild,不需要特別處理。 - 注意 loss scale 及
no_sync()區段,確保多 step accumulation 時 prompt/response 的 backward 一致。
建議與注意事項
1. 常見坑
- 快取導致樣本 shuffle 不均
- 若你把「相同 prompt 的多個 response」綁在一起,容易造成某些 prompt 被過度訓練。
-
建議在 dataset 層維持 樣本級 shuffle,不要把 prompt 當成硬分桶,或定期重組 group。
-
mask 錯誤導致梯度泄漏
- 如果 attention mask 沒處理好,可能出現:response token 看到未來 token,或不同樣本互相看到彼此的 prompt。
-
尤其在 concat prefix 時,要確認:
- padding token 完全被 mask 掉;
- prefix 與 response 的因果 mask 正確(response 不該看到未來 response)。
-
policy / value head 不一致
- 很多 RL pipeline 會同時跑 policy head + value head。
- 如果你只對 policy 路徑用 prompt cache,而 value 還在跑 full sequence,
會導致兩邊的 feature distribution 不一致。 - 建議:兩個 head 共用同一套 prompt+response 拆圖邏輯,或至少在 feature 塊對齊。
2. 什麼時候值得導入?
你可以簡單做一個估算:
- 計算平均
T_prompt / T_resp。 - 估算你的訓練 step 中,有多少時間是花在 forward(相對於通信/IO)。
- 目標提速 ≈
T_total / (T_resp + T_prompt / cache_reuse_factor)。
若粗算下來:
- 理論加速 > 2x,且你目前的 RL 訓練被 FLOPs-bound(非 IO-bound),那導入很可能值得。
- 若你被 data loading 或 reward 模型 inference 卡住,則先優化 pipeline 再考慮這一層。
3. 實務指引(TL;DR)
- 優先導入場景:
- RLHF/RLAIF 的對話評分、工具調用評分、長上下文 code RL。
- prompt 長度是 response 的 3–10 倍。
-
使用 Qwen3.5-4B 或相近大小模型,GPU 計算是主要瓶頸。
-
預期收益:
- 實測可達 3x–7.5x throughput 提升。
- 允許你把 batch 撐大,減少 gradient accumulation,進一步提高 GPU 利用率。
-
相同 GPU 成本下,能多跑數倍 rollout 或更長訓練步數。
-
導入步驟建議:
- 先在小 batch 上實作
forward_prompt+forward_response_with_prefix,只做 sanity check。 - 確認與原 full sequence 訓練的 loss/梯度差異在可接受範圍(數值抖動為正常)。
- 再導入
DDP/FSDP+AMP,逐步拉大 batch 測 throughput。 - 監控 loss 曲線與最終 RL reward,確認沒有明顯退化。
只要你的 RL 任務落在「長 prompt / 短 response」區間,RL 訓練版 prompt cache 幾乎就是一次性的大幅成本折扣;對正在做 RLHF/RLAIF 的團隊,值得花 1–2 週工程時間好好實作一版。
🚀 你現在可以做的事
- 在現有 RLHF/RLAIF 代碼中量測平均
T_prompt / T_resp,判斷是否達到導入門檻(≥3)- 在一個小型實驗中實作
forward_prompt與forward_response_with_prefix,對比 full sequence 訓練的 loss/梯度- 在實際 Qwen3.5-4B 或現用模型上開啟 prompt cache 實驗,記錄 throughput 與成本變化,評估是否全面導入

