From 0fb663559a39fa7354f31dfd731c68fbdd7deabd Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 22 Jun 2022 10:11:19 +0800 Subject: [PATCH] set_state_dict not use state_dict hook (#43407) (#43711) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在 amp-o2功能开发过程中,为了支持指定网络存储数据类型的功能,添加state_dict hook功能,但是在Layer的set_state_dict是通过state_dict获取网络参数并加载的,hook接口的存在导致 set_state_dict无法加载到原本网络参数。 本pr通过增加hook控制开关,在set_state_dict中禁用hook解决该问题。 详见pr43407 --- python/paddle/fluid/dygraph/layers.py | 31 ++++++++++++------- .../test_imperative_auto_mixed_precision.py | 31 +++++++++++++++++++ 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 6392b0d1151..e9f92dfd1a3 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1323,7 +1323,8 @@ class Layer(object): destination=None, include_sublayers=True, structured_name_prefix="", - include_non_persistable_buffer=False): + include_non_persistable_buffer=False, + use_hook=True): """ Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict @@ -1331,6 +1332,7 @@ class Layer(object): destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False + use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True """ if destination is None: @@ -1354,25 +1356,28 @@ class Layer(object): layer_item._state_dict_impl( destination_temp, include_sublayers, structured_name_prefix + layer_name + ".", - include_non_persistable_buffer)) + include_non_persistable_buffer, use_hook)) destination = destination_temp - for state_dict_hook in self._state_dict_hooks.values(): - hook_result = state_dict_hook(destination) - if hook_result is not None: - destination = hook_result + if use_hook: + for state_dict_hook in self._state_dict_hooks.values(): + hook_result = state_dict_hook(destination) + if hook_result is not None: + destination = hook_result return destination def to_static_state_dict(self, destination=None, include_sublayers=True, - structured_name_prefix=""): + structured_name_prefix="", + use_hook=True): ''' Get all parameters and buffers of current layer and its sub-layers. And set them into a dict Parameters: destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True + use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True Retruns: dict: a dict contains all the parameters and persistable buffers. @@ -1392,18 +1397,21 @@ class Layer(object): destination=destination, include_sublayers=include_sublayers, structured_name_prefix=structured_name_prefix, - include_non_persistable_buffer=True) + include_non_persistable_buffer=True, + use_hook=use_hook) def state_dict(self, destination=None, include_sublayers=True, - structured_name_prefix=""): + structured_name_prefix="", + use_hook=True): ''' Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict Parameters: destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True + use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True Retruns: dict: a dict contains all the parameters and persistable buffers. @@ -1423,7 +1431,8 @@ class Layer(object): destination=destination, include_sublayers=include_sublayers, structured_name_prefix=structured_name_prefix, - include_non_persistable_buffer=False) + include_non_persistable_buffer=False, + use_hook=use_hook) @framework.deprecate_stat_dict def set_state_dict(self, state_dict, use_structured_name=True): @@ -1474,7 +1483,7 @@ class Layer(object): return param, state matched_param_state = [] - for key, param in self.state_dict().items(): + for key, param in self.state_dict(use_hook=False).items(): key_name = key if use_structured_name else param.name try: match_res = _check_match(key_name, param) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 62458881305..f14606ca2d9 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -696,6 +696,37 @@ class TestAmpDecorator(unittest.TestCase): self.assertEqual((param.dtype == paddle.float32), True) +class TestStateDictHookForAMP(unittest.TestCase): + + def test_state_dict_hook(self): + + def func_isinstance(): + paddle.seed(100) + model = paddle.nn.Linear(2, 4) + model = paddle.amp.decorate(models=model, + level='O2', + save_dtype='float32') + param_value_ori = {} + for param in model.parameters(): + param_value_ori[param.name] = param.numpy() + + state_dict = model.state_dict() + for key, value in state_dict.items(): + state_dict[key] = value.cast("float16") + model.set_state_dict(state_dict) + + param_value_now = {} + for param in model.parameters(): + param_value_now[param.name] = param.numpy() + + for key in param_value_ori.keys(): + print(np.equal(param_value_ori[key], param_value_now[key])) + + with _test_eager_guard(): + func_isinstance() + func_isinstance() + + class TestPureFp16SaveLoad(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() -- GitLab