From bce9c8c4e97e30406e5bfd78feeeec3c31a80601 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Sun, 3 Apr 2022 11:19:52 +0800 Subject: [PATCH] [Eager] Support two callback related tests (#41275) --- .../tests/test_callback_reduce_lr_on_plateau.py | 15 +++++++++++++-- python/paddle/tests/test_callback_visualdl.py | 8 +++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/python/paddle/tests/test_callback_reduce_lr_on_plateau.py b/python/paddle/tests/test_callback_reduce_lr_on_plateau.py index e950528ee4b..d7680537f37 100644 --- a/python/paddle/tests/test_callback_reduce_lr_on_plateau.py +++ b/python/paddle/tests/test_callback_reduce_lr_on_plateau.py @@ -29,6 +29,7 @@ from paddle.hapi.callbacks import config_callbacks from paddle.vision.datasets import MNIST from paddle.metric import Accuracy from paddle.nn.layer.loss import CrossEntropyLoss +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph # Accelerate unittest @@ -38,7 +39,7 @@ class CustomMnist(MNIST): class TestReduceLROnPlateau(unittest.TestCase): - def test_reduce_lr_on_plateau(self): + def func_reduce_lr_on_plateau(self): transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) train_dataset = CustomMnist(mode='train', transform=transform) val_dataset = CustomMnist(mode='test', transform=transform) @@ -59,7 +60,12 @@ class TestReduceLROnPlateau(unittest.TestCase): epochs=10, callbacks=[callbacks]) - def test_warn_or_error(self): + def test_reduce_lr_on_plateau(self): + with _test_eager_guard(): + self.func_reduce_lr_on_plateau() + self.func_reduce_lr_on_plateau() + + def func_warn_or_error(self): with self.assertRaises(ValueError): paddle.callbacks.ReduceLROnPlateau(factor=2.0) # warning @@ -101,6 +107,11 @@ class TestReduceLROnPlateau(unittest.TestCase): epochs=3, callbacks=[callbacks]) + def test_warn_or_error(self): + with _test_eager_guard(): + self.func_warn_or_error() + self.func_warn_or_error() + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tests/test_callback_visualdl.py b/python/paddle/tests/test_callback_visualdl.py index db3b83f2b14..355e88edd2b 100644 --- a/python/paddle/tests/test_callback_visualdl.py +++ b/python/paddle/tests/test_callback_visualdl.py @@ -29,6 +29,7 @@ import paddle.vision.transforms as T from paddle.vision.datasets import MNIST from paddle.metric import Accuracy from paddle.nn.layer.loss import CrossEntropyLoss +from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph class MnistDataset(MNIST): @@ -43,7 +44,7 @@ class TestCallbacks(unittest.TestCase): def tearDown(self): shutil.rmtree(self.save_dir) - def test_visualdl_callback(self): + def func_visualdl_callback(self): # visualdl not support python2 if sys.version_info < (3, ): return @@ -70,6 +71,11 @@ class TestCallbacks(unittest.TestCase): batch_size=64, callbacks=callback) + def test_visualdl_callback(self): + with _test_eager_guard(): + self.func_visualdl_callback() + self.func_visualdl_callback() + if __name__ == '__main__': unittest.main() -- GitLab