未验证 提交 fafadbab 编写于 作者: W WeiXin 提交者: GitHub

Rename variables when use 'jit.load' (#28933)

* Rename variables when use 'jit.load'

* Check whether the original graph contains the variable with the same name

* add comment

* rename output/input of op and edit unittest

* modify the code according to CI

* edit code according to CI

* edit code according to CI

* edit code according to CI

* edit code according to CI

* edit code according to CI

* edit code according to CI
上级 a3faa520
......@@ -149,6 +149,79 @@ def _append_loaded_suffix_to_var(program_desc):
return suffix_varname_dict
@switch_to_static_graph
def _generate_unique_var_name_sync_with_main_program(prefix):
return unique_name.generate(prefix)
def _get_loaded_var_new_old(program_desc, all_new_old_dict_all):
new_old_dict = dict()
persistable_vars = _get_persistable_vars(program_desc)
for var_desc in persistable_vars:
name_new = var_desc.name()
new_old_dict[name_new] = all_new_old_dict_all[name_new]
return new_old_dict
def _rename_var_program_desc(program_desc):
"""
Change the name of the loaded variables.Use 'unique_name.generate' to avoid duplication
e.g. x ==> x_0, x_0 ==> x_1
"""
dict_rename_var_old_new = dict()
dict_rename_var_new_old = dict()
old_names = []
for b_idx in six.moves.range(program_desc.num_blocks()):
cur_block = program_desc.block(b_idx)
for var in cur_block.all_vars():
old_names.append(var.name())
persistable_vars = _get_persistable_vars(program_desc)
for b_idx in six.moves.range(program_desc.num_blocks()):
cur_block = program_desc.block(b_idx)
for var_idx, var in enumerate(cur_block.all_vars()):
if var not in persistable_vars:
continue
name_old = var.name()
while True:
temp_name = name_old.split('_')
if len(temp_name) > 1 and temp_name[-1].isnumeric():
temp_name = "_".join(temp_name[:-1])
else:
temp_name = "_".join(temp_name)
name_new = _generate_unique_var_name_sync_with_main_program(
temp_name)
if name_new not in old_names[:var_idx] + old_names[var_idx +
1:]:
break
if name_old != name_new:
cur_block._rename_var(
cpt.to_bytes(name_old), cpt.to_bytes(name_new))
dict_rename_var_old_new[name_old] = name_new
dict_rename_var_new_old[name_new] = name_old
for b_idx in six.moves.range(program_desc.num_blocks()):
cur_block = program_desc.block(b_idx)
for op_idx in six.moves.range(cur_block.op_size()):
op = cur_block.op(op_idx)
for input_arg_name in op.input_arg_names():
if input_arg_name in dict_rename_var_old_new:
if input_arg_name != dict_rename_var_old_new[
input_arg_name]:
op._rename_input(
input_arg_name,
dict_rename_var_old_new[input_arg_name])
for output_arg_name in op.output_arg_names():
if output_arg_name in dict_rename_var_old_new:
if output_arg_name != dict_rename_var_old_new[
output_arg_name]:
op._rename_output(
output_arg_name,
dict_rename_var_old_new[output_arg_name])
program_desc.flush()
return dict_rename_var_new_old, dict_rename_var_old_new
@switch_to_static_graph
def _build_program_by_desc(program_desc):
prog = framework.Program()
......@@ -227,6 +300,8 @@ class _ProgramHolder(object):
return self._inner_scope
def _preprocess(self, program_desc):
# rename variables of 'program_desc'
rename_new_old_dict, _ = _rename_var_program_desc(program_desc)
# 1. Prune original program
# remove feed, fetch and scale-1 op, remove op_callstack attr
ops_to_remove = []
......@@ -291,7 +366,9 @@ class _ProgramHolder(object):
# and later after loading, a new linear is added. At this time,
# there will be a problem of duplicate names, so here is unified
# to add the LOADED suffix to the parameters of the model loaded
self._suffix_varname_dict = _append_loaded_suffix_to_var(program_desc)
self._suffix_varname_dict = _get_loaded_var_new_old(program_desc,
rename_new_old_dict)
# - get persistable var
self._persistable_names = _get_persistable_var_names(program_desc)
......@@ -397,8 +474,12 @@ def _load_persistable_vars_by_program(model_path,
if params_filename is not None:
load_var_list = []
for name in sorted(load_var_dict.keys()):
load_var_list.append(load_var_dict[name])
dict_name_old_new = {
v: k
for k, v in program_holder._suffix_varname_dict.items()
}
for name in sorted(dict_name_old_new.keys()):
load_var_list.append(load_var_dict[dict_name_old_new[name]])
framework._dygraph_tracer().trace_op(
type='load_combine',
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid import unique_name
from paddle.fluid.dygraph.base import switch_to_static_graph
@switch_to_static_graph
def _generate_unique_var_name_sync_with_main_program(prefix):
return unique_name.generate(prefix)
def rename_var_with_generator(names_old):
dict_rename_var_old_new = dict()
names_old = list(names_old)
for var_idx, name_old in enumerate(names_old):
while True:
temp_name = name_old.split('_')
if len(temp_name) > 1 and temp_name[-1].isnumeric():
temp_name = "_".join(temp_name[:-1])
else:
temp_name = "_".join(temp_name)
name_new = _generate_unique_var_name_sync_with_main_program(
temp_name)
if name_new not in names_old[:var_idx] + names_old[var_idx + 1:]:
break
dict_rename_var_old_new[name_old] = name_new
return dict_rename_var_old_new
......@@ -23,7 +23,9 @@ import six
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import unique_name
from test_imperative_base import new_program_scope
from jit_load_rename_var import rename_var_with_generator
LOADED_VAR_SUFFIX = ".load_0"
......@@ -128,6 +130,9 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase):
model_filename=self.model_filename,
params_filename=self.params_filename)
suffix_varname_dict = mnist._program_holder_dict[
'forward']._suffix_varname_dict
dict_old_new = {v: k for k, v in suffix_varname_dict.items()}
dy_param_init_value = {}
for param in mnist.parameters():
dy_param_init_value[param.name] = param.numpy()
......@@ -169,7 +174,7 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase):
for param in mnist.parameters():
dy_param_value[param.name] = param.numpy()
return dy_x_data, dy_out, dy_param_init_value, dy_param_value
return dy_x_data, dy_out, dy_param_init_value, dy_param_value, dict_old_new
def load_and_train_static(self):
with new_program_scope():
......@@ -298,7 +303,8 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase):
self.train_and_save_model()
# Phase 2. load model & train dygraph
dy_x_data, dy_out, dy_param_init_value, dy_param_value = \
dy_x_data, dy_out, dy_param_init_value, dy_param_value, dict_old_new_init= \
self.load_and_train_dygraph()
static_x_data, static_out, static_param_init_value, static_param_value = \
......@@ -308,14 +314,14 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase):
self.assertTrue(np.array_equal(static_x_data, dy_x_data))
for key, value in six.iteritems(static_param_init_value):
key += LOADED_VAR_SUFFIX
key = dict_old_new_init[key]
self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
# np.testing.assert_array_almost_equal(static_out, dy_out)
self.assertTrue(np.allclose(static_out, dy_out, atol=1e-04))
for key, value in six.iteritems(static_param_value):
key += LOADED_VAR_SUFFIX
key = dict_old_new_init[key]
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-4))
def test_mnist_train_with_params_filename(self):
......@@ -325,8 +331,8 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase):
# Phase 1. run and save static model
self.train_and_save_model()
# Phase 2. load model & train dygraph
dy_x_data, dy_out, dy_param_init_value, dy_param_value = \
# Phase 2. load model & train dygraph
dy_x_data, dy_out, dy_param_init_value, dy_param_value, dict_old_new_init= \
self.load_and_train_dygraph()
static_x_data, static_out, static_param_init_value, static_param_value = \
......@@ -334,16 +340,15 @@ class TestImperativeStaticModelRunnerMnist(unittest.TestCase):
# Phase 3. compare
self.assertTrue(np.array_equal(static_x_data, dy_x_data))
for key, value in six.iteritems(static_param_init_value):
key += LOADED_VAR_SUFFIX
key = dict_old_new_init[key]
self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
# np.testing.assert_array_almost_equal(static_out, dy_out)
self.assertTrue(np.allclose(static_out, dy_out, atol=1e-04))
for key, value in six.iteritems(static_param_value):
key += LOADED_VAR_SUFFIX
key = dict_old_new_init[key]
self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-4))
def test_mnist_infer_no_params_filename(self):
......
......@@ -23,7 +23,9 @@ import six
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import unique_name
from test_imperative_base import new_program_scope
from jit_load_rename_var import rename_var_with_generator
import paddle.fluid.transpiler.details.program_utils as pu
......@@ -211,15 +213,20 @@ class TestImperativeStaticModelRunnerWhile(unittest.TestCase):
self.train_and_save_model()
# # Phase 2. load model & train dygraph
dy_out, dy_param_init_value, dy_param_value = \
with unique_name.guard():
dy_out, dy_param_init_value, dy_param_value = \
self.load_and_train_dygraph()
static_out, static_param_init_value, static_param_value = \
self.load_and_train_static()
with unique_name.guard():
static_out, static_param_init_value, static_param_value = \
self.load_and_train_static()
# Phase 3. compare
with unique_name.guard():
dict_old_new_init = rename_var_with_generator(
static_param_init_value.keys())
for key, value in six.iteritems(static_param_init_value):
key += LOADED_VAR_SUFFIX
key = dict_old_new_init[key]
self.assertTrue(np.array_equal(value, dy_param_init_value[key]))
self.assertTrue(np.allclose(static_out, dy_out))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册