BatchNormalization思考
之前在卷积网络中写过一点bn的东西,主要是论文中的一些想法,本文探讨一下Pytorch中的BatchNorm的实现。
参考:
[1]. https://www.jianshu.com/p/b38e14c1f14d
[2]. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
[3]. https://blog.csdn.net/winycg/article/details/88974107
[4]. https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py
[5]. https://discuss.pytorch.org/t/why-track-running-stats-is-not-set-to-false-during-eval/25412
BatchNorm具体数学形式:
torch.nn.BatchNorm1d(num_features: int, eps: float = 1e-5,momentum: float = 0.1,
affine: bool = True,track_running_stats: bool = True)
:
num_features:特征的数量(样本维度)
eps:$\epsilon$,默认是$1e-5$,防止方差为0的情况
momentum: 计算running_mean和running_var的滑动平均系数。采用下列的公式:
其中,$x_t$ 指当前批次样本的均值或方差,$\hat x$ 指之前计算的均值或方差。
affine:True表示使用 $\gamma$ 和 $\beta$ 参数,是超参数,通过学习得到,初始时$\gamma = 1,\beta = 0$
track_running_stats:如果为True,训练阶段采用实时的batch均值和方差, 同时采用滑动平均来计算全局的running_mean和running_var,测试阶段采用计算出的running_mean和running_var;如果是False,则训练阶段和测试阶段都采用实时的batch均值和方差。
说明:
- momentum其实用到了指数加权的思想(https://zhuanlan.zhihu.com/p/29895933),本质上是一种近似求平均的方法。
- 训练阶段,要求BatchNorm要求输入的样本数量必须大于1。测试时,对样本个数没有要求,使用
model.eval()
(model.eval()
pytorch会自动将bn层中的training设置为False,track_running_stats设置为True)则会固定Batchnorm层的参数,这时均值与方差会用running_mean
与running_var
代替并且不再更新。测试时不推荐将track_running_stats
设置为False
。 - Pytorch在BatchNorm中计算方差时,
torch.var(x,unbiased=False)
。区别可搜索”方差的贝塞尔校正“。 - 定义
bn = nn.BatchNorm1d(3)
,通过bn.running_mean
和bn.running_var
获取统计数据的均值和方差。
以下代码来自链接[4]。
1 | class MyBatchNorm2d(nn.BatchNorm2d): |