Activation Recomputation:激活重计算
Activation Recomputation(激活重计算)是一种在深度学习训练中通过牺牲少量计算来换取显存节省的技术,在前向传播时丢弃中间激活值,反向传播时重新计算,从而支持更大模型或更大批量的训练。
一句话解释
Activation Recomputation(激活重计算)是一种显存优化技术,在前向传播时只保留部分关键激活值,丢弃大部分中间结果,然后在反向传播时重新计算被丢弃的激活值,从而大幅度降低显存占用。
为什么会被关注
随着大模型参数规模达到千亿甚至万亿级别,显存瓶颈成为训练时的主要限制。传统做法需要保存所有前向计算的中间激活值用于反向传播,这会使显存需求随模型深度线性增长。激活重计算通过用计算换内存,使得单卡可以训练更大的模型或使用更大的批量,成为各大AI框架(如PyTorch、Megatron-LM)的标准功能。
核心逻辑
在前向传播过程中,每个网络层都会产生激活值(中间特征)。为了反向传播计算梯度,通常需要这些激活值。激活重计算策略会标记某些层或计算区域,在前向完成后立即释放其激活值。反向传播时,从最近的检查点(Checkpoint)重新执行正向计算,恢复需要的激活值。这样做增加了约30%至50%的计算开销,但可能减少60%至80%的显存占用。
常见场景
大规模Transformer模型训练(如GPT、LLaMA、BERT)中,激活重计算常与张量并行、流水线并行结合使用。在PyTorch中可通过`torch.utils.checkpoint`接口轻松启用。训练超长序列(如8192以上)时,激活重计算几乎必不可少。此外,在显存有限的消费级显卡上微调大模型时,此技术也被广泛采用。
容易混淆的点
激活重计算(Activation Recomputation)和梯度检查点(Gradient Checkpointing)常被混用,实际上两者是同一概念的不同叫法。容易与“显存交换”(将数据换到CPU内存)混淆,但后者依靠PCIe传输,延迟更高。另外,它并非减少计算量,反而是增加计算量,只是让原本不能训练的大模型变得可训练。
本文内容用于 AI 热词解释和概念整理,仅供学习和理解参考。若涉及表述偏差或内容修正,欢迎联系站点进行更新。
相关热词Gradient Checkpointing是一种深度学习训练中的显存优化技术,通过选择性丢弃中间激活值并在反向传播时重新计算,从而显著降低GPU显存占用,适用于长序列、大模型等显存瓶颈场景。
显存优化是一系列旨在减少深度学习模型运行时对显卡内存占用的技术。它通过模型压缩、动态调度、混合精度等方法,让庞大的AI模型能在消费级显卡上运行,是降低AI应用成本、推动技术普及的核心环节。

