From fafadbab702d06065691adcfe23cac15021515dc Mon Sep 17 00:00:00 2001 From: WeiXin Date: Wed, 25 Nov 2020 19:55:04 +0800 Subject: [PATCH] Rename variables when use 'jit.load' (#28933) * Rename variables when use 'jit.load' * Check whether the original graph contains the variable with the same name * add comment * rename output/input of op and edit unittest * modify the code according to CI * edit code according to CI * edit code according to CI * edit code according to CI * edit code according to CI * edit code according to CI * edit code according to CI --- python/paddle/fluid/dygraph/io.py | 87 ++++++++++++++++++- .../tests/unittests/jit_load_rename_var.py | 41 +++++++++ .../test_imperative_static_runner_mnist.py | 23 +++-- .../test_imperative_static_runner_while.py | 15 +++- 4 files changed, 150 insertions(+), 16 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/jit_load_rename_var.py diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index 8797bbcf928..05d2b0bf1e3 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -149,6 +149,79 @@ def _append_loaded_suffix_to_var(program_desc): return suffix_varname_dict +@switch_to_static_graph +def _generate_unique_var_name_sync_with_main_program(prefix): + return unique_name.generate(prefix) + + +def _get_loaded_var_new_old(program_desc, all_new_old_dict_all): + new_old_dict = dict() + persistable_vars = _get_persistable_vars(program_desc) + for var_desc in persistable_vars: + name_new = var_desc.name() + new_old_dict[name_new] = all_new_old_dict_all[name_new] + return new_old_dict + + +def _rename_var_program_desc(program_desc): + """ + Change the name of the loaded variables.Use 'unique_name.generate' to avoid duplication + e.g. x ==> x_0, x_0 ==> x_1 + """ + dict_rename_var_old_new = dict() + dict_rename_var_new_old = dict() + old_names = [] + for b_idx in six.moves.range(program_desc.num_blocks()): + cur_block = program_desc.block(b_idx) + for var in cur_block.all_vars(): + old_names.append(var.name()) + persistable_vars = _get_persistable_vars(program_desc) + for b_idx in six.moves.range(program_desc.num_blocks()): + cur_block = program_desc.block(b_idx) + for var_idx, var in enumerate(cur_block.all_vars()): + if var not in persistable_vars: + continue + name_old = var.name() + while True: + temp_name = name_old.split('_') + if len(temp_name) > 1 and temp_name[-1].isnumeric(): + temp_name = "_".join(temp_name[:-1]) + else: + temp_name = "_".join(temp_name) + + name_new = _generate_unique_var_name_sync_with_main_program( + temp_name) + if name_new not in old_names[:var_idx] + old_names[var_idx + + 1:]: + break + if name_old != name_new: + cur_block._rename_var( + cpt.to_bytes(name_old), cpt.to_bytes(name_new)) + dict_rename_var_old_new[name_old] = name_new + dict_rename_var_new_old[name_new] = name_old + + for b_idx in six.moves.range(program_desc.num_blocks()): + cur_block = program_desc.block(b_idx) + for op_idx in six.moves.range(cur_block.op_size()): + op = cur_block.op(op_idx) + for input_arg_name in op.input_arg_names(): + if input_arg_name in dict_rename_var_old_new: + if input_arg_name != dict_rename_var_old_new[ + input_arg_name]: + op._rename_input( + input_arg_name, + dict_rename_var_old_new[input_arg_name]) + for output_arg_name in op.output_arg_names(): + if output_arg_name in dict_rename_var_old_new: + if output_arg_name != dict_rename_var_old_new[ + output_arg_name]: + op._rename_output( + output_arg_name, + dict_rename_var_old_new[output_arg_name]) + program_desc.flush() + return dict_rename_var_new_old, dict_rename_var_old_new + + @switch_to_static_graph def _build_program_by_desc(program_desc): prog = framework.Program() @@ -227,6 +300,8 @@ class _ProgramHolder(object): return self._inner_scope def _preprocess(self, program_desc): + # rename variables of 'program_desc' + rename_new_old_dict, _ = _rename_var_program_desc(program_desc) # 1. Prune original program # remove feed, fetch and scale-1 op, remove op_callstack attr ops_to_remove = [] @@ -291,7 +366,9 @@ class _ProgramHolder(object): # and later after loading, a new linear is added. At this time, # there will be a problem of duplicate names, so here is unified # to add the LOADED suffix to the parameters of the model loaded - self._suffix_varname_dict = _append_loaded_suffix_to_var(program_desc) + self._suffix_varname_dict = _get_loaded_var_new_old(program_desc, + rename_new_old_dict) + # - get persistable var self._persistable_names = _get_persistable_var_names(program_desc) @@ -397,8 +474,12 @@ def _load_persistable_vars_by_program(model_path, if params_filename is not None: load_var_list = [] - for name in sorted(load_var_dict.keys()): - load_var_list.append(load_var_dict[name]) + dict_name_old_new = { + v: k + for k, v in program_holder._suffix_varname_dict.items() + } + for name in sorted(dict_name_old_new.keys()): + load_var_list.append(load_var_dict[dict_name_old_new[name]]) framework._dygraph_tracer().trace_op( type='load_combine', diff --git a/python/paddle/fluid/tests/unittests/jit_load_rename_var.py b/python/paddle/fluid/tests/unittests/jit_load_rename_var.py new file mode 100644 index 00000000000..9e3424bf990 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/jit_load_rename_var.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +from paddle.fluid import unique_name +from paddle.fluid.dygraph.base import switch_to_static_graph + + +@switch_to_static_graph +def _generate_unique_var_name_sync_with_main_program(prefix): + return unique_name.generate(prefix) + + +def rename_var_with_generator(names_old): + dict_rename_var_old_new = dict() + names_old = list(names_old) + for var_idx, name_old in enumerate(names_old): + while True: + temp_name = name_old.split('_') + if len(temp_name) > 1 and temp_name[-1].isnumeric(): + temp_name = "_".join(temp_name[:-1]) + else: + temp_name = "_".join(temp_name) + name_new = _generate_unique_var_name_sync_with_main_program( + temp_name) + if name_new not in names_old[:var_idx] + names_old[var_idx + 1:]: + break + dict_rename_var_old_new[name_old] = name_new + return dict_rename_var_old_new diff --git a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py index f10d2df7f06..bab2674e878 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_mnist.py @@ -23,7 +23,9 @@ import six import paddle import paddle.fluid as fluid from paddle.fluid import core +from paddle.fluid import unique_name from test_imperative_base import new_program_scope +from jit_load_rename_var import rename_var_with_generator LOADED_VAR_SUFFIX = ".load_0" @@ -128,6 +130,9 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase): model_filename=self.model_filename, params_filename=self.params_filename) + suffix_varname_dict = mnist._program_holder_dict[ + 'forward']._suffix_varname_dict + dict_old_new = {v: k for k, v in suffix_varname_dict.items()} dy_param_init_value = {} for param in mnist.parameters(): dy_param_init_value[param.name] = param.numpy() @@ -169,7 +174,7 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase): for param in mnist.parameters(): dy_param_value[param.name] = param.numpy() - return dy_x_data, dy_out, dy_param_init_value, dy_param_value + return dy_x_data, dy_out, dy_param_init_value, dy_param_value, dict_old_new def load_and_train_static(self): with new_program_scope(): @@ -298,7 +303,8 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase): self.train_and_save_model() # Phase 2. load model & train dygraph - dy_x_data, dy_out, dy_param_init_value, dy_param_value = \ + + dy_x_data, dy_out, dy_param_init_value, dy_param_value, dict_old_new_init= \ self.load_and_train_dygraph() static_x_data, static_out, static_param_init_value, static_param_value = \ @@ -308,14 +314,14 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase): self.assertTrue(np.array_equal(static_x_data, dy_x_data)) for key, value in six.iteritems(static_param_init_value): - key += LOADED_VAR_SUFFIX + key = dict_old_new_init[key] self.assertTrue(np.array_equal(value, dy_param_init_value[key])) # np.testing.assert_array_almost_equal(static_out, dy_out) self.assertTrue(np.allclose(static_out, dy_out, atol=1e-04)) for key, value in six.iteritems(static_param_value): - key += LOADED_VAR_SUFFIX + key = dict_old_new_init[key] self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-4)) def test_mnist_train_with_params_filename(self): @@ -325,8 +331,8 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase): # Phase 1. run and save static model self.train_and_save_model() - # Phase 2. load model & train dygraph - dy_x_data, dy_out, dy_param_init_value, dy_param_value = \ + # Phase 2. load model & train dygraph + dy_x_data, dy_out, dy_param_init_value, dy_param_value, dict_old_new_init= \ self.load_and_train_dygraph() static_x_data, static_out, static_param_init_value, static_param_value = \ @@ -334,16 +340,15 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase): # Phase 3. compare self.assertTrue(np.array_equal(static_x_data, dy_x_data)) - for key, value in six.iteritems(static_param_init_value): - key += LOADED_VAR_SUFFIX + key = dict_old_new_init[key] self.assertTrue(np.array_equal(value, dy_param_init_value[key])) # np.testing.assert_array_almost_equal(static_out, dy_out) self.assertTrue(np.allclose(static_out, dy_out, atol=1e-04)) for key, value in six.iteritems(static_param_value): - key += LOADED_VAR_SUFFIX + key = dict_old_new_init[key] self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-4)) def test_mnist_infer_no_params_filename(self): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py index db47170c7bf..841df6d0896 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_static_runner_while.py @@ -23,7 +23,9 @@ import six import paddle import paddle.fluid as fluid from paddle.fluid import core +from paddle.fluid import unique_name from test_imperative_base import new_program_scope +from jit_load_rename_var import rename_var_with_generator import paddle.fluid.transpiler.details.program_utils as pu @@ -211,15 +213,20 @@ class TestImperativeStaticModelRunnerWhile(unittest.TestCase): self.train_and_save_model() # # Phase 2. load model & train dygraph - dy_out, dy_param_init_value, dy_param_value = \ + with unique_name.guard(): + dy_out, dy_param_init_value, dy_param_value = \ self.load_and_train_dygraph() - static_out, static_param_init_value, static_param_value = \ - self.load_and_train_static() + with unique_name.guard(): + static_out, static_param_init_value, static_param_value = \ + self.load_and_train_static() # Phase 3. compare + with unique_name.guard(): + dict_old_new_init = rename_var_with_generator( + static_param_init_value.keys()) for key, value in six.iteritems(static_param_init_value): - key += LOADED_VAR_SUFFIX + key = dict_old_new_init[key] self.assertTrue(np.array_equal(value, dy_param_init_value[key])) self.assertTrue(np.allclose(static_out, dy_out)) -- GitLab