diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 1ef719b9da187be659d9c898ec996b5ad0c0d8a6..7075024369f328b59ecac014b0960fc26f447ff2 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -29,6 +29,9 @@ from .layer_object_helper import LayerObjectHelper from .base import program_desc_tracing_guard, param_guard from paddle.fluid import framework from ..param_attr import ParamAttr +from paddle.fluid.executor import Executor, global_scope +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import _current_expected_place as _get_device __all__ = ['Layer'] @@ -797,7 +800,7 @@ class Layer(core.Layer): raise ValueError( "super(YourLayer, self).__init__() should be called first") if len(self._loaddict_holder) > 0: - assert value.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in stat_dict".format( + assert value.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in state_dict".format( value.name) value.set_value(self._loaddict_holder[value.name]) @@ -943,12 +946,13 @@ class Layer(core.Layer): destination = destination_temp return destination - def set_dict(self, - stat_dict, - include_sublayers=True, - use_structured_name=True): + @framework.deprecate_stat_dict + def set_state_dict(self, + state_dict, + include_sublayers=True, + use_structured_name=True): ''' - Set parameters and persistable buffers from stat_dict. All the parameters and buffers will be reset by the tensor in the stat_dict + Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict Parameters: state_dict(dict) : Dict contains all the parameters and persistable buffers. @@ -961,72 +965,67 @@ class Layer(core.Layer): Examples: .. 
code-block:: python - import paddle.fluid as fluid - with fluid.dygraph.guard(): - emb = fluid.dygraph.Embedding([10, 10]) + import paddle + + paddle.disable_static() + + emb = paddle.nn.Embedding([10, 10]) - state_dict = emb.state_dict() - fluid.save_dygraph( state_dict, "paddle_dy") - - para_state_dict, _ = fluid.load_dygraph( "paddle_dy") - - emb.set_dict( para_state_dict ) + state_dict = emb.state_dict() + paddle.save(state_dict, "paddle_dy") + + para_state_dict, _ = paddle.load("paddle_dy") - ''' - self.load_dict( - stat_dict, - include_sublayers=include_sublayers, - use_structured_name=use_structured_name) + emb.set_state_dict(para_state_dict) - def load_dict(self, - stat_dict, - include_sublayers=True, - use_structured_name=True): ''' - Set parameters and persistable buffers from stat_dict. All the parameters and persistabl buffers will be reset by the tensor in the stat_dict - This api will be Deprecated. Please use set_dict - - Parameters: - state_dict(dict) : Dict contains all the parameters and persistable buffers. - include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True - use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key. - Default: True - Returns: - None - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - with fluid.dygraph.guard(): - emb = fluid.dygraph.Embedding([10, 10]) - - state_dict = emb.state_dict() - fluid.save_dygraph( state_dict, "paddle_dy") - - para_state_dict, _ = fluid.load_dygraph( "paddle_dy") - - emb.load_dict( para_state_dict ) - - ''' - - inner_state_dict = self.state_dict() + def _check_match(key, param): + state = state_dict.get(key, None) + if state is None: + raise ValueError("{} is not found in the provided dict.".format( + key)) + if list(state.shape) != list(param.shape): + raise ValueError( + "{} receives a shape {}, but the expected shape is {}.". 
+ format(key, list(state.shape), list(param.shape))) + return param, state + + matched_param_state = [] + for key, param in self.state_dict().items(): + key_name = key if use_structured_name else param.name + try: + match_res = _check_match(key_name, param) + matched_param_state.append(match_res) + except ValueError as err: + warnings.warn(("Skip loading for {}. ".format(key) + str(err))) + + if in_dygraph_mode(): + for param, state in matched_param_state: + param.set_value(state) + else: - for name, param_or_buffer in inner_state_dict.items(): - key_name = name if use_structured_name else param_or_buffer.name - if key_name in stat_dict: - param_or_buffer.set_value(stat_dict[key_name]) - else: - raise RuntimeError( - "Parameter or persistable buffer not found, Can't find [ {} ] in stat_dict" - "use_structured_name is set to [{}]".format( - key_name, use_structured_name)) - unused_para_list = [] - for k, v in stat_dict.items(): - if k not in inner_state_dict: - unused_para_list.append(k) - if len(unused_para_list) > 0: - warnings.warn( - "Variables [ {} ] are not used, because not included in layers state_dict". 
- format(" ".join(unused_para_list))) + def _set_var(var, ndarray): + t = global_scope().find_var(var.name).get_tensor() + p = t._place() + if p.is_cpu_place(): + place = core.CPUPlace() + elif p.is_cuda_pinned_place(): + place = core.CUDAPinnedPlace() + else: + p = core.Place() + p.set_place(t._place()) + place = core.CUDAPlace(p.gpu_device_id()) + t.set(ndarray, place) + + executor = Executor(_get_device())._default_executor + # restore parameter states + core._create_loaded_parameter( + [param for param, state in matched_param_state], + global_scope(), executor) + for param, state in matched_param_state: + _set_var(param, state) + + # [aliases] Compatible with old method names + set_dict = set_state_dict + load_dict = set_state_dict diff --git a/python/paddle/fluid/dygraph/learning_rate_scheduler.py b/python/paddle/fluid/dygraph/learning_rate_scheduler.py index cce383be7e22cd066199f814db80a75367862b82..cd6af6fd5b575e8188088bde9e8944ab94c7e0f8 100644 --- a/python/paddle/fluid/dygraph/learning_rate_scheduler.py +++ b/python/paddle/fluid/dygraph/learning_rate_scheduler.py @@ -97,7 +97,7 @@ class LearningRateDecay(object): """ self.keys = ['step_num'] - def set_dict(self, state_dict): + def set_state_dict(self, state_dict): """ Loads the schedulers state. """ @@ -114,6 +114,9 @@ class LearningRateDecay(object): "There are some unused values in state_dict. 
    @framework.deprecate_stat_dict
    def set_state_dict(self,
                       state_dict,
                       include_sublayers=True,
                       use_structured_name=True):
        '''
        Set parameters of self._layers from state_dict. All the parameters of self._layers will be reset by the tensor in the state_dict

        This method only forwards its arguments to
        ``self._layers.set_state_dict``.

        Parameters:
            state_dict(dict) : Dict contains all the parameters
            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter name as key.
                                                  Default: True
        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle
                import paddle.fluid as fluid

                paddle.disable_static()

                strategy = fluid.dygraph.prepare_context()
                emb = paddle.nn.Embedding([10, 10])
                emb = fluid.dygraph.DataParallel(emb, strategy)

                state_dict = emb.state_dict()
                paddle.save(state_dict, "paddle_dy")

                para_state_dict, _ = paddle.load("paddle_dy")

                emb.set_state_dict(para_state_dict)

        '''

        self._layers.set_state_dict(
            state_dict,
            include_sublayers=include_sublayers,
            use_structured_name=use_structured_name)
# NOTE(chenweihang): There is an argument name typo (stat_dict, the correct name
# is state_dict) in the fluid APIs Layer.set_dict and Optimizer.load. This
# decorator corrects the argument name without breaking callers that still
# pass the old keyword.
# NOTE(chenweihang): `wrap_decorator` is not used here because it moves kwargs
# to args, which does not work for this keyword-rewriting decorator.
def deprecate_stat_dict(func):
    """Let ``func`` keep accepting the misspelled keyword ``stat_dict``.

    A call that passes ``stat_dict=...`` emits a ``DeprecationWarning`` and is
    forwarded to ``func`` with the value under ``state_dict`` instead. Calls
    that do not use ``stat_dict`` are forwarded unchanged.

    Args:
        func: Callable that takes a ``state_dict`` argument.

    Returns:
        The wrapped callable, carrying the metadata of ``func``.

    Raises:
        TypeError: If both ``stat_dict`` and ``state_dict`` are passed as
            keywords, since one of them would otherwise be silently dropped.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if 'stat_dict' in kwargs:
            if 'state_dict' in kwargs:
                raise TypeError(
                    "{}() got both `stat_dict` and `state_dict`; pass only "
                    "`state_dict`.".format(func.__name__))
            # stacklevel=2 attributes the warning to the caller, not to
            # this wrapper.
            warnings.warn(
                "The argument `stat_dict` has deprecated, please change it to `state_dict`.",
                DeprecationWarning,
                stacklevel=2)
            kwargs['state_dict'] = kwargs.pop('stat_dict')
        return func(*args, **kwargs)

    return wrapper
code-block:: python - with fluid.dygraph.guard(): - emb = fluid.dygraph.Embedding([10, 10]) + import paddle + + paddle.disable_static() + + emb = paddle.nn.Embedding([10, 10]) - state_dict = emb.state_dict() - fluid.save_dygraph(state_dict, "paddle_dy") + state_dict = emb.state_dict() + paddle.save(state_dict, "paddle_dy") - adam = fluid.optimizer.Adam(learning_rate=fluid.layers.noam_decay( 100, 10000), + adam = paddle.optimizer.Adam(learning_rate=fluid.layers.noam_decay( 100, 10000), parameter_list=emb.parameters()) - state_dict = adam.state_dict() - fluid.save_dygraph(state_dict, "paddle_dy") + state_dict = adam.state_dict() - para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy") + para_state_dict, opti_state_dict = paddle.load("paddle_dy") - adam.set_dict(opti_state_dict) + adam.set_state_dict(opti_state_dict) ''' from paddle.optimizer.lr_scheduler import _LRScheduler @@ -257,6 +259,9 @@ class Optimizer(object): tensor.set(load_para_np, framework._current_expected_place()) + # [aliases] Compatible with old method names + set_dict = set_state_dict + def get_opti_var_name_list(self): return self._opti_name_list @@ -4595,15 +4600,16 @@ class RecomputeOptimizer(Optimizer): ), "_checkpoints should be a list of Variable or a list of String" self._checkpoints = checkpoints - def load(self, stat_dict): + @framework.deprecate_stat_dict + def load(self, state_dict): """ - :api_attr: Static Graph + :api_attr: Static Graph load function is not supported by Recompute Optimizer for now. :return: None Args: - stat_dict: the dict load by load_persistable method + state_dict: the dict load by load_persistable method Examples: .. 
code-block:: python @@ -4627,8 +4633,8 @@ class RecomputeOptimizer(Optimizer): sgd = fluid.optimizer.RecomputeOptimizer(sgd) sgd._set_checkpoints([fc_1, pred]) try: - stat_dict = {} - sgd.load(stat_dict) + state_dict = {} + sgd.load(state_dict) except NotImplementedError as e: print(cpt.get_exception_message(e)) """ diff --git a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py index 2c92f760fdecaa2169f8cb76c79e08ef0c486a05..22e19efcb58d19c41835565de2c8c01fe253702a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py @@ -374,8 +374,7 @@ class TestDygraphPtbRnn(unittest.TestCase): adam._learning_rate.step_num = 0 para_state_dict, opti_state_dict = paddle.load("./test_dy") - print(opti_state_dict['LR_Scheduler']) - adam.set_dict(opti_state_dict) + adam.set_state_dict(opti_state_dict) opti_dict = adam.state_dict() for k, v in opti_dict.items(): @@ -393,7 +392,7 @@ class TestDygraphPtbRnn(unittest.TestCase): var.set(np.zeros_like(np_t), place) - ptb_model.set_dict(para_state_dict) + ptb_model.set_state_dict(stat_dict=para_state_dict) state_dict = ptb_model.state_dict() @@ -483,7 +482,7 @@ class TestDygraphPtbRnn(unittest.TestCase): if isinstance(adam._learning_rate, LearningRateDecay): adam._learning_rate.step_num = 0 - adam.set_dict(self.opti_dict) + adam.set_state_dict(self.opti_dict) opti_dict = adam.state_dict() for k, v in opti_dict.items(): if isinstance(v, core.VarBase): @@ -500,7 +499,7 @@ class TestDygraphPtbRnn(unittest.TestCase): var.set(np.zeros_like(np_t), place) - ptb_model.set_dict(self.state_dict) + ptb_model.set_state_dict(self.state_dict) state_dict = ptb_model.state_dict() @@ -593,7 +592,7 @@ class TestDygraphPtbRnn(unittest.TestCase): if isinstance(adam._learning_rate, LearningRateDecay): adam._learning_rate.step_num = 0 - adam.set_dict(np_opti_dict) + 
adam.set_state_dict(np_opti_dict) opti_dict = adam.state_dict() for k, v in opti_dict.items(): @@ -613,7 +612,7 @@ class TestDygraphPtbRnn(unittest.TestCase): var.set(np.zeros_like(np_t), place) - ptb_model.set_dict(np_state_dict) + ptb_model.set_state_dict(np_state_dict) state_dict = ptb_model.state_dict() @@ -656,8 +655,8 @@ class TestDygraphPtbRnn(unittest.TestCase): last_hidden = None last_cell = None - adam.set_dict(self.opti_dict) - ptb_model.set_dict(self.state_dict) + adam.set_state_dict(self.opti_dict) + ptb_model.set_state_dict(self.state_dict) for i in range(1): x_data = np.arange(12).reshape(4, 3).astype('int64') @@ -745,8 +744,8 @@ class TestDygraphPtbRnn(unittest.TestCase): last_cell = None state_dict, opti_dict = fluid.load_dygraph("./test_dy") - adam.set_dict(opti_dict) - ptb_model.set_dict(state_dict) + adam.set_state_dict(opti_dict) + ptb_model.set_state_dict(state_dict) for i in range(1): x_data = np.arange(12).reshape(4, 3).astype('int64') @@ -849,8 +848,8 @@ class TestDygraphPtbRnn(unittest.TestCase): for k, v in self.state_dict.items(): np_state_dict[k] = v.numpy() - adam.set_dict(np_opti_dict) - ptb_model.set_dict(np_state_dict) + adam.set_state_dict(np_opti_dict) + ptb_model.set_state_dict(np_state_dict) for i in range(1): x_data = np.arange(12).reshape(4, 3).astype('int64') y_data = np.arange(1, 13).reshape(4, 3).astype('int64') diff --git a/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py b/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py index 3ccd1dbda3a443d50e43ba498cb3d5b529318c32..3eb413a62664057c56567d5834b216110fac04fb 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py @@ -918,6 +918,29 @@ class TestDygraphPtbRnn(unittest.TestCase): para_state_dict, opti_state_dict = paddle.load( os.path.join('saved_dy', 'emb_dy.pdopt')) + def test_no_state_in_input_dict(self): + with fluid.dygraph.guard(): + 
emb = fluid.dygraph.Embedding([10, 10]) + state_dict = emb.state_dict() + paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy')) + + para_state_dict, _ = paddle.load(os.path.join('saved_dy', 'emb_dy')) + para_state_dict.pop('weight') + + emb.set_state_dict(para_state_dict) + + def test_state_shape_mismatch(self): + with fluid.dygraph.guard(): + emb = fluid.dygraph.Embedding([10, 10]) + state_dict = emb.state_dict() + paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy')) + + para_state_dict, _ = paddle.load(os.path.join('saved_dy', 'emb_dy')) + para_state_dict['weight'] = np.expand_dims( + para_state_dict['weight'], axis=-1) + + emb.set_state_dict(para_state_dict) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py index 2e6e516aa2edde79e6524b4b35507ea95876ec53..91d705223316360b8c05954259724a5f7d246440 100644 --- a/python/paddle/fluid/tests/unittests/test_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_optimizer.py @@ -832,8 +832,8 @@ class TestRecomputeOptimizer(unittest.TestCase): recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer) recompute_optimizer._set_checkpoints([b1_out]) try: - stat_dict = {} - recompute_optimizer.load(stat_dict) + state_dict = {} + recompute_optimizer.load(state_dict) except NotImplementedError as e: self.assertEqual( "load function is not supported by Recompute Optimizer for now", diff --git a/python/paddle/hapi/__init__.py b/python/paddle/hapi/__init__.py index fb16b829d5b8e563be9b4e1e5db5d19dded23521..67965de5d97621e188acfa1e0384325b9ec5b7aa 100644 --- a/python/paddle/hapi/__init__.py +++ b/python/paddle/hapi/__init__.py @@ -19,10 +19,7 @@ from . import model_summary from . 
import model from .model import * from .model_summary import summary -from .dygraph_layer_patch import monkey_patch_layer logger.setup_logger() __all__ = ['callbacks'] + model.__all__ + ['summary'] - -monkey_patch_layer() diff --git a/python/paddle/hapi/dygraph_layer_patch.py b/python/paddle/hapi/dygraph_layer_patch.py deleted file mode 100644 index e3a2948b69305fcb08c14c850f5738ac46aea2be..0000000000000000000000000000000000000000 --- a/python/paddle/hapi/dygraph_layer_patch.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings - -import paddle.fluid as fluid -from paddle.fluid.framework import in_dygraph_mode -from paddle.fluid.framework import _current_expected_place as _get_device - - -def monkey_patch_layer(): - def load_dict(self, - stat_dict, - include_sublayers=True, - use_structured_name=True): - ''' - Set parameters from stat_dict. All the parameters will be reset by the - tensor in the stat_dict - - This api will be Deprecated. Please use set_dict - - Parameters: - state_dict(dict) : Dict contains all the parameters - include_sublayers(bool, optional) : If true, also include the - parameters from sublayers. Default: True - use_structured_name(bool, optional) : If true, use structured name - as key, otherwise, use parameter name as key. Default: True - Returns: - None - - Examples: - .. 
code-block:: python - - import paddle.fluid as fluid - with fluid.dygraph.guard(): - emb = fluid.dygraph.Embedding([10, 10]) - - state_dict = emb.state_dict() - fluid.save_dygraph( state_dict, "paddle_dy") - - para_state_dict, _ = fluid.load_dygraph( "paddle_dy") - emb.load_dict( para_state_dict ) - - ''' - - def _check_match(key, param): - state = stat_dict.get(key, None) - if state is None: - raise ValueError( - "{} is not found in the providing file.".format(key)) - if list(state.shape) != list(param.shape): - raise ValueError( - "{} receives a shape {}, but the expected shape is {}.". - format(key, list(state.shape), list(param.shape))) - return param, state - - matched_param_state = [] - for key, param in self.state_dict().items(): - key_name = key if use_structured_name else param.name - try: - match_res = _check_match(key_name, param) - matched_param_state.append(match_res) - except ValueError as err: - warnings.warn(("Skip loading for {}. ".format(key) + str(err))) - - if in_dygraph_mode(): - for param, state in matched_param_state: - param.set_value(state) - else: - - def _set_var(var, ndarray): - t = fluid.global_scope().find_var(var.name).get_tensor() - p = t._place() - if p.is_cpu_place(): - place = fluid.CPUPlace() - elif p.is_cuda_pinned_place(): - place = fluid.CUDAPinnedPlace() - else: - p = fluid.core.Place() - p.set_place(t._place()) - place = fluid.CUDAPlace(p.gpu_device_id()) - t.set(ndarray, place) - - executor = fluid.Executor(_get_device())._default_executor - # restore parameter states - fluid.core._create_loaded_parameter( - [param for param, state in matched_param_state], - fluid.global_scope(), executor) - for param, state in matched_param_state: - _set_var(param, state) - - setattr(fluid.dygraph.Layer, 'load_dict', load_dict) diff --git a/python/paddle/optimizer/lr_scheduler.py b/python/paddle/optimizer/lr_scheduler.py index 4ecaffb8fa509bdc54067bb25f8d1b5191b7ac1b..61391704061bda7dfbad7252cbc04c0b7d6492a4 100644 --- 
a/python/paddle/optimizer/lr_scheduler.py +++ b/python/paddle/optimizer/lr_scheduler.py @@ -109,7 +109,7 @@ class _LRScheduler(object): """ self.keys = ['last_epoch', 'last_lr'] - def set_dict(self, state_dict): + def set_state_dict(self, state_dict): """ Loads the schedulers state. """ @@ -126,8 +126,8 @@ class _LRScheduler(object): "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict" ) - # alias for set_dict - set_state_dict = set_dict + # alias for set_state_dict + set_dict = set_state_dict def get_lr(self): # calculate by python float