未验证 提交 1d0d6594 编写于 作者: Z zhangbo9674 提交者: GitHub

set_state_dict not use state_dict hook (#43407)

* set_state_dict not use state_dict hook

* add ut

* refine doc
上级 64e2f10c
...@@ -1327,7 +1327,8 @@ class Layer(object): ...@@ -1327,7 +1327,8 @@ class Layer(object):
destination=None, destination=None,
include_sublayers=True, include_sublayers=True,
structured_name_prefix="", structured_name_prefix="",
include_non_persistable_buffer=False): include_non_persistable_buffer=False,
use_hook=True):
""" """
Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
...@@ -1335,6 +1336,7 @@ class Layer(object): ...@@ -1335,6 +1336,7 @@ class Layer(object):
destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None
include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False
use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True
""" """
if destination is None: if destination is None:
...@@ -1358,25 +1360,28 @@ class Layer(object): ...@@ -1358,25 +1360,28 @@ class Layer(object):
layer_item._state_dict_impl( layer_item._state_dict_impl(
destination_temp, include_sublayers, destination_temp, include_sublayers,
structured_name_prefix + layer_name + ".", structured_name_prefix + layer_name + ".",
include_non_persistable_buffer)) include_non_persistable_buffer, use_hook))
destination = destination_temp destination = destination_temp
for state_dict_hook in self._state_dict_hooks.values(): if use_hook:
hook_result = state_dict_hook(destination) for state_dict_hook in self._state_dict_hooks.values():
if hook_result is not None: hook_result = state_dict_hook(destination)
destination = hook_result if hook_result is not None:
destination = hook_result
return destination return destination
def to_static_state_dict(self, def to_static_state_dict(self,
destination=None, destination=None,
include_sublayers=True, include_sublayers=True,
structured_name_prefix=""): structured_name_prefix="",
use_hook=True):
''' '''
Get all parameters and buffers of current layer and its sub-layers. And set them into a dict Get all parameters and buffers of current layer and its sub-layers. And set them into a dict
Parameters: Parameters:
destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None
include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True
Returns: Returns:
dict: a dict contains all the parameters and persistable buffers. dict: a dict contains all the parameters and persistable buffers.
...@@ -1396,18 +1401,21 @@ class Layer(object): ...@@ -1396,18 +1401,21 @@ class Layer(object):
destination=destination, destination=destination,
include_sublayers=include_sublayers, include_sublayers=include_sublayers,
structured_name_prefix=structured_name_prefix, structured_name_prefix=structured_name_prefix,
include_non_persistable_buffer=True) include_non_persistable_buffer=True,
use_hook=use_hook)
def state_dict(self, def state_dict(self,
destination=None, destination=None,
include_sublayers=True, include_sublayers=True,
structured_name_prefix=""): structured_name_prefix="",
use_hook=True):
''' '''
Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
Parameters: Parameters:
destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None
include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True
Returns: Returns:
dict: a dict contains all the parameters and persistable buffers. dict: a dict contains all the parameters and persistable buffers.
...@@ -1427,7 +1435,8 @@ class Layer(object): ...@@ -1427,7 +1435,8 @@ class Layer(object):
destination=destination, destination=destination,
include_sublayers=include_sublayers, include_sublayers=include_sublayers,
structured_name_prefix=structured_name_prefix, structured_name_prefix=structured_name_prefix,
include_non_persistable_buffer=False) include_non_persistable_buffer=False,
use_hook=use_hook)
@framework.deprecate_stat_dict @framework.deprecate_stat_dict
def set_state_dict(self, state_dict, use_structured_name=True): def set_state_dict(self, state_dict, use_structured_name=True):
...@@ -1478,7 +1487,7 @@ class Layer(object): ...@@ -1478,7 +1487,7 @@ class Layer(object):
return param, state return param, state
matched_param_state = [] matched_param_state = []
for key, param in self.state_dict().items(): for key, param in self.state_dict(use_hook=False).items():
key_name = key if use_structured_name else param.name key_name = key if use_structured_name else param.name
try: try:
match_res = _check_match(key_name, param) match_res = _check_match(key_name, param)
......
...@@ -704,6 +704,37 @@ class TestAmpDecorator(unittest.TestCase): ...@@ -704,6 +704,37 @@ class TestAmpDecorator(unittest.TestCase):
self.assertEqual((param.dtype == paddle.float32), True) self.assertEqual((param.dtype == paddle.float32), True)
class TestStateDictHookForAMP(unittest.TestCase):
    """Round-trip check for ``set_state_dict`` on an AMP-decorated layer."""

    def test_state_dict_hook(self):
        """Reload a float16-cast state dict and verify the parameters.

        The model is decorated with ``level='O2'`` and
        ``save_dtype='float32'``. The state dict returned by
        ``state_dict()`` is cast back to float16 and handed to
        ``set_state_dict``; afterwards every parameter must hold the same
        values it held before the round trip.
        """

        def check_round_trip():
            paddle.seed(100)
            model = paddle.nn.Linear(2, 4)
            model = paddle.amp.decorate(models=model,
                                        level='O2',
                                        save_dtype='float32')

            # Snapshot the parameter values before touching the state dict.
            param_value_ori = {
                param.name: param.numpy()
                for param in model.parameters()
            }

            state_dict = model.state_dict()
            # Only existing keys are reassigned, so the dict does not change
            # size while it is being iterated.
            for key, value in state_dict.items():
                state_dict[key] = value.cast("float16")
            model.set_state_dict(state_dict)

            param_value_now = {
                param.name: param.numpy()
                for param in model.parameters()
            }

            # Assert instead of printing, so a wrong value fails the test.
            for key, expected in param_value_ori.items():
                np.testing.assert_array_equal(param_value_now[key], expected)

        # Run once inside the eager guard and once outside it.
        with _test_eager_guard():
            check_round_trip()
        check_round_trip()
class TestPureFp16SaveLoad(unittest.TestCase): class TestPureFp16SaveLoad(unittest.TestCase):
def test_save_dtype_exception(self): def test_save_dtype_exception(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册