未验证 提交 c928a35e 编写于 作者: W wanghuancoder 提交者: GitHub

set_state_dict return missing_keys and unexpected_keys (#48436)

* refine set_state_dict
上级 f5c520bb
......@@ -1600,7 +1600,8 @@ class Layer:
use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key.
Default: True
Returns:
None
missing_keys(list): A list of str containing the missing keys.
unexpected_keys(list): A list of str containing the unexpected keys.
Examples:
.. code-block:: python
......@@ -1615,15 +1616,20 @@ class Layer:
emb.set_state_dict(para_state_dict)
'''
missing_keys = []
match_keys = set()
unexpected_keys = []
def _check_match(key, param):
state = state_dict.get(key, None)
if state is None:
missing_keys.append(key)
raise ValueError(
"{} is not found in the provided dict.".format(key)
)
if isinstance(state, dict) or isinstance(state, list):
if len(state) != len(param):
missing_keys.append(key)
raise ValueError(
"{} receieves the length of {}, "
"but the expected shape is {}".format(
......@@ -1631,6 +1637,7 @@ class Layer:
)
)
else:
match_keys.add(key)
return param, state
else:
state_shape = (
......@@ -1640,11 +1647,13 @@ class Layer:
)
if list(state_shape) != list(param.shape):
missing_keys.append(key)
raise ValueError(
"{} receives a shape {}, but the expected shape is {}.".format(
key, list(state_shape), list(param.shape)
)
)
match_keys.add(key)
return param, state
matched_param_state = []
......@@ -1655,7 +1664,9 @@ class Layer:
matched_param_state.append(match_res)
except ValueError as err:
warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
for key in state_dict.keys():
if key not in match_keys:
unexpected_keys.append(key)
if _non_static_mode():
for param, state in matched_param_state:
param.set_value(state)
......@@ -1693,6 +1704,8 @@ class Layer:
"This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'."
)
return missing_keys, unexpected_keys
def to(self, device=None, dtype=None, blocking=None):
'''
Cast the parameters and buffers of Layer by the given device, dtype and blocking.
......
......@@ -53,6 +53,15 @@ class MyModel(nn.Layer):
return super().set_state_dict(state_dict)
class MyModel2(nn.Layer):
    """Minimal layer holding a single 100 -> 300 ``nn.Linear`` sub-layer.

    The sub-layer is stored under the attribute name ``linear``; the test in
    this file expects its state-dict keys to be ``linear.weight`` and
    ``linear.bias``.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=100, out_features=300)

    def forward(self, x):
        """Apply the linear sub-layer to ``x`` and return its output."""
        projected = self.linear(x)
        return projected
def is_state_dict_equal(model1, model2):
st1 = model1.state_dict()
st2 = model2.state_dict()
......@@ -73,5 +82,18 @@ class TestStateDictConvert(unittest.TestCase):
self.assertTrue(is_state_dict_equal(model1, model2))
class TestStateDictReturn(unittest.TestCase):
    """Checks the ``(missing_keys, unexpected_keys)`` pair returned by ``set_state_dict``."""

    def test_missing_keys_and_unexpected_keys(self):
        layer = MyModel2()
        # The only supplied entry matches no parameter of the layer, so both
        # of the layer's own keys should be reported missing and the supplied
        # key should be reported unexpected.
        bogus_state = {"unexpected_keys": paddle.to_tensor(1)}

        missing_keys, unexpected_keys = layer.set_state_dict(bogus_state)

        # Comparing whole lists pins length, content and order in one assert.
        self.assertEqual(
            list(missing_keys), ["linear.weight", "linear.bias"]
        )
        self.assertEqual(list(unexpected_keys), ["unexpected_keys"])
# Run the unittest test runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册