未验证 提交 bce9c8c4 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support two callback related tests (#41275)

上级 af247f95
...@@ -29,6 +29,7 @@ from paddle.hapi.callbacks import config_callbacks ...@@ -29,6 +29,7 @@ from paddle.hapi.callbacks import config_callbacks
from paddle.vision.datasets import MNIST from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy from paddle.metric import Accuracy
from paddle.nn.layer.loss import CrossEntropyLoss from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph
# Accelerate unittest # Accelerate unittest
...@@ -38,7 +39,7 @@ class CustomMnist(MNIST): ...@@ -38,7 +39,7 @@ class CustomMnist(MNIST):
class TestReduceLROnPlateau(unittest.TestCase): class TestReduceLROnPlateau(unittest.TestCase):
def test_reduce_lr_on_plateau(self): def func_reduce_lr_on_plateau(self):
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = CustomMnist(mode='train', transform=transform) train_dataset = CustomMnist(mode='train', transform=transform)
val_dataset = CustomMnist(mode='test', transform=transform) val_dataset = CustomMnist(mode='test', transform=transform)
...@@ -59,7 +60,12 @@ class TestReduceLROnPlateau(unittest.TestCase): ...@@ -59,7 +60,12 @@ class TestReduceLROnPlateau(unittest.TestCase):
epochs=10, epochs=10,
callbacks=[callbacks]) callbacks=[callbacks])
def test_warn_or_error(self): def test_reduce_lr_on_plateau(self):
with _test_eager_guard():
self.func_reduce_lr_on_plateau()
self.func_reduce_lr_on_plateau()
def func_warn_or_error(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
paddle.callbacks.ReduceLROnPlateau(factor=2.0) paddle.callbacks.ReduceLROnPlateau(factor=2.0)
# warning # warning
...@@ -101,6 +107,11 @@ class TestReduceLROnPlateau(unittest.TestCase): ...@@ -101,6 +107,11 @@ class TestReduceLROnPlateau(unittest.TestCase):
epochs=3, epochs=3,
callbacks=[callbacks]) callbacks=[callbacks])
def test_warn_or_error(self):
with _test_eager_guard():
self.func_warn_or_error()
self.func_warn_or_error()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -29,6 +29,7 @@ import paddle.vision.transforms as T ...@@ -29,6 +29,7 @@ import paddle.vision.transforms as T
from paddle.vision.datasets import MNIST from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy from paddle.metric import Accuracy
from paddle.nn.layer.loss import CrossEntropyLoss from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph
class MnistDataset(MNIST): class MnistDataset(MNIST):
...@@ -43,7 +44,7 @@ class TestCallbacks(unittest.TestCase): ...@@ -43,7 +44,7 @@ class TestCallbacks(unittest.TestCase):
def tearDown(self): def tearDown(self):
shutil.rmtree(self.save_dir) shutil.rmtree(self.save_dir)
def test_visualdl_callback(self): def func_visualdl_callback(self):
# visualdl not support python2 # visualdl not support python2
if sys.version_info < (3, ): if sys.version_info < (3, ):
return return
...@@ -70,6 +71,11 @@ class TestCallbacks(unittest.TestCase): ...@@ -70,6 +71,11 @@ class TestCallbacks(unittest.TestCase):
batch_size=64, batch_size=64,
callbacks=callback) callbacks=callback)
def test_visualdl_callback(self):
with _test_eager_guard():
self.func_visualdl_callback()
self.func_visualdl_callback()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册