From c928a35e664ce1cec96af80daa7371783e347505 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 29 Nov 2022 16:24:41 +0800 Subject: [PATCH] set_state_dict return missing_keys and unexpected_keys (#48436) * refine set_state_dict --- python/paddle/fluid/dygraph/layers.py | 17 ++++++++++++-- .../unittests/test_state_dict_convert.py | 22 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 1593cc78e6a..02b0e2bcfe1 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1600,7 +1600,8 @@ class Layer: use_structured_name(bool, optional) : If true, use structured name as key, otherwise, use parameter or buffer name as key. Default: True Returns: - None + missing_keys(list):A list of str containing the missing keys + unexpected_keys(list):A list of str containing the unexpected keys Examples: .. code-block:: python @@ -1615,15 +1616,20 @@ class Layer: emb.set_state_dict(para_state_dict) ''' + missing_keys = [] + match_keys = set() + unexpected_keys = [] def _check_match(key, param): state = state_dict.get(key, None) if state is None: + missing_keys.append(key) raise ValueError( "{} is not found in the provided dict.".format(key) ) if isinstance(state, dict) or isinstance(state, list): if len(state) != len(param): + missing_keys.append(key) raise ValueError( "{} receieves the length of {}, " "but the expected shape is {}".format( @@ -1631,6 +1637,7 @@ class Layer: ) ) else: + match_keys.add(key) return param, state else: state_shape = ( @@ -1640,11 +1647,13 @@ class Layer: ) if list(state_shape) != list(param.shape): + missing_keys.append(key) raise ValueError( "{} receives a shape {}, but the expected shape is {}.".format( key, list(state_shape), list(param.shape) ) ) + match_keys.add(key) return param, state matched_param_state = [] @@ -1655,7 +1664,9 @@ class Layer: matched_param_state.append(match_res) except ValueError as err: warnings.warn(("Skip loading for {}. ".format(key) + str(err))) - + for key in state_dict.keys(): + if key not in match_keys: + unexpected_keys.append(key) if _non_static_mode(): for param, state in matched_param_state: param.set_value(state) @@ -1693,6 +1704,8 @@ class Layer: "This error might happens in dy2static, while calling 'set_state_dict' dynamicly in 'forward', which is not supported. If you only need call 'set_state_dict' once, move it to '__init__'." ) + return missing_keys, unexpected_keys + def to(self, device=None, dtype=None, blocking=None): ''' Cast the parameters and buffers of Layer by the give device, dtype and blocking. diff --git a/python/paddle/fluid/tests/unittests/test_state_dict_convert.py b/python/paddle/fluid/tests/unittests/test_state_dict_convert.py index f62f983e903..77a18161337 100644 --- a/python/paddle/fluid/tests/unittests/test_state_dict_convert.py +++ b/python/paddle/fluid/tests/unittests/test_state_dict_convert.py @@ -53,6 +53,15 @@ class MyModel(nn.Layer): return super().set_state_dict(state_dict) +class MyModel2(nn.Layer): + def __init__(self): + super().__init__() + self.linear = nn.Linear(100, 300) + + def forward(self, x): + return self.linear(x) + + def is_state_dict_equal(model1, model2): st1 = model1.state_dict() st2 = model2.state_dict() @@ -73,5 +82,18 @@ class TestStateDictConvert(unittest.TestCase): self.assertTrue(is_state_dict_equal(model1, model2)) +class TestStateDictReturn(unittest.TestCase): + def test_missing_keys_and_unexpected_keys(self): + model1 = MyModel2() + tmp_dict = dict() + tmp_dict["unexpected_keys"] = paddle.to_tensor(1) + missing_keys, unexpected_keys = model1.set_state_dict(tmp_dict) + self.assertEqual(len(missing_keys), 2) + self.assertEqual(missing_keys[0], "linear.weight") + self.assertEqual(missing_keys[1], "linear.bias") + self.assertEqual(len(unexpected_keys), 1) + self.assertEqual(unexpected_keys[0], "unexpected_keys") + + if __name__ == "__main__": unittest.main() -- GitLab