未验证 提交 8aa5be90 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Cherry-pick] Fix CUDA11.8 Unittest Accuracy (#49374)

Fix CUDA11.8 Unittest Accuracy
上级 b5fdd175
......@@ -20,19 +20,22 @@ from functools import partial
class TestResnetGPU(TestResnetBase):
def test_seresnext_with_learning_rate_decay(self):
# NOTE(zcd): This test is compare the result of use parallel_executor
# and executor, and the result of drop_out op and batch_norm op in
# this two executor have diff, so the two ops should be removed
# from the model.
check_func = partial(self.check_network_convergence,
optimizer=seresnext_net.optimizer,
use_parallel_executor=False)
self._compare_result_with_origin_model(check_func,
use_device=DeviceType.CUDA,
delta2=1e-5,
compare_separately=False)
check_func = partial(
self.check_network_convergence,
optimizer=seresnext_net.optimizer,
use_parallel_executor=False,
)
self._compare_result_with_origin_model(
check_func,
use_device=DeviceType.CUDA,
delta2=1e-3,
compare_separately=False,
)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册