# 模型参数初始化对齐方法 # 1. 背景 Paddle提供了大量的初始化方法,包括`Constant`, `KaimingUniform`, `KaimingNormal`, `TruncatedNormal`, `Uniform`, `XavierNormal`, `XavierUniform`等,合适的初始化方法能够帮助模型快速地收敛或者达到更高的精度。 论文复现的过程中,在训练对齐环节,需要保证Paddle的复现代码和参考代码保持一致,从而实现完全对齐。然而由于不同框架的差异性,部分API中参数提供的默认初始化方法有区别,该文档以`nn.Conv2D`以及`nn.Linear`这两个最常用的API为例,介绍怎样实现对齐。 **更多参考链接:** * Paddle初始化相关API链接:[初始化API官网文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Overview_cn.html#chushihuaxiangguan) * Paddle提供的初始化方式为直接修改API的`ParamAttr`,与`torch.nn.init`等系列API的使用方式不同,PaddleDetection中实现了与`torch.nn.init`系列API完全对齐的初始化API,包括`uniform_`, `normal_`, `constant_`, `ones_`, `zeros_`, `xavier_uniform_`, `xavier_normal_`, `kaiming_uniform_`, `kaiming_normal_`, `linear_init_`, `conv_init_`,可以参考[initializer.py](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/initializer.py),查看更多的实现细节。 # 2. 不同框架的初始化差异 ## 2.1 默认初始化的对齐方法 在此情况下,一般需要查看文档,了解参考代码的初始化方法,从而通过修改初始化方法,实现初始化的对齐。 下面以`nn.Conv2D` API为例进行说明。 * **Step1:** 基于Paddle与torch,定义2个卷积操作,绘制其weight参数的直方图,如下所示。 ```python import paddle import torch import numpy as np import matplotlib.pyplot as plt %matplotlib inline conv2d_pd = paddle.nn.Conv2D(4096, 512, 3) conv2d_pt = torch.nn.Conv2d(4096, 512, 3) conv2d_pd_weight = conv2d_pd.weight.numpy().reshape((-1, )) conv2d_pt_weight = conv2d_pt.weight.detach().numpy().reshape((-1, )) plt.figure(figsize=(10, 6)) temp = plt.hist([conv2d_pd_weight, conv2d_pt_weight], bins=100, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Conv2D weight", "torch.nn.Conv2d weight"}) ```
结合[paddle文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/ParamAttr_cn.html#paramattr)和[torch文档](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html?highlight=conv2d#torch.nn.Conv2d)可知,paddle的初始化是`XavierNormal`,torch的初始化是`uniform`,初始化方法边界值是`(-sqrt(groups/(in_channels*prod(*kernal_size))), sqrt(groups/(in_channels*prod(*kernal_size))))`。 * **Step2:** 由上述分析,基于`paddle.nn.initializer.Uniform` API,自定义Paddle中Conv2D的初始化,代码如下所示: ```python import paddle import torch import math import numpy as np import matplotlib.pyplot as plt import paddle.nn.initializer as init %matplotlib inline # 该例子中,对应上述公式的group=1,in_channels=4096,kernal_size=3,由于二维卷积的卷积核是二维的,所以此处的结果为4096*3*3 conv2d_pd = paddle.nn.Conv2D(4096, 512, 3, init.Uniform(-1/math.sqrt(4096*3*3), 1/math.sqrt(4096*3*3))) conv2d_pt = torch.nn.Conv2d(4096, 512, 3) conv2d_pd_weight = conv2d_pd.weight.numpy().reshape((-1, )) conv2d_pt_weight = conv2d_pt.weight.detach().numpy().reshape((-1, )) plt.figure(figsize=(10, 6)) temp = plt.hist([conv2d_pd_weight, conv2d_pt_weight], bins=100, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Conv2D weight", "torch.nn.Conv2d weight"}) ```
从图中可知,二者的初始化参数分布实现一致。 ## 2.2 自定义初始化的对齐方法 部分参考代码中,初始化的方法是通过使用`torch.nn.init`系列API实现,可以认为是自定义初始化。例如:[resnet](https://github.com/pytorch/vision/blob/ec1c2a12cf00c6df83c7fb88f75b8117cda2f970/torchvision/models/resnet.py#L208)中使用的`kaiming_normal_`传入了`mode`和`nonlinearity`两个参数: ```python if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") ``` 在这类问题中,可以先尝试使用`2.1`章节中的`Step1`,查看使用Paddle同名初始化方式的默认参数是否能够对齐。如果无法对齐,可以查阅[initializer.py](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/initializer.py),使用该文件中的初始化函数,实现对齐。 不同框架的初始化方法有所不同,开发者论文复现过程中难以排查,因此下面第3章介绍通过自定义初始化的方式,实现不同框架的参数初始化分布一致,最终帮助大家更加顺利地完成论文复现。 **注意:** BatchNorm2D等大多数的API中,可学习参数的初始化分布相同,在此为进一步对比,也给出其权重的可视化对比图像。 # 3. 初始化参数分布对比 ## 3.1 默认初始化不同的API权重直方图对比 | Paddle API | torch API | 默认初始化方法的参数分布对比图 | 修改初始化参数方法 | 修改之后的参数分布对比 | |:---------:|:------------------:|:------------:|:------------:|:------------:| | `paddle.nn.Conv2D` weight参数 | `torch.nn.Conv2d` weight参数 | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/conv2d_weight_default_diff.jpeg) | 见附录`4.1.1` | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/conv2d_weight_fixed_diff.jpeg) | | `paddle.nn.Conv2D` bias参数 | `torch.nn.Conv2d` bias参数 | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/conv2d_bias_default_diff.jpeg) | 见附录`4.1.1` | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/conv2d_bias_fixed_diff.jpeg) | | `paddle.nn.Linear` weight参数 | `torch.nn.Linear` weight参数 | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/linear_weight_default_diff.jpeg) | 见附录`4.1.2` | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/linear_weight_fixed_diff.jpeg) | | `paddle.nn.Linear` bias参数 | `torch.nn.Linear` bias参数 | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/linear_bias_default_diff.jpeg) | 见附录`4.1.2` | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/linear_bias_fixed_diff.jpeg) | | `paddle.nn.Embedding` weight参数 | `torch.nn.Embedding` weight参数 | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/embedding_weight_default_diff.jpeg) | 见附录`4.1.3` | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/embedding_weight_fixed_diff.jpeg) | ## 3.2 默认初始化相同的API权重直方图对比 | Paddle API | torch API | 默认初始化方法的参数分布对比图 | |:---------:|:------------------:|:------------:| | `paddle.nn.BatchNorm2D` weight参数 | `torch.nn.BatchNorm2d` weight参数 | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/bn_weight_default_diff.jpeg) | | `paddle.nn.BatchNorm2D` bias参数 | `torch.nn.BatchNorm2d` bias参数 | ![](https://paddle-model-ecology.bj.bcebos.com/images/initializer/bn_bias_default_diff.jpeg) | # 4. 附录 ## 4.1 初始化对齐代码 ### 4.1.1 paddle.nn.Conv2D * 默认初始化以及可视化代码 ```python import paddle import torch import numpy as np import matplotlib.pyplot as plt %matplotlib inline conv2d_pd = paddle.nn.Conv2D(4096, 512, 3) conv2d_pt = torch.nn.Conv2d(4096, 512, 3) conv2d_pd_weight = conv2d_pd.weight.numpy().reshape((-1, )) conv2d_pd_bias = conv2d_pd.bias.numpy().reshape((-1, )) conv2d_pt_weight = conv2d_pt.weight.detach().numpy().reshape((-1, )) conv2d_pt_bias = conv2d_pd.bias.numpy().reshape((-1, )) plt.figure(figsize=(10, 6)) temp = plt.hist([conv2d_pd_weight, conv2d_pt_weight], bins=100, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Conv2D weight", "torch.nn.Conv2d weight"}) plt.figure(figsize=(10, 6)) temp = plt.hist([conv2d_pd_bias, conv2d_pt_bias], bins=50, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Conv2D bias", "torch.nn.Conv2d bias"}) ``` * 修正后初始化以及可视化代码 ```python import paddle import torch import math import numpy as np import matplotlib.pyplot as plt import paddle.nn.initializer as init %matplotlib inline conv2d_pd = paddle.nn.Conv2D(4096, 512, 3, weight_attr=init.Uniform(-1/math.sqrt(4096*3*3), 1/math.sqrt(4096*3*3)), bias_attr=init.Uniform(-1/math.sqrt(4096*3*3), 1/math.sqrt(4096*3*3))) conv2d_pt = torch.nn.Conv2d(4096, 512, 3) conv2d_pd_weight = conv2d_pd.weight.numpy().reshape((-1, )) conv2d_pd_bias = conv2d_pd.bias.numpy().reshape((-1, )) conv2d_pt_weight = conv2d_pt.weight.detach().numpy().reshape((-1, )) conv2d_pt_bias = conv2d_pd.bias.numpy().reshape((-1, )) plt.figure(figsize=(10, 6)) temp = plt.hist([conv2d_pd_weight, conv2d_pt_weight], bins=100, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Conv2D weight", "torch.nn.Conv2d weight"}) plt.figure(figsize=(10, 6)) temp = plt.hist([conv2d_pd_bias, conv2d_pt_bias], bins=50, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Conv2D bias", "torch.nn.Conv2d bias"}) ``` ### 4.1.2 paddle.nn.Linear * 默认初始化以及可视化代码 ```python import paddle import torch import numpy as np import matplotlib.pyplot as plt %matplotlib inline linear_pd = paddle.nn.Linear(4096, 512) linear_pt = torch.nn.Linear(4096, 512) linear_pd_weight = linear_pd.weight.numpy().reshape((-1, )) linear_pd_bias = linear_pd.bias.numpy().reshape((-1, )) linear_pt_weight = linear_pt.weight.detach().numpy().reshape((-1, )) linear_pt_bias = linear_pt.bias.numpy().reshape((-1, )) plt.figure(figsize=(10, 6)) temp = plt.hist([linear_pd_weight, linear_pt_weight], bins=100, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Linear weight", "torch.nn.Linear weight"}) plt.figure(figsize=(10, 6)) temp = plt.hist([linear_pd_bias, linear_pt_bias], bins=50, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Linear bias", "torch.nn.Linear bias"}) ``` * 修正后初始化以及可视化代码 ```python import paddle import torch import math import numpy as np import matplotlib.pyplot as plt import paddle.nn.initializer as init %matplotlib inline # linear的初始化方法同样适用于2.1节中的公式,此处的kernal_size等价于1。 linear_pd = paddle.nn.Linear(4096, 512, weight_attr=init.Uniform(-1/math.sqrt(4096), 1/math.sqrt(4096)), bias_attr=init.Uniform(-1/math.sqrt(4096), 1/math.sqrt(4096))) linear_pt = torch.nn.Linear(4096, 512) linear_pd_weight = linear_pd.weight.numpy().reshape((-1, )) linear_pd_bias = linear_pd.bias.numpy().reshape((-1, )) linear_pt_weight = linear_pt.weight.detach().numpy().reshape((-1, )) linear_pt_bias = linear_pt.bias.numpy().reshape((-1, )) plt.figure(figsize=(10, 6)) temp = plt.hist([linear_pd_weight, linear_pt_weight], bins=100, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Linear weight", "torch.nn.Linear weight"}) plt.figure(figsize=(10, 6)) temp = plt.hist([linear_pd_bias, linear_pt_bias], bins=50, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Linear bias", "torch.nn.Linear bias"}) ``` ### 4.1.3 paddle.nn.Embedding * 默认初始化以及可视化代码 ```python import paddle import torch import numpy as np import matplotlib.pyplot as plt %matplotlib inline eb_pd = paddle.nn.Embedding(1024, 512) eb_pt = torch.nn.Embedding(1024, 512) eb_pd_weight = eb_pd.weight.numpy().reshape((-1, )) eb_pt_weight = eb_pt.weight.detach().numpy().reshape((-1, )) plt.figure(figsize=(10, 6)) temp = plt.hist([eb_pd_weight, eb_pt_weight], bins=100, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Embedding weight", "torch.nn.Embedding weight"}) ``` * 修正后初始化以及可视化代码 ```python import paddle import torch import numpy as np import matplotlib.pyplot as plt import paddle.nn.initializer as init %matplotlib inline # linear的初始化方法同样适用于2.1节中的公式,此处的kernal_size等价于1。 eb_pd = paddle.nn.Embedding(1024, 512, weight_attr=init.Normal(0.0, 1.0)) eb_pt = torch.nn.Embedding(1024, 512) eb_pd_weight = eb_pd.weight.numpy().reshape((-1, )) eb_pt_weight = eb_pt.weight.detach().numpy().reshape((-1, )) plt.figure(figsize=(10, 6)) temp = plt.hist([eb_pd_weight, eb_pt_weight], bins=100, rwidth=0.8, histtype="step") plt.xlabel("value") plt.ylabel("count") plt.legend({"paddle.nn.Embedding weight", "torch.nn.Embedding weight"}) ```