未验证 提交 6d87f600 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix thres in Step1 (#5475)

* fix ste1

* fix in re-impl cv

* Update 01_test_forward.py
上级 8b9fd465
...@@ -277,7 +277,7 @@ if __name__ == "__main__": ...@@ -277,7 +277,7 @@ if __name__ == "__main__":
* 模型在前向对齐验证时,需要调用`model.eval()`方法,保证组网中的随机量被关闭,比如BatchNorm、Dropout等。 * 模型在前向对齐验证时,需要调用`model.eval()`方法,保证组网中的随机量被关闭,比如BatchNorm、Dropout等。
* 给定相同的输入数据,为保证可复现性,如果有随机数生成,固定相关的随机种子。 * 给定相同的输入数据,为保证可复现性,如果有随机数生成,固定相关的随机种子。
* 我们可以基于reprod logger 的比较结果判断对齐效果,一般误差在1e-6附近的话,可以认为前向没有问题。 * 我们可以基于reprod logger 的比较结果判断对齐效果,一般误差在1e-5附近的话,可以认为前向没有问题。
* 如果最终输出结果diff较大,可以使用二分的方法进行排查,比如说ResNet50,包含1个stem、4个res-stage、global avg-pooling以及最后的fc层,那么完成模型组网和权重转换之后,如果模型输出没有对齐,可以尝试输出中间某一个res-stage的tensor进行对比,如果相同,则向后进行排查;如果不同,则继续向前进行排查,以此类推,直到找到导致没有对齐的操作。 * 如果最终输出结果diff较大,可以使用二分的方法进行排查,比如说ResNet50,包含1个stem、4个res-stage、global avg-pooling以及最后的fc层,那么完成模型组网和权重转换之后,如果模型输出没有对齐,可以尝试输出中间某一个res-stage的tensor进行对比,如果相同,则向后进行排查;如果不同,则继续向前进行排查,以此类推,直到找到导致没有对齐的操作。
**【实战】** **【实战】**
......
...@@ -9,6 +9,10 @@ from mobilenetv3_ref.torchvision.models import mobilenet_v3_small as mv3_small_t ...@@ -9,6 +9,10 @@ from mobilenetv3_ref.torchvision.models import mobilenet_v3_small as mv3_small_t
def test_forward(): def test_forward():
device = "gpu" # you can also set it as "cpu"
torch_device = torch.device("cuda:0" if device == "gpu" else "cpu")
paddle.set_device(device)
# load paddle model # load paddle model
paddle_model = mv3_small_paddle() paddle_model = mv3_small_paddle()
paddle_model.eval() paddle_model.eval()
...@@ -21,6 +25,8 @@ def test_forward(): ...@@ -21,6 +25,8 @@ def test_forward():
torch_state_dict = torch.load("./data/mobilenet_v3_small-047dcff4.pth") torch_state_dict = torch.load("./data/mobilenet_v3_small-047dcff4.pth")
torch_model.load_state_dict(torch_state_dict) torch_model.load_state_dict(torch_state_dict)
torch_model.to(torch_device)
# load data # load data
inputs = np.load("./data/fake_data.npy") inputs = np.load("./data/fake_data.npy")
...@@ -31,7 +37,9 @@ def test_forward(): ...@@ -31,7 +37,9 @@ def test_forward():
reprod_logger.save("./result/forward_paddle.npy") reprod_logger.save("./result/forward_paddle.npy")
# save the torch output # save the torch output
torch_out = torch_model(torch.tensor(inputs, dtype=torch.float32)) torch_out = torch_model(
torch.tensor(
inputs, dtype=torch.float32).to(torch_device))
reprod_logger.add("logits", torch_out.cpu().detach().numpy()) reprod_logger.add("logits", torch_out.cpu().detach().numpy())
reprod_logger.save("./result/forward_ref.npy") reprod_logger.save("./result/forward_ref.npy")
...@@ -46,4 +54,5 @@ if __name__ == "__main__": ...@@ -46,4 +54,5 @@ if __name__ == "__main__":
# compare result and produce log # compare result and produce log
diff_helper.compare_info(torch_info, paddle_info) diff_helper.compare_info(torch_info, paddle_info)
diff_helper.report(path="./result/log/forward_diff.log") diff_helper.report(
path="./result/log/forward_diff.log", diff_threshold=1e-5)
...@@ -157,14 +157,12 @@ python 01_test_forward.py ...@@ -157,14 +157,12 @@ python 01_test_forward.py
根据示例代码可以看到,我们将结果保存在`result/log/forward_diff.log`中,打开对应文件或者直接观察命令行输出,就会有下列结果: 根据示例代码可以看到,我们将结果保存在`result/log/forward_diff.log`中,打开对应文件或者直接观察命令行输出,就会有下列结果:
```bash ```bash
[2021/12/21 15:00:38] root INFO: logits: [2022/02/28 05:31:40] root INFO: logits:
[2021/12/21 15:00:38] root INFO: mean diff: check passed: False, value: 2.308018565599923e-06 [2022/02/28 05:31:40] root INFO: mean diff: check passed: True, value: 1.7629824924370041e-06
[2021/12/21 15:00:38] root INFO: diff check failed [2022/02/28 05:31:40] root INFO: diff check passed
``` ```
这里我们发现在`reprod_log`默认的平均差异小于1e-6的标准下,当前前向对齐是不符合条件的,但是这是由于前向 op 计算导致的微小的差异。 由于前向 op 计算导致的微小的差异。一般说来前向误差在 1e-5 左右的 diff 是可以接受的,到这里我们就验证了网络的前向是对齐的,完成了第一个打卡点。
一般说来前向误差在 1e-5 左右都是可以接受的,到这里我们就验证了网络的前向是对齐的,完成了第一个打卡点。
<a name="4.2"></a> <a name="4.2"></a>
### 4.2 数据加载对齐 ### 4.2 数据加载对齐
......
from torchvision.transforms import autoaugment, transforms from torchvision.transforms import transforms
class ClassificationPresetTrain: class ClassificationPresetTrain:
......
[2022/01/03 16:50:19] root INFO: logits: [2022/02/28 05:31:40] root INFO: logits:
[2022/01/03 16:50:19] root INFO: mean diff: check passed: False, value: 2.308018565599923e-06 [2022/02/28 05:31:40] root INFO: mean diff: check passed: True, value: 1.7629824924370041e-06
[2022/01/03 16:50:19] root INFO: diff check failed [2022/02/28 05:31:40] root INFO: diff check passed
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册