Gradient Checkpointing(梯度检查点)
Gradient Checkpointing是一种深度学习训练中的显存优化技术,通过选择性丢弃中间激活值并在反向传播时重新计算,从而显著降低GPU显存占用,适用于长序列、大模型等显存瓶颈场景。
一句话解释
Gradient Checkpointing是一种深度学习训练策略,训练时只保存部分中间层的激活值(检查点),其余未保存的激活值在反向传播时根据已保存值和原始输入重新计算,从而用额外计算量换取显存占用的大幅降低。
为什么会被关注
随着Transformer、大语言模型等模型规模急速增长,GPU显存成为训练瓶颈。传统做法需存储所有中间激活值,显存极易耗尽。Gradient Checkpointing让工程师能在不增加硬件成本的前提下,训练更深的模型或处理更长的序列,因此成为大模型训练标配技术。
核心逻辑
在正向传播中,神经网络会产生活化值,用于反向传播计算梯度。Gradient Checkpointing将这些激活值分段保存,只保留关键“检查点”的完整状态。反向传播时,检查点之间的片段需要从最近检查点重新执行正向计算来恢复被丢弃的激活值。这本质是“用时间换空间”:每个检查点之间的计算量翻倍,但显存占用从O(L)降低到O(sqrt(L))或O(log L)。
常见场景
训练超长序列的Transformer模型(如GPT、BERT)时,大模型的层数深或序列长度超过4096,显存会快速占满。此外,在多卡并行训练中,每张卡的显存限制也通过Gradient Checkpointing缓解。推理时若需大batch,也可用类似策略。
容易混淆的点
Gradient Checkpointing不等于梯度累积(gradient accumulation),后者是多次小batch更新,缓解显存但增加通信;前者是单次前向中节约激活值存储。它也不等于模型并行或流水线并行,虽然常配合使用。另外,检查点策略会小幅延长训练时间(约20-30%),并非所有场景都适用,需权衡显存与速度。
本文内容用于 AI 热词解释和概念整理,仅供学习和理解参考。若涉及表述偏差或内容修正,欢迎联系站点进行更新。
相关热词大模型是指通过在海量数据上训练、拥有庞大参数规模的深度学习模型,其核心能力在于理解和生成人类语言及各类内容,是当前生成式AI(如ChatGPT)的技术基石。
显存优化是一系列旨在减少深度学习模型运行时对显卡内存占用的技术。它通过模型压缩、动态调度、混合精度等方法,让庞大的AI模型能在消费级显卡上运行,是降低AI应用成本、推动技术普及的核心环节。

