標籤: 長 prompt 短 response

  • RL 訓練版 Prompt Cache 7.5x 提速解析

    RL 訓練版 Prompt Cache 7.5x 提速解析

    📌 本文重點

    • 長 prompt / 短 response RL 訓練會浪費 >90% 計算
    • 把推理用 KV/prefix cache 思路搬進帶梯度訓練可大幅提速
    • 在 Qwen3.5-4B 上實測最高約 7.5x throughput 提升

    長 prompt、短 response 的 RLHF/RLAIF 任務(例如對話評分、工具調用評分)有一個非常痛的點:每個樣本都在重算同一段 prompt。對 1000-token prompt、100-token response 的場景,你實際上有 >90% 的 FLOPs 在白白重褾。這篇要講的是:如何把推理時的 KV/prefix cache 思路搬進帶梯度的 RL 訓練,在 Qwen3.5-4B 上實測最高拿到 7.5x 速度提升,並給你一套可以直接落地的工程實作方案。

    💡 關鍵: 在長 prompt / 短 response 場景中,重用 prompt 前向計算可將大部分重複 FLOPs 直接省掉,帶來數倍級 throughput 提升。


    重點說明

    1. 為什麼 RL 訓練會浪費那麼多計算?

    典型的 RLHF/RLAIF 術次資料形態:

    • prompt:系統 + 多輪對話 + 任務描述(幾百到上千 tokens)
    • response:模型生成或候選回答(幾十到一兩百 tokens)

    多數開源 RL engine(包括許多自寫 pipeline)會:

    [ prompt tokens ][ response tokens ]
      T_prompt           T_resp
    

    對每一個樣本、每一次 rollout / gradient step,都從頭跑整條序列,雖然 prompt 完全相同,只是 response 不同。這會帶來幾個直接影響:

    1. GPU 利用率被長 prompt 綁死
    2. 你以為自己 batch size 是 64,其實「有效」只有在 response 段,前面 90% 的計算是在重放。
    3. batch 設計被 context 長度限制
    4. 1000+ token prompt 會吃掉大部份 memory,導致你無法疊大 batch,只能靠 gradient accumulation,進一步增加 step latency。
    5. RL 特有放大器
    6. 同一個 prompt 下可能要算多個候選 response、policy/value 多頭、不同 reward function,全都從 prompt 重新 forward 一次。

    因此,只要你是「長 prompt / 短 response」型任務,任何一點在 prompt 端節省的 FLOPs,都是純利潤


    2. 把 KV/prefix cache 搬進訓練:核心思路

    推理時我們早就習慣用 KV cache/prefix cache

    1. 先跑一次 prompt,存下每層的 key/value(或 hidden states)。
    2. 生成 response 時,只計算增量 token,復用前綴。

    在訓練中要做到類似的事情,難點在於:

    • 我們需要 完整的 computation graph(for backprop)。
    • 不能只存數值(像推理那樣),還要讓 autograd 知道這些值是可導的。
    • 不能打壞 attention:response 的 attention 要能看見 prompt token 的 hidden states。

    一種工程上可行的做法(簡化描述):

    1. 把序列拆成兩段圖:prompt graph + response graph。
    2. prompt 部分:
    3. 前向一次,拿到 prompt hidden states(例如每層的 h_prompt)與最後一層的 cache-like 表示。
    4. 保留其 computation graph(不 detach),但不馬上 backward。
    5. response 部分:
    6. 再跑一次 LLM,但將 prompt 當成固定 prefix 傳入,使 response token 的 attention 能看到這些 prefix hidden states。
    7. 在 PyTorch 裡可以透過自訂 forward 函數,把 prompt hidden states 塞回 attention 模組,類似手動實作 prefix cache。
    8. loss 計算只對 response tokens 做(例如 policy loss、value loss),但梯度會沿著 response→prompt 的 graph 反傳,保證不破壞訓練正確性。

    關鍵是:

    • 只對 prompt 前向一次,但仍然讓 prompt 參與梯度更新。
    • 對同一 prompt 的多個 response,重複使用一份 prompt hidden states(甚至在一個批次中共享)。

    在 Qwen3.5-4B 上,reddit 實測:

    • prompt : response ≈ 10:1(例如 1000:100)
    • RL 任務:長對話 + 短完成
    • 快取後在長 prompt/短 response 工作負載下 最高取得 ~7.5x step throughput 提升(取決於實際長度比與 IO/通信開銷)。

    💡 關鍵: 當 prompt 與 response 長度比約 10:1 時,只重算 response 部分可在實測中帶來約 7.5 倍 step throughput 提升。


    3. 什麼任務最吃紅利?

    根據 Qwen3.5-4B 測試經驗與工作負載特性,大致可以這樣判斷:

    1. 長 prompt / 短 response(T_prompt / T_resp ≥ 4
    2. 如:對話 RLHF 評分(用戶上下文很長,模型答覆很短)。
    3. 工具調用評分:所有工具 schema + log 作為 prompt,再對短 decision 進行 RL。
    4. 部分代碼 RL:整個大檔案為 prompt,模型只改一小段。
    5. 這類場景通常可以拿到 3x–7.5x 的實際提速。

    6. 中 prompt / 中 response(T_prompt / T_resp ≈ 1

    7. 如:通用問答 RLHF(prompt 只有一兩句,回答較長)。
    8. 提速有限,約 1.2x–2x,且實作複雜度可能不值。

    9. 短 prompt / 長 response(T_prompt / T_resp < 1

    10. 基本沒紅利,甚至會因複雜控制流、多段 graph 而變慢。

    實務上可以用一條 thumb rule:

    如果你平均的 prompt token 數是 response 的 3 倍以上,就應該認真評估導入。

    💡 關鍵:T_prompt 至少約為 T_resp 的 3 倍時,引入訓練版 prompt cache 通常才有顯著性價比。


    實作範例

    以下示例是 PyTorch 為主,偏 pseudo code,但結構與實務工程接近。

    1. 資料結構與 DataLoader 改寫

    我們先把一個 RL batch 明確拆成 prompt / response:

    # 每個樣本:
    # prompt_ids: [T_p]
    # resp_ids:   [T_r]
    
    class RLDataset(torch.utils.data.Dataset):
        def __getitem__(self, idx):
            item = self.data[idx]
            return {
                "prompt_ids": item.prompt_ids,   # 長
                "resp_ids": item.resp_ids,       # 短
                "reward": item.reward,           # 或 advantage
            }
    
    
    def collate_fn(batch):
        # padding & batch 組合
        prompt_ids = pad_sequence([b["prompt_ids"] for b in batch], batch_first=True)
        resp_ids   = pad_sequence([b["resp_ids"]   for b in batch], batch_first=True)
    
        # 生成對應 mask
        prompt_attn_mask = (prompt_ids != pad_token_id)
        resp_attn_mask   = (resp_ids   != pad_token_id)
    
        return {
            "prompt_ids": prompt_ids,
            "resp_ids": resp_ids,
            "prompt_mask": prompt_attn_mask,
            "resp_mask": resp_attn_mask,
            "reward": torch.tensor([b["reward"] for b in batch]),
        }
    

    2. 模型 forward:拆成 prompt graph + response graph

    假設你有一個可插拔的 LLM 模型 model,我們新增兩個關鍵 API:

    • model.forward_prompt(...):只跑 prompt,返回 hidden states(及必要 cache)。
    • model.forward_response_with_prefix(...):給定 prefix hidden states,跑 response。
    class RLPromptCacheModel(nn.Module):
        def forward_prompt(self, input_ids, attention_mask):
            # 返回每層的 hidden,或最後一層即可
            # 重要:不要 detach,保持 grad
            outputs = self.transformer(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            return outputs.hidden_states  # list[Layer][B, T_p, H]
    
        def forward_response_with_prefix(self,
                                         resp_ids,
                                         resp_mask,
                                         prompt_hidden_states,
                                         prompt_mask):
            # 這裡需要改造 attention:
            # 讓每層 self-attention 的 KV = [prompt, resp]
            # 可以在每層 module 裡寫一個 hook,或實作 custom attn。
            outputs = self.transformer_with_prefix(
                resp_ids=resp_ids,
                resp_mask=resp_mask,
                prefix_hidden_states=prompt_hidden_states,
                prefix_mask=prompt_mask,
            )
            return outputs.last_hidden_state
    

    核心點:transformer_with_prefix 要做到:

    • 對於每層的 self-attention:
    • query 來自 response tokens;
    • key/value 為 [prefix_hidden_states; resp_hidden]
    • 這讓 response token 能正常 attend 到 prompt,並保持完整 graph。

    實務上可以參考 FlashAttention / prefix-tuning 的實作方式,直接拼接 prefix hidden 作為額外 token,再控制 mask:

    def transformer_with_prefix(...):
        # 假設我們把 prefix & response 在 time 維度上串起來
        # 注意這裡是邏輯串接,實際可用 concat + mask 控制
        concat_hidden = torch.cat([prefix_hidden, resp_emb], dim=1)  # [B, T_p+T_r, H]
        concat_mask   = torch.cat([prefix_mask, resp_mask], dim=1)   # [B, T_p+T_r]
    
        # 交給原本的 transformer 做 self-attention
        outputs = self.base_transformer(
            hidden_states=concat_hidden,
            attention_mask=concat_mask,
        )
        # 只取 response 對應位置的輸出
        resp_hidden_out = outputs.last_hidden_state[:, -resp_len:, :]
        return resp_hidden_out
    

    3. Loss 計算與 RL head

    以 policy gradient 為例,我們只對 response token 做 loss:

    prompt_hs = model.forward_prompt(batch["prompt_ids"], batch["prompt_mask"])  # list[L]
    
    resp_logits = model.forward_response_with_prefix(
        batch["resp_ids"],
        batch["resp_mask"],
        prompt_hs,
        batch["prompt_mask"],
    )
    
    # policy head
    logits = policy_head(resp_logits)  # [B, T_r, V]
    log_probs = F.log_softmax(logits, dim=-1)
    
    # 只對實際採樣到的 token 做 loss
    # 假設 resp_ids 是我們的 action
    token_logp = log_probs.gather(-1, batch["resp_ids"].unsqueeze(-1)).squeeze(-1)
    
    # 依 RL 演算法計算 advantage 等
    loss = -(token_logp * advantage_mask).sum() / num_valid_tokens
    loss.backward()
    

    因為 prompt_hs 沒有被 detach,梯度會沿著 response 部分回傳到 prompt 部分,等效於一次走完整個序列,但 prompt 只 forward 一次


    4. 與 gradient checkpointing / mixed precision / DDP 整合

    • gradient checkpointing
    • 可以只對 response graph 開啟 checkpoint,prompt graph 一般不需要再切。
    • 若 prompt 特別長,可在 prompt 段也設 checkpoint,但要注意不要把 cache 給破壞(照 layer 切即可)。

    • mixed precision (AMP/Fp16/bf16)

    • 保持 prompt & response forward 使用同一個 torch.cuda.amp.autocast 區塊。
    • prompt cached hidden 和 response 的精度必須一致,避免 dtype mismatch。

    • DDP/FSDP

    • 基本原則:prompt forward 也在每個 rank 上做一次,不要跨 rank 共用 hidden,避免額外通信。
    • FSDP 來說,prompt hidden 是 activation,照樣會被 shard/rebuild,不需要特別處理。
    • 注意 loss scale 及 no_sync() 區段,確保多 step accumulation 時 prompt/response 的 backward 一致。

    建議與注意事項

    1. 常見坑

    1. 快取導致樣本 shuffle 不均
    2. 若你把「相同 prompt 的多個 response」綁在一起,容易造成某些 prompt 被過度訓練。
    3. 建議在 dataset 層維持 樣本級 shuffle,不要把 prompt 當成硬分桶,或定期重組 group。

    4. mask 錯誤導致梯度泄漏

    5. 如果 attention mask 沒處理好,可能出現:response token 看到未來 token,或不同樣本互相看到彼此的 prompt。
    6. 尤其在 concat prefix 時,要確認:

      • padding token 完全被 mask 掉;
      • prefix 與 response 的因果 mask 正確(response 不該看到未來 response)。
    7. policy / value head 不一致

    8. 很多 RL pipeline 會同時跑 policy head + value head。
    9. 如果你只對 policy 路徑用 prompt cache,而 value 還在跑 full sequence,
      會導致兩邊的 feature distribution 不一致。
    10. 建議:兩個 head 共用同一套 prompt+response 拆圖邏輯,或至少在 feature 塊對齊。

    2. 什麼時候值得導入?

    你可以簡單做一個估算:

    • 計算平均 T_prompt / T_resp
    • 估算你的訓練 step 中,有多少時間是花在 forward(相對於通信/IO)。
    • 目標提速 ≈ T_total / (T_resp + T_prompt / cache_reuse_factor)

    若粗算下來:

    • 理論加速 > 2x,且你目前的 RL 訓練被 FLOPs-bound(非 IO-bound),那導入很可能值得。
    • 若你被 data loading 或 reward 模型 inference 卡住,則先優化 pipeline 再考慮這一層。

    3. 實務指引(TL;DR)

    • 優先導入場景
    • RLHF/RLAIF 的對話評分、工具調用評分、長上下文 code RL。
    • prompt 長度是 response 的 3–10 倍。
    • 使用 Qwen3.5-4B 或相近大小模型,GPU 計算是主要瓶頸。

    • 預期收益

    • 實測可達 3x–7.5x throughput 提升。
    • 允許你把 batch 撐大,減少 gradient accumulation,進一步提高 GPU 利用率。
    • 相同 GPU 成本下,能多跑數倍 rollout 或更長訓練步數。

    • 導入步驟建議

    • 先在小 batch 上實作 forward_prompt + forward_response_with_prefix,只做 sanity check。
    • 確認與原 full sequence 訓練的 loss/梯度差異在可接受範圍(數值抖動為正常)。
    • 再導入 DDP/FSDP + AMP,逐步拉大 batch 測 throughput。
    • 監控 loss 曲線與最終 RL reward,確認沒有明顯退化。

    只要你的 RL 任務落在「長 prompt / 短 response」區間,RL 訓練版 prompt cache 幾乎就是一次性的大幅成本折扣;對正在做 RLHF/RLAIF 的團隊,值得花 1–2 週工程時間好好實作一版。


    🚀 你現在可以做的事

    • 在現有 RLHF/RLAIF 代碼中量測平均 T_prompt / T_resp,判斷是否達到導入門檻(≥3)
    • 在一個小型實驗中實作 forward_promptforward_response_with_prefix,對比 full sequence 訓練的 loss/梯度
    • 在實際 Qwen3.5-4B 或現用模型上開啟 prompt cache 實驗,記錄 throughput 與成本變化,評估是否全面導入