未验证 提交 c928a35e 编写于 作者: W wanghuancoder 提交者: GitHub

set_state_dict return missing_keys and unexpected_keys (#48436)

* refine set_state_dict
上级 f5c520bb
...@@ -1600,7 +1600,8 @@ class Layer: ...@@ -1600,7 +1600,8 @@ class Layer:
use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key. use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
Default: True Default: True
Returns: Returns:
None missing_keys(list):A list of str containing the missing keys
unexpected_keys(list):A list of str containing the unexpected keys
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -1615,15 +1616,20 @@ class Layer: ...@@ -1615,15 +1616,20 @@ class Layer:
emb.set_state_dict(para_state_dict) emb.set_state_dict(para_state_dict)
''' '''
missing_keys = []
match_keys = set()
unexpected_keys = []
def _check_match(key, param): def _check_match(key, param):
state = state_dict.get(key, None) state = state_dict.get(key, None)
if state is None: if state is None:
missing_keys.append(key)
raise ValueError( raise ValueError(
"{} is not found in the provided dict.".format(key) "{} is not found in the provided dict.".format(key)
) )
if isinstance(state, dict) or isinstance(state, list): if isinstance(state, dict) or isinstance(state, list):
if len(state) != len(param): if len(state) != len(param):
missing_keys.append(key)
raise ValueError( raise ValueError(
"{} receieves the length of {}, " "{} receieves the length of {}, "
"but the expected shape is {}".format( "but the expected shape is {}".format(
...@@ -1631,6 +1637,7 @@ class Layer: ...@@ -1631,6 +1637,7 @@ class Layer:
) )
) )
else: else:
match_keys.add(key)
return param, state return param, state
else: else:
state_shape = ( state_shape = (
...@@ -1640,11 +1647,13 @@ class Layer: ...@@ -1640,11 +1647,13 @@ class Layer:
) )
if list(state_shape) != list(param.shape): if list(state_shape) != list(param.shape):
missing_keys.append(key)
raise ValueError( raise ValueError(
"{} receives a shape {}, but the expected shape is {}.".format( "{} receives a shape {}, but the expected shape is {}.".format(
key, list(state_shape), list(param.shape) key, list(state_shape), list(param.shape)
) )
) )
match_keys.add(key)
return param, state return param, state
matched_param_state = [] matched_param_state = []
...@@ -1655,7 +1664,9 @@ class Layer: ...@@ -1655,7 +1664,9 @@ class Layer:
matched_param_state.append(match_res) matched_param_state.append(match_res)
except ValueError as err: except ValueError as err:
warnings.warn(("Skip loading for {}. ".format(key) + str(err))) warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
for key in state_dict.keys():
if key not in match_keys:
unexpected_keys.append(key)
if _non_static_mode(): if _non_static_mode():
for param, state in matched_param_state: for param, state in matched_param_state:
param.set_value(state) param.set_value(state)
...@@ -1693,6 +1704,8 @@ class Layer: ...@@ -1693,6 +1704,8 @@ class Layer:
"This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'." "This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'."
) )
return missing_keys, unexpected_keys
def to(self, device=None, dtype=None, blocking=None): def to(self, device=None, dtype=None, blocking=None):
''' '''
Cast the parameters and buffers of Layer by the give device, dtype and blocking. Cast the parameters and buffers of Layer by the give device, dtype and blocking.
......
...@@ -53,6 +53,15 @@ class MyModel(nn.Layer): ...@@ -53,6 +53,15 @@ class MyModel(nn.Layer):
return super().set_state_dict(state_dict) return super().set_state_dict(state_dict)
class MyModel2(nn.Layer):
    """Minimal layer that wraps a single 100 -> 300 linear sublayer.

    The sublayer is stored on the attribute ``linear``.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=100, out_features=300)

    def forward(self, x):
        """Apply the linear sublayer to ``x`` and return its output."""
        projected = self.linear(x)
        return projected
def is_state_dict_equal(model1, model2): def is_state_dict_equal(model1, model2):
st1 = model1.state_dict() st1 = model1.state_dict()
st2 = model2.state_dict() st2 = model2.state_dict()
...@@ -73,5 +82,18 @@ class TestStateDictConvert(unittest.TestCase): ...@@ -73,5 +82,18 @@ class TestStateDictConvert(unittest.TestCase):
self.assertTrue(is_state_dict_equal(model1, model2)) self.assertTrue(is_state_dict_equal(model1, model2))
class TestStateDictReturn(unittest.TestCase):
    """Checks the ``(missing_keys, unexpected_keys)`` pair returned by ``set_state_dict``."""

    def test_missing_keys_and_unexpected_keys(self):
        """Load a dict holding a single unrelated key into ``MyModel2``.

        Both linear parameters are expected to be reported as missing, and
        the unrelated key is expected to be reported as unexpected.
        """
        model = MyModel2()
        # The only entry is named after no parameter of the model.
        bogus_state = {"unexpected_keys": paddle.to_tensor(1)}

        missing, unexpected = model.set_state_dict(bogus_state)

        # Comparing whole lists pins both the length and the order.
        self.assertEqual(missing, ["linear.weight", "linear.bias"])
        self.assertEqual(unexpected, ["unexpected_keys"])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册