LLMs Interview Note 2.1.2 各种Normalization

本文为 LLMs Interview Note 的学习笔记。

(下面假设输入数据 xxN×C×H×WN\times C\times H\times W 的形状)

BatchNorm

  • 做BN的动机:不同batch的数据分布可能不同,初始时微小的差异可能随着深层网络而逐渐加大,造成参数尺度和方差漂移(internal covariate shift),容易造成梯度爆炸/梯度消失,从而无法训练,这限制人们只能使用较小的学习率来训练模型;为了缓解这一问题需要有一种normalization方法来平衡batch之间的数据分布,以此提升模型训练的稳定性。

x=xμσγ+βx' = \dfrac{x-\mu}{\sigma}\cdot \gamma+\beta

μ=1NHWiNjHkWxipjk\mu = \dfrac{1}{NHW}\cdot\sum_{i\to N}\sum_{j\to H}\sum_{k\to W} x_{ipjk}

  • 这里 μ\muσ\sigma 是BN层根据训练数据估计出来的;而 γ\gammaβ\beta 是可学习的仿射参数向量(长度为 C)。

  • 加入仿射参数的目的:提升模型的表达能力,允许数据在经过 BN 层后依然具有一个比较广的范围

  • 优势:

    • 允许使用较大的学习率训练模型

    • 减弱对初始初始化的强依赖性

    • 可以保持数据均值和方差不变,让训练更稳定

    • BN 相当于一种正则化,可以防止模型过拟合

  • 劣势:

    • 在 batch size 较小的情况下,无法对一个 batch 内均值和方差进行有效估计

LayerNorm

为了解决 BN 在小 batch size 下的缺陷,人们提出来 LayerNorm,它对每个数据样本单独进行 norm,和 batch size 无关

μ=1CHWiCjHkWxpijk\mu = \dfrac{1}{CHW} \cdot \sum_{i\to C}\sum_{j\to H}\sum_{k\to W} x_{pijk}

在PyTorch的实现里,用户可以传入一个 normalized_shape 参数,用来控制对哪几维数据进行归一化。

  • 一般来说,LN 在 RNN/Transformer 上效果比较明显,但是在CNN上,效果一般不如BN。

  • LN 中 γ\gammaβ\beta 参数是一个形状为 C×H×WC\times H\times W可学习张量,而不是一个标量

  • LN 中不需要计算 running mean / variance,而是直接在每个样本上计算统计量

  • 缺点:LN 忽略了不同 channel 之间的分布差异,因此不适用于图像生成等任务

InstanceNorm

InstanceNorm 最初用于图像生成任务。

为了解决 LN 无法处理不同 channel 之间差异的问题,人们引入了 InstanceNorm,对不同 channel 分别做归一化。

μ=1HWiHjWxpqij\mu = \dfrac{1}{HW} \sum_{i\to H} \sum_{j\to W} x_{pqij}

IN中的偏置和缩放参数是一个长度为 C 的向量。

GroupNorm

InstanceNorm 对每个 channel 单独计算均值方差,数据量比较小,可能无法得出对统计量好的估计;LayerNorm 对所有 channel 合并计算统计量,考虑不到 channel 之间的差异。

而 GroupNorm 是一种介于 InstanceNorm 与 LayerNorm 之间的方法。它把所有 channel 平均分为 GG 组,每组包含 CG\frac CG 个通道,在同一组样本之间做标准化。

μ=1HWCGiCGjHkWxpijk\mu = \frac{1}{HW\frac CG} \sum _{i\to\frac CG}\sum_{j \to H}\sum_{k\to W} x_{pijk}

RMSNorm

RMSNorm 是基于 LayerNorm 的改进,和 LN 相比,RMSNorm 去除了减去均值的部分。

xi=xi1ni=1nxi2gix_i' = \dfrac{x_i}{\sqrt{\frac1n \sum_{i=1}^n x_i^2}}\cdot g_i

其中 1ni=1nxi2\sqrt{\frac1n \sum_{i=1}^n x_i^2} 是 RMS(Root Square Mean)操作,gig_i 是一个可学习的缩放向量。

pRMSNorm

相比 RMSNorm 而言,pRMSNorm 不需要计算所有样本的 RMS,而是取前 p (0<p<1) 比例的数据进行计算。

xi=xi1npi=1npxi2gix_i' = \dfrac{x_i}{\sqrt{\frac{1}{np} \sum_{i=1}^{np} x_i^2}}\cdot g_i