Unverified · Commit 9cb57f94 authored by Chen Weihang, committed by GitHub

Update set_dict method name & add aliases (#26700)

* update set_dict method name & add aliases

* fix var name error

* fix alias formats

* use set_state_dict in unittest

* add decorator to solve compatibility problem

* polish decorator

* replace layer set_state_dict by patched method

* remove monkey patch layer import

* fix import function error

* add unittest for coverage
Parent 3900f66c
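Before the diff, a minimal standalone sketch of the compatibility pattern this commit applies across Layer, DataParallel, LearningRateDecay, Optimizer, and _LRScheduler (illustrative class and attribute names, not the actual Paddle sources): the method is renamed to set_state_dict and the old spellings are kept as class-level aliases.

class Layer:  # illustrative stand-in, not paddle.fluid.dygraph.Layer
    def set_state_dict(self, state_dict):
        # canonical implementation lives under the new name
        self._state = dict(state_dict)

    # old spellings keep working as plain class-level aliases
    set_dict = set_state_dict
    load_dict = set_state_dict

layer = Layer()
layer.load_dict({'w': 1})  # deprecated spelling, identical behavior
assert layer._state == {'w': 1}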
@@ -29,6 +29,9 @@ from .layer_object_helper import LayerObjectHelper
 from .base import program_desc_tracing_guard, param_guard
 from paddle.fluid import framework
 from ..param_attr import ParamAttr
+from paddle.fluid.executor import Executor, global_scope
+from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.framework import _current_expected_place as _get_device

 __all__ = ['Layer']
@@ -797,7 +800,7 @@ class Layer(core.Layer):
             raise ValueError(
                 "super(YourLayer, self).__init__() should be called first")
         if len(self._loaddict_holder) > 0:
-            assert value.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in stat_dict".format(
+            assert value.name in self._loaddict_holder, "Parameter not found, Can't not find [ {} ] in state_dict".format(
                 value.name)

             value.set_value(self._loaddict_holder[value.name])
@@ -943,12 +946,13 @@ class Layer(core.Layer):
             destination = destination_temp
         return destination

-    def set_dict(self,
-                 stat_dict,
-                 include_sublayers=True,
-                 use_structured_name=True):
+    @framework.deprecate_stat_dict
+    def set_state_dict(self,
+                       state_dict,
+                       include_sublayers=True,
+                       use_structured_name=True):
         '''
-        Set parameters and persistable buffers from stat_dict. All the parameters and buffers will be reset by the tensor in the stat_dict
+        Set parameters and persistable buffers from state_dict. All the parameters and buffers will be reset by the tensor in the state_dict

         Parameters:
             state_dict(dict) : Dict contains all the parameters and persistable buffers.
@@ -961,72 +965,67 @@ class Layer(core.Layer):
         Examples:
             .. code-block:: python

-                import paddle.fluid as fluid
-                with fluid.dygraph.guard():
-                    emb = fluid.dygraph.Embedding([10, 10])
-
-                    state_dict = emb.state_dict()
-                    fluid.save_dygraph( state_dict, "paddle_dy")
-
-                    para_state_dict, _ = fluid.load_dygraph( "paddle_dy")
-
-                    emb.set_dict( para_state_dict )
-
-        '''
-        self.load_dict(
-            stat_dict,
-            include_sublayers=include_sublayers,
-            use_structured_name=use_structured_name)
-
-    def load_dict(self,
-                  stat_dict,
-                  include_sublayers=True,
-                  use_structured_name=True):
-        '''
-        Set parameters and persistable buffers from stat_dict. All the parameters and persistabl buffers will be reset by the tensor in the stat_dict
-
-        This api will be Deprecated. Please use set_dict
-
-        Parameters:
-            state_dict(dict) : Dict contains all the parameters and persistable buffers.
-            include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
-            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
-                                                  Default: True
-        Returns:
-            None
-
-        Examples:
-            .. code-block:: python
-
-                import paddle.fluid as fluid
-                with fluid.dygraph.guard():
-                    emb = fluid.dygraph.Embedding([10, 10])
-
-                    state_dict = emb.state_dict()
-                    fluid.save_dygraph( state_dict, "paddle_dy")
-
-                    para_state_dict, _ = fluid.load_dygraph( "paddle_dy")
-
-                    emb.load_dict( para_state_dict )
-
-        '''
-        inner_state_dict = self.state_dict()
-
-        for name, param_or_buffer in inner_state_dict.items():
-            key_name = name if use_structured_name else param_or_buffer.name
-            if key_name in stat_dict:
-                param_or_buffer.set_value(stat_dict[key_name])
-            else:
-                raise RuntimeError(
-                    "Parameter or persistable buffer not found, Can't find [ {} ] in stat_dict"
-                    "use_structured_name is set to [{}]".format(
-                        key_name, use_structured_name))
-        unused_para_list = []
-        for k, v in stat_dict.items():
-            if k not in inner_state_dict:
-                unused_para_list.append(k)
-        if len(unused_para_list) > 0:
-            warnings.warn(
-                "Variables [ {} ] are not used, because not included in layers state_dict".
-                format(" ".join(unused_para_list)))
+                import paddle
+
+                paddle.disable_static()
+
+                emb = paddle.nn.Embedding([10, 10])
+
+                state_dict = emb.state_dict()
+                paddle.save(state_dict, "paddle_dy")
+
+                para_state_dict, _ = paddle.load("paddle_dy")
+
+                emb.set_state_dict(para_state_dict)
+
+        '''
+
+        def _check_match(key, param):
+            state = state_dict.get(key, None)
+            if state is None:
+                raise ValueError("{} is not found in the provided dict.".format(
+                    key))
+            if list(state.shape) != list(param.shape):
+                raise ValueError(
+                    "{} receives a shape {}, but the expected shape is {}.".
+                    format(key, list(state.shape), list(param.shape)))
+            return param, state
+
+        matched_param_state = []
+        for key, param in self.state_dict().items():
+            key_name = key if use_structured_name else param.name
+            try:
+                match_res = _check_match(key_name, param)
+                matched_param_state.append(match_res)
+            except ValueError as err:
+                warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
+
+        if in_dygraph_mode():
+            for param, state in matched_param_state:
+                param.set_value(state)
+        else:
+
+            def _set_var(var, ndarray):
+                t = global_scope().find_var(var.name).get_tensor()
+                p = t._place()
+                if p.is_cpu_place():
+                    place = core.CPUPlace()
+                elif p.is_cuda_pinned_place():
+                    place = core.CUDAPinnedPlace()
+                else:
+                    p = core.Place()
+                    p.set_place(t._place())
+                    place = core.CUDAPlace(p.gpu_device_id())
+                t.set(ndarray, place)
+
+            executor = Executor(_get_device())._default_executor
+            # restore parameter states
+            core._create_loaded_parameter(
+                [param for param, state in matched_param_state],
+                global_scope(), executor)
+            for param, state in matched_param_state:
+                _set_var(param, state)
+
+    # [aliases] Compatible with old method names
+    set_dict = set_state_dict
+    load_dict = set_state_dict
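As a usage illustration of the behavior just introduced (a toy sketch assuming a Paddle 2.0 dygraph environment; paddle.nn.Linear is only an example layer, not part of the commit): entries that are missing or have mismatched shapes are now warned about and skipped, where the removed load_dict raised RuntimeError on a missing key.

import warnings

import numpy as np
import paddle

paddle.disable_static()

linear = paddle.nn.Linear(4, 4)
state = linear.state_dict()
state.pop('bias')                  # simulate a missing entry
state['weight'] = np.ones([4, 5])  # simulate a shape mismatch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    linear.set_state_dict(state)   # both entries are skipped with a warning
assert len(caught) == 2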
@@ -97,7 +97,7 @@ class LearningRateDecay(object):
         """
         self.keys = ['step_num']

-    def set_dict(self, state_dict):
+    def set_state_dict(self, state_dict):
         """
         Loads the schedulers state.
         """
@@ -114,6 +114,9 @@ class LearningRateDecay(object):
                 "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
             )

+    # [aliases] Compatible with old method names
+    set_dict = set_state_dict
+
     def step(self):
         raise NotImplementedError()
...
@@ -587,12 +587,13 @@ class DataParallel(layers.Layer):
             include_sublayers=include_sublayers,
             structured_name_prefix=structured_name_prefix)

-    def set_dict(self,
-                 stat_dict,
-                 include_sublayers=True,
-                 use_structured_name=True):
+    @framework.deprecate_stat_dict
+    def set_state_dict(self,
+                       state_dict,
+                       include_sublayers=True,
+                       use_structured_name=True):
         '''
-        Set parameters of self._layers from stat_dict. All the parameters of self._layers will be reset by the tensor in the stat_dict
+        Set parameters of self._layers from state_dict. All the parameters of self._layers will be reset by the tensor in the state_dict

         Parameters:
             state_dict(dict) : Dict contains all the parameters
@@ -605,62 +606,27 @@ class DataParallel(layers.Layer):
         Examples:
             .. code-block:: python

-                import paddle.fluid as fluid
-                with fluid.dygraph.guard():
-                    strategy=fluid.dygraph.prepare_context()
-                    emb = fluid.dygraph.Embedding([10, 10])
-                    emb = fluid.dygraph.DataParallel(emb, strategy)
-
-                    state_dict = emb.state_dict()
-                    fluid.save_dygraph( state_dict, "paddle_dy")
-
-                    para_state_dict, _ = fluid.load_dygraph( "paddle_dy")
-
-                    emb.set_dict( para_state_dict )
-
-        '''
-        self._layers.set_dict(
-            stat_dict,
-            include_sublayers=include_sublayers,
-            use_structured_name=use_structured_name)
-
-    def load_dict(self,
-                  stat_dict,
-                  include_sublayers=True,
-                  use_structured_name=True):
-        '''
-        Set parameters of self._layers from stat_dict. All the parameters of self._layers will be reset by the tensor in the stat_dict
-
-        This api will be Deprecated. Please use set_dict
-
-        Parameters:
-            state_dict(dict) : Dict contains all the parameters
-            include_sublayers(bool, optional) : If true, also include the parameters from sublayers. Default: True
-            use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter name as key.
-                                                  Default: True
-        Returns:
-            None
-
-        Examples:
-            .. code-block:: python
-
-                import paddle.fluid as fluid
-                with fluid.dygraph.guard():
-                    strategy=fluid.dygraph.prepare_context()
-                    emb = fluid.dygraph.Embedding([10, 10])
-                    emb = fluid.dygraph.DataParallel(emb, strategy)
-
-                    state_dict = emb.state_dict()
-                    fluid.save_dygraph( state_dict, "paddle_dy")
-
-                    para_state_dict, _ = fluid.load_dygraph( "paddle_dy")
-
-                    emb.load_dict( para_state_dict )
-
-        '''
-        self._layers.load_dict(
-            stat_dict,
+                import paddle
+
+                paddle.disable_static()
+
+                emb = paddle.nn.Embedding([10, 10])
+                emb = fluid.dygraph.DataParallel(emb, strategy)
+
+                state_dict = emb.state_dict()
+                paddle.save(state_dict, "paddle_dy")
+
+                para_state_dict, _ = paddle.load("paddle_dy")
+
+                emb.set_state_dict(para_state_dict)
+
+        '''
+        self._layers.set_state_dict(
+            state_dict,
             include_sublayers=include_sublayers,
             use_structured_name=use_structured_name)
+
+    # [aliases] Compatible with old method names
+    set_dict = set_state_dict
+    load_dict = set_state_dict
@@ -36,6 +36,7 @@ from . import core
 from . import unique_name
 import paddle.version as fluid_version
 import warnings
+import functools

 __all__ = [
     'Program',
@@ -238,6 +239,25 @@ def _fake_interface_only_(func):
     return __impl__


+# NOTE(chenweihang): There is argument name typo (stat_dict, correct name is state_dict)
+# in fluid api Layer.set_dict, Optimizer.load, in order to correct the argument without
+# introducing compatibility issues, add this decorator
+# NOTE(chenweihang): not using `wrap_decorator` here is because `wrap_decorator` will
+# move kwargs to args, which doesn't work in this decorate case
+def deprecate_stat_dict(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        if 'stat_dict' in kwargs:
+            warnings.warn(
+                "The argument `stat_dict` has deprecated, please change it to `state_dict`.",
+                DeprecationWarning)
+            kwargs['state_dict'] = kwargs['stat_dict']
+            kwargs.pop('stat_dict')
+        return func(*args, **kwargs)
+
+    return wrapper
+
+
 dygraph_not_support = wrap_decorator(_dygraph_not_support_)
 dygraph_only = wrap_decorator(_dygraph_only_)
 fake_interface_only = wrap_decorator(_fake_interface_only_)
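A quick self-contained usage sketch of the decorator pattern in the hunk above (the inline copy mirrors deprecate_stat_dict; load_fn is an illustrative name): the remap only fires for keyword arguments, which is exactly why the NOTE avoids wrap_decorator's kwargs-to-args move.

import functools
import warnings

def deprecate_stat_dict(func):  # mirrors the decorator defined above
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if 'stat_dict' in kwargs:
            warnings.warn(
                "The argument `stat_dict` has deprecated, please change it to `state_dict`.",
                DeprecationWarning)
            kwargs['state_dict'] = kwargs.pop('stat_dict')
        return func(*args, **kwargs)

    return wrapper

@deprecate_stat_dict
def load_fn(state_dict):
    return state_dict

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    load_fn(stat_dict={'w': 1})  # old keyword is remapped, one DeprecationWarning
    load_fn({'w': 1})            # positional call passes through untouched
assert len(caught) == 1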
...
@@ -170,7 +170,7 @@ class Optimizer(object):
         return state_dict

     @framework.dygraph_only
-    def set_dict(self, state_dict):
+    def set_state_dict(self, state_dict):
         '''
         Load optimizer state dict. For Adam optimizer, contains beta1, beta2, momentum etc. If LearningRateDecay have been used, global_step will be changed.
@@ -182,20 +182,22 @@ class Optimizer(object):
         Examples:
             .. code-block:: python

-                with fluid.dygraph.guard():
-                    emb = fluid.dygraph.Embedding([10, 10])
-
-                    state_dict = emb.state_dict()
-                    fluid.save_dygraph(state_dict, "paddle_dy")
-
-                    adam = fluid.optimizer.Adam(learning_rate=fluid.layers.noam_decay( 100, 10000),
-                                                parameter_list=emb.parameters())
-                    state_dict = adam.state_dict()
-                    fluid.save_dygraph(state_dict, "paddle_dy")
-
-                    para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy")
-
-                    adam.set_dict(opti_state_dict)
+                import paddle
+
+                paddle.disable_static()
+
+                emb = paddle.nn.Embedding([10, 10])
+
+                state_dict = emb.state_dict()
+                paddle.save(state_dict, "paddle_dy")
+
+                adam = paddle.optimizer.Adam(learning_rate=fluid.layers.noam_decay( 100, 10000),
+                                             parameter_list=emb.parameters())
+                state_dict = adam.state_dict()
+
+                para_state_dict, opti_state_dict = paddle.load("paddle_dy")
+
+                adam.set_state_dict(opti_state_dict)

         '''
         from paddle.optimizer.lr_scheduler import _LRScheduler
@@ -257,6 +259,9 @@ class Optimizer(object):
                 tensor.set(load_para_np, framework._current_expected_place())

+    # [aliases] Compatible with old method names
+    set_dict = set_state_dict
+
     def get_opti_var_name_list(self):
         return self._opti_name_list
@@ -4595,15 +4600,16 @@ class RecomputeOptimizer(Optimizer):
         ), "_checkpoints should be a list of Variable or a list of String"
         self._checkpoints = checkpoints

-    def load(self, stat_dict):
+    @framework.deprecate_stat_dict
+    def load(self, state_dict):
         """
         :api_attr: Static Graph

         load function is not supported by Recompute Optimizer for now.
         :return: None

         Args:
-            stat_dict: the dict load by load_persistable method
+            state_dict: the dict load by load_persistable method

         Examples:
             .. code-block:: python
@@ -4627,8 +4633,8 @@ class RecomputeOptimizer(Optimizer):
                 sgd = fluid.optimizer.RecomputeOptimizer(sgd)
                 sgd._set_checkpoints([fc_1, pred])
                 try:
-                    stat_dict = {}
-                    sgd.load(stat_dict)
+                    state_dict = {}
+                    sgd.load(state_dict)
                 except NotImplementedError as e:
                     print(cpt.get_exception_message(e))
         """
...
@@ -374,8 +374,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 adam._learning_rate.step_num = 0

             para_state_dict, opti_state_dict = paddle.load("./test_dy")
-            print(opti_state_dict['LR_Scheduler'])
-            adam.set_dict(opti_state_dict)
+            adam.set_state_dict(opti_state_dict)

             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
@@ -393,7 +392,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                     var.set(np.zeros_like(np_t), place)

-            ptb_model.set_dict(para_state_dict)
+            ptb_model.set_state_dict(stat_dict=para_state_dict)

             state_dict = ptb_model.state_dict()
@@ -483,7 +482,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             if isinstance(adam._learning_rate, LearningRateDecay):
                 adam._learning_rate.step_num = 0

-            adam.set_dict(self.opti_dict)
+            adam.set_state_dict(self.opti_dict)
             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
                 if isinstance(v, core.VarBase):
@@ -500,7 +499,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                     var.set(np.zeros_like(np_t), place)

-            ptb_model.set_dict(self.state_dict)
+            ptb_model.set_state_dict(self.state_dict)

             state_dict = ptb_model.state_dict()
@@ -593,7 +592,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             if isinstance(adam._learning_rate, LearningRateDecay):
                 adam._learning_rate.step_num = 0

-            adam.set_dict(np_opti_dict)
+            adam.set_state_dict(np_opti_dict)
             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
@@ -613,7 +612,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                     var.set(np.zeros_like(np_t), place)

-            ptb_model.set_dict(np_state_dict)
+            ptb_model.set_state_dict(np_state_dict)

             state_dict = ptb_model.state_dict()
@@ -656,8 +655,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
             last_hidden = None
             last_cell = None

-            adam.set_dict(self.opti_dict)
-            ptb_model.set_dict(self.state_dict)
+            adam.set_state_dict(self.opti_dict)
+            ptb_model.set_state_dict(self.state_dict)

             for i in range(1):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
@@ -745,8 +744,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
             last_cell = None

             state_dict, opti_dict = fluid.load_dygraph("./test_dy")
-            adam.set_dict(opti_dict)
-            ptb_model.set_dict(state_dict)
+            adam.set_state_dict(opti_dict)
+            ptb_model.set_state_dict(state_dict)

             for i in range(1):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
@@ -849,8 +848,8 @@ class TestDygraphPtbRnn(unittest.TestCase):
             for k, v in self.state_dict.items():
                 np_state_dict[k] = v.numpy()

-            adam.set_dict(np_opti_dict)
-            ptb_model.set_dict(np_state_dict)
+            adam.set_state_dict(np_opti_dict)
+            ptb_model.set_state_dict(np_state_dict)

             for i in range(1):
                 x_data = np.arange(12).reshape(4, 3).astype('int64')
                 y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
...
@@ -918,6 +918,29 @@ class TestDygraphPtbRnn(unittest.TestCase):
             para_state_dict, opti_state_dict = paddle.load(
                 os.path.join('saved_dy', 'emb_dy.pdopt'))

+    def test_no_state_in_input_dict(self):
+        with fluid.dygraph.guard():
+            emb = fluid.dygraph.Embedding([10, 10])
+            state_dict = emb.state_dict()
+            paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy'))
+
+            para_state_dict, _ = paddle.load(os.path.join('saved_dy', 'emb_dy'))
+            para_state_dict.pop('weight')
+
+            emb.set_state_dict(para_state_dict)
+
+    def test_state_shape_mismatch(self):
+        with fluid.dygraph.guard():
+            emb = fluid.dygraph.Embedding([10, 10])
+            state_dict = emb.state_dict()
+            paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy'))
+
+            para_state_dict, _ = paddle.load(os.path.join('saved_dy', 'emb_dy'))
+            para_state_dict['weight'] = np.expand_dims(
+                para_state_dict['weight'], axis=-1)
+
+            emb.set_state_dict(para_state_dict)
+

 if __name__ == '__main__':
     unittest.main()
@@ -832,8 +832,8 @@ class TestRecomputeOptimizer(unittest.TestCase):
         recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
         recompute_optimizer._set_checkpoints([b1_out])
         try:
-            stat_dict = {}
-            recompute_optimizer.load(stat_dict)
+            state_dict = {}
+            recompute_optimizer.load(state_dict)
         except NotImplementedError as e:
             self.assertEqual(
                 "load function is not supported by Recompute Optimizer for now",
...
@@ -19,10 +19,7 @@ from . import model_summary
 from . import model
 from .model import *
 from .model_summary import summary
-from .dygraph_layer_patch import monkey_patch_layer

 logger.setup_logger()

 __all__ = ['callbacks'] + model.__all__ + ['summary']
-
-monkey_patch_layer()
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import warnings
-
-import paddle.fluid as fluid
-from paddle.fluid.framework import in_dygraph_mode
-from paddle.fluid.framework import _current_expected_place as _get_device
-
-
-def monkey_patch_layer():
-    def load_dict(self,
-                  stat_dict,
-                  include_sublayers=True,
-                  use_structured_name=True):
-        '''
-        Set parameters from stat_dict. All the parameters will be reset by the
-        tensor in the stat_dict
-
-        This api will be Deprecated. Please use set_dict
-
-        Parameters:
-            state_dict(dict) : Dict contains all the parameters
-            include_sublayers(bool, optional) : If true, also include the
-                parameters from sublayers. Default: True
-            use_structured_name(bool, optional) : If true, use structured name
-                as key, otherwise, use parameter name as key. Default: True
-        Returns:
-            None
-
-        Examples:
-            .. code-block:: python
-
-                import paddle.fluid as fluid
-                with fluid.dygraph.guard():
-                    emb = fluid.dygraph.Embedding([10, 10])
-
-                    state_dict = emb.state_dict()
-                    fluid.save_dygraph( state_dict, "paddle_dy")
-
-                    para_state_dict, _ = fluid.load_dygraph( "paddle_dy")
-
-                    emb.load_dict( para_state_dict )
-
-        '''
-
-        def _check_match(key, param):
-            state = stat_dict.get(key, None)
-            if state is None:
-                raise ValueError(
-                    "{} is not found in the providing file.".format(key))
-            if list(state.shape) != list(param.shape):
-                raise ValueError(
-                    "{} receives a shape {}, but the expected shape is {}.".
-                    format(key, list(state.shape), list(param.shape)))
-            return param, state
-
-        matched_param_state = []
-        for key, param in self.state_dict().items():
-            key_name = key if use_structured_name else param.name
-            try:
-                match_res = _check_match(key_name, param)
-                matched_param_state.append(match_res)
-            except ValueError as err:
-                warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
-
-        if in_dygraph_mode():
-            for param, state in matched_param_state:
-                param.set_value(state)
-        else:
-
-            def _set_var(var, ndarray):
-                t = fluid.global_scope().find_var(var.name).get_tensor()
-                p = t._place()
-                if p.is_cpu_place():
-                    place = fluid.CPUPlace()
-                elif p.is_cuda_pinned_place():
-                    place = fluid.CUDAPinnedPlace()
-                else:
-                    p = fluid.core.Place()
-                    p.set_place(t._place())
-                    place = fluid.CUDAPlace(p.gpu_device_id())
-                t.set(ndarray, place)
-
-            executor = fluid.Executor(_get_device())._default_executor
-            # restore parameter states
-            fluid.core._create_loaded_parameter(
-                [param for param, state in matched_param_state],
-                fluid.global_scope(), executor)
-            for param, state in matched_param_state:
-                _set_var(param, state)
-
-    setattr(fluid.dygraph.Layer, 'load_dict', load_dict)
@@ -109,7 +109,7 @@ class _LRScheduler(object):
         """
         self.keys = ['last_epoch', 'last_lr']

-    def set_dict(self, state_dict):
+    def set_state_dict(self, state_dict):
         """
         Loads the schedulers state.
         """
@@ -126,8 +126,8 @@ class _LRScheduler(object):
                 "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
             )

-    # alias for set_dict
-    set_state_dict = set_dict
+    # alias for set_state_dict
+    set_dict = set_state_dict

     def get_lr(self):
         # calculate by python float
...
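One caveat on the alias direction swapped in this last hunk, shown in a minimal sketch (hypothetical classes, not Paddle code): making set_state_dict the canonical definition puts decorators and future edits on the new name, but a class-level alias binds once at class creation and does not follow subclass overrides.

class Scheduler:
    def set_state_dict(self, state_dict):
        return "base"

    set_dict = set_state_dict  # bound to Scheduler.set_state_dict right here

class CustomScheduler(Scheduler):
    def set_state_dict(self, state_dict):  # overrides the canonical name only
        return "custom"

s = CustomScheduler()
assert s.set_state_dict({}) == "custom"
assert s.set_dict({}) == "base"  # the alias still points at the base function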