diff --git a/python/paddle/hapi/callbacks.py b/python/paddle/hapi/callbacks.py
index 8a89ee8517426ef9f2a8a85b9bc948ad80bcf6e4..2ffe7a986d5eb3342014c5d89f1365167ae894d7 100644
--- a/python/paddle/hapi/callbacks.py
+++ b/python/paddle/hapi/callbacks.py
@@ -15,12 +15,15 @@
 import os
 import numbers

-from paddle.fluid.dygraph.parallel import ParallelEnv
+import paddle
+from paddle.distributed import ParallelEnv
 from paddle.utils import try_import

 from .progressbar import ProgressBar

-__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL']
+__all__ = [
+    'Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL', 'LRScheduler'
+]


 def config_callbacks(callbacks=None,
@@ -42,6 +45,9 @@ def config_callbacks(callbacks=None,
     if not any(isinstance(k, ModelCheckpoint) for k in cbks):
         cbks = cbks + [ModelCheckpoint(save_freq, save_dir)]

+    if not any(isinstance(k, LRScheduler) for k in cbks):
+        cbks = cbks + [LRScheduler()]
+
     cbk_list = CallbackList(cbks)
     cbk_list.set_model(model)
     metrics = metrics or [] if mode != 'test' else []
@@ -485,6 +491,96 @@ class ModelCheckpoint(Callback):
                 self.model.save(path)


+class LRScheduler(Callback):
+    """Learning rate scheduler callback function.
+    Args:
+        by_step(bool, optional): whether to update the learning rate scheduler
+            at the end of each step (batch). Default: True.
+        by_epoch(bool, optional): whether to update the learning rate scheduler
+            at the end of each epoch. Default: False.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.vision.transforms as T
+            from paddle.static import InputSpec
+
+            inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
+            labels = [InputSpec([None, 1], 'int64', 'label')]
+
+            transform = T.Compose([
+                T.Transpose(),
+                T.Normalize([127.5], [127.5])
+            ])
+            train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
+
+            lenet = paddle.vision.LeNet()
+            model = paddle.Model(lenet,
+                inputs, labels)
+
+            base_lr = 1e-3
+            boundaries = [5, 8]
+            warmup_steps = 4
+
+            def make_optimizer(parameters=None):
+                momentum = 0.9
+                weight_decay = 5e-4
+                values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
+                learning_rate = paddle.optimizer.lr.PiecewiseDecay(
+                    boundaries=boundaries, values=values)
+                learning_rate = paddle.optimizer.lr.LinearWarmup(
+                    learning_rate=learning_rate,
+                    warmup_steps=warmup_steps,
+                    start_lr=base_lr / 5.,
+                    end_lr=base_lr,
+                    verbose=True)
+                optimizer = paddle.optimizer.Momentum(
+                    learning_rate=learning_rate,
+                    weight_decay=weight_decay,
+                    momentum=momentum,
+                    parameters=parameters)
+                return optimizer
+
+            optim = make_optimizer(parameters=lenet.parameters())
+            model.prepare(optimizer=optim,
+                          loss=paddle.nn.CrossEntropyLoss(),
+                          metrics=paddle.metric.Accuracy())
+
+            # If no LRScheduler callback is passed, an LRScheduler instance
+            # that updates by step is created automatically.
+            model.fit(train_dataset, batch_size=64)
+
+            # create a learning rate scheduler that updates by epoch
+            callback = paddle.callbacks.LRScheduler(by_step=False, by_epoch=True)
+            model.fit(train_dataset, batch_size=64, callbacks=callback)
+    """
+
+    def __init__(self, by_step=True, by_epoch=False):
+        if by_step and by_epoch:
+            raise ValueError(
+                "by_step option is mutually exclusive with by_epoch")
+
+        self.by_step = by_step
+        self.by_epoch = by_epoch
+
+    def on_epoch_end(self, epoch, logs=None):
+        if self.by_epoch:
+            if self.model._optimizer and \
+                hasattr(self.model._optimizer, '_learning_rate') and \
+                isinstance(self.model._optimizer._learning_rate,
+                           paddle.optimizer.lr.LRScheduler):
+                self.model._optimizer._learning_rate.step()
+
+    def on_train_batch_end(self, step, logs=None):
+        if self.by_step:
+            if self.model._optimizer and \
+                hasattr(self.model._optimizer, '_learning_rate') and \
+                isinstance(self.model._optimizer._learning_rate,
+                           paddle.optimizer.lr.LRScheduler):
+                self.model._optimizer._learning_rate.step()
+
+
 class VisualDL(Callback):
     """VisualDL callback function
     Args:
diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index d5d2ec70e9906565a8a0db71e5775ff902109e25..1414cc8bb0dc0ae17463547264cb2db069c340fc 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -459,13 +459,6 @@ class StaticGraphAdapter(object):
             if len(name) > 0:
                 rets.insert(i, feed[name])

-        # step learning rate scheduler on each batch end
-        if self.model._optimizer and self.mode == 'train' and \
-                hasattr(self.model._optimizer, '_learning_rate') and \
-                isinstance(self.model._optimizer._learning_rate,
-                           paddle.optimizer.lr.LRScheduler):
-            self.model._optimizer._learning_rate.step()
-
         # LoDTensor cannot be fetch as numpy directly
         rets = [np.array(v) for v in rets]
         if self.mode == 'test':
@@ -666,12 +659,6 @@ class DynamicGraphAdapter(object):
             self.model._optimizer.minimize(final_loss)
             self.model.network.clear_gradients()

-        # step learning rate scheduler on each batch end
-        if self.model._optimizer and \
-                isinstance(self.model._optimizer._learning_rate,
-                           paddle.optimizer.lr.LRScheduler):
-            self.model._optimizer._learning_rate.step()
-
         metrics = []
         for metric in self.model._metrics:
             metric_outs = metric.compute(*(to_list(outputs) + labels))
diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py
index ab7a3654e582c968002fa535f7d2856e0190a45b..c09259f06b899378da4532cd9735bc8ed3b7fd28 100644
--- a/python/paddle/tests/test_model.py
+++ b/python/paddle/tests/test_model.py
@@ -645,12 +645,13 @@ class TestModelFunction(unittest.TestCase):


 class TestModelWithLRScheduler(unittest.TestCase):
-    def test_fit(self):
+    def test_fit_by_step(self):
+        base_lr = 1e-3
+        boundaries = [5, 8]
+
         def make_optimizer(parameters=None):
-            base_lr = 1e-3
             momentum = 0.9
             weight_decay = 5e-4
-            boundaries = [5, 8]
             values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
             learning_rate = paddle.optimizer.lr.PiecewiseDecay(
                 boundaries=boundaries, values=values)
@@ -680,6 +681,8 @@ class TestModelWithLRScheduler(unittest.TestCase):

         dataset = MyDataset()
         model.fit(dataset, dataset, batch_size=4, epochs=10, num_workers=0)
+        np.testing.assert_allclose(model._optimizer._learning_rate.last_lr,
+                                   base_lr * (0.1**len(boundaries)))

         # static test
         paddle.enable_static()
@@ -693,6 +696,93 @@ class TestModelWithLRScheduler(unittest.TestCase):

         dataset = MyDataset()
         model.fit(dataset, dataset, batch_size=4, epochs=10, num_workers=0)
+        np.testing.assert_allclose(model._optimizer._learning_rate.last_lr,
+                                   base_lr * (0.1**len(boundaries)))
+
+    def test_fit_by_epoch(self):
+        base_lr = 1e-3
+        boundaries = [5, 8]
+        epochs = 10
+        warmup_epochs = 4
+
+        def make_optimizer(parameters=None):
+            momentum = 0.9
+            weight_decay = 5e-4
+            values = [base_lr * (0.1**i) for i in range(len(boundaries) + 1)]
+            learning_rate = paddle.optimizer.lr.PiecewiseDecay(
+                boundaries=boundaries, values=values)
+            learning_rate = paddle.optimizer.lr.LinearWarmup(
+                learning_rate=learning_rate,
+                warmup_steps=warmup_epochs,
+                start_lr=base_lr / 5.,
+                end_lr=base_lr,
+                verbose=True)
+            optimizer = paddle.optimizer.Momentum(
+                learning_rate=learning_rate,
+                weight_decay=weight_decay,
+                momentum=momentum,
+                parameters=parameters)
+            return optimizer
+
+        # dynamic test
+        device = paddle.set_device('cpu')
+        fluid.enable_dygraph(device)
+        net = MyModel()
+        inputs = [InputSpec([None, 20], 'float32', 'x')]
+        labels = [InputSpec([None, 1], 'int64', 'label')]
+        optim = make_optimizer(net.parameters())
+        model = Model(net, inputs, labels)
+        model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
+
+        dataset = MyDataset()
+
+        lr_scheduler_callback = paddle.callbacks.LRScheduler(
+            by_step=False, by_epoch=True)
+
+        model.fit(dataset,
+                  dataset,
+                  batch_size=4,
+                  epochs=epochs,
+                  num_workers=0,
+                  callbacks=lr_scheduler_callback)
+
+        cnt = 0
+        for b in boundaries:
+            if b + warmup_epochs <= epochs:
+                cnt += 1
+
+        np.testing.assert_allclose(model._optimizer._learning_rate.last_lr,
+                                   base_lr * (0.1**cnt))
+        # static test
+        paddle.enable_static()
+
+        net = MyModel()
+        inputs = [InputSpec([None, 20], 'float32', 'x')]
+        labels = [InputSpec([None, 1], 'int64', 'label')]
+        optim = make_optimizer(net.parameters())
+        model = Model(net, inputs, labels)
+        model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
+
+        dataset = MyDataset()
+
+        lr_scheduler_callback = paddle.callbacks.LRScheduler(
+            by_step=False, by_epoch=True)
+
+        model.fit(dataset,
+                  dataset,
+                  batch_size=4,
+                  epochs=epochs,
+                  num_workers=0,
+                  callbacks=lr_scheduler_callback)
+
+        cnt = 0
+        for b in boundaries:
+            if b + warmup_epochs <= epochs:
+                cnt += 1
+
+        np.testing.assert_allclose(model._optimizer._learning_rate.last_lr,
+                                   base_lr * (0.1**cnt))
+
+
 class TestRaiseError(unittest.TestCase):
     def test_input_without_name(self):
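Usage sketch (illustrative, not part of the diff): the new callback only calls
optimizer._learning_rate.step(), once per batch when by_step=True (the default
instance that config_callbacks now registers automatically) or once per epoch
when by_epoch=True. A minimal standalone sketch of the epoch-wise behavior,
assuming nothing beyond the paddle.optimizer.lr.PiecewiseDecay schedule already
used in the tests above:

    import paddle

    # Same schedule as the tests: decay the lr by 10x after epochs 5 and 8.
    boundaries = [5, 8]
    values = [1e-3 * (0.1**i) for i in range(len(boundaries) + 1)]
    scheduler = paddle.optimizer.lr.PiecewiseDecay(
        boundaries=boundaries, values=values)

    for epoch in range(10):
        # ... run all training batches of this epoch ...
        scheduler.step()  # what LRScheduler.on_epoch_end does when by_epoch=True
        print(epoch, scheduler.last_lr)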