未验证 提交 8aa5be90 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Cherry-pick] Fix CUDA11.8 Unittest Accuracy (#49374)

Fix CUDA11.8 Unittest Accuracy
上级 b5fdd175
...@@ -20,19 +20,22 @@ from functools import partial ...@@ -20,19 +20,22 @@ from functools import partial
class TestResnetGPU(TestResnetBase): class TestResnetGPU(TestResnetBase):
def test_seresnext_with_learning_rate_decay(self): def test_seresnext_with_learning_rate_decay(self):
# NOTE(zcd): This test is compare the result of use parallel_executor # NOTE(zcd): This test is compare the result of use parallel_executor
# and executor, and the result of drop_out op and batch_norm op in # and executor, and the result of drop_out op and batch_norm op in
# this two executor have diff, so the two ops should be removed # this two executor have diff, so the two ops should be removed
# from the model. # from the model.
check_func = partial(self.check_network_convergence, check_func = partial(
self.check_network_convergence,
optimizer=seresnext_net.optimizer, optimizer=seresnext_net.optimizer,
use_parallel_executor=False) use_parallel_executor=False,
self._compare_result_with_origin_model(check_func, )
self._compare_result_with_origin_model(
check_func,
use_device=DeviceType.CUDA, use_device=DeviceType.CUDA,
delta2=1e-5, delta2=1e-3,
compare_separately=False) compare_separately=False,
)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册