Batch Normalization

本文最后更新于 2025年3月19日 上午

批量归一化层

公式

\(\mu_B=\frac{1}{ |B| } \mathop \Sigma \limits_{ i\in B }x_i\)

\(\sigma_B^2=\frac{1}{ |B| } \mathop \Sigma\limits_{i \in B}(x_i- \mu_B )^2+\epsilon\)

再做额外调整:

\(x_{ i+1 }=\gamma\frac{x_i- \mu_B } { \sigma_B }+\beta\)

  • 可学习的参数为\(\gamma,\beta\)
  • 作用在:
    • 全连接层和卷积层输出上,激活函数前
    • 全连接层和卷积层输入上
  • 对全连接层,作用在特征维
  • 对卷积层,作用在通道维

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def batchNormalization(X, gamma, beta, moving_mean, moving_var, eps, momentum):
if not torch.is_grad_enabled():
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
else:
assert len(X.shape) in (2, 4)
if len(X.shape) == 2:
mean = X.mean(dim=0)
var = ((X - mean)**2).mean(dim=0)
else:
mean = X.mean(dim=(0,2,3), keepdims=True)
var = ((X - mean)**2).mean(dim=(0,2,3),keepdim=True)
X_hat = (X - mean) / torch.sqrt(var + eps)
moving_mean = momentum * moving_mean + (1 - momentum) * mean
moving_var = momentum * moving_var + (1 - momentum) * var
Y=gamma * X_hat + beta
return Y, moving_mean, moving_var.data

作用

  • 起初用来减少内部变量转移
  • 后续指出可能是通过在每个小批量里加入噪音控制模型复杂度
  • 不用和丢弃法混合使用

总结

  • 可以加速收敛速度,一般不改变模型精度
  • 使用真实数据时,第一步是标准化输入特征(使其均值为0,方差为1),这种标准化可以很好地与优化器配合使用

BatchNorm和LayerNorm的区别

BatchNorm1d为例:

\(y = \frac{ x - \mathrm{ E }[x] }{ \sqrt{ \mathrm{ Var }[x] + \epsilon } } * \gamma + \beta\)

1
2
3
4
5
6
input = torch.randn(1,2,3) # 1:batch, 2:feature, 3:sequence length
bn1 = nn.BatchNorm1D(2) # 选取特征为2
print(bn1(input))

#tensor([[[ 0.9848, -1.3714, 0.3865],
# [ 1.2078, -1.2410, 0.0332]]], grad_fn=<NativeBatchNormBackward0>)

这可以等同于以下操作:

1
2
3
4
5
6
7
8
9
10
a1 = (input[:,0,:] - input[:,0,:].mean()) / torch.sqrt(input[:,0,:].var(unbiased=False) + 1e-5)
a2 = (input[:,1,:] - input[:,1,:].mean()) / torch.sqrt(input[:,1,:].var(unbiased=False) + 1e-5)

torch.cat((a1, a2), dim=0).reshape(1,2,3) - bn1(input) < 1e-5
'''
a1 = (input[:,0,:] - input[:,0,:].mean()) / torch.sqrt(input[:,0,:].var(unbiased=False) + 1e-5)
a2 = (input[:,1,:] - input[:,1,:].mean()) / torch.sqrt(input[:,1,:].var(unbiased=False) + 1e-5)
# 比较两种方法
torch.cat((a1, a2), dim=0).reshape(1,2,3) - bn1(input) < 1e-5
'''

LayerNorm


Batch Normalization
https://meteor041.git.io/2024/11/14/Batch Normalization/
作者
meteor041
发布于
2024年11月14日
许可协议