未验证 提交 cf8e86df 编写于 作者: H Huihuang Zheng 提交者: GitHub

[CINN] Enable test_resnet50_with_cinn (#44017)

上级 a42f48bd
......@@ -1687,3 +1687,22 @@ if($ENV{USE_STANDALONE_EXECUTOR})
set_tests_properties(test_imperative_mnist_sorted_gradient
PROPERTIES ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0)
endif()
if(WITH_CINN AND WITH_TESTING)
set_tests_properties(
test_resnet50_with_cinn
PROPERTIES
LABELS
"RUN_TYPE=CINN"
ENVIRONMENT
FLAGS_allow_cinn_ops="conv2d;conv2d_grad;elementwise_add;elementwise_add_grad;relu;relu_grad;sum"
)
set_tests_properties(
test_parallel_executor_run_cinn
PROPERTIES
LABELS
"RUN_TYPE=CINN"
ENVIRONMENT
FLAGS_allow_cinn_ops="conv2d;conv2d_grad;elementwise_add;elementwise_add_grad;relu;relu_grad;sum"
)
endif()
......@@ -108,6 +108,10 @@ class TestResnet50Accuracy(unittest.TestCase):
loss_c = self.train(place, loop_num, feed, use_cinn=True)
loss_p = self.train(place, loop_num, feed, use_cinn=False)
print("Losses of CINN:")
print(loss_c)
print("Losses of Paddle")
print(loss_p)
self.assertTrue(np.allclose(loss_c, loss_p, atol=1e-5))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册