From 407de03905b864dc7acc408f3db175caaa5a940b Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Mon, 24 Aug 2020 17:32:31 +0800 Subject: [PATCH] [2.0API] Reconstruct all API related to LR Scheduler, unify dygraph and static (#26550) * Reconstruct all API related to lr scheduler, unify dygraph and static * Reconstruct all API related to lr scheduler, unify dygraph and static * fix doc * fix doc * fix doc of lr_scheduler * fix unittest and english doc * fix english doc * fix confilt * fix doc --- python/paddle/fluid/executor.py | 23 +- python/paddle/fluid/framework.py | 2 + python/paddle/fluid/optimizer.py | 47 +- .../unittests/test_learning_rate_scheduler.py | 513 +++++- python/paddle/optimizer/__init__.py | 9 +- python/paddle/optimizer/lr_scheduler.py | 1442 +++++++++++++++++ python/paddle/static/__init__.py | 1 + 7 files changed, 1962 insertions(+), 75 deletions(-) create mode 100644 python/paddle/optimizer/lr_scheduler.py diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 5759b942763..52cfd9bf0a3 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -850,6 +850,7 @@ class Executor(object): def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name, return_numpy, return_merged): + from paddle.optimizer.lr_scheduler import _LRScheduler exe = program._executor # TODO(zhenghuihuang): quantization uses Graph in CompiledProgram # instead of program. We will add support for checking Vars in Graph @@ -893,6 +894,16 @@ class Executor(object): res.append(res_dict) exe.feed_tensors_into_local_scopes(res) + if hasattr(program._program, 'lr_sheduler'): + lr_sheduler = program._program.lr_sheduler + assert isinstance(lr_sheduler, _LRScheduler), "must be _LRScheduler" + lr_value = lr_sheduler() + lr_var = program._program.global_block().vars[lr_sheduler._var_name] + lr_tensor = _as_lodtensor(lr_value, core.CPUPlace(), lr_var.dtype) + exe.feed_and_split_tensor_into_local_scopes({ + lr_sheduler._var_name: lr_tensor + }) + fetch_var_names = list(map(_to_name_str, fetch_list)) tensors = exe.run(fetch_var_names, return_merged)._move_to_list() return as_numpy(tensors) if return_numpy else tensors @@ -1222,7 +1233,7 @@ class Executor(object): def _run_program(self, program, feed, fetch_list, feed_var_name, fetch_var_name, scope, return_numpy, use_program_cache): - + from paddle.optimizer.lr_scheduler import _LRScheduler if feed is None: feed = {} elif isinstance(feed, (list, tuple)): @@ -1278,6 +1289,16 @@ class Executor(object): fetch_var_name=fetch_var_name) self._feed_data(program, feed, feed_var_name, scope) + if hasattr(program, 'lr_sheduler'): + assert isinstance(program.lr_sheduler, + _LRScheduler), "must be _LRScheduler" + lr_sheduler = program.lr_sheduler + lr_value = lr_sheduler() + lr_var = program.global_block().vars[lr_sheduler._var_name] + data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype)) + tensor = core.get_variable_tensor(scope, lr_sheduler._var_name) + tensor.set(data, self.place) + if not use_program_cache: self._default_executor.run(program.desc, scope, 0, True, True, fetch_var_name) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 3169cc9dae8..ef50294b8e7 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -4450,6 +4450,8 @@ class Program(object): p._current_role = self._current_role p.__op_role_var = self.__op_role_var p._appending_grad_times = self._appending_grad_times + if hasattr(self, 'lr_sheduler'): + p.lr_sheduler = self.lr_sheduler #NOTE(zhiqiu): we sync the cloned program, to update its program by # its desc. diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index a2a5a85cc0a..8f34576b836 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -68,14 +68,16 @@ class Optimizer(object): regularization=None, grad_clip=None, name=None): + # Because of the loop import, so place it in the function body + from paddle.optimizer.lr_scheduler import _LRScheduler self._parameter_list = list( parameter_list) if parameter_list is not None else None self._name = name if framework.in_dygraph_mode(): - if not isinstance(learning_rate, float) and \ - not isinstance(learning_rate, LearningRateDecay): + if not isinstance(learning_rate, + (float, LearningRateDecay, _LRScheduler)): raise TypeError( - "learning rate should be float or LearningRateDecay, got %s here" + "learning rate should be float or _LRScheduler, got %s here" % type(learning_rate)) if self._parameter_list is None: raise AttributeError( @@ -90,11 +92,11 @@ class Optimizer(object): % regularization.__str__()) break else: - if not isinstance(learning_rate, float) and \ - not isinstance(learning_rate, framework.Variable): + if not isinstance(learning_rate, + (float, framework.Variable, _LRScheduler)): raise TypeError( - "learning rate should be float or Variable, got %s here" % - type(learning_rate)) + "learning rate should be float or _LRScheduler, got %s here" + % type(learning_rate)) if grad_clip is not None: if not isinstance(grad_clip, GradientClipBase): @@ -144,11 +146,15 @@ class Optimizer(object): state_dict = adam.state_dict() ''' + from paddle.optimizer.lr_scheduler import _LRScheduler state_dict = {} for k, v in self._accumulators.items(): for para_name, var_tmp in v.items(): state_dict[var_tmp.name] = var_tmp # global step if use lr decay + if isinstance(self._learning_rate, _LRScheduler): + state_dict["LR_Scheduler"] = self._learning_rate.state_dict() + return state_dict if isinstance(self._learning_rate, LearningRateDecay): state_dict["LR_Scheduler"] = self._learning_rate.state_dict() @@ -192,6 +198,9 @@ class Optimizer(object): adam.set_dict(opti_state_dict) ''' + from paddle.optimizer.lr_scheduler import _LRScheduler + if isinstance(self._learning_rate, _LRScheduler): + self._learning_rate.set_dict(state_dict["LR_Scheduler"]) if isinstance(self._learning_rate, LearningRateDecay): self._learning_rate.set_dict(state_dict["LR_Scheduler"]) @@ -252,6 +261,30 @@ class Optimizer(object): return self._opti_name_list def _create_global_learning_rate(self): + from paddle.optimizer.lr_scheduler import _LRScheduler + if isinstance(self._learning_rate, _LRScheduler): + lr_var = self._global_learning_rate() + # only create global lr_var once + if not isinstance(lr_var, framework.Variable): + lr_name = unique_name.generate('learning_rate') + self._learning_rate._var_name = lr_name + lr_var = self.helper.create_global_variable( + name=lr_name, + shape=[1], + persistable=True, + stop_gradient=True, + dtype='float32' if self._dtype is None else self._dtype) + main_prog = framework.default_main_program() + main_prog.lr_sheduler = self._learning_rate + main_prog.lr_var = lr_var + self._learning_rate_map[framework.default_main_program( + )] = lr_var + + lr_value = float(self._learning_rate()) + self.helper.set_variable_initializer( + lr_var, initializer=Constant(value=lr_value)) + return + if imperative_base.enabled(): # create learning rate Variable if isinstance(self._learning_rate, float): diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py index 71b452d4a2d..9a2e7b85e52 100644 --- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py +++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py @@ -19,6 +19,7 @@ import math import numpy as np import unittest +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.fluid.framework as framework @@ -553,79 +554,459 @@ def reduce_lr_on_plateau(decay_rate, threshold, cooldown, patience, m, n, loss, class TestReduceLROnPlateauDecay(unittest.TestCase): - def test_dygraph_mode(self): - with fluid.dygraph.guard(): - # the decay rate must be less than 1.0 - with self.assertRaises(ValueError): - fluid.dygraph.ReduceLROnPlateau( - learning_rate=1.0, decay_rate=2.0) - # the mode must be "min" or "max" - with self.assertRaises(ValueError): - fluid.dygraph.ReduceLROnPlateau(learning_rate=1.0, mode="test") - # the threshold_mode must be "rel" or "abs" - with self.assertRaises(ValueError): - fluid.dygraph.ReduceLROnPlateau( - learning_rate=1.0, threshold_mode="test") - - base_lr = 1.0 - patience = 3 - cooldown = 1 - decay_rate = 0.5 - threshold = 1e-4 - linear = fluid.dygraph.Linear(10, 10) + def test_ReduceLR(self): + # the decay rate must be less than 1.0 + with self.assertRaises(ValueError): + paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=2.0) + # the mode must be "min" or "max" + with self.assertRaises(ValueError): + paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, mode="test") + # the threshold_mode must be "rel" or "abs" + with self.assertRaises(ValueError): + paddle.optimizer.ReduceLROnPlateau( + learning_rate=1.0, threshold_mode="test") + with self.assertRaises(TypeError): + paddle.optimizer.ReduceLROnPlateau(learning_rate="test") + with self.assertRaises(TypeError): + paddle.optimizer.ReduceLROnPlateau(learning_rate=0.5).step("test") + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + + for place in places: for m, n in zip(['min', 'max', 'min', 'max'], ['rel', 'rel', 'abs', 'abs']): kwargs = { - 'learning_rate': base_lr, - 'decay_rate': decay_rate, - 'threshold': threshold, - 'verbose': True, - 'patience': patience, - 'cooldown': cooldown, + 'learning_rate': 1.0, 'mode': m, + 'factor': 0.5, + 'patience': 3, + 'threshold': 1e-4, 'threshold_mode': n, - 'eps': 1e-6 + 'cooldown': 1, + 'min_lr': 0, + 'epsilon': 1e-8, + 'verbose': False, } - print("class=" + fluid.dygraph.ReduceLROnPlateau.__name__ + - " kwargs=" + str(kwargs)) - lr = fluid.dygraph.ReduceLROnPlateau(**kwargs) - sgd = fluid.optimizer.SGD(learning_rate=lr, - parameter_list=linear.parameters()) - - best = float("-10000") if m == "max" else float("10000") - expected_lr = 1.0 - cooldown_counter = 0 - num_bad_epochs = 0 - var_list = [best, expected_lr, cooldown_counter, num_bad_epochs] - step_num = 0 - epoch_num = 0 - for epoch in range(30): - total_loss = 0 - - for batch_id in range(2): - step_num += 1 - x = fluid.dygraph.to_variable( - np.array([step_num]).astype('float32')) - loss = layers.sin(x) - sgd.minimize(loss) - total_loss += loss - - epoch_num += 1 - # get expected lr from fluid - avg_loss = total_loss / 1 - lr.step(avg_loss) - actual_lr = lr().numpy()[0] - - # get expected lr form python - expected_lr = reduce_lr_on_plateau(decay_rate, threshold, - cooldown, patience, m, n, - avg_loss, var_list) - self.assertEqual( - expected_lr, - actual_lr, - msg='Failed reduce lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'. - format(epoch_num, expected_lr, actual_lr)) + paddle.enable_static() + self._test_static(place, kwargs) + paddle.disable_static(place) + self._test_dygraph(place, kwargs) + paddle.enable_static() + + def _test_static(self, place, kwargs): + paddle.enable_static() + + best = float("-10000") if kwargs['mode'] == "max" else float("10000") + current_lr = 1.0 + cooldown_counter = 0 + num_bad_epochs = 0 + var_list = [best, current_lr, cooldown_counter, num_bad_epochs] + + main_prog = fluid.Program() + start_prog = fluid.Program() + with fluid.program_guard(main_prog, start_prog): + x = fluid.layers.create_global_var( + [1], 1, 'float32', persistable=True) + paddle.increment(x) + loss = paddle.sin(x) + scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs) + adam = fluid.optimizer.Adam(learning_rate=scheduler) + adam.minimize(loss) + lr_var = adam._global_learning_rate() + test_prog = main_prog.clone() + + exe = fluid.Executor(place) + exe.run(start_prog) + + for epoch in range(20): + for batch_id in range(1): + out, actual_lr = exe.run(main_prog, + fetch_list=[loss.name, lr_var.name]) + expected_lr = reduce_lr_on_plateau( + kwargs['factor'], kwargs['threshold'], kwargs['cooldown'], + kwargs['patience'], kwargs['mode'], + kwargs['threshold_mode'], out[0], var_list) + + scheduler.step(out[0]) + actual_lr = scheduler() + self.assertEqual(actual_lr, np.array(expected_lr)) + + for epoch in range(10): + for batch_id in range(1): + out, actual_lr = exe.run(test_prog, + fetch_list=[loss.name, lr_var.name]) + expected_lr = reduce_lr_on_plateau( + kwargs['factor'], kwargs['threshold'], kwargs['cooldown'], + kwargs['patience'], kwargs['mode'], + kwargs['threshold_mode'], out[0], var_list) + scheduler.step(out[0]) + actual_lr = scheduler() + self.assertEqual(actual_lr, np.array(expected_lr)) + + def _test_dygraph(self, place, kwargs): + paddle.disable_static(place) + + best = float("-10000") if kwargs['mode'] == "max" else float("10000") + current_lr = 1.0 + cooldown_counter = 0 + num_bad_epochs = 0 + var_list = [best, current_lr, cooldown_counter, num_bad_epochs] + + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, + parameter_list=linear.parameters()) + + for epoch in range(20): + for batch_id in range(1): + x = paddle.to_tensor(epoch).astype('float32') + loss = paddle.sin(x) + loss.backward() + sgd.minimize(loss) + + scheduler.step(loss) + # get lr from paddle + current_lr = scheduler() + # get lr form python + expected_lr = reduce_lr_on_plateau( + kwargs['factor'], kwargs['threshold'], kwargs['cooldown'], + kwargs['patience'], kwargs['mode'], kwargs['threshold_mode'], + loss, var_list) + self.assertEqual(current_lr, expected_lr) + state_dict = sgd.state_dict() + scheduler1 = paddle.optimizer.ReduceLROnPlateau(**kwargs) + sgd1 = paddle.optimizer.SGD(learning_rate=scheduler1, + parameter_list=linear.parameters()) + sgd1.set_dict(state_dict) + self.assertEqual(scheduler.cooldown_counter, + scheduler1.cooldown_counter) + self.assertEqual(scheduler.best.numpy()[0], scheduler1.best) + self.assertEqual(scheduler.num_bad_epochs, scheduler1.num_bad_epochs) + self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch) + self.assertEqual(scheduler.last_lr, scheduler1.last_lr) + + +def noam_lr(epoch_num, d_model, warmup_steps, learning_rate=1.0, verbose=False): + if epoch_num == 0: + a = 1 + else: + a = math.pow(epoch_num, -0.5) + b = math.pow(warmup_steps, -1.5) * epoch_num + return learning_rate * math.pow(d_model, -0.5) * min(a, b) + + +def lambda_lr(epoch_num, learning_rate, lr_lambda, verbose=False): + return learning_rate * lr_lambda(epoch_num) + + +def piecewise_lr(epoch_num, boundaries, values, verbose=False): + assert len(boundaries) + 1 == len(values) + for i in range(len(boundaries)): + if epoch_num < boundaries[i]: + return values[i] + return values[len(values) - 1] + + +def exponential_lr(epoch_num, learning_rate, gamma, verbose=False): + return learning_rate * gamma**epoch_num + + +def natural_exp_lr(epoch_num, learning_rate, gamma, verbose=False): + return learning_rate * math.exp(-1 * gamma * epoch_num) + + +def inverse_time_lr(epoch_num, learning_rate, gamma, verbose=False): + return learning_rate / (1 + gamma * epoch_num) + + +def polynomial_lr(epoch_num, + learning_rate, + decay_steps, + end_lr=0.0001, + power=1.0, + cycle=False, + verbose=False): + + if cycle: + div = math.ceil(epoch_num / float(decay_steps)) + if epoch_num == 0: + div = 1 + decay_steps = decay_steps * div + else: + epoch_num = min(epoch_num, decay_steps) + return (learning_rate - end_lr) * ( + (1 - float(epoch_num) / float(decay_steps))**power) + end_lr + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lr + elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos( + math.pi / self.T_max)) / 2 + + return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / ( + 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * ( + self.last_lr - self.eta_min) + self.eta_min + + +cosine_annealing_lr_current = None + + +def cosine_annealing_lr(epoch_num, + learning_rate, + T_max, + eta_min=0, + verbose=False): + global cosine_annealing_lr_current + if epoch_num == 0: + cosine_annealing_lr_current = learning_rate + elif (epoch_num - 1 - T_max) % (2 * T_max) == 0: + cosine_annealing_lr_current = cosine_annealing_lr_current + ( + learning_rate - eta_min) * (1 - math.cos(math.pi / float(T_max)) + ) / 2 + else: + cosine_annealing_lr_current = (1 + math.cos( + math.pi * epoch_num / float(T_max))) / (1 + math.cos(math.pi * ( + epoch_num - 1) / float(T_max))) * (cosine_annealing_lr_current - + eta_min) + eta_min + return cosine_annealing_lr_current + + +def linear_warmup_lr(epoch_num, + learning_rate, + warmup_steps, + start_lr, + end_lr, + verbose=False): + if epoch_num < warmup_steps: + return start_lr + (end_lr - start_lr) * (float(epoch_num) / + float(warmup_steps)) + else: + return learning_rate + + +def multi_step_lr(epoch_num, + learning_rate, + milestones, + gamma=0.1, + verbose=False): + for i in range(len(milestones)): + if epoch_num < milestones[i]: + return learning_rate * (gamma**i) + return learning_rate * (gamma**len(milestones)) + + +def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False): + return learning_rate * math.pow(gamma, epoch_num // step_size) + + +class TestLRScheduler(unittest.TestCase): + def _test_static(self, python_func, paddle_api, kwarg, place): + main_prog = fluid.Program() + start_prog = fluid.Program() + with fluid.program_guard(main_prog, start_prog): + x = fluid.data(name='x', shape=[3, 4, 5]) + y = fluid.data(name='y', shape=[3, 4, 5]) + z = fluid.layers.fc(x, 100) + loss = fluid.layers.mean(z) + scheduler = paddle_api(**kwarg) + adam = fluid.optimizer.Adam(learning_rate=scheduler) + adam.minimize(loss) + lr_var = adam._global_learning_rate() + test_prog = main_prog.clone() + + num = 0 + exe = fluid.Executor(place) + exe.run(start_prog) + for epoch in range(5): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + self.assertEqual(out, np.array(python_func(num, **kwarg))) + scheduler.step() + num += 1 + + for epoch in range(5): + for batch_id in range(2): + out = exe.run( + test_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + self.assertEqual(out, np.array(python_func(num, **kwarg))) + scheduler.step() + num += 1 + + if isinstance(place, fluid.CPUPlace): + compiled_train_prog = fluid.CompiledProgram( + main_prog).with_data_parallel( + loss_name=loss.name, places=fluid.cpu_places(4)) + for epoch in range(5): + python_result = python_func(num, **kwarg) + for batch_id in range(2): + _ = exe.run( + compiled_train_prog, + feed={ + 'x': np.random.randn(12, 4, 5).astype('float32'), + 'y': np.random.randn(12, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scopes = compiled_train_prog._executor.local_scopes() + out = np.array(scopes[0].var(lr_var.name).get_tensor()) + self.assertEqual(out, np.array(python_result)) + out = np.array(scopes[1].var(lr_var.name).get_tensor()) + self.assertEqual(out, np.array(python_result)) + out = np.array(scopes[2].var(lr_var.name).get_tensor()) + self.assertEqual(out, np.array(python_result)) + out = np.array(scopes[3].var(lr_var.name).get_tensor()) + self.assertEqual(out, np.array(python_result)) + scheduler.step() + num += 1 + + compiled_test_prog = fluid.CompiledProgram( + test_prog).with_data_parallel( + loss_name=loss.name, + share_vars_from=compiled_train_prog, + places=fluid.cpu_places(4)) + for epoch in range(5): + python_result = python_func(num, **kwarg) + for batch_id in range(2): + _ = exe.run( + compiled_test_prog, + feed={ + 'x': np.random.randn(12, 4, 5).astype('float32'), + 'y': np.random.randn(12, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scopes = compiled_test_prog._executor.local_scopes() + out = np.array(scopes[0].var(lr_var.name).get_tensor()) + self.assertEqual(out, np.array(python_result)) + out = np.array(scopes[1].var(lr_var.name).get_tensor()) + self.assertEqual(out, np.array(python_result)) + out = np.array(scopes[2].var(lr_var.name).get_tensor()) + self.assertEqual(out, np.array(python_result)) + out = np.array(scopes[3].var(lr_var.name).get_tensor()) + self.assertEqual(out, np.array(python_result)) + scheduler.step() + num += 1 + + def _test_dygraph(self, python_func, paddle_api, kwarg, place): + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle_api(**kwarg) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, + parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + + self.assertAlmostEqual(sgd.current_step_lr(), + python_func(epoch, **kwarg)) + if paddle_api.__name__ != "CosineAnnealingLR": + scheduler.step() + else: + scheduler.step(epoch + 1) + + def test_scheduler(self): + with self.assertRaises(NotImplementedError): + paddle.optimizer.lr_scheduler._LRScheduler().step() + with self.assertRaises(TypeError): + paddle.optimizer.MultiStepLR( + learning_rate="test", milestones=[1, 2, 3]) + with self.assertRaises(TypeError): + paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones='test') + with self.assertRaises(ValueError): + paddle.optimizer.MultiStepLR( + learning_rate=0.5, milestones=[3, 2, 1]) + with self.assertRaises(ValueError): + paddle.optimizer.MultiStepLR( + learning_rate=0.5, milestones=[1, 2, 3], gamma=2) + + func_api_kwargs = [(noam_lr, paddle.optimizer.NoamLR, { + "d_model": 0.01, + "warmup_steps": 100, + "verbose": False + }), (piecewise_lr, paddle.optimizer.PiecewiseLR, { + "boundaries": [3, 6, 9, 15, 20], + "values": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], + "verbose": False + }), (natural_exp_lr, paddle.optimizer.NaturalExpLR, { + "learning_rate": 0.5, + "gamma": 0.1, + "verbose": False + }), (inverse_time_lr, paddle.optimizer.InverseTimeLR, { + "learning_rate": 0.5, + "gamma": 0.1, + "verbose": True + }), (polynomial_lr, paddle.optimizer.PolynomialLR, { + "learning_rate": 0.5, + "decay_steps": 20, + "end_lr": 0, + "power": 1.0, + "cycle": False, + "verbose": False + }), (polynomial_lr, paddle.optimizer.PolynomialLR, { + "learning_rate": 0.5, + "decay_steps": 20, + "end_lr": 0, + "power": 1.0, + "cycle": True, + "verbose": False + }), (linear_warmup_lr, paddle.optimizer.LinearLrWarmup, { + 'learning_rate': 0.5, + 'warmup_steps': 20, + 'start_lr': 0, + 'end_lr': 0.5, + "verbose": False + }), (exponential_lr, paddle.optimizer.ExponentialLR, { + "learning_rate": 0.5, + "gamma": 0.9, + "verbose": False + }), (multi_step_lr, paddle.optimizer.MultiStepLR, { + "learning_rate": 0.5, + "milestones": [3, 6, 9, 15, 20], + "gamma": 0.8, + "verbose": True + }), (step_lr, paddle.optimizer.StepLR, { + "learning_rate": 0.5, + "step_size": 2, + "gamma": 0.8, + "verbose": False + }), (lambda_lr, paddle.optimizer.LambdaLR, { + "learning_rate": 0.5, + "lr_lambda": lambda x: 0.95**x, + "verbose": False + }), (cosine_annealing_lr, paddle.optimizer.CosineAnnealingLR, { + "learning_rate": 0.5, + "T_max": 10, + "verbose": True + })] + + for python_func, paddle_api, kwarg in func_api_kwargs: + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + + for place in places: + paddle.enable_static() + self._test_static(python_func, paddle_api, kwarg, place) + paddle.disable_static(place) + self._test_dygraph(python_func, paddle_api, kwarg, place) + paddle.enable_static() if __name__ == '__main__': diff --git a/python/paddle/optimizer/__init__.py b/python/paddle/optimizer/__init__.py index 7159baeb305..49314c9832d 100644 --- a/python/paddle/optimizer/__init__.py +++ b/python/paddle/optimizer/__init__.py @@ -19,7 +19,10 @@ __all__ = [ 'ExponentialMovingAverage', 'Ftrl', 'FtrlOptimizer', 'LambOptimizer', 'LarsMomentum', 'LarsMomentumOptimizer', 'LookaheadOptimizer', 'ModelAverage', 'Momentum', 'MomentumOptimizer', 'PipelineOptimizer', - 'RecomputeOptimizer', 'RMSProp', 'SGD', 'SGDOptimizer', 'Optimizer' + 'RecomputeOptimizer', 'RMSProp', 'SGD', 'SGDOptimizer', 'Optimizer', + '_LRScheduler', 'NoamLR', 'PiecewiseLR', 'NaturalExpLR', 'InverseTimeLR', + 'PolynomialLR', 'LinearLrWarmup', 'ExponentialLR', 'MultiStepLR', 'StepLR', + 'LambdaLR', 'ReduceLROnPlateau', 'CosineAnnealingLR' ] @@ -36,3 +39,7 @@ from .adam import Adam from .adamw import AdamW from .adamax import Adamax from .rmsprop import RMSProp + +from . import lr_scheduler +from .lr_scheduler import _LRScheduler, NoamLR, PiecewiseLR, NaturalExpLR, InverseTimeLR, PolynomialLR, \ + LinearLrWarmup, ExponentialLR, MultiStepLR, StepLR, LambdaLR, ReduceLROnPlateau, CosineAnnealingLR diff --git a/python/paddle/optimizer/lr_scheduler.py b/python/paddle/optimizer/lr_scheduler.py new file mode 100644 index 00000000000..d01e62abaa6 --- /dev/null +++ b/python/paddle/optimizer/lr_scheduler.py @@ -0,0 +1,1442 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import numpy +import warnings +from paddle import Tensor + +__all__ = [ + 'NoamLR', 'PiecewiseLR', 'NaturalExpLR', 'InverseTimeLR', 'PolynomialLR', + 'LinearLrWarmup', 'ExponentialLR', 'MultiStepLR', 'StepLR', 'LambdaLR', + 'ReduceLROnPlateau', 'CosineAnnealingLR' +] + + +class _LRScheduler(object): + """LRScheduler Base class. + + Define the common interface of an LRScheduler. + User can 'form paddle.optimizer.lr_scheduler import _LRScheduler' + And inherit from it to have a custom implementation of get_lr(). + """ + + def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False): + if not isinstance(learning_rate, (float, int)): + raise TypeError( + "The type of learning rate must be float, but received {}". + format(type(learning_rate))) + self.base_lr = float(learning_rate) + self.last_lr = float(learning_rate) + self.last_epoch = last_epoch + self.verbose = verbose + self._var_name = None + + self.step() + + def __call__(self): + """ + Return last computed learning rate on current epoch. + """ + return self.last_lr + + def step(self, epoch=None): + """ + 'step' should be called after 'minimize' . It will update the learning rate in optimizer according to 'epoch'. + The new learning rate will take effect on next epoch. + + Args: + epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1. + + Returns: + None + + Examples: + Please refer to the example of current _LRScheduler. + """ + if epoch is None: + self.last_epoch += 1 + self.last_lr = self.get_lr() + else: + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + self.last_lr = self._get_closed_form_lr() + else: + self.last_lr = self.get_lr() + + if self.verbose: + print('Epoch {}: {} set learning rate to {}.'.format( + self.last_epoch, self.__class__.__name__, self.last_lr)) + + def state_dict(self): + """ + Returns the state of the scheduler as a :class:`dict`. + + It is a subset of self.__dict__ . + """ + self._state_keys() + state_dict = {} + for key in self.keys: + if key not in self.__dict__: + continue + value = self.__dict__[key] + if isinstance(value, Tensor): + assert value.shape == [ + 1 + ], "shape of Tensor in state_dict must be [1] {}".format( + value.shape) + value = value.numpy()[0] + state_dict[key] = value + + return state_dict + + # For those subclass who overload _LRScheduler, "last_epoch, last_lr" will be saved by default. + # (Note): you can change it for your subclass. + def _state_keys(self): + """ + set the keys in self.__dict__ that are needed to be saved. + """ + self.keys = ['last_epoch', 'last_lr'] + + def set_dict(self, state_dict): + """ + Loads the schedulers state. + """ + self._state_keys() + for key in self.keys: + if key in state_dict: + self.__dict__[key] = state_dict[key] + else: + raise RuntimeError( + "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict". + format(key)) + if len(state_dict) > len(self.keys): + warnings.warn( + "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict" + ) + + # alias for set_dict + set_state_dict = set_dict + + def get_lr(self): + # calculate by python float + raise NotImplementedError + + +class NoamLR(_LRScheduler): + """ + + Applies Noam Lear to the initial learning rate. + + The algorithm can be described as following. + + .. math:: + + new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5}) + + Please reference `attention is all you need `_ + + + Args: + d$_{model}$(int): The dimensionality of input and output feature vector of model. It is a python int number. + warmup_steps(int): The number of warmup steps. A super parameter. It is a python int number + learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``NoamLR`` instance to schedule learning rate. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.NoamLR(d_model=0.01, warmup_steps=100, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on static mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.NoamLR(d_model=0.01, warmup_steps=100, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + + """ + + def __init__(self, + d_model, + warmup_steps, + learning_rate=1.0, + last_epoch=-1, + verbose=False): + self.d_model = d_model + self.warmup_steps = warmup_steps + super(NoamLR, self).__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + if self.last_epoch == 0: + a = 1 + else: + a = self.last_epoch**-0.5 + b = self.warmup_steps**-1.5 * self.last_epoch + return self.base_lr * (self.d_model**-0.5) * min(a, b) + + +class PiecewiseLR(_LRScheduler): + """ + + Piecewise learning rate scheduler. + + The algorithm can be described as the code below: + + .. code-block:: text + + boundaries = [100, 200] + values = [1.0, 0.5, 0.1] + if epoch < 100: + learning_rate = 1.0 + elif 100 <= global_step < 200: + learning_rate = 0.5 + else: + learning_rate = 0.1 + + Args: + boundaries(list): A list of steps numbers. The type of element in the list is python int. + values(list): A list of learning rate values that will be picked during different epoch boundaries. + The type of element in the list is python float. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``PiecewiseLR`` instance to schedule learning rate. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.PiecewiseLR(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on static mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.PiecewiseLR(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + """ + + def __init__(self, boundaries, values, last_epoch=-1, verbose=False): + self.boundaries = boundaries + self.values = values + super(PiecewiseLR, self).__init__( + last_epoch=last_epoch, verbose=verbose) + + def get_lr(self): + + for i in range(len(self.boundaries)): + if self.last_epoch < self.boundaries[i]: + return self.values[i] + return self.values[len(self.values) - 1] + + +class NaturalExpLR(_LRScheduler): + """ + + Applies natural exponential decay to the initial learning rate. + + The algorithm can be described as following: + + .. math:: + + new\_learning\_rate = learning\_rate * e^{- gama * epoch} + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + gamma (float, optional): A Ratio to update the learning rate. Default: 0.1. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``NaturalExpLR`` instance to schedule learning rate. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.NaturalExpLR(learning_rate=0.5, gamma=0.1, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on static mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.NaturalExpLR(learning_rate=0.5, gamma=0.1, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + """ + + def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False): + self.gamma = gamma + super(NaturalExpLR, self).__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch) + + +class InverseTimeLR(_LRScheduler): + """ + + Applies inverse time decay to the initial learning rate. + + The algorithm can be described as following: + + .. math:: + + new\_learning\_rate = \\frac{learning\_rate}{1 + gamma * epoch} + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + It should be less than 1.0. Default: 0.1. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``InverseTimeLR`` instance to schedule learning rate. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.InverseTimeLR(learning_rate=0.5, gamma=0.1, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on static mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.InverseTimeLR(learning_rate=0.5, gamma=0.1, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + + """ + + def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False): + self.gamma = gamma + super(InverseTimeLR, self).__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + return self.base_lr / (1 + self.gamma * self.last_epoch) + + +class PolynomialLR(_LRScheduler): + """ + + Applies polynomial decay to the initial learning rate. + + The algorithm can be described as following. + + If cycle is set to True, then: + + .. math:: + + decay\_steps & = decay\_steps * math.ceil(\\frac{epoch}{decay\_steps}) + + new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\\frac{epoch}{decay\_steps})^{power}+end\_lr + + If cycle is set to False, then: + + .. math:: + + epoch & = min(epoch, decay\_steps) + + new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\\frac{epoch}{decay\_steps})^{power}+end\_lr + + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + decay_steps(int): The decay step size. It determines the decay cycle. + end_lr(float, optional): The minimum final learning rate. Default: 0.0001. + power(float, optional): Power of polynomial. Default: 1.0. + cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decrease + to ``end_lr`` . If False, the learning rate is monotone decreasing. Default: False. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``PolynomialLR`` instance to schedule learning rate. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.PolynomialLR(learning_rate=0.5, decay_steps=20, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on statich mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.PolynomialLR(learning_rate=0.5, decay_steps=20, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + """ + + def __init__(self, + learning_rate, + decay_steps, + end_lr=0.0001, + power=1.0, + cycle=False, + last_epoch=-1, + verbose=False): + self.decay_steps = decay_steps + self.end_lr = end_lr + self.power = power + self.cycle = cycle + super(PolynomialLR, self).__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + tmp_epoch_num = self.last_epoch + tmp_decay_steps = self.decay_steps + if self.cycle: + div_res = math.ceil( + float(self.last_epoch) / float(self.decay_steps)) + + if self.last_epoch == 0: + div_res = 1 + tmp_decay_steps = self.decay_steps * div_res + else: + tmp_epoch_num = min(self.last_epoch, self.decay_steps) + + return (self.base_lr - self.end_lr) * ( + (1 - float(tmp_epoch_num) / float(tmp_decay_steps) + )**self.power) + self.end_lr + + +class LinearLrWarmup(_LRScheduler): + """ + + Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler. + For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks `_ + + When epoch < warmup_steps, learning rate is updated as: + + .. code-block:: text + + lr = start_lr + (end_lr - start_lr) * (epoch / warmup_steps) + + where start_lr is the initial learning rate, and end_lr is the final learning rate; + + When epoch >= warmup_steps, learning rate is updated as: + + .. code-block:: text + + lr = learning_rate + + where lr is float or any subclass of ``_LRScheduler`` . + + Args: + learning_rate (float|_LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``_LRScheduler`` . + warmup_steps (int): total steps of warm up. + start_lr (float): Initial learning rate of warm up. + end_lr (float): Final learning rate of warm up. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``LinearLrWarmup`` instance to schedule learning rate. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.LinearLrWarmup( + learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on statich mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.LinearLrWarmup( + learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + """ + + def __init__(self, + learning_rate, + warmup_steps, + start_lr, + end_lr, + last_epoch=-1, + verbose=False): + type_check = isinstance(learning_rate, float) or isinstance( + learning_rate, int) or isinstance(learning_rate, _LRScheduler) + if not type_check: + raise TypeError( + "the type of learning_rate should be [int, float or _LRScheduler], the current type is {}". + format(learning_rate)) + self.learning_rate = learning_rate + self.warmup_steps = warmup_steps + self.start_lr = start_lr + self.end_lr = end_lr + assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format( + end_lr, start_lr) + super(LinearLrWarmup, self).__init__(start_lr, last_epoch, verbose) + + def get_lr(self): + if self.last_epoch < self.warmup_steps: + return (self.end_lr - self.start_lr) * float( + self.last_epoch) / float(self.warmup_steps) + self.start_lr + else: + if isinstance(self.learning_rate, _LRScheduler): + self.learning_rate.step() + return self.learning_rate() + + return self.learning_rate + + +class ExponentialLR(_LRScheduler): + """ + + Update learning rate by 'gamma' each epoch. + + The algorithm can be described as following. + + .. math:: + + new\_learning\_rate = last\_learning\_rate * gamma + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + It should be less than 1.0. Default: 0.1. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``ExponentialLR`` instance to schedule learning rate. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.ExponentialLR(learning_rate=0.5, gamma=0.9, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on statich mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.ExponentialLR(learning_rate=0.5, gamma=0.9, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + """ + + def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False): + self.gamma = gamma + super(ExponentialLR, self).__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + return self.base_lr * (self.gamma**self.last_epoch) + + +class MultiStepLR(_LRScheduler): + """ + Update the learning rate by ``gama`` once ``epoch`` reaches one of the milestones. + + The algorithm can be described as the code below. + + .. code-block:: text + + learning_rate = 0.5 + milestones = [30, 50] + gamma = 0.1 + if epoch < 30: + learning_rate = 0.5 + elif epoch < 50: + learning_rate = 0.05 + else: + learning_rate = 0.005 + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + milestones (tuple|list): List or tuple of each boundaries. Must be increasing. + gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + It should be less than 1.0. Default: 0.1. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + + Returns: + ``MultiStepLR`` instance to schedule learning rate. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on statich mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + """ + + def __init__(self, + learning_rate, + milestones, + gamma=0.1, + last_epoch=-1, + verbose=False): + if not isinstance(milestones, (tuple, list)): + raise TypeError( + "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s." + % type(milestones)) + + if not all([ + milestones[i] < milestones[i + 1] + for i in range(len(milestones) - 1) + ]): + raise ValueError('The elements of milestones must be incremented') + if gamma >= 1.0: + raise ValueError('gamma should be < 1.0.') + + self.milestones = milestones + self.gamma = gamma + super(MultiStepLR, self).__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + for i in range(len(self.milestones)): + if self.last_epoch < self.milestones[i]: + return self.base_lr * (self.gamma**i) + return self.base_lr * (self.gamma**len(self.milestones)) + + +class StepLR(_LRScheduler): + """ + Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epoch. + + The algorithm can be described as the code below. + + .. code-block:: text + + learning_rate = 0.5 + step_size = 30 + gamma = 0.1 + + learning_rate = 0.5 if epoch < 30 + learning_rate = 0.05 if 30 <= epoch < 60 + learning_rate = 0.005 if 60 <= epoch < 90 + ... + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + step_size (int): the interval to update. + gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . + It should be less than 1.0. Default: 0.1. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``StepLR`` instance to schedule learning rate. + + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.StepLR(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on statich mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.StepLR(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + """ + + def __init__(self, + learning_rate, + step_size, + gamma=0.1, + last_epoch=-1, + verbose=False): + if not isinstance(step_size, int): + raise TypeError( + "The type of 'step_size' must be 'int', but received %s." % + type(step_size)) + if gamma >= 1.0: + raise ValueError('gamma should be < 1.0.') + + self.step_size = step_size + self.gamma = gamma + super(StepLR, self).__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + i = self.last_epoch // self.step_size + return self.base_lr * (self.gamma**i) + + +class LambdaLR(_LRScheduler): + """ + Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is funciton which receives ``epoch`` . + + The algorithm can be described as the code below. + + .. code-block:: text + + learning_rate = 0.5 # init learning_rate + lr_lambda = lambda epoch: 0.95 ** epoch + + learning_rate = 0.5 # epoch 0 + learning_rate = 0.475 # epoch 1 + learning_rate = 0.45125 # epoch 2 + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the initial learning rate by this factor. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``LambdaLR`` instance to schedule learning rate. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.LambdaLR(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on statich mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.LambdaLR(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + + """ + + def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False): + if not callable(lr_lambda): + raise TypeError( + "The type of 'lr_lambda' in 'LambdaLR' must be 'function', but received %s." + % type(lr_lambda)) + + self.lr_lambda = lr_lambda + super(LambdaLR, self).__init__(learning_rate, last_epoch, verbose) + + def get_lr(self): + return self.base_lr * self.lr_lambda(self.last_epoch) + + +class ReduceLROnPlateau(_LRScheduler): + """ + Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate + by 2 to 10 times once model performance has no longer improvement. + + The ``metrics`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``metrics`` + stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` . + (Specially, ``mode`` can also be set to ``'max`` , in this case, when ``metrics`` stop ascending for a ``patience`` + number of epochs, the learning rate will be reduced.) + + In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming above operation. + + Args: + learning_rate (float): The initial learning rate. It is a python float number. + mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the + learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` , the learning + rate will reduce when ``loss`` stops ascending. Default: ``'min'`` . + factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` . + It should be less than 1.0. Default: 0.1. + patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced. + Default: 10. + threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` . + This make tiny changes of ``loss`` will be ignored. Default: 1e-4. + threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss`` + is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum + change of ``loss`` is ``threshold`` . Default: ``'rel'`` . + cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0. + min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0. + epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than eps, the update is + ignored. Default: 1e-8. + verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``. + + + Returns: + ``ReduceLROnPlateau`` instance to schedule learning rate. + + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step(loss) + + # train on statich mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step(out[0]) + + """ + + def __init__(self, + learning_rate, + mode='min', + factor=0.1, + patience=10, + threshold=1e-4, + threshold_mode='rel', + cooldown=0, + min_lr=0, + epsilon=1e-8, + verbose=False): + mode = mode.lower() + if mode not in ['min', 'max']: + raise ValueError('mode: ' + mode + ' is unknown!') + self.mode = mode + + if factor >= 1.0: + raise ValueError( + 'new_lr = origin_lr * gamma and gamma should be < 1.0.') + self.factor = factor + + threshold_mode = threshold_mode.lower() + if threshold_mode not in ['rel', 'abs']: + raise ValueError('threshold mode: ' + threshold_mode + + ' is unknown!') + self.threshold_mode = threshold_mode + if not isinstance(learning_rate, (float, int)): + raise TypeError( + "The type of 'learning_rate' in 'ReduceLROnPlateau' must be 'float', but received %s." + % type(learning_rate)) + + self.verbose = verbose + self.patience = patience + self.threshold = threshold + self.threshold_mode = threshold_mode + self.cooldown = cooldown + self.min_lr = min_lr + self.epsilon = epsilon + + self.cooldown_counter = 0 + self.best = None + self.num_bad_epochs = 0 + + # Can not call Parent __init__, so implement here. + self.base_lr = float(learning_rate) + self.last_lr = float(learning_rate) + self.last_epoch = 0 + self.verbose = verbose + self._var_name = None + + # "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored. + def _state_keys(self): + self.keys = [ + 'cooldown_counter', 'best', 'num_bad_epochs', 'last_epoch', + 'last_lr' + ] + + def step(self, metrics, epoch=None): + """ + step should be called after 'minimize' . It will update the learning rate in optimizer according to ``metrics`` . + The new learning rate will take effect on next epoch. + + Args: + metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce. + If it stop descending for a ``patience`` number of epochs, the learning rate will reduce. If it's 'Tensor' or + 'numpy.ndarray', its shape must be [1]. + epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1. + + Returns: + None + + Examples: + Please refer to the example of current _LRScheduler. + """ + if epoch is None: + self.last_epoch = self.last_epoch + 1 + else: + self.last_epoch = epoch + + # loss must be 1-D Tensor with shape [1] + if isinstance(metrics, (Tensor, numpy.ndarray)): + assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \ + "should be (1L,), but the current metrics.shape is {}. Maybe that " \ + "you should call paddle.mean to process it first.".format(loss.shape) + elif not isinstance(metrics, + (int, float, numpy.float32, numpy.float64)): + raise TypeError( + "metrics must be 'int', 'float', 'np.float', 'numpy.ndarray' or 'paddle.Tensor', but receive {}". + format(type(metrics))) + + if self.cooldown_counter > 0: + self.cooldown_counter -= 1 + else: + if self.best is None or self._is_better(metrics, self.best): + self.best = metrics + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.num_bad_epochs > self.patience: + self.cooldown_counter = self.cooldown + self.num_bad_epochs = 0 + new_lr = max(self.last_lr * self.factor, self.min_lr) + if self.last_lr - new_lr > self.epsilon: + self.last_lr = new_lr + if self.verbose: + print('Epoch {}: {} set learning rate to {}.'.format( + self.last_epoch, self.__class__.__name__, + self.last_lr)) + + def _is_better(self, current, best): + print("mode", self.mode, 'threshold_mode', self.threshold_mode) + if self.mode == 'min' and self.threshold_mode == 'rel': + return current < best - best * self.threshold + + elif self.mode == 'min' and self.threshold_mode == 'abs': + return current < best - self.threshold + + elif self.mode == 'max' and self.threshold_mode == 'rel': + return current > best + best * self.threshold + + else: + return current > best + self.threshold + + +class CosineAnnealingLR(_LRScheduler): + """ + + Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to + the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in + SGDR: + + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + The algorithm can be described as following. + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts `_. + Note that this only implements the cosine annealing part of SGDR, and not the restarts. + + Args: + learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number. + T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. + eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + verbose (bool): If ``True``, prints a message to stdout for each update. Default: ``False`` . + + Returns: + ``CosineAnnealingLR`` instance to schedule learning rate. + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + + # train on default dygraph mode + paddle.disable_static() + x = np.random.uniform(-1, 1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + scheduler = paddle.optimizer.CosineAnnealingLR(learning_rate=0.5, T_max=10, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameter_list=linear.parameters()) + for epoch in range(20): + for batch_id in range(2): + x = paddle.to_tensor(x) + out = linear(x) + loss = paddle.reduce_mean(out) + out.backward() + sgd.minimize(loss) + linear.clear_gradients() + scheduler.step() + + # train on statich mode + paddle.enable_static() + main_prog = paddle.static.Program() + start_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, start_prog): + x = paddle.static.data(name='x', shape=[-1, 4, 5]) + y = paddle.static.data(name='y', shape=[-1, 4, 5]) + z = paddle.static.nn.fc(x, 100) + loss = paddle.mean(z) + scheduler = paddle.optimizer.CosineAnnealingLR(learning_rate=0.5, T_max=10, verbose=True) + sgd = paddle.optimizer.SGD(learning_rate=scheduler) + sgd.minimize(loss) + lr_var = sgd._global_learning_rate() + + exe = paddle.static.Executor() + exe.run(start_prog) + for epoch in range(20): + for batch_id in range(2): + out = exe.run( + main_prog, + feed={ + 'x': np.random.randn(3, 4, 5).astype('float32'), + 'y': np.random.randn(3, 4, 5).astype('float32') + }, + fetch_list=lr_var.name) + scheduler.step() + """ + + def __init__(self, + learning_rate, + T_max, + eta_min=0, + last_epoch=-1, + verbose=False): + if not isinstance(T_max, int): + raise TypeError( + "The type of 'T_max' in 'CosineAnnealingLR' must be 'int', but received %s." + % type(T_max)) + if not isinstance(eta_min, (float, int)): + raise TypeError( + "The type of 'eta_min' in 'CosineAnnealingLR' must be 'float, int', but received %s." + % type(eta_min)) + self.T_max = T_max + self.eta_min = float(eta_min) + super(CosineAnnealingLR, self).__init__(learning_rate, last_epoch, + verbose) + + def get_lr(self): + if self.last_epoch == 0: + return self.base_lr + elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos( + math.pi / self.T_max)) / 2 + + return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / ( + 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * ( + self.last_lr - self.eta_min) + self.eta_min + + def _get_closed_form_lr(self): + return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos( + math.pi * self.last_epoch / self.T_max)) / 2 diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 93060c7865e..42a28a4f04e 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -21,6 +21,7 @@ __all__ = [ 'load', 'data', 'InputSpec' ] +from . import nn from .input import data #DEFINE_ALIAS from .input import InputSpec #DEFINE_ALIAS from ..fluid.executor import Executor #DEFINE_ALIAS -- GitLab