From 4e99c2af0a6317cc8ec1b1ce142ba3d6284a1687 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Sat, 28 Sep 2019 14:31:31 +0800 Subject: [PATCH] Add shape check in load model (#19936) * add parameter shape check when load parameter from file; test=develop * fix test error; test=develop * add wrong shape check; test=develop * remove useless code; test=develop * add testcase setup * add teardown, remove temp model_path; test=develop * add clean process; test=develop --- python/paddle/fluid/io.py | 25 ++++++++++ .../unittests/test_load_vars_shape_check.py | 48 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_load_vars_shape_check.py diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 0866320a1cd..3f8318688f3 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -21,6 +21,8 @@ import six import logging from functools import reduce +import numpy as np + import paddle import paddle.reader from paddle.reader import * @@ -649,11 +651,19 @@ def load_vars(executor, if not isinstance(main_program, Program): raise TypeError("program should be as Program type or None") + #save origin param shape + orig_para_shape = {} load_var_map = {} for each_var in vars: assert isinstance(each_var, Variable) if each_var.type == core.VarDesc.VarType.RAW: continue + + if isinstance(each_var, Parameter): + var_temp = paddle.fluid.global_scope().find_var(each_var.name) + assert var_temp != None, "can't not find var: " + each_var.name + orig_para_shape[each_var.name] = ( + np.array(var_temp.get_tensor())).shape new_var = _clone_var_in_block_(load_block, each_var) if filename is None: load_block.append_op( @@ -678,6 +688,21 @@ def load_vars(executor, attrs={'file_path': os.path.join(load_dirname, filename)}) executor.run(load_prog) + #check var shape + for each_var in vars: + if not isinstance(each_var, Parameter): + continue + var_temp = paddle.fluid.global_scope().find_var(each_var.name) + assert var_temp != None, "can't not find var: " + each_var.name + new_shape = (np.array(var_temp.get_tensor())).shape + assert each_var.name in orig_para_shape, earch_var.name + "MUST in var list" + orig_shape = orig_para_shape.get(each_var.name) + if new_shape != orig_shape: + raise RuntimeError( + "Shape not matching: the Program requires a parameter with a shape of ({}), " + "while the loaded parameter (namely [ {} ]) has a shape of ({}).". + format(orig_shape, each_var.name, new_shape)) + def load_params(executor, dirname, main_program=None, filename=None): """ diff --git a/python/paddle/fluid/tests/unittests/test_load_vars_shape_check.py b/python/paddle/fluid/tests/unittests/test_load_vars_shape_check.py new file mode 100644 index 00000000000..3e2e778d40e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_load_vars_shape_check.py @@ -0,0 +1,48 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import shutil +import numpy as np +import paddle as paddle +import paddle.fluid as fluid +from paddle.fluid.executor import Executor + + +class TestLoadVarsShapeCheck(unittest.TestCase): + def setUp(self): + self.model_path = "./model_temp/" + + def test_shape_check_save_load(self): + program_1 = fluid.Program() + startup_program_1 = fluid.Program() + + with fluid.program_guard(program_1, startup_program_1): + input = fluid.layers.data(name="x", shape=[-1, 10], dtype='float32') + out = fluid.layers.fc(input, 20) + place = fluid.CPUPlace() + exe = Executor(place) + exe.run(startup_program_1) + + fluid.io.save_params(exe, self.model_path, main_program=program_1) + fluid.io.load_params(exe, self.model_path, main_program=program_1) + + def tearDown(self): + if os.path.exists(self.model_path): + shutil.rmtree(self.model_path) + + +if __name__ == "__main__": + unittest.main() -- GitLab