← 所有文章
On-Policy Distillation · Notes

On-Policy 蒸馏的两个坑:训推不一致,与后期 reward 塌成 0/1

记两个在 on-policy 蒸馏(TCoD)里真实踩到的坑:一个是同样权重却对不上的「训推不一致」,一个是训练后期 teacher 给学生的信号塌成 0/1 的现象。前者是算子层面的数值问题,后者是蒸馏信号的结构性问题。

2026 年 6 月 On-Policy蒸馏训推不一致TCoD强化学习数值精度
同一份权重,推理侧 decode 出的 μ 和训练侧 prefill 重算的 π_θ 并不相等。

同一份权重,推理侧 decode 出的 μ 和训练侧 prefill 重算的 π_θ 并不相等。

目录


    0两个坑

    同一份权重,不同算子路径算出来的概率不相等;而当学生学到后期,老师能给的「软监督」会塌成非 0 即 1。

    这两件事都发生在 on-policy distillation 里:学生在自己的分布上采样,老师在这些采样上给信号。第一个坑藏在基础设施的数值层面,第二个坑藏在蒸馏信号的结构里。下面分别拆开。


    1坑一:训推不一致 —— 同样的权重,算出来的概率却不一样

    先看 on-policy 蒸馏一轮里到底发生了几次前向计算:

    1. 学生采样:学生用推理引擎(vLLM / SGLang)逐 token 解码(decode)生成回答,顺手吐出采样概率 → 行为策略 $\mu$。
    2. 学生重算 logprob:训练时要算梯度,得在训练引擎(FSDP / Megatron)里对整条序列做一次 prefill(整段并行前向) → $\pi_\theta$。
    3. 老师打分:老师(冻结)对同一条序列做 prefill,给出 $\pi_{T}$。

    蒸馏损失大致是在学生采样的轨迹上算逐 token 的 KL(on-policy 蒸馏通常用 reverse KL,$\mathrm{KL}(\pi_{\theta}\,\|\,\pi_{T})$)。这里藏着三层不一致:前两层是「同一个 token,logprob 却不同」(数值问题);第三层更狠,是「连 token 本身都变了」(正确性问题)。

    ① decode vs prefill 不一致

    ② 训练引擎 vs 推理引擎不一致(经典 TIM)

    vLLM 和 FSDP 即使加载完全相同的 checkpoint,因为 kernel / 算子实现不同,对同一输入给出的下一 token 分布也不同。后果:采样策略 $\mu$ ≠ 训练策略 $\pi_\theta$,本该 on-policy 的训练悄悄变成 off-policy。而且这种偏差是「基础设施层面的噪声」,不是 PPO mini-step 那种 off-policy,光靠 clip 治不了。

    ③ token-in-token-out:别让分词漂移把训练带偏

    前两层是「同样的 token,logprob 不同」;这一层是「token 本身就不是同一串」,业内叫 retokenization drift(再分词漂移)。数据从推理侧传到训练侧有两种做法:

    问题在于 $\text{tokenize}(\text{detokenize}(\text{ids})) \neq \text{ids}$:哪怕字符串一模一样,重新编码出来的 token 边界也可能不同——

    retokenization drift:同一字符串被重新切成不同的 token
    同一个字符串重新编码,token 边界可能变化——训练用的 token 不再是采样时的那串。

    后果比数值不一致更硬:训练时算 logprob/梯度的 token 和实际采样的不是同一串 → 挂错了标的、本质变成 off-policy;而且 loss mask 也会错(只有模型自己生成的 behavior-policy token 该训,工具/模板等插入 token 该 mask 掉)。这不是小事——有工作(Strands-SGLang)报告非 TITO 的 agent RL 在 step 50 前就训崩了。

    解法:全程带着 token id 走,绝不 round-trip 文本。vLLM 的 OpenAI 端点用 "return_token_ids": true 会连 prompt_token_idstoken_ids 一起返回;SGLang 原生 /generate 直接给 token + logprob + mask。只有外部插入的文本(工具/用户)才做一次规范 tokenize 并 mask 掉——这样「生成的 token = 训练的 token = 算 logprob 的 token」。


    2为什么是「算子层面」

    根因不在算法,而在 GPU 上浮点算子的实现细节:

    一句话:数学上「应该相等」,但因为算子实现、并行策略、精度不同,数值上不相等


    37B / 14B 上重要吗?

    常见误区是「只有大模型才需要管」。其实它和参数量没有直接关系——大部分诊断 TIM 的工作就是在 7B 级别的小模型(如 DeepSeek-R1-Distill-Qwen-7B、Qwen 1.5B–7B)上复现并看到训练崩的。真正决定严重程度的是另一组变量:

    放大因素为什么放大和「大模型」的关系
    生成长度自回归误差逐 token 累积,序列越长分叉越大长 CoT / 长程 agent 最致命,与尺寸无关
    精度 bf16/fp16bf16 尾数少,舍入误差大与尺寸无关,最主因
    TP 并行度all-reduce 求和顺序变,浮点不结合大模型 TP 切得多 → 偏差更大
    MoE 路由推理/训练选的专家可能不同大模型多是 MoE → 多一层不一致
    数据复用 / off-policy 程度越 off-policy,$\mu$ 与 $\pi_\theta$ 差越远与尺寸无关

    直接放大它的是「长序列 + bf16 + 复用数据」,这些在 7B 上完全可以很严重;而 TP、MoE 随规模加重——这才是「大模型更明显」这一印象的真正来源。不是参数多本身让它变严重,而是大模型往往同时叠了 MoE + 大 TP + 长上下文。

    什么时候可基本忽略:输出短、严格 on-policy 不复用数据、dense + 小 TP、用 fp16,或做的是纯离线蒸馏 / SFT(不采样)。但只要任务是长链推理 / 多轮 agent,哪怕 7B 也得认真对待。


    4怎么应对训推不一致

    1. 数值对齐:重算 logprob 的引擎/精度/batch 配置尽量和采样一致,train_micro_batch_size == logprob_batch_size
    2. 确定性 / batch-invariant kernels:用对 batch 不敏感的确定性算子,做到和训练引擎零失配。
    3. 换 FP16:最省事的一招,从根上压住 bf16 的累积舍入误差。
    4. 直接用推理引擎返回的 logprob 当 $\mu$,不再用训练引擎重算去近似。
    5. 坚持 token-in-token-out(TITO):全程传 token id 而非文本,杜绝 retokenization drift(见 §1③)。
    6. 重要性采样矫正:既然变成 off-policy,就用 IS 比值纠偏并截断控方差(TIS):

      $$ w_t=\min\!\Big(\frac{\pi_\theta(o_t)}{\mu(o_t)},\,C\Big) $$

      序列级版本叫 MIS,长序列上通常比 token 级更稳。

    5坑二:后期 reward 塌成 0/1

    在 TCoD 的训练后期观察到:老师给学生的 reward(这里指对学生采样 token 的「认可度」,比如 teacher 概率 $\pi_{T}(o_t)\in[0,1]$)几乎都集中在 0 或 1 两端——要么完全认可、要么完全否定,中间 0.3~0.7 的「软」信号基本消失。这通常是预期内的,有几层成因:

    三者叠加 → 老师尖 + 学生也尖 + 中间软样本被消化 → reward 自然双峰到 {0,1}。这恰恰说明 temporal curriculum 把容易的、信息量大的部分吃完了。

    teacher 认可度分布从早期铺开到后期双峰塌到 0/1(示意)
    示意图:随训练推进,老师对学生 token 的认可度从「铺开」变成「双峰塌到 0 / 1」,中间的软信号消失。

    5.1 还有一个「假 0/1」:精度下溢

    如果 reward 直接取 teacher prob,而在 bf16/fp16 下算:匹配 token 的概率太接近 1 会被舍入成 1.0,不匹配的太小会下溢成 0.0。也就是本来还有梯度的软信号,被低精度量化成了硬 0/1。排查方法:把 reward / logprob 那段计算单独提到 fp32,看双峰是否缓解;缓解了就说明至少有一部分是数值假象。


    6是收敛,还是塌缩?

    0/1 化本身不能直接判断好坏,要配合其它指标一起看:

    同时观察健康收敛病态塌缩 / 课程耗尽
    reward≈1 的占比持续上升上不去,或 0/1 各半卡住
    下游 eval 准确率上升或高位平台早就不动了
    学生 token 熵 / 多样性适中、稳定骤降、开始重复复读
    reward=0 的 token越来越少,集中在真难点一直一大片,且学生很自信地错

    7怎么把梯度救回来

    1. 首选:用全词表 KL,而不是只对采样 token 取 teacher prob 当标量 reward。0/1 的根源就是「只看了采样 token 那一个数」;改成对整份词表分布算 KL 后,即使采样 token 老师给≈1,其余词表上的分布差异仍然提供梯度,根本不会塌到 0/1。这比调温度更该排第一。
    2. 蒸馏温度是个有「甜区」的小旋钮,不是越大越好。先分清两种温度:
      • 采样温度(生成时怎么抽 token)调大 → 生成内容会越来越离谱,不要用它来救 0/1
      • 蒸馏 / loss 温度只作用在算 KL 时的 softmax $p_i=\mathrm{softmax}(z_i/T)$ 上,不参与生成(OPD 里老师只打分、不采样),所以它本身不会让老师吐出离谱内容。
      • 但即便是 loss 温度,$T$ 太大也会坏事:目标分布趋近均匀 → 等于告诉学生「所有 token 差不多好」→ 学生分布被压平、质量崩;而且它会改变学生收敛到的最优(收敛到「被加热过的老师」),是有偏的。实操取 $T\approx 1.2\text{–}2$ 的小幅软化,并先对老师分布做 top-k / top-p 截断再软化,只在合理候选里软化,避免把概率质量摊到垃圾长尾上。
    3. 控学生采样温度 / 加 entropy bonus:防熵塌缩,保持探索,让采样还能覆盖到老师不确定的区域。
    4. 课程切换:对那批 reward=0 的硬残差,换更强/更近的老师,或干脆转成 RLVR(可验证奖励)去啃——蒸馏的天花板就是老师,硬骨头得靠环境奖励。
    5. 排查精度:先用 fp32 排除把软信号假性量化成 0/1。

    两个坑其实指向同一句话:on-policy 蒸馏里,「学生采样的分布」和「用来算梯度/打分的分布」必须尽量是同一个东西——数值上要一致(坑一),信号上要还有区分度(坑二)。