標籤: 長上下文

  • In-Place Test-Time Training 讓模型邊推理邊變強

    📌 本文重點

    • In-Place TTT 讓 LLM 在推理時就地微調
    • 只更新 MLP projection,成本可控又穩定
    • 對小模型與超長 context 任務效果特別明顯

    傳統 「先訓練後部署」 的模型,在上線後面對新 domain、新長文檔,只能靠 prompt/RAG「繞著問題走」,權重本身完全不會變。In-Place Test-Time Training(In-Place TTT)要解決的就是:

    在不重新訓練整個模型、不影響服務穩定性的前提下,讓 LLM 在推理過程中針對當前 context 做「就地微調」,即用即學。

    對開發者的直接好處:

    • 小模型 + 超長 context 的 任務表現大幅提升(paper 裡 0.4B 模型在 128k context 明顯變強)
    • 某些固定 domain(公司內知識庫、特定專案)可以在 session 內 自動做 domain adaptation,不用頻繁 retrain
    • 不需要新架構、不改 attention,只在 MLP projection 上動手腳,工程改動可控

    💡 關鍵: In-Place TTT 用少量 MLP 投影更新,就能在 128k 這種超長 context 下明顯強化 0.4B 等小模型的表現,達到「小模型大任務」的效果


    重點說明

    1. TTT 是什麼?為什麼舊做法不適合 LLM?

    Test-Time Training(TTT) 的想法很簡單:

    1. 推理時,先根據當前輸入做一點訓練(更新一小部分參數 = 快權重 fast weights
    2. 再用更新後的模型做真正的預測

    在 CV/小模型上,常見做法是加一個小 head、對自監督目標(rotation、jigsaw 等)做幾步 gradient update。

    但直接搬到 LLM 會踩坑:

    • 架構不兼容:很多 TTT 方案假設有額外 head 或特定 feature,LLM 一般沒有設計這種 TTT head
    • 計算太貴:如果動 transformer block 的多數權重,長 context 下反向傳播非常重
    • 目標不對齊:TTT 的 loss 不是 next-token prediction,會和原本的語言建模目標衝突,容易越調越壞

    In-Place TTT 的貢獻,就是解決這三個問題:

    • 不改架構,只把 MLP 最終 projection matrix 視為快權重
    • 只在這些 projection 上做反向,成本可控
    • 設計與 自回歸 NTP(next-token prediction)對齊 的訓練目標

    2. 為什麼選 MLP 最終 projection 當快權重?

    以典型 transformer block 的 FFN 為例:

    x -> W1 -> act -> W2 -> 殘差加回
    

    In-Place TTT 把 W2(MLP 最終 projection) 變成可更新快權重,原因:

    1. 局部但影響力大:W2 直接把非線性特徵投回主 hidden space,調這一層可以改變 token 的語義表徵,但不動 attention 結構
    2. 參數量適中:比動整個 block 或 embedding 便宜得多,長 context 反向成本可接受
    3. 穩定性好:不動 attention,有助於保持原模型的「語言能力」,在此之上做局部適配

    實務上的感受是:這種類似 LoRA 但只開在 MLP output projection 的調整方式,對長上下文裡的 pattern 適配非常有效,且容易控制範圍。

    💡 關鍵: 只動 MLP 的最後投影 W2,等於用最小的權重區塊,撬動整個 hidden space 的語義調整,兼顧效果與穩定性


    3. 如何設計與 next-token prediction 對齊的 TTT 目標?

    In-Place TTT 不再另外設計自監督任務,而是直接基於 自回歸 NTP

    • 你有一段長 context:[x1, x2, ..., xT]
    • 原本推理只 forward,loss 不回傳
    • 現在在部分 token 上:
    • 用原模型 logits 做 next-token cross-entropy
    • 但只對 MLP W2 回傳梯度並更新(其他權重 frozen)

    為了讓計算量可控,實作上會:

    • 把長 context 切成多個 blocks(例如每 512 或 1024 tokens)
    • 每個 block:
    • 先 forward 得到 loss
    • 只在該 block 上做一兩步梯度更新
    • 不回頭修正舊 block 的輸出(in-place)

    這樣:

    • 目標和原 pretrain 任務完全一致,不會「學歪」
    • 計算複雜度近似於「多做一輪 forward+backward」,可以和現有 KV cache/長 context 優化一起用

    4. 分塊更新如何兼容長上下文與多請求並行?

    對系統工程師來說,最大問題是:

    長 context + 反向傳播 + 多 request,會不會把 GPU 打爆?

    In-Place TTT 的分塊策略大致如下:

    • Block-wise TTT
    • Example:128k context,每 1k tokens 一塊,共 128 個 block
    • 每個 block:forward、算 loss、只更新 W2 的少量參數
    • 不需要保留整段的中間梯度,只需要當前 block 的 activations
    • 和 KV cache 相容
    • KV cache 還是照原本自回歸生成方式累積
    • TTT 僅對當前 block 的 MLP projection 做 backward,不需要對 KV 做反向
    • 多請求並行
    • 每個 request 自己帶一組「快權重 state」(或 delta)
    • 服務層可以:
      • 把 base 模型權重設為 read-only
      • 為每個 session 存一個 快權重 buffer(例如低秩 delta 或 mask 更新)

    這點和 KV cache 管理框架(vLLM, InfiniGen, H2O 等) 類似:

    • KV cache 管 context-dependent activations
    • 快權重則是 context-dependent 權重偏移

    兩者一起用時要特別小心記憶體分配和 eviction 策略:TTT 的 state 不宜無限制累積。

    💡 關鍵: 把 TTT 的更新設計成 per-session 的小型權重偏移,就能在保持多並發與穩定性的前提下,讓每個請求都「自己學自己」


    實作範例:在推理服務中加一層簡單的 In-Place TTT

    以下用 PyTorch 風格虛擬碼示意,重點是流程與邊界,而非完整訓練程式。

    1. 模型改造:標記可更新的 MLP projection

    class TTTMLP(nn.Module):
        def __init__(self, inner_dim, hidden_dim, enable_ttt=False):
            super().__init__()
            self.w1 = nn.Linear(hidden_dim, inner_dim)
            self.act = nn.GELU()
            self.w2 = nn.Linear(inner_dim, hidden_dim)
    
            # 只在 TTT 模式下允許 w2 被更新
            self.enable_ttt = enable_ttt
            for p in self.w1.parameters():
                p.requires_grad = False
            for p in self.w2.parameters():
                p.requires_grad = enable_ttt
    
        def forward(self, x):
            return self.w2(self.act(self.w1(x)))
    

    在整個 transformer 裡,只需要把原本 MLP 換成 TTTMLP(或只在指定層啟用 enable_ttt)。

    2. 推理 + In-Place TTT loop(單 request)

    def run_with_ttt(model, tokenizer, input_ids,
                     max_new_tokens=128,
                     ttt_block_size=1024,
                     ttt_steps=1,
                     ttt_lr=1e-4):
        """
        model: 已載入的 LLM,除了 MLP.w2 之外都 frozen
        input_ids: 長 context token ids
        """
        device = next(model.parameters()).device
        input_ids = input_ids.to(device)
    
        # 建一個專用 optimizer,只管 MLP.w2
        ttt_params = [p for n, p in model.named_parameters()
                      if p.requires_grad and 'mlp.w2' in n]
        optimizer = torch.optim.AdamW(ttt_params, lr=ttt_lr)
    
        model.eval()
    
        # === 1) 先在 context 上做 block-wise TTT ===
        for start in range(0, input_ids.size(1) - 1, ttt_block_size):
            end = min(start + ttt_block_size, input_ids.size(1) - 1)
            block = input_ids[:, start:end+1]  # [B, L_block+1]
    
            with torch.enable_grad():
                logits = model(block[:, :-1])  # predict next token
                target = block[:, 1:]
                loss = torch.nn.functional.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    target.reshape(-1),
                )
    
                optimizer.zero_grad()
                loss.backward()
                # 只會更新 MLP.w2
                optimizer.step()
    
        # === 2) 使用更新後的模型繼續生成 ===
        generated = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
        )
        return generated
    

    重點:

    • TTT 階段
    • torch.enable_grad() 打開梯度
    • 只在 context token 上跑 1~數步更新
    • 生成階段
    • 關掉梯度,只用更新後的快權重

    3. 服務層:何時開啟/關閉 TTT?

    可以在 serving config 裡加一個 flag:

    model:
      name: my-ttt-llm
      enable_ttt: true
      ttt_block_size: 1024
      ttt_steps: 1
      ttt_lr: 1e-4
      ttt_max_tokens: 32768   # 超過就不再做 TTT,避免過度漂移
    
    policy:
      enable_ttt_for:
        - domain: "enterprise_qa"
        - domain: "long_doc_summarization"
      disable_ttt_for:
        - domain: "safety_sensitive"
        - domain: "code_generation_prod"
    

    路由層可以依「任務類型」或「租戶設定」決定是否啟用 TTT。


    建議與注意事項

    1. 哪些任務適合/不適合開 TTT?

    適合:

    • Domain adaptation:企業內知識庫 QA、專案文件解讀、特定產品 FAQ
    • 超長 context 理解:法律條文、技術規格書、長期會議紀錄
    • 單 session 可犧牲一點穩定性、換取更高表現 的任務(例如 research 助理)

    不太適合:

    • 安全敏感場景(合規、法務、金融決策):
    • TTT 可能放大 prompt 中的偏見或惡意樣本
    • 多人共用同一快權重的情境
    • 不建議跨 user 共用 TTT state,否則有「資訊污染」風險
    • 需要行為高度穩定可重現 的任務(例如 production codegen、評測 pipeline)

    建議做法:

    • 預設 關閉 TTT,只對少數實驗/內部使用場景逐步開啟
    • 透過 A/B 測 和離線 eval 確認收益和風險

    2. 如何限制更新範圍,避免「越調越壞」?

    可以採幾個防護措施:

    1. 學習率和步數上限
    2. ttt_lr 通常比 finetune 更小(例如 1e-4 或更低)
    3. 每個 block 最多 1~2 步更新即可
    4. Layer 範圍限制
    5. 只開啟 中後段幾層的 MLP.w2,減少對基本語言能力的影響
    6. 正則化 / 參考權重
    7. 在 loss 中加入對 base 權重的 L2 正則
    8. 或採 低秩 delta/adapter(權重不直接改動,易於 reset)
    9. TTL / Reset 策略
    10. TTT state 掛在 session 上,session 結束就丟棄
    11. 長 session 內可定期 reset:例如每處理 64k token reset 一次

    3. 與 RAG、cache、LoRA 等架構的整合與踩坑

    與 RAG:

    • RAG 解決「資料更新」,TTT 解決「怎麼更好地用同一份長 context
    • 組合策略:
    • RAG 提供切好的 chunks
    • In-Place TTT 在讀完多個 chunks 後,讓模型更懂這批文件中常見的 schema/名詞
    • 踩坑:
    • 若檢索結果品質不穩,TTT 可能被噪音「帶偏」,建議在高 confidence 檢索結果上才啟用 TTT

    與 KV cache:

    • 注意 權重更新與 KV cache 一致性
    • 一般做法是 TTT 階段只用來更新權重,不直接用於生成
    • 開始生成時,重新用更新後權重+原 context 做一次 forward 建立 KV(或只從 TTT 後的 block 開始)
    • 若使用 vLLM 類似框架,要確保:
    • 同一 session 的 KV cache 與 TTT 權重版本對齊,避免「舊權重產生的 KV」搭配「新權重」

    與 LoRA / adapters:

    • 一種實務上更穩的做法:
    • base 模型 frozen
    • 安裝一個 LoRA-adapter 只開在 MLP.w2
    • TTT 只更新 LoRA 權重
    • 好處:
    • 快權重是「外掛」,可以 per-session 建立、銷毀
    • 也可以在 A/B 測時方便地比較「有/無 TTT adapter」

    4. 小模型 + 超長 context:何時值得導入?

    In-Place TTT 對 參數較小、context 很長 的組合特別有價值:

    • 小模型在長 context 內本來就容易「記不住 pattern」,TTT 可以在當前 session 內補強
    • 和 LongSpec 等 長 context + speculative decoding 技術搭配:
    • 先用 LongSpec 等提升解碼效率
    • 再用 In-Place TTT 提升長文本任務表現

    可以用以下 heuristics 判斷是否值得導入:

    • 模型 ≤ 7B,context ≥ 32k,而且任務高度依賴長文本理解 → 強烈建議實驗 TTT
    • 模型很大(70B+)且 context 不長(≤ 8k) → TTT 收益相對有限
    • 對 latency/成本極敏感的線上服務 → 可以先在 async 或批次任務(離線總結、分析)試水溫

    總結:In-Place TTT 提供了一個工程上可落地的途徑,讓 LLM 在推理時針對當前長 context 做小幅度、可控的「就地學習」。

    若你正在:

    • 用小模型硬扛超長 context
    • 常常為某個固定 domain 調 prompt/RAG 卻仍覺得不穩

    那麼在 pipeline 中加入 MLP projection 級別的 In-Place TTT,是很值得 A/B 測的一步升級。


    🚀 你現在可以做的事

    • 在現有 LLM 服務中挑選一個長文本任務(如長文總結)實作一版只更新 MLP W2 的 In-Place TTT PoC
    • 針對「開/關 TTT」設計離線與線上 A/B 測試,觀察長 context 任務指標與 latency、成本變化
    • 若已有 RAG pipeline,在高信心檢索場景下試著加入 block-wise TTT,評估對 domain QA 準確率的提升