From b808979b8c0111e4abefdf9d92658163ca77e51a Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 15 Oct 2020 11:14:00 +0800 Subject: [PATCH] step lr_scheduler on epoch end in hapi/model.py fit (#27730) * step lr_scheduler on epoch end in hapi/model.py fit. test=develop --- python/paddle/hapi/model.py | 5 ++++ python/paddle/tests/test_model.py | 49 ++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 21e3054dde..5890d9760e 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -1461,6 +1461,11 @@ class Model(object): cbks.on_end('eval', eval_logs) + # step learning rate scheduler on each epcoh end + if isinstance(self._optimizer._learning_rate, + paddle.optimizer.lr.LRScheduler): + self._optimizer._learning_rate.step() + cbks.on_end('train', logs) self._test_dataloader = None diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index 46ea5ec995..4e732c59eb 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -33,7 +33,7 @@ from paddle.nn.layer.loss import CrossEntropyLoss from paddle.metric import Accuracy from paddle.vision.datasets import MNIST from paddle.vision.models import LeNet -from paddle.io import DistributedBatchSampler +from paddle.io import DistributedBatchSampler, Dataset from paddle.hapi.model import prepare_distributed_context from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator @@ -295,6 +295,15 @@ class MyModel(paddle.nn.Layer): return y +class MyDataset(Dataset): + def __getitem__(self, idx): + return np.random.random(size=(20,)).astype(np.float32), \ + np.random.randint(0, 10, size=(1,)).astype(np.int64) + + def __len__(self): + return 40 + + class TestModelFunction(unittest.TestCase): def set_seed(self, seed=1024): paddle.manual_seed(seed) @@ -599,6 +608,44 @@ class TestModelFunction(unittest.TestCase): shutil.rmtree(save_dir) +class TestModelWithLRScheduler(unittest.TestCase): + def test_fit(self): + def make_optimizer(parameters=None): + base_lr = 1e-3 + momentum = 0.9 + weight_decay = 5e-4 + boundaries = [5, 8] + values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)] + learning_rate = paddle.optimizer.lr.PiecewiseDecay( + boundaries=boundaries, values=values) + learning_rate = paddle.optimizer.lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=4, + start_lr=base_lr / 5., + end_lr=base_lr, + verbose=True) + optimizer = paddle.optimizer.Momentum( + learning_rate=learning_rate, + weight_decay=weight_decay, + momentum=momentum, + parameters=parameters) + return optimizer + + device = paddle.set_device('cpu') + fluid.enable_dygraph(device) + net = MyModel() + inputs = [InputSpec([None, 20], 'float32', 'x')] + labels = [InputSpec([None, 1], 'int64', 'label')] + optim = make_optimizer(net.parameters()) + model = Model(net, inputs, labels) + model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum")) + + dataset = MyDataset() + model.fit(dataset, dataset, batch_size=4, epochs=10, num_workers=0) + + paddle.enable_static() + + class TestRaiseError(unittest.TestCase): def test_input_without_name(self): net = MyModel() -- GitLab