diff --git a/tutorials/article-implementation/initializer.md b/tutorials/article-implementation/initializer.md
new file mode 100644
index 0000000000000000000000000000000000000000..422df8dff3a70aa1b10ae7570489ee769d3b1827
--- /dev/null
+++ b/tutorials/article-implementation/initializer.md
@@ -0,0 +1,252 @@
+
+# Model Parameter Initialization Alignment Methods
+
+# 1. Background
+
+Paddle provides many initialization methods, including `Constant`, `KaimingUniform`, `KaimingNormal`, `TruncatedNormal`, `Uniform`, `XavierNormal`, and `XavierUniform`. A suitable initialization method helps a model converge faster or reach higher accuracy.
+
+When reproducing a paper, the training-alignment stage requires the Paddle reproduction code to behave the same as the reference code so that the two are fully aligned. However, because frameworks differ, some APIs use different default initialization methods for their parameters. This document uses the two most common APIs, `nn.Conv2D` and `nn.Linear`, as examples to show how to align them.
+
+**Further references:**
+
+* Paddle initialization APIs: [official initialization API documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Overview_cn.html#chushihuaxiangguan)
+* Paddle sets initialization by modifying an API's `ParamAttr` directly, which differs from how the `torch.nn.init` family of APIs is used. PaddleDetection implements initialization APIs that are fully aligned with the `torch.nn.init` family, including `uniform_`, `normal_`, `constant_`, `ones_`, `zeros_`, `xavier_uniform_`, `xavier_normal_`, `kaiming_uniform_`, `kaiming_normal_`, `linear_init_`, and `conv_init_`. See [initializer.py](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/initializer.py) for the implementation details.
+
+
+# 2. Initialization differences between frameworks
+
+## 2.1 Aligning default initialization
+
+In this case, you usually need to consult the documentation to find out which initialization the reference code uses, and then change the initialization on the Paddle side to match it.
+
+The `nn.Conv2D` API is used as an example below.
+
+* **Step1:** Define one convolution in Paddle and one in torch, then plot histograms of their weight parameters, as shown below.
+
+```python
+import paddle
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+%matplotlib inline
+
+conv2d_pd = paddle.nn.Conv2D(4096, 512, 3)
+conv2d_pt = torch.nn.Conv2d(4096, 512, 3)
+
+conv2d_pd_weight = conv2d_pd.weight.numpy().reshape((-1, ))
+conv2d_pt_weight = conv2d_pt.weight.detach().numpy().reshape((-1, ))
+plt.figure(figsize=(10, 6))
+temp = plt.hist([conv2d_pd_weight, conv2d_pt_weight], bins=100, rwidth=0.8, histtype="step", label=["paddle.nn.Conv2D weight", "torch.nn.Conv2d weight"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()  # labels come from plt.hist, so each curve keeps its own name
+```
+
+<div align="center">
+<img src="https://paddle-model-ecology.bj.bcebos.com/images/initializer/conv2d_weight_default_diff.jpeg" width = "600" />
+</div>
+
+According to the [paddle documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/ParamAttr_cn.html#paramattr) and the [torch documentation](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html?highlight=conv2d#torch.nn.Conv2d), paddle's initialization is `XavierNormal` and torch's initialization is `uniform`, with bounds `(-sqrt(groups/(in_channels*prod(*kernel_size))), sqrt(groups/(in_channels*prod(*kernel_size))))`.
+
+
+* **Step2:** Based on the analysis above, use the `paddle.nn.initializer.Uniform` API to customize the initialization of Conv2D in Paddle, as shown in the following code:
+
+```python
+import math
+import paddle
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import paddle.nn.initializer as init
+%matplotlib inline
+# In this example, the formula above uses groups=1, in_channels=4096, kernel_size=3. A 2D convolution has a 2D kernel, so the denominator is 4096*3*3.
+conv2d_pd = paddle.nn.Conv2D(4096, 512, 3,
+    weight_attr=init.Uniform(-1/math.sqrt(4096*3*3), 1/math.sqrt(4096*3*3)))
+conv2d_pt = torch.nn.Conv2d(4096, 512, 3)
+
+conv2d_pd_weight = conv2d_pd.weight.numpy().reshape((-1, ))
+conv2d_pt_weight = conv2d_pt.weight.detach().numpy().reshape((-1, ))
+plt.figure(figsize=(10, 6))
+temp = plt.hist([conv2d_pd_weight, conv2d_pt_weight], bins=100, rwidth=0.8, histtype="step", label=["paddle.nn.Conv2D weight", "torch.nn.Conv2d weight"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()  # labels come from plt.hist, so each curve keeps its own name
+```
+
+<div align="center">
+<img src="https://paddle-model-ecology.bj.bcebos.com/images/initializer/conv2d_weight_fixed_diff.jpeg" width = "600" />
+</div>
+
+The figure shows that the two initial parameter distributions now match.
+
+
+## 2.2 Aligning custom initialization
+
+
+In some reference code, initialization is done with the `torch.nn.init` family of APIs, which can be regarded as custom initialization. For example, the `kaiming_normal_` call used in [resnet](https://github.com/pytorch/vision/blob/ec1c2a12cf00c6df83c7fb88f75b8117cda2f970/torchvision/models/resnet.py#L208) passes the two arguments `mode` and `nonlinearity`:
+
+```python
+if isinstance(m, nn.Conv2d):
+    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+```
+
+For this kind of problem, first try `Step1` from Section `2.1` to check whether the Paddle initializer of the same name aligns when used with its default arguments. If it does not, consult [initializer.py](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/initializer.py) and use the initialization functions in that file to achieve alignment.
+
+Initialization methods differ between frameworks, and such differences are hard for developers to track down while reproducing a paper. Section 3 below therefore shows how custom initialization makes the initial parameter distributions consistent across frameworks, which should help you complete the reproduction more smoothly.
+
+**Note:** For most APIs, such as BatchNorm2D, the learnable parameters have the same initial distribution in both frameworks. For a further comparison, visualizations of their weights are also provided here.
+
+# 3. Comparison of initial parameter distributions
+
+## 3.1 Weight histograms for APIs whose default initialization differs
+
+| Paddle API | torch API | Parameter distributions with default initialization | How to change the initialization | Parameter distributions after the change |
+|:---------:|:------------------:|:------------:|:------------:|:------------:|
+| `paddle.nn.Conv2D` weight | `torch.nn.Conv2d` weight |  | See Appendix `4.1.1` |  |
+| `paddle.nn.Conv2D` bias | `torch.nn.Conv2d` bias |  | See Appendix `4.1.1` |  |
+| `paddle.nn.Linear` weight | `torch.nn.Linear` weight |  | See Appendix `4.1.2` |  |
+| `paddle.nn.Linear` bias | `torch.nn.Linear` bias |  | See Appendix `4.1.2` |  |
+
+
+
+## 3.2 Weight histograms for APIs whose default initialization is the same
+
+| Paddle API | torch API | Parameter distributions with default initialization |
+|:---------:|:------------------:|:------------:|
+| `paddle.nn.BatchNorm2D` weight | `torch.nn.BatchNorm2d` weight |  |
+| `paddle.nn.BatchNorm2D` bias | `torch.nn.BatchNorm2d` bias |  |
+
+# 4. Appendix
+
+## 4.1 Initialization alignment code
+
+### 4.1.1 paddle.nn.Conv2D
+
+* Default initialization and visualization code
+
+```python
+import paddle
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+%matplotlib inline
+
+conv2d_pd = paddle.nn.Conv2D(4096, 512, 3)
+conv2d_pt = torch.nn.Conv2d(4096, 512, 3)
+
+conv2d_pd_weight = conv2d_pd.weight.numpy().reshape((-1, ))
+conv2d_pd_bias = conv2d_pd.bias.numpy().reshape((-1, ))
+conv2d_pt_weight = conv2d_pt.weight.detach().numpy().reshape((-1, ))
+conv2d_pt_bias = conv2d_pt.bias.detach().numpy().reshape((-1, ))  # read the torch bias, not the paddle one
+plt.figure(figsize=(10, 6))
+temp = plt.hist([conv2d_pd_weight, conv2d_pt_weight], bins=100, rwidth=0.8, histtype="step", label=["paddle.nn.Conv2D weight", "torch.nn.Conv2d weight"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()
+
+plt.figure(figsize=(10, 6))
+temp = plt.hist([conv2d_pd_bias, conv2d_pt_bias], bins=50, rwidth=0.8, histtype="step", label=["paddle.nn.Conv2D bias", "torch.nn.Conv2d bias"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()
+```
+
+
+
+* Fixed initialization and visualization code
+
+```python
+import math
+import paddle
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import paddle.nn.initializer as init
+%matplotlib inline
+
+conv2d_pd = paddle.nn.Conv2D(4096, 512, 3,
+    weight_attr=init.Uniform(-1/math.sqrt(4096*3*3), 1/math.sqrt(4096*3*3)),
+    bias_attr=init.Uniform(-1/math.sqrt(4096*3*3), 1/math.sqrt(4096*3*3)))
+conv2d_pt = torch.nn.Conv2d(4096, 512, 3)
+
+conv2d_pd_weight = conv2d_pd.weight.numpy().reshape((-1, ))
+conv2d_pd_bias = conv2d_pd.bias.numpy().reshape((-1, ))
+conv2d_pt_weight = conv2d_pt.weight.detach().numpy().reshape((-1, ))
+conv2d_pt_bias = conv2d_pt.bias.detach().numpy().reshape((-1, ))  # read the torch bias, not the paddle one
+plt.figure(figsize=(10, 6))
+temp = plt.hist([conv2d_pd_weight, conv2d_pt_weight], bins=100, rwidth=0.8, histtype="step", label=["paddle.nn.Conv2D weight", "torch.nn.Conv2d weight"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()
+
+plt.figure(figsize=(10, 6))
+temp = plt.hist([conv2d_pd_bias, conv2d_pt_bias], bins=50, rwidth=0.8, histtype="step", label=["paddle.nn.Conv2D bias", "torch.nn.Conv2d bias"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()
+```
+
+
+### 4.1.2 paddle.nn.Linear
+
+* Default initialization and visualization code
+
+```python
+import paddle
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+%matplotlib inline
+
+linear_pd = paddle.nn.Linear(4096, 512)
+linear_pt = torch.nn.Linear(4096, 512)
+
+linear_pd_weight = linear_pd.weight.numpy().reshape((-1, ))
+linear_pd_bias = linear_pd.bias.numpy().reshape((-1, ))
+linear_pt_weight = linear_pt.weight.detach().numpy().reshape((-1, ))
+linear_pt_bias = linear_pt.bias.detach().numpy().reshape((-1, ))  # detach() is required before numpy()
+plt.figure(figsize=(10, 6))
+temp = plt.hist([linear_pd_weight, linear_pt_weight], bins=100, rwidth=0.8, histtype="step", label=["paddle.nn.Linear weight", "torch.nn.Linear weight"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()
+
+plt.figure(figsize=(10, 6))
+temp = plt.hist([linear_pd_bias, linear_pt_bias], bins=50, rwidth=0.8, histtype="step", label=["paddle.nn.Linear bias", "torch.nn.Linear bias"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()
+```
+
+* Fixed initialization and visualization code
+
+```python
+import math
+import paddle
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+import paddle.nn.initializer as init
+%matplotlib inline
+# The formula from Section 2.1 also applies to Linear initialization; kernel_size is effectively 1 here.
+linear_pd = paddle.nn.Linear(4096, 512,
+    weight_attr=init.Uniform(-1/math.sqrt(4096), 1/math.sqrt(4096)),
+    bias_attr=init.Uniform(-1/math.sqrt(4096), 1/math.sqrt(4096)))
+linear_pt = torch.nn.Linear(4096, 512)
+
+linear_pd_weight = linear_pd.weight.numpy().reshape((-1, ))
+linear_pd_bias = linear_pd.bias.numpy().reshape((-1, ))
+linear_pt_weight = linear_pt.weight.detach().numpy().reshape((-1, ))
+linear_pt_bias = linear_pt.bias.detach().numpy().reshape((-1, ))  # detach() is required before numpy()
+plt.figure(figsize=(10, 6))
+temp = plt.hist([linear_pd_weight, linear_pt_weight], bins=100, rwidth=0.8, histtype="step", label=["paddle.nn.Linear weight", "torch.nn.Linear weight"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()
+
+plt.figure(figsize=(10, 6))
+temp = plt.hist([linear_pd_bias, linear_pt_bias], bins=50, rwidth=0.8, histtype="step", label=["paddle.nn.Linear bias", "torch.nn.Linear bias"])
+plt.xlabel("value")
+plt.ylabel("count")
+plt.legend()
+```
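
Review note: the appendix snippets hard-code the bounds `1/math.sqrt(4096*3*3)` and `1/math.sqrt(4096)`. As a minimal sketch, the bound formula from Section 2.1 could be wrapped in a small helper so that other layer shapes do not need the arithmetic redone by hand. The function name `torch_default_uniform_bound` is hypothetical; it is not part of Paddle, PyTorch, or this patch, and it uses only the standard library.

```python
import math


def torch_default_uniform_bound(in_channels, kernel_size=(), groups=1):
    """Return b such that the Section 2.1 formula samples from U(-b, b).

    Hypothetical helper (an assumption for illustration, not a library API).
    Computes sqrt(groups / (in_channels * prod(kernel_size))).
    An empty kernel_size gives a product of 1, which covers the Linear case.
    """
    denominator = in_channels * math.prod(kernel_size)
    return math.sqrt(groups / denominator)


# Conv2D(4096, 512, 3): groups=1, in_channels=4096, kernel 3x3
conv_bound = torch_default_uniform_bound(4096, (3, 3))
# Linear(4096, 512): kernel_size is effectively 1
linear_bound = torch_default_uniform_bound(4096)
print(conv_bound, linear_bound)
```

The returned value could then be passed as `init.Uniform(-bound, bound)` for both `weight_attr` and `bias_attr`, in the same way as the hard-coded values in the appendix.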