未验证 提交 b66c833f 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP Unittest] Optimize unit test, save setUp time (#52889)

上级 dd2a749a
......@@ -195,7 +195,15 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
def test_static(self):
def test_all(self):
self._test_static()
self._test_dynamic()
self._test_static_save_and_load_inference_model()
self._test_static_save_and_run_inference_predictor()
self._test_double_grad_dynamic()
self._test_with_dataloader()
def _test_static(self):
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
......@@ -208,7 +216,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
)
check_output(out, pd_out, "out")
def test_dynamic(self):
def _test_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
......@@ -224,7 +232,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
check_output(out, pd_out, "out")
check_output(x_grad, pd_x_grad, "x_grad")
def test_static_save_and_load_inference_model(self):
def _test_static_save_and_load_inference_model(self):
paddle.enable_static()
np_data = np.random.random((1, 1, 28, 28)).astype("float32")
np_label = np.random.random((1, 1)).astype("int64")
......@@ -249,7 +257,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
check_output(predict, predict_infer, "predict")
paddle.disable_static()
def test_static_save_and_run_inference_predictor(self):
def _test_static_save_and_run_inference_predictor(self):
paddle.enable_static()
np_data = np.random.random((1, 1, 28, 28)).astype("float32")
np_label = np.random.random((1, 1)).astype("int64")
......@@ -280,7 +288,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
check_output_allclose(predict, predict_infer, "predict")
paddle.disable_static()
def test_double_grad_dynamic(self):
def _test_double_grad_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
if device == 'cpu' and dtype == 'float16':
......@@ -295,7 +303,7 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
check_output(out, pd_out, "out")
check_output(dx_grad, pd_dx_grad, "dx_grad")
def test_with_dataloader(self):
def _test_with_dataloader(self):
for device in self.devices:
paddle.set_device(device)
# data loader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册