From c00f0dafa79f985a5ffd3d885e387ec57f5cf8e0 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Mon, 21 Nov 2022 07:40:18 +0800 Subject: [PATCH] add state_dict convert (#48161) --- python/paddle/fluid/dygraph/layers.py | 2 +- .../unittests/test_state_dict_convert.py | 77 +++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_state_dict_convert.py diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 752694b614a..5e15519bd96 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1627,7 +1627,7 @@ class Layer: return param, state matched_param_state = [] - for key, param in self.state_dict(use_hook=False).items(): + for key, param in self._state_dict_impl(use_hook=False).items(): key_name = key if use_structured_name else param.name try: match_res = _check_match(key_name, param) diff --git a/python/paddle/fluid/tests/unittests/test_state_dict_convert.py b/python/paddle/fluid/tests/unittests/test_state_dict_convert.py new file mode 100644 index 00000000000..f62f983e903 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_state_dict_convert.py @@ -0,0 +1,77 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +import unittest + + +class MyModel(nn.Layer): + def __init__(self): + super().__init__() + self.linear = nn.Linear(100, 300) + + def forward(self, x): + return self.linear(x) + + @paddle.no_grad() + def state_dict( + self, + destination=None, + include_sublayers=True, + structured_name_prefix="", + use_hook=True, + ): + st = super().state_dict( + destination=destination, + include_sublayers=include_sublayers, + structured_name_prefix=structured_name_prefix, + use_hook=use_hook, + ) + st["linear.new_weight"] = paddle.transpose( + st.pop("linear.weight"), [1, 0] + ) + return st + + @paddle.no_grad() + def set_state_dict(self, state_dict, use_structured_name=True): + state_dict["linear.weight"] = paddle.transpose( + state_dict.pop("linear.new_weight"), [1, 0] + ) + return super().set_state_dict(state_dict) + + +def is_state_dict_equal(model1, model2): + st1 = model1.state_dict() + st2 = model2.state_dict() + assert set(st1.keys()) == set(st2.keys()) + for k, v1 in st1.items(): + v2 = st2[k] + if not np.array_equal(v1.numpy(), v2.numpy()): + return False + return True + + +class TestStateDictConvert(unittest.TestCase): + def test_main(self): + model1 = MyModel() + model2 = MyModel() + self.assertFalse(is_state_dict_equal(model1, model2)) + model2.set_state_dict(model1.state_dict()) + self.assertTrue(is_state_dict_equal(model1, model2)) + + +if __name__ == "__main__": + unittest.main() -- GitLab