未验证 提交 bce9c8c4 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support two callback related tests (#41275)

上级 af247f95
......@@ -29,6 +29,7 @@ from paddle.hapi.callbacks import config_callbacks
from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy
from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph
# Accelerate unittest
......@@ -38,7 +39,7 @@ class CustomMnist(MNIST):
class TestReduceLROnPlateau(unittest.TestCase):
def test_reduce_lr_on_plateau(self):
def func_reduce_lr_on_plateau(self):
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = CustomMnist(mode='train', transform=transform)
val_dataset = CustomMnist(mode='test', transform=transform)
......@@ -59,7 +60,12 @@ class TestReduceLROnPlateau(unittest.TestCase):
epochs=10,
callbacks=[callbacks])
def test_warn_or_error(self):
def test_reduce_lr_on_plateau(self):
with _test_eager_guard():
self.func_reduce_lr_on_plateau()
self.func_reduce_lr_on_plateau()
def func_warn_or_error(self):
with self.assertRaises(ValueError):
paddle.callbacks.ReduceLROnPlateau(factor=2.0)
# warning
......@@ -101,6 +107,11 @@ class TestReduceLROnPlateau(unittest.TestCase):
epochs=3,
callbacks=[callbacks])
def test_warn_or_error(self):
with _test_eager_guard():
self.func_warn_or_error()
self.func_warn_or_error()
if __name__ == '__main__':
unittest.main()
......@@ -29,6 +29,7 @@ import paddle.vision.transforms as T
from paddle.vision.datasets import MNIST
from paddle.metric import Accuracy
from paddle.nn.layer.loss import CrossEntropyLoss
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph
class MnistDataset(MNIST):
......@@ -43,7 +44,7 @@ class TestCallbacks(unittest.TestCase):
def tearDown(self):
shutil.rmtree(self.save_dir)
def test_visualdl_callback(self):
def func_visualdl_callback(self):
# visualdl not support python2
if sys.version_info < (3, ):
return
......@@ -70,6 +71,11 @@ class TestCallbacks(unittest.TestCase):
batch_size=64,
callbacks=callback)
def test_visualdl_callback(self):
with _test_eager_guard():
self.func_visualdl_callback()
self.func_visualdl_callback()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册