题目
大模型训练中常出现 loss spike(损失突增)。请说明原因、危害与排查处理流程。
参考答案
loss spike 表现:训练 loss 本来平稳下降,突然飙升(可能数十倍),伴随梯度爆炸、参数混乱,轻则模型退化,重则出现 NaN/Inf 训练中断。
常见原因:
- 数据异常:混入超长序列、重复 batch、低质脏数据、数值异常 token。
- 梯度爆炸:深层网络梯度连乘导致数值溢出,尤其 Post-Norm 架构。
- 学习率过大:某步更新跨过最优点,进入不稳定区域。
- 数值精度:FP16 下梯度下溢或上溢。
- 分布式不同步:all-reduce 异常导致部分卡用错梯度。
- 架构问题:注意力未加缩放、未做 QK-norm 等。
急救处理(先止血):
- 梯度裁剪(Gradient Clipping):限制梯度范数(如 max_norm=1.0),是最常用的防爆炸手段。
- 学习率回退:发现 spike 立即降低学习率,很多训练框架支持自动 LR decay on spike。
- 检查点回滚:回滚到 spike 前的健康 checkpoint,跳过引发 spike 的数据 batch 重训。
- 跳过异常 batch:检测到 loss 异常高时直接跳过该 step 不更新。
根治手段:
- 架构稳定性:用 Pre-Norm + RMSNorm + QK-norm + 初始化调整(μP 等)。
- 精度:BF16 替代 FP16,避免溢出;关键累积用 FP32。
- 数据质量:严格过滤超长/重复/异常样本,训练前做分布检查。
- 学习率调度:warmup + cosine decay,避免初期过大更新。
- Embedding 与输出层:这些层参数大、梯度大,常需单独学习率或额外正则。
- Loss scaling 监控:FP16 下动态 loss scale,监控 scale 是否持续下降(下降意味溢出频发)。
调试方法论:
- 看 loss 曲线判断是单点 spike 还是持续不收敛。
- 查 spike 出现的 step 与数据 batch,定位是否数据问题。
- 查梯度/激活统计(min/max/mean/std)逐层排查。
- 查 optimizer state 是否 NaN。
- 缩小复现范围(小 batch、单卡复现)。
面试加分点:
- 指出 spike 在百亿到千亿规模训练中几乎是常态,不是”是否遇到”而是”如何应对”——GPT-3、PaLM 训练报告都记录了多次 spike 与回滚。
- 强调”回滚 + 跳 batch”是工程上最实用的应急组合。
- LR warmup + 梯度裁剪 + BF16 是预防三件套,缺一不可。
出处:CSDN《大模型面试题52:LLM 如果在训练过程中 loss 值出现 spike 应该怎么办》、《大模型最新面试题系列:训练篇之训练稳定性》。
内容来源
整理自 CSDN《大模型面试题52:LLM 如果在训练过程中 loss 值出现 spike 应该怎么办》及《大模型最新面试题系列:训练篇之训练稳定性》
本站内容整理自公开面经与开源仓库,仅供学习交流,严禁杜撰。