diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py
index 3a6e46bc64ec9af4dfcc4164a5078b2af85dca7a..fc5485078a69fbe43f4efa8321559eb0314a67c4 100644
--- a/python/paddle/fluid/dygraph/checkpoint.py
+++ b/python/paddle/fluid/dygraph/checkpoint.py
@@ -78,9 +78,9 @@ def save_dygraph(state_dict, model_path):
     for k, v in state_dict.items():
         if isinstance(v, (Variable, core.VarBase)):
             model_dict[k] = v.numpy()
+            name_table[k] = v.name
         else:
             model_dict[k] = v
-            name_table[k] = v.name
 
     model_dict["StructuredToParameterName@@"] = name_table
     file_name = model_path + suffix
diff --git a/python/paddle/fluid/dygraph/learning_rate_scheduler.py b/python/paddle/fluid/dygraph/learning_rate_scheduler.py
index 0e6654f2f1ee31a1d2998fe52c2cce2ad70594da..a6226a9ead27f6c6729fdac877ea9eb3a4f61267 100644
--- a/python/paddle/fluid/dygraph/learning_rate_scheduler.py
+++ b/python/paddle/fluid/dygraph/learning_rate_scheduler.py
@@ -15,6 +15,7 @@
 from __future__ import print_function
 
 import math
+import warnings
 
 from .. import unique_name
 from ..framework import Variable
@@ -66,6 +67,51 @@ class LearningRateDecay(object):
             persistable=False)
         return lr
 
+    def state_dict(self):
+        """
+        Returns the state of the scheduler as a :class:`dict`.
+
+        It is a subset of self.__dict__.
+        """
+        self._state_keys()
+        state_dict = {}
+        for key in self.keys:
+            if key not in self.__dict__:
+                continue
+            value = self.__dict__[key]
+            if isinstance(value, Variable):
+                assert value.shape == [
+                    1
+                ], "shape of Variable in state_dict must be [1] {}".format(
+                    value.shape)
+                value = value.numpy()[0]
+            state_dict[key] = value
+
+        return state_dict
+
+    def _state_keys(self):
+        """
+        Set the keys in self.__dict__ that need to be saved.
+        """
+        self.keys = ['step_num']
+
+    def set_dict(self, state_dict):
+        """
+        Loads the scheduler's state.
+        """
+        self._state_keys()
+        for key in self.keys:
+            if key in state_dict:
+                self.__dict__[key] = state_dict[key]
+            else:
+                raise RuntimeError(
+                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
+                    format(key))
+        if len(state_dict) > len(self.keys):
+            warnings.warn(
+                "There are some unused values in state_dict. Maybe the optimizer has a different 'LearningRateDecay' when invoking state_dict and set_dict"
+            )
+
     def step(self):
         raise NotImplementedError()
 
@@ -402,7 +448,7 @@ class PolynomialDecay(LearningRateDecay):
         learning_rate(Variable|float): The initial learning rate. If the type is Variable, it's a tensor with shape [1], the data type
             can be float32 or float64. It also can be set to python int number.
-        decay_steps(int32): The decay step size. It determines the decay cycle.
+        decay_steps(int): The decay step size. It determines the decay cycle.
         end_learning_rate(float, optional): The minimum final learning rate. The default value is 0.0001.
         power(float, optional): Power of polynomial. The default value is 1.0.
         cycle(bool, optional): If set true, decay the learning rate every decay_steps. The default value is False.
@@ -784,7 +830,7 @@ class ReduceLROnPlateau(LearningRateDecay):
             raise ValueError(
                 'new_lr = origin_lr * decay_rate and decay_rate should be < 1.0.'
             )
-        self.decay_rate = decay_rate
+        self.decay_rate = self.create_lr_var(decay_rate)
 
         threshold_mode = threshold_mode.lower()
         if threshold_mode not in ['rel', 'abs']:
@@ -793,8 +839,10 @@ class ReduceLROnPlateau(LearningRateDecay):
         self.threshold_mode = threshold_mode
         check_type(learning_rate, 'learning_rate', (float, int, Variable),
                    'ReduceLROnPlateau')
-        if isinstance(learning_rate, (float, int)):
-            learning_rate = self.create_lr_var(learning_rate)
+        if not isinstance(learning_rate, (float, int, Variable)):
+            raise TypeError(
+                "The type of 'learning_rate' in 'ReduceLROnPlateau' must be 'float, int, Variable', but received %s."
+                % type(learning_rate))
 
         self.learning_rate = learning_rate
         self.verbose = verbose
@@ -808,9 +856,17 @@ class ReduceLROnPlateau(LearningRateDecay):
         self.cooldown_counter = 0
         self.best_loss = None
         self.num_bad_epochs = 0
-        self.epoch = 0
+        self.epoch_num = 0
+
+    def _state_keys(self):
+        self.keys = [
+            'cooldown_counter', 'best_loss', 'num_bad_epochs', 'epoch_num',
+            'learning_rate'
+        ]
 
     def __call__(self):
+        if not isinstance(self.learning_rate, Variable):
+            self.learning_rate = self.create_lr_var(self.learning_rate)
         return self.learning_rate
 
     def step(self, loss):
@@ -836,7 +892,7 @@ class ReduceLROnPlateau(LearningRateDecay):
             "should be (1L,), but the current loss.shape is {}. Maybe that " \
             "you should call fluid.layers.mean to process it first.".format(loss.shape)
 
-        self.epoch += 1
+        self.epoch_num += 1
         if self.cooldown_counter > 0:
             self.cooldown_counter -= 1
         else:
@@ -854,10 +910,11 @@ class ReduceLROnPlateau(LearningRateDecay):
                                           self.decay_rate, self.min_lr)
                 if self.learning_rate - new_lr > self.eps:
                     if self.verbose:
+                        old_lr = self.learning_rate.numpy()[0] if isinstance(
+                            self.learning_rate,
+                            Variable) else self.learning_rate
                         print('Epoch {}: reducing learning rate from {} to {}.'.
-                              format(self.epoch,
-                                     self.learning_rate.numpy()[0],
-                                     new_lr.numpy()[0]))
+                              format(self.epoch_num, old_lr, new_lr.numpy()[0]))
                     self.learning_rate = new_lr
 
     def _is_better(self, current, best):
@@ -890,22 +947,28 @@ class _LearningRateEpochDecay(LearningRateDecay):
             raise TypeError(
                 "The type of 'learning_rate' must be 'float, int', but received %s."
                 % type(learning_rate))
-        if learning_rate >= 1.0:
-            raise ValueError("The initial learning rate")
+        if learning_rate < 0:
+            raise ValueError("Invalid learning rate: {}".format(learning_rate))
 
         self.base_lr = float(learning_rate)
 
         self.epoch_num = -1
+        self.dtype = dtype
         if dtype is None:
             self.dtype = "float32"
         self.learning_rate = self.create_lr_var(self.base_lr)
 
         self.epoch()
 
+    def _state_keys(self):
+        self.keys = ['epoch_num', 'learning_rate']
+
     def __call__(self):
         """
         Return last computed learning rate on current epoch.
         """
+        if not isinstance(self.learning_rate, Variable):
+            self.learning_rate = self.create_lr_var(self.learning_rate)
         return self.learning_rate
 
     def epoch(self, epoch=None):
@@ -918,8 +981,6 @@ class _LearningRateEpochDecay(LearningRateDecay):
             self.epoch_num = epoch
 
         self.learning_rate = self.get_lr()
-        if isinstance(self.learning_rate, float):
-            self.learning_rate = self.create_lr_var(self.learning_rate)
 
     def get_lr(self):
         raise NotImplementedError
@@ -946,7 +1007,7 @@ class StepDecay(_LearningRateEpochDecay):
 
     Parameters:
         learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
-        step_size (int): Period of learning rate decay..
+        step_size (int): Period of learning rate decay.
         decay_rate (float, optional): The Ratio that the learning rate will be reduced.
            ``new_lr = origin_lr * decay_rate`` . It should be less than 1.0. Default: 0.1.
@@ -1024,7 +1085,7 @@ class MultiStepDecay(_LearningRateEpochDecay):
         learning_rate = 0.005
 
     Parameters:
-        learning_rate (float|int): The initial learning rate. It can be set to python float or int number. If it
+        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
         milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
         decay_rate (float, optional): The Ratio that the learning rate will be reduced.
            ``new_lr = origin_lr * decay_rate`` . It should be less than 1.0. Default: 0.1.
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 7fb233736659b683a2e9ecd886cbf3cb95f52d7b..d72307586c85e3e2fedbf7ec7480f789cc74628b 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -33,7 +33,7 @@ from .layers import ops
 from .regularizer import append_regularization_ops
 from .dygraph import base as imperative_base
 from .dygraph import no_grad
-from .dygraph.learning_rate_scheduler import LearningRateDecay
+from .dygraph.learning_rate_scheduler import LearningRateDecay, _LearningRateEpochDecay
 from paddle.fluid import core
 from paddle.fluid.layers import tensor
 from functools import reduce
@@ -148,17 +148,17 @@ class Optimizer(object):
                 state_dict[var_tmp.name] = var_tmp
         # global step if use lr decay
         if isinstance(self._learning_rate, LearningRateDecay):
-            var_tmp = None
-            if framework.in_dygraph_mode():
+            state_dict["LR_Scheduler"] = self._learning_rate.state_dict()
+
+            if not isinstance(self._learning_rate, _LearningRateEpochDecay):
+                var_tmp = None
                 var_temp = framework._varbase_creator(
                     None, name='global_step', dtype='int32')
-            else:
-                var_temp = Variable(None, name='global_step', dtype='int32')
 
-            tensor.fill_constant(
-                [1], "int32", self._learning_rate.step_num, out=var_temp)
+                tensor.fill_constant(
+                    [1], "int32", self._learning_rate.step_num, out=var_temp)
 
-            state_dict['global_step'] = var_temp
+                state_dict['global_step'] = var_temp
         return state_dict
 
     @framework.dygraph_only
@@ -192,30 +192,28 @@ class Optimizer(object):
         '''
 
         if isinstance(self._learning_rate, LearningRateDecay):
-            assert 'global_step' in state_dict, \
-                    'Global step not in state dict, Dygraph use LearningRateDecay, global_step must in state_dict'
-            global_step = state_dict['global_step']
-
-            if isinstance(global_step, core.VarBase):
-                step_np = global_step
-                step_np = np.array(step_np.value().get_tensor())
-                assert step_np.shape == (1,), \
-                        "global step shape is (1,), the shape is {}".format( step_np.shape )
-
-                self._learning_rate.step_num = int(step_np[0])
-            elif isinstance(global_step, Variable):
-                step_np = global_step.numpy()
-                assert step_np.shape == (1,), \
-                        "global step shape is (1,), the shape is {}".format( step_np.shape )
-                self._learning_rate.step_num = step_np[0]
-            elif isinstance(global_step, np.ndarray):
-                assert global_step.shape == (1,), \
-                        "global step shape is (1,), the shape is {}".format( global_step.shape )
-                self._learning_rate.step_num = global_step[0]
-            else:
-                raise RuntimeError(
-                    "Type not supprt, value in state dict must be [VarBase, Variable, numpy], the type is ",
-                    type(global_step))
+            self._learning_rate.set_dict(state_dict["LR_Scheduler"])
+
+            if not isinstance(self._learning_rate, _LearningRateEpochDecay):
+                assert 'global_step' in state_dict, \
+                        'Global step not in state dict, Dygraph use LearningRateDecay, global_step must in state_dict'
+                global_step = state_dict['global_step']
+
+                if isinstance(global_step, Variable):
+                    step_np = global_step
+                    step_np = np.array(step_np.value().get_tensor())
+                    assert step_np.shape == (1,), \
+                            "global step shape is (1,), the shape is {}".format( step_np.shape )
+
+                    self._learning_rate.step_num = int(step_np[0])
+                elif isinstance(global_step, np.ndarray):
+                    assert global_step.shape == (1,), \
+                            "global step shape is (1,), the shape is {}".format( global_step.shape )
+                    self._learning_rate.step_num = global_step[0]
+                else:
+                    raise RuntimeError(
+                        "Type not supprt, value in state dict must be [VarBase, Variable, numpy], the type is ",
+                        type(global_step))
 
         self._accumulators_holder = state_dict
         for k, v in self._accumulators.items():
@@ -346,11 +344,14 @@ class Optimizer(object):
         """
         current_lr = self._global_learning_rate()
-        if current_lr:
+        if isinstance(current_lr, framework.Variable):
             return self._global_learning_rate().numpy()[0]
 
         if isinstance(self._learning_rate, float):
             return self._learning_rate
+        elif isinstance(self._learning_rate, _LearningRateEpochDecay):
+            step_lr = self._learning_rate()
+            return step_lr.numpy()[0]
         else:
             step_lr = self._learning_rate.step()
             if isinstance(step_lr, (float, int)):
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py
index 694ad077c024f8f7144c606d1f3819f377fd034b..545f7125a9d5d43f64a640082ae24da424b835e5 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py
@@ -276,8 +276,11 @@ class TestDygraphPtbRnn(unittest.TestCase):
             self.opti_dict = adam.state_dict()
             self.base_opti = {}
             for k, v in self.opti_dict.items():
-                self.base_opti[v.name] = v.numpy()
-                self.assertTrue(np.sum(np.abs(v.numpy())) != 0)
+                if isinstance(v, core.VarBase):
+                    self.base_opti[v.name] = v.numpy()
+                    self.assertTrue(np.sum(np.abs(v.numpy())) != 0)
+                else:
+                    self.base_opti[k] = v
 
             fluid.save_dygraph(self.opti_dict, "./test_dy")
@@ -359,11 +362,12 @@ class TestDygraphPtbRnn(unittest.TestCase):
             opti_dict = adam.state_dict()
             # set to zero
             for k, v in opti_dict.items():
-                np_t = v.numpy()
-                var = v.value().get_tensor()
-                var.set(np.zeros_like(np_t), place)
+                if isinstance(v, core.VarBase):
+                    np_t = v.numpy()
+                    var = v.value().get_tensor()
+                    var.set(np.zeros_like(np_t), place)
 
-                self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
+                    self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
 
             if isinstance(adam._learning_rate, LearningRateDecay):
                 adam._learning_rate.step_num = 0
@@ -373,8 +377,11 @@ class TestDygraphPtbRnn(unittest.TestCase):
 
             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
-                self.assertTrue(
-                    np.array_equal(v.numpy(), self.base_opti[v.name]))
+                if isinstance(v, core.VarBase):
+                    self.assertTrue(
+                        np.array_equal(v.numpy(), self.base_opti[v.name]))
+                else:
+                    self.assertEqual(v, self.base_opti[k])
 
             # check parameter
             state_dict = ptb_model.state_dict()
@@ -464,21 +471,24 @@ class TestDygraphPtbRnn(unittest.TestCase):
             opti_dict = adam.state_dict()
             # set to zero
             for k, v in opti_dict.items():
-                np_t = v.numpy()
-                var = v.value().get_tensor()
-                var.set(np.zeros_like(np_t), place)
+                if isinstance(v, core.VarBase):
+                    np_t = v.numpy()
+                    var = v.value().get_tensor()
+                    var.set(np.zeros_like(np_t), place)
 
-                self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
+                    self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
 
             if isinstance(adam._learning_rate, LearningRateDecay):
                 adam._learning_rate.step_num = 0
 
             adam.set_dict(self.opti_dict)
-
             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
-                self.assertTrue(
-                    np.array_equal(v.numpy(), self.base_opti[v.name]))
+                if isinstance(v, core.VarBase):
+                    self.assertTrue(
+                        np.array_equal(v.numpy(), self.base_opti[v.name]))
+                else:
+                    self.assertEqual(v, self.base_opti[k])
 
             # check parameter
             state_dict = ptb_model.state_dict()
@@ -569,12 +579,14 @@ class TestDygraphPtbRnn(unittest.TestCase):
             np_opti_dict = {}
             # set to zero
             for k, v in opti_dict.items():
-                np_t = v.numpy()
-                np_opti_dict[v.name] = np_t
-                var = v.value().get_tensor()
-                var.set(np.zeros_like(np_t), place)
-
-                self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
+                if isinstance(v, core.VarBase):
+                    np_t = v.numpy()
+                    np_opti_dict[v.name] = np_t
+                    var = v.value().get_tensor()
+                    var.set(np.zeros_like(np_t), place)
+                    self.assertTrue(np.sum(np.abs(v.numpy())) == 0)
+                else:
+                    np_opti_dict[k] = v
 
             if isinstance(adam._learning_rate, LearningRateDecay):
                 adam._learning_rate.step_num = 0
@@ -583,8 +595,11 @@ class TestDygraphPtbRnn(unittest.TestCase):
 
             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
-                self.assertTrue(
-                    np.array_equal(v.numpy(), self.base_opti[v.name]))
+                if isinstance(v, core.VarBase):
+                    self.assertTrue(
+                        np.array_equal(v.numpy(), self.base_opti[v.name]))
+                else:
+                    self.assertEqual(v, self.base_opti[k])
 
             # check parameter
             state_dict = ptb_model.state_dict()
@@ -825,7 +840,10 @@ class TestDygraphPtbRnn(unittest.TestCase):
         np_state_dict = {}
 
         for k, v in self.opti_dict.items():
-            np_opti_dict[v.name] = v.numpy()
+            if isinstance(v, core.VarBase):
+                np_opti_dict[v.name] = v.numpy()
+            else:
+                np_opti_dict[k] = v
 
         for k, v in self.state_dict.items():
             np_state_dict[k] = v.numpy()
diff --git a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
index 296cbe25a0612b03af877b8b43c5a0baca8799df..71b452d4a2dd192c756599eb24949084bfa0860e 100644
--- a/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
+++ b/python/paddle/fluid/tests/unittests/test_learning_rate_scheduler.py
@@ -121,6 +121,104 @@ def lambda_decay(global_step, learning_rate, lr_lambda):
 
 
 class TestLearningRateDecayDygraph(unittest.TestCase):
+    def test_LR_state_dict(self):
+        with fluid.dygraph.guard():
+            x = np.random.uniform(-1, 1, [3, 10]).astype("float32")
+            linear = fluid.dygraph.Linear(10, 10)
+            input = fluid.dygraph.to_variable(x)
+
+            Exponential_scheduler = fluid.dygraph.ExponentialDecay(
+                learning_rate=0.1,
+                decay_steps=10000,
+                decay_rate=0.5,
+                staircase=True)
+            Step_scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
+            Reducelr_scheduler = fluid.dygraph.ReduceLROnPlateau(
+                learning_rate=1.0, decay_rate=0.5, patience=5, cooldown=3)
+
+            adam1 = fluid.optimizer.Adam(
+                learning_rate=Exponential_scheduler,
+                parameter_list=linear.parameters())
+            adam2 = fluid.optimizer.Adam(
+                learning_rate=Step_scheduler,
+                parameter_list=linear.parameters())
+            adam3 = fluid.optimizer.Adam(
+                learning_rate=Reducelr_scheduler,
+                parameter_list=linear.parameters())
+            print(adam3.state_dict())
+
+            for epoch in range(10):
+                out = linear(input)
+                loss = fluid.layers.reduce_mean(out)
+                loss.backward()
+                adam1.minimize(loss)
+                adam2.minimize(loss)
+                adam3.minimize(loss)
+                linear.clear_gradients()
+
+                Step_scheduler.epoch()
+                Reducelr_scheduler.step(loss)
+
+            fluid.dygraph.save_dygraph(linear.state_dict(), "save_path")
+
+            Exponential_scheduler_test = fluid.dygraph.ExponentialDecay(
+                learning_rate=0.1,
+                decay_steps=10000,
+                decay_rate=0.5,
+                staircase=True)
+            Step_scheduler_test = fluid.dygraph.StepDecay(0.5, step_size=3)
+            Reducelr_scheduler_test = fluid.dygraph.ReduceLROnPlateau(
+                learning_rate=1.0, decay_rate=0.5, patience=5, cooldown=3)
+
+            fluid.dygraph.save_dygraph(adam1.state_dict(), "save_path")
+            _, opt_state = fluid.dygraph.load_dygraph("save_path")
+            adam_test = fluid.optimizer.Adam(
+                learning_rate=Exponential_scheduler_test,
+                parameter_list=linear.parameters())
+            adam_test.set_dict(opt_state)
+            self.assertEqual(adam_test._learning_rate.step_num,
+                             adam1._learning_rate.step_num,
+                             "epoch_num is different before and after set_dict")
+
+            fluid.dygraph.save_dygraph(adam2.state_dict(), "save_path")
+            _, opt_state = fluid.dygraph.load_dygraph("save_path")
+            adam_test = fluid.optimizer.Adam(
+                learning_rate=Step_scheduler_test,
+                parameter_list=linear.parameters())
+            adam_test.set_dict(opt_state)
+            self.assertEqual(adam_test._learning_rate.epoch_num,
+                             adam2._learning_rate.epoch_num,
+                             "epoch_num is different before and after set_dict")
+            self.assertEqual(
+                adam_test._learning_rate(),
+                adam2._learning_rate(),
+                "current learning rate is different before and after set_dict")
+
+            fluid.dygraph.save_dygraph(adam3.state_dict(), "save_path")
+            _, opt_state = fluid.dygraph.load_dygraph("save_path")
+            adam_test = fluid.optimizer.Adam(
+                learning_rate=Reducelr_scheduler_test,
+                parameter_list=linear.parameters())
+            adam_test.set_dict(opt_state)
+            self.assertEqual(adam_test._learning_rate.best_loss,
+                             adam3._learning_rate.best_loss.numpy()[0],
+                             "best_loss is different before and after set_dict")
+            self.assertEqual(
+                adam_test._learning_rate.cooldown_counter,
+                adam3._learning_rate.cooldown_counter,
+                "cooldown_counter is different before and after set_dict")
+            self.assertEqual(
+                adam_test._learning_rate.num_bad_epochs,
+                adam3._learning_rate.num_bad_epochs,
+                "num_bad_epochs is different before and after set_dict")
+            self.assertEqual(adam_test._learning_rate.epoch_num,
+                             adam3._learning_rate.epoch_num,
+                             "epoch is different before and after set_dict")
+            self.assertEqual(
+                adam_test._learning_rate(),
+                adam3._learning_rate(),
+                "current learning rate is different before and after set_dict")
+
     def test_NoamDecay(self):
         with fluid.dygraph.guard():
             d_model = 0.01
@@ -169,17 +267,22 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
             learning_rate = 0.5
             milestones = [2, 4, 8]
             decay_rate = 0.2
+            linear = fluid.dygraph.Linear(10, 10)
+
             scheduler = fluid.dygraph.MultiStepDecay(learning_rate, milestones,
                                                      decay_rate)
+
+            adam = fluid.optimizer.AdamOptimizer(
+                learning_rate=scheduler, parameter_list=linear.parameters())
             for epoch in range(10):
                 right_result = multi_step_decay(epoch, learning_rate,
                                                 milestones, decay_rate)
-                fluid_result = scheduler().numpy()[0]
+                fluid_result = adam.current_step_lr()
                 scheduler.epoch()
                 self.assertAlmostEqual(
                     right_result,
                     fluid_result,
-                    msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'.
+                    msg='Failed lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
                    format(epoch, right_result, fluid_result))
 
             with self.assertRaises(ValueError):
@@ -190,6 +293,12 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
                 lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50],
                                                   1)
 
+            with self.assertRaises(TypeError):
+                lr = fluid.dygraph.MultiStepDecay("test", [20, 30, 50])
+
+            with self.assertRaises(ValueError):
+                lr = fluid.dygraph.MultiStepDecay(-1, [20, 30, 50])
+
     def test_StepDecay(self):
         with fluid.dygraph.guard():
             learning_rate = 0.5
@@ -205,21 +314,14 @@ class TestLearningRateDecayDygraph(unittest.TestCase):
                 self.assertAlmostEqual(
                     right_result,
                     fluid_result,
-                    msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'.
+                    msg='Failed lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
                     format(epoch, right_result, fluid_result))
 
             with self.assertRaises(TypeError):
-                lr = fluid.dygraph.MultiStepDecay(learning_rate, "test", 0.1)
-
-            with self.assertRaises(ValueError):
-                lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50],
-                                                  1)
-
-            with self.assertRaises(TypeError):
-                lr = fluid.dygraph.MultiStepDecay("test", [20, 30, 50])
+                lr = fluid.dygraph.StepDecay(learning_rate, "test", 0.1)
 
             with self.assertRaises(ValueError):
-                lr = fluid.dygraph.MultiStepDecay(2.0, [20, 30, 50])
+                lr = fluid.dygraph.StepDecay(learning_rate, 20, 2)
 
     def test_LambdaDecay(self):
         with fluid.dygraph.guard():
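Note: taken together, these changes mean Optimizer.state_dict() now carries an "LR_Scheduler" entry produced by LearningRateDecay.state_dict(), and Optimizer.set_dict() restores it through LearningRateDecay.set_dict(); epoch-based schedulers (_LearningRateEpochDecay) no longer depend on a 'global_step' entry. Below is a minimal usage sketch, mirroring the flow of the new test_LR_state_dict test; the checkpoint prefix "ckpt" and the toy network are illustrative and not part of the patch.

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    linear = fluid.dygraph.Linear(10, 10)
    # StepDecay is an epoch-based scheduler; its state is captured through
    # _state_keys() -> ['epoch_num', 'learning_rate'].
    scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
    adam = fluid.optimizer.Adam(
        learning_rate=scheduler, parameter_list=linear.parameters())

    x = fluid.dygraph.to_variable(
        np.random.uniform(-1, 1, [3, 10]).astype("float32"))
    for epoch in range(5):
        loss = fluid.layers.reduce_mean(linear(x))
        loss.backward()
        adam.minimize(loss)
        linear.clear_gradients()
        scheduler.epoch()  # advance the epoch-based scheduler

    # The optimizer state now contains an "LR_Scheduler" sub-dict; epoch-based
    # schedulers do not store a 'global_step' entry anymore.
    # "ckpt" is an arbitrary checkpoint prefix chosen for this sketch.
    fluid.dygraph.save_dygraph(linear.state_dict(), "ckpt")
    fluid.dygraph.save_dygraph(adam.state_dict(), "ckpt")

    # Restore into a freshly constructed scheduler/optimizer pair.
    scheduler_new = fluid.dygraph.StepDecay(0.5, step_size=3)
    adam_new = fluid.optimizer.Adam(
        learning_rate=scheduler_new, parameter_list=linear.parameters())
    _, opt_state = fluid.dygraph.load_dygraph("ckpt")
    adam_new.set_dict(opt_state)  # also calls scheduler_new.set_dict(...)

    assert scheduler_new.epoch_num == scheduler.epoch_num
    assert adam_new.current_step_lr() == adam.current_step_lr()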