未验证 提交 6d87f600 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix thres in Step1 (#5475)

* fix ste1

* fix in re-impl cv

* Update 01_test_forward.py
上级 8b9fd465
......@@ -277,7 +277,7 @@ if __name__ == "__main__":
* 模型在前向对齐验证时,需要调用`model.eval()`方法,保证组网中的随机量被关闭,比如BatchNorm、Dropout等。
* 给定相同的输入数据,为保证可复现性,如果有随机数生成,固定相关的随机种子。
* 我们可以基于reprod logger 的比较结果判断对齐效果,一般误差在1e-6附近的话,可以认为前向没有问题。
* 我们可以基于reprod logger 的比较结果判断对齐效果,一般误差在1e-5附近的话,可以认为前向没有问题。
* 如果最终输出结果diff较大,可以使用二分的方法进行排查,比如说ResNet50,包含1个stem、4个res-stage、global avg-pooling以及最后的fc层,那么完成模型组网和权重转换之后,如果模型输出没有对齐,可以尝试输出中间某一个res-stage的tensor进行对比,如果相同,则向后进行排查;如果不同,则继续向前进行排查,以此类推,直到找到导致没有对齐的操作。
**【实战】**
......
......@@ -9,6 +9,10 @@ from mobilenetv3_ref.torchvision.models import mobilenet_v3_small as mv3_small_t
def test_forward():
device = "gpu" # you can also set it as "cpu"
torch_device = torch.device("cuda:0" if device == "gpu" else "cpu")
paddle.set_device(device)
# load paddle model
paddle_model = mv3_small_paddle()
paddle_model.eval()
......@@ -21,6 +25,8 @@ def test_forward():
torch_state_dict = torch.load("./data/mobilenet_v3_small-047dcff4.pth")
torch_model.load_state_dict(torch_state_dict)
torch_model.to(torch_device)
# load data
inputs = np.load("./data/fake_data.npy")
......@@ -31,7 +37,9 @@ def test_forward():
reprod_logger.save("./result/forward_paddle.npy")
# save the torch output
torch_out = torch_model(torch.tensor(inputs, dtype=torch.float32))
torch_out = torch_model(
torch.tensor(
inputs, dtype=torch.float32).to(torch_device))
reprod_logger.add("logits", torch_out.cpu().detach().numpy())
reprod_logger.save("./result/forward_ref.npy")
......@@ -46,4 +54,5 @@ if __name__ == "__main__":
# compare result and produce log
diff_helper.compare_info(torch_info, paddle_info)
diff_helper.report(path="./result/log/forward_diff.log")
diff_helper.report(
path="./result/log/forward_diff.log", diff_threshold=1e-5)
......@@ -157,14 +157,12 @@ python 01_test_forward.py
根据示例代码可以看到,我们将结果保存在`result/log/forward_diff.log`中,打开对应文件或者直接观察命令行输出,就会有下列结果:
```bash
[2021/12/21 15:00:38] root INFO: logits:
[2021/12/21 15:00:38] root INFO: mean diff: check passed: False, value: 2.308018565599923e-06
[2021/12/21 15:00:38] root INFO: diff check failed
[2022/02/28 05:31:40] root INFO: logits:
[2022/02/28 05:31:40] root INFO: mean diff: check passed: True, value: 1.7629824924370041e-06
[2022/02/28 05:31:40] root INFO: diff check passed
```
这里我们发现在`reprod_log`默认的平均差异小于1e-6的标准下,当前前向对齐是不符合条件的,但是这是由于前向 op 计算导致的微小的差异。
一般说来前向误差在 1e-5 左右都是可以接受的,到这里我们就验证了网络的前向是对齐的,完成了第一个打卡点。
由于前向 op 计算导致的微小的差异。一般说来前向误差在 1e-5 左右的 diff 是可以接受的,到这里我们就验证了网络的前向是对齐的,完成了第一个打卡点。
<a name="4.2"></a>
### 4.2 数据加载对齐
......
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms import transforms
class ClassificationPresetTrain:
......
[2022/01/03 16:50:19] root INFO: logits:
[2022/01/03 16:50:19] root INFO: mean diff: check passed: False, value: 2.308018565599923e-06
[2022/01/03 16:50:19] root INFO: diff check failed
[2022/02/28 05:31:40] root INFO: logits:
[2022/02/28 05:31:40] root INFO: mean diff: check passed: True, value: 1.7629824924370041e-06
[2022/02/28 05:31:40] root INFO: diff check passed
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册