From 1c39efb783c62f20a039459f4974c028074accd4 Mon Sep 17 00:00:00 2001 From: Leo Zhao <48052473+LeoZhao-Intel@users.noreply.github.com> Date: Sat, 4 Jan 2020 12:59:22 +0800 Subject: [PATCH] Enable test conv2d ngraph (#22074) --- .../paddle/fluid/tests/unittests/ngraph/CMakeLists.txt | 2 -- .../tests/unittests/ngraph/test_conv2d_ngraph_op.py | 9 ++++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt index e9866699a60..5ed2d0aa80c 100644 --- a/python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ngraph/CMakeLists.txt @@ -1,8 +1,6 @@ file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") -list(REMOVE_ITEM TEST_OPS test_conv2d_ngraph_op) - foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS FLAGS_use_ngraph=true) endforeach(TEST_OP) diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py index 4894af949a2..7a9adcf4cd8 100644 --- a/python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py +++ b/python/paddle/fluid/tests/unittests/ngraph/test_conv2d_ngraph_op.py @@ -20,6 +20,13 @@ from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride, TestWithGr import numpy as np +class TestNGRAPHWithStride(TestWithStride): + def init_test_case(self): + super(TestNGRAPHWithStride, self).init_test_case() + self.use_cuda = False + self.dtype = np.float32 + + class TestNGRAPHDepthwiseConv(TestDepthwiseConv): def init_test_case(self): super(TestNGRAPHDepthwiseConv, self).init_test_case() @@ -55,7 +62,7 @@ class TestNGRAPHDepthwiseConvWithDilation2(TestDepthwiseConvWithDilation2): self.dtype = np.float32 -del TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2 +del TestWithStride, TestDepthwiseConv, TestDepthwiseConv2, TestDepthwiseConv3, TestDepthwiseConvWithDilation, TestDepthwiseConvWithDilation2 if __name__ == '__main__': unittest.main() -- GitLab