未验证 提交 b66c833f 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP Unittest] Optimize unit test, save setUp time (#52889)

上级 dd2a749a
...@@ -195,7 +195,15 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -195,7 +195,15 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
paddle.seed(SEED) paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED) paddle.framework.random._manual_program_seed(SEED)
def test_static(self): def test_all(self):
self._test_static()
self._test_dynamic()
self._test_static_save_and_load_inference_model()
self._test_static_save_and_run_inference_predictor()
self._test_double_grad_dynamic()
self._test_with_dataloader()
def _test_static(self):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16': if device == 'cpu' and dtype == 'float16':
...@@ -208,7 +216,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -208,7 +216,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
) )
check_output(out, pd_out, "out") check_output(out, pd_out, "out")
def test_dynamic(self): def _test_dynamic(self):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16': if device == 'cpu' and dtype == 'float16':
...@@ -224,7 +232,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -224,7 +232,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
check_output(out, pd_out, "out") check_output(out, pd_out, "out")
check_output(x_grad, pd_x_grad, "x_grad") check_output(x_grad, pd_x_grad, "x_grad")
def test_static_save_and_load_inference_model(self): def _test_static_save_and_load_inference_model(self):
paddle.enable_static() paddle.enable_static()
np_data = np.random.random((1, 1, 28, 28)).astype("float32") np_data = np.random.random((1, 1, 28, 28)).astype("float32")
np_label = np.random.random((1, 1)).astype("int64") np_label = np.random.random((1, 1)).astype("int64")
...@@ -249,7 +257,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -249,7 +257,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
check_output(predict, predict_infer, "predict") check_output(predict, predict_infer, "predict")
paddle.disable_static() paddle.disable_static()
def test_static_save_and_run_inference_predictor(self): def _test_static_save_and_run_inference_predictor(self):
paddle.enable_static() paddle.enable_static()
np_data = np.random.random((1, 1, 28, 28)).astype("float32") np_data = np.random.random((1, 1, 28, 28)).astype("float32")
np_label = np.random.random((1, 1)).astype("int64") np_label = np.random.random((1, 1)).astype("int64")
...@@ -280,7 +288,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -280,7 +288,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
check_output_allclose(predict, predict_infer, "predict") check_output_allclose(predict, predict_infer, "predict")
paddle.disable_static() paddle.disable_static()
def test_double_grad_dynamic(self): def _test_double_grad_dynamic(self):
for device in self.devices: for device in self.devices:
for dtype in self.dtypes: for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16': if device == 'cpu' and dtype == 'float16':
...@@ -295,7 +303,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): ...@@ -295,7 +303,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
check_output(out, pd_out, "out") check_output(out, pd_out, "out")
check_output(dx_grad, pd_dx_grad, "dx_grad") check_output(dx_grad, pd_dx_grad, "dx_grad")
def test_with_dataloader(self): def _test_with_dataloader(self):
for device in self.devices: for device in self.devices:
paddle.set_device(device) paddle.set_device(device)
# data loader # data loader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册