Unverified · Commit 7618cbdc · authored by WangZhen, committed by GitHub

[Dy2St]Fix abnormal growth of memory in train mode and no_grad for Dy2St (#47398) (#47414)

* [Dy2St]Fix abnormal growth of memory in train mode and no_grad for Dy2St 
Parent c42929c5
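The scenario behind the fix: when a Dy2St (dygraph-to-static) program runs in train mode but autograd is globally disabled, for example under paddle.no_grad(), the forward pass used to keep every inner scope alive for a backward step that never comes, so memory grew with each call. A minimal sketch of that pattern, with an illustrative model path, input shape, and step count (not taken from this PR):

import paddle

# Illustrative: any model previously exported with paddle.jit.save.
layer = paddle.jit.load("./example_model")
layer.train()  # train mode, so is_test is False inside RunProgramAPI

with paddle.no_grad():  # gradients disabled, no backward pass will run
    for _ in range(1000):
        x = paddle.rand([8, 784])  # hypothetical input shape
        out = layer(x)
        # Before this change each forward kept its inner scope alive waiting
        # for a backward pass that never happens, so memory grew per step.
        # With the added HasGrad() check the scope is marked reusable and
        # garbage-collected right after the forward run.

After this change, RunProgramAPI treats "gradients globally disabled" the same as test mode when deciding whether the per-step scopes can be reclaimed, as the two C++ hunks below show.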
@@ -394,7 +394,7 @@ inline void RunProgramAPI(
   }
   VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());
-  if (is_test) {
+  if (is_test || !egr::Controller::Instance().HasGrad()) {
     VLOG(4) << "is test, set this scope can reused";
     global_inner_scope->SetCanReuesd(true);
     details::GcScope(global_inner_scope);
@@ -470,7 +470,7 @@ inline void RunProgramAPI(
   // Debug info: scope info when run end
   VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());
   // Step 5. Drop all children scopes while testing.
-  if (is_test) {
+  if (is_test || !egr::Controller::Instance().HasGrad()) {
     out_scope_vec->front()->DropKids();
   }
   VLOG(2) << "The number of sub scopes after forward: "
...
@@ -30,8 +30,14 @@ from paddle.fluid.layers import nn
 from paddle.fluid.layers.utils import _hash_with_id
 from paddle.fluid.dygraph.base import switch_to_static_graph
 from paddle.fluid.framework import _non_static_mode
-from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor
-from paddle.fluid.dygraph.dygraph_to_static.partial_program import add_build_strategy_for, LazyInitialized
+from paddle.fluid.executor import (
+    _is_enable_standalone_executor,
+    _is_dy2st_enable_standalone_executor,
+)
+from paddle.fluid.dygraph.dygraph_to_static.partial_program import (
+    add_build_strategy_for,
+    LazyInitialized,
+)
 from paddle import _C_ops, _legacy_C_ops

 __all__ = ['TranslatedLayer']
@@ -53,17 +59,20 @@ def _load_program_desc(model_file_path):
     program_desc = core.ProgramDesc(program_desc_str)
     if not core._is_program_version_supported(program_desc._version()):
-        raise ValueError("Unsupported program version: %d\n" %
-                         program_desc._version())
+        raise ValueError(
+            "Unsupported program version: %d\n" % program_desc._version()
+        )

     return program_desc


 def _is_persistable(var_desc):
-    if var_desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
-            var_desc.type() == core.VarDesc.VarType.FETCH_LIST or \
-            var_desc.type() == core.VarDesc.VarType.READER or \
-            var_desc.type() == core.VarDesc.VarType.RAW:
+    if (
+        var_desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
+        or var_desc.type() == core.VarDesc.VarType.FETCH_LIST
+        or var_desc.type() == core.VarDesc.VarType.READER
+        or var_desc.type() == core.VarDesc.VarType.RAW
+    ):
         return False
     return var_desc.persistable()
@@ -208,9 +217,11 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
             name_old = var.name()
             is_double_grad_var = "@GRAD" in name_old
             has_double_grad = has_double_grad or is_double_grad_var
-            should_rename = (include is None or name_old in include) and (
-                exclude is None
-                or name_old not in exclude) and not is_double_grad_var
+            should_rename = (
+                (include is None or name_old in include)
+                and (exclude is None or name_old not in exclude)
+                and not is_double_grad_var
+            )
             if should_rename:
                 temp_name = name_old.split('_')
                 if len(temp_name) > 1 and temp_name[-1].isnumeric():
@@ -219,15 +230,19 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
                     temp_name = name_old
                 while True:
                     name_new = _generate_unique_var_name_sync_with_main_program(
-                        temp_name)
-                    if name_new not in old_names[:var_idx] + old_names[var_idx +
-                                                                       1:]:
+                        temp_name
+                    )
+                    if (
+                        name_new
+                        not in old_names[:var_idx] + old_names[var_idx + 1 :]
+                    ):
                         break
             else:
                 name_new = name_old
             if name_old != name_new:
-                cur_block._rename_var(cpt.to_bytes(name_old),
-                                      cpt.to_bytes(name_new))
+                cur_block._rename_var(
+                    cpt.to_bytes(name_old), cpt.to_bytes(name_new)
+                )
             if not is_double_grad_var:
                 dict_rename_var_old_new[name_old] = name_new
                 dict_rename_var_new_old[name_new] = name_old
@@ -242,13 +257,16 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
                     var_name = var.name()
                     if "@GRAD" in var_name and name_old in var_name:
                         new_var_name = var_name.replace(
-                            name_old, dict_rename_var_old_new[name_old])
+                            name_old, dict_rename_var_old_new[name_old]
+                        )
                         double_grad_rename_dict[var_name] = new_var_name
         for var_name in double_grad_rename_dict:
             dict_rename_var_old_new[var_name] = double_grad_rename_dict[
-                var_name]
+                var_name
+            ]
             dict_rename_var_new_old[
-                double_grad_rename_dict[var_name]] = var_name
+                double_grad_rename_dict[var_name]
+            ] = var_name

     # Rename on program desc
     for b_idx in six.moves.range(program_desc.num_blocks()):
@@ -257,27 +275,38 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
             op = cur_block.op(op_idx)
             for input_arg_name in op.input_arg_names():
                 if input_arg_name in dict_rename_var_old_new:
-                    if input_arg_name != dict_rename_var_old_new[input_arg_name]:
+                    if (
+                        input_arg_name
+                        != dict_rename_var_old_new[input_arg_name]
+                    ):
                         op._rename_input(
                             input_arg_name,
-                            dict_rename_var_old_new[input_arg_name])
+                            dict_rename_var_old_new[input_arg_name],
+                        )
                         if cur_block.has_var(cpt.to_bytes(input_arg_name)):
                             cur_block._rename_var(
                                 cpt.to_bytes(input_arg_name),
                                 cpt.to_bytes(
-                                    dict_rename_var_old_new[input_arg_name]))
+                                    dict_rename_var_old_new[input_arg_name]
+                                ),
+                            )
             for output_arg_name in op.output_arg_names():
                 if output_arg_name in dict_rename_var_old_new:
-                    if output_arg_name != dict_rename_var_old_new[
-                            output_arg_name]:
+                    if (
+                        output_arg_name
+                        != dict_rename_var_old_new[output_arg_name]
+                    ):
                         op._rename_output(
                             output_arg_name,
-                            dict_rename_var_old_new[output_arg_name])
+                            dict_rename_var_old_new[output_arg_name],
+                        )
                         if cur_block.has_var(cpt.to_bytes(output_arg_name)):
                             cur_block._rename_var(
                                 cpt.to_bytes(output_arg_name),
                                 cpt.to_bytes(
-                                    dict_rename_var_old_new[output_arg_name]))
+                                    dict_rename_var_old_new[output_arg_name]
+                                ),
+                            )
     program_desc.flush()
     return dict_rename_var_new_old, dict_rename_var_old_new
@@ -308,8 +337,8 @@ class _ProgramHolder(object):
     """
     Holds the execution information of a Program.

    _ProgramHolder is the execution unit of TranslatedLayer,
    if TranslatedLayer contains multiple _ProgramHolder,
    it can execute multiple methods

    _ProgramHolder is an internal concept.
@@ -333,7 +362,8 @@ class _ProgramHolder(object):
         self._infer_program_desc = self._preprocess(program_desc)
         # forward + backward program
         self._train_program_desc = self._append_backward_desc(
-            self._infer_program_desc)
+            self._infer_program_desc
+        )

     # forward:
     @switch_to_static_graph
@@ -354,11 +384,13 @@ class _ProgramHolder(object):
     def _create_backward_train_program(self):
         whole_program = _build_program_by_desc(self._train_program_desc)
         start_op_index = self._infer_program_desc.block(0).op_size() + 2 * len(
-            self._output_descs)
+            self._output_descs
+        )
         end_op_index = whole_program.desc.block(0).op_size()
-        if (start_op_index < end_op_index):
-            return add_build_strategy_for(whole_program, start_op_index,
-                                          end_op_index)
+        if start_op_index < end_op_index:
+            return add_build_strategy_for(
+                whole_program, start_op_index, end_op_index
+            )
         else:
             return paddle.static.Program()
@@ -406,7 +438,8 @@ class _ProgramHolder(object):
         # rename persistable variables of 'program_desc'
         list_persistable_var = _get_persistable_var_names(program_desc)
         rename_new_old_dict, _ = _rename_var_program_desc(
-            program_desc, list_persistable_var)
+            program_desc, list_persistable_var
+        )
         # 1. Prune original program
         # remove feed, fetch and scale-1 op, remove op_callstack attr
         ops_to_remove = []
@@ -418,14 +451,17 @@ class _ProgramHolder(object):
                 feed_var_name = cpt.to_bytes(op.input('X')[0])
                 root_block._remove_var(feed_var_name)
                 self._input_descs.append(
-                    root_block.find_var(cpt.to_bytes(op.output('Out')[0])))
+                    root_block.find_var(cpt.to_bytes(op.output('Out')[0]))
+                )
             elif op.type() == 'scale' and op.output('Out')[0].startswith(
-                    'save_infer_model/scale_'):
+                'save_infer_model/scale_'
+            ):
                 ops_to_remove.append(i)
                 out_var_name = cpt.to_bytes(op.output('Out')[0])
                 root_block._remove_var(out_var_name)
                 self._output_descs.append(
-                    root_block.find_var(cpt.to_bytes(op.input('X')[0])))
+                    root_block.find_var(cpt.to_bytes(op.input('X')[0]))
+                )
             elif op.type() == 'fetch':
                 ops_to_remove.append(i)
                 fetch_var_name = cpt.to_bytes(op.output('Out')[0])
@@ -433,7 +469,8 @@ class _ProgramHolder(object):
                 # NOTE: some old pre-train models have no extra scale_op
                 if not op.input('X')[0].startswith('save_infer_model/scale_'):
                     self._output_descs.append(
-                        root_block.find_var(cpt.to_bytes(op.input('X')[0])))
+                        root_block.find_var(cpt.to_bytes(op.input('X')[0]))
+                    )
             else:
                 if op.has_attr("op_callstack"):
                     op.remove_attr("op_callstack")
@@ -478,7 +515,8 @@ class _ProgramHolder(object):
         # there will be a problem of duplicate names, so here is unified
         # to add the LOADED suffix to the parameters of the model loaded
         self._suffix_varname_dict = _get_loaded_var_new_old(
-            program_desc, rename_new_old_dict)
+            program_desc, rename_new_old_dict
+        )

         # - get persistable var
         self._persistable_names = _get_persistable_var_names(program_desc)
@@ -492,9 +530,9 @@ class _ProgramHolder(object):
         with framework.program_guard(program):
             for i, out in enumerate(self._output_descs):
                 var = program.global_block().var(out.name())
-                var = nn.scale(var,
-                               1.,
-                               name="translated_layer/scale_{}".format(i))
+                var = nn.scale(
+                    var, 1.0, name="translated_layer/scale_{}".format(i)
+                )
                 scale_output_vars.append(var)
         # 2. update output names & descs
         for i, var in enumerate(scale_output_vars):
@@ -519,15 +557,19 @@ class _ProgramHolder(object):
             block = program.block(block_idx)
             for op in block.ops:
                 if op.type == "batch_norm":
-                    if "ReserveSpace" not in op.output_names or len(
-                            op.output("ReserveSpace")) == 0:
+                    if (
+                        "ReserveSpace" not in op.output_names
+                        or len(op.output("ReserveSpace")) == 0
+                    ):
                         reserve_space = block.create_var(
                             name=unique_name.generate_with_ignorable_key(
-                                ".".join(["reserve_space", 'tmp'])),
+                                ".".join(["reserve_space", 'tmp'])
+                            ),
                             dtype=block.var(op.input("X")[0]).dtype,
                             type=core.VarDesc.VarType.LOD_TENSOR,
                             persistable=False,
-                            stop_gradient=True)
+                            stop_gradient=True,
+                        )
                         op.desc.set_output("ReserveSpace", [reserve_space.name])
         return program
@@ -573,9 +615,9 @@ class _ProgramHolder(object):
 # NOTE: [compatible] deal with model saved by save_inference_model,
 # which need get var info from program desc
-def _load_persistable_vars_by_program(model_path,
-                                      program_holder,
-                                      params_filename=None):
+def _load_persistable_vars_by_program(
+    model_path, program_holder, params_filename=None
+):
     # make sure the path has been checked
     persistable_vars = _get_persistable_vars(program_holder.infer_program)
     load_var_dict = {}
@@ -584,37 +626,43 @@ def _load_persistable_vars_by_program(
         if _is_parameter(each_var, program_holder.infer_program):
             # create output varbase
             if framework._in_eager_without_dygraph_check():
-                new_var = framework.EagerParamBase(shape=each_var.shape(),
-                                                   dtype=each_var.dtype(),
-                                                   name=each_var.name(),
-                                                   type=each_var.type(),
-                                                   persistable=True)
+                new_var = framework.EagerParamBase(
+                    shape=each_var.shape(),
+                    dtype=each_var.dtype(),
+                    name=each_var.name(),
+                    type=each_var.type(),
+                    persistable=True,
+                )
             else:
-                new_var = framework.ParamBase(shape=each_var.shape(),
-                                              dtype=each_var.dtype(),
-                                              name=each_var.name(),
-                                              type=each_var.type(),
-                                              persistable=True)
+                new_var = framework.ParamBase(
+                    shape=each_var.shape(),
+                    dtype=each_var.dtype(),
+                    name=each_var.name(),
+                    type=each_var.type(),
+                    persistable=True,
+                )
         else:
-            new_var = framework._varbase_creator(type=each_var.type(),
-                                                 name=each_var.name(),
-                                                 shape=each_var.shape(),
-                                                 dtype=each_var.dtype(),
-                                                 persistable=True)
+            new_var = framework._varbase_creator(
+                type=each_var.type(),
+                name=each_var.name(),
+                shape=each_var.shape(),
+                dtype=each_var.dtype(),
+                persistable=True,
+            )
         if params_filename is None:
             framework._dygraph_tracer().trace_op(
                 type='load',
                 inputs={},
                 outputs={'Out': new_var},
-                attrs={'file_path': os.path.join(model_path, orig_each_name)})
+                attrs={'file_path': os.path.join(model_path, orig_each_name)},
+            )
         new_var.stop_gradient = False
         load_var_dict[each_var.name()] = new_var

     if params_filename is not None:
         load_var_list = []
         dict_name_old_new = {
-            v: k
-            for k, v in program_holder._suffix_varname_dict.items()
+            v: k for k, v in program_holder._suffix_varname_dict.items()
         }
         for name in sorted(dict_name_old_new.keys()):
             load_var_list.append(load_var_dict[dict_name_old_new[name]])
@@ -623,7 +671,8 @@ def _load_persistable_vars_by_program(
             type='load_combine',
             inputs={},
             outputs={'Out': load_var_list},
-            attrs={'file_path': os.path.join(model_path, params_filename)})
+            attrs={'file_path': os.path.join(model_path, params_filename)},
+        )

     for each_var in persistable_vars:
         if not _is_parameter(each_var, program_holder.infer_program):
@@ -645,8 +694,9 @@ def _load_persistable_vars_by_program(
     return load_var_dict


-def _load_persistable_vars(model_path, var_info_path, program_holder,
-                           params_filename):
+def _load_persistable_vars(
+    model_path, var_info_path, program_holder, params_filename
+):
     # 1. load extra var info
     with open(var_info_path, 'rb') as f:
         extra_var_info = pickle.load(f)
@@ -655,8 +705,7 @@ def _load_persistable_vars(
     load_var_dict = dict()
     load_var_list = []
     inv_suffix_varname_dict = {
-        value: key
-        for key, value in program_holder._suffix_varname_dict.items()
+        value: key for key, value in program_holder._suffix_varname_dict.items()
     }

     # NOTE(chenweihang): we need load persistable vars based the program,
@@ -667,7 +716,8 @@ def _load_persistable_vars(
             raise RuntimeError(
                 "The model to be loaded is not complete."
                 "The variable `%s` of program cannot be found in loaded model.",
-                name)
+                name,
+            )
         # get suffix var name, see [why need to append suffix to persistable vars]
         new_name = inv_suffix_varname_dict[name]
         # create output varbase
@@ -680,7 +730,8 @@ def _load_persistable_vars(
                     ],  # only to pass check, this shape is not meaningful
                     dtype=core.VarDesc.VarType.FP32,
                     name=new_name,
-                    persistable=True)
+                    persistable=True,
+                )
             else:
                 new_var = framework.ParamBase(
                     shape=[
@@ -688,10 +739,12 @@ def _load_persistable_vars(
                     ],  # only to pass check, this shape is not meaningful
                     dtype=core.VarDesc.VarType.FP32,
                     name=new_name,
-                    persistable=True)
+                    persistable=True,
+                )
         else:
-            new_var = framework._varbase_creator(name=new_name,
-                                                 persistable=True)
+            new_var = framework._varbase_creator(
+                name=new_name, persistable=True
+            )

         new_var.stop_gradient = extra_var_info[name]['stop_gradient']
         load_var_dict[new_name] = new_var
@@ -704,10 +757,12 @@ def _load_persistable_vars(
     if len(extra_var_info) != 0:
         raise ValueError("The model to be loaded is incomplete.")
     else:
-        framework._dygraph_tracer().trace_op(type='load_combine',
-                                             inputs={},
-                                             outputs={'Out': load_var_list},
-                                             attrs={'file_path': var_file_path})
+        framework._dygraph_tracer().trace_op(
+            type='load_combine',
+            inputs={},
+            outputs={'Out': load_var_list},
+            attrs={'file_path': var_file_path},
+        )

     return load_var_dict
@@ -729,17 +784,18 @@ def _construct_program_holders(model_path, model_filename=None):
         # [compatible] if assign model_filename, only can load one program as Layer.forward
         model_filename = os.path.basename(model_filename)
         model_file_path = os.path.join(model_path, model_filename)
-        model_name = model_filename[:-len(INFER_MODEL_SUFFIX)]
-        #Load every file that meets the requirements in the directory model_path.
+        model_name = model_filename[: -len(INFER_MODEL_SUFFIX)]
+        # Load every file that meets the requirements in the directory model_path.
         for filename in os.listdir(model_path):
             if model_filename == filename:
                 func_name = 'forward'
                 model_file_path = os.path.join(model_path, model_filename)
             elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(
-                    model_name):
-                parsing_names = filename[len(model_name
-                                             ):-len(INFER_MODEL_SUFFIX) +
-                                         1].split('.')
+                model_name
+            ):
+                parsing_names = filename[
+                    len(model_name) : -len(INFER_MODEL_SUFFIX) + 1
+                ].split('.')
                 if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
                     func_name = parsing_names[1]
                     model_file_path = os.path.join(model_path, filename)
@@ -748,7 +804,8 @@ def _construct_program_holders(model_path, model_filename=None):
                 else:
                     continue
             program_holder_dict[func_name] = _ProgramHolder(
-                _load_program_desc(model_file_path))
+                _load_program_desc(model_file_path)
+            )
     else:
         for _, _, file_names in os.walk(model_path):
             for name in file_names:
@@ -760,30 +817,32 @@ def _construct_program_holders(model_path, model_filename=None):
                 else:
                     method_name.replace('model', '')
                 program_holder_dict[method_name] = _ProgramHolder(
-                    _load_program_desc(model_file_path))
+                    _load_program_desc(model_file_path)
+                )

     return program_holder_dict


-def _construct_params_and_buffers(model_path,
-                                  programs,
-                                  params_filename=None,
-                                  append_suffix=True):
+def _construct_params_and_buffers(
+    model_path, programs, params_filename=None, append_suffix=True
+):
     var_info_filename = str(params_filename) + ".info"
     var_info_path = os.path.join(model_path, var_info_filename)
     params_path = os.path.join(model_path, str(params_filename))

     if os.path.exists(var_info_path):
-        var_dict = _load_persistable_vars(model_path, var_info_path,
-                                          programs['forward'], params_filename)
-        model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)]
-        #Load every file that meets the requirements in the directory model_path.
+        var_dict = _load_persistable_vars(
+            model_path, var_info_path, programs['forward'], params_filename
+        )
+        model_name = params_filename[: -len(INFER_PARAMS_SUFFIX)]
+        # Load every file that meets the requirements in the directory model_path.
         for file_name in os.listdir(model_path):
             if file_name.startswith(model_name) and file_name.endswith(
-                    INFER_PARAMS_SUFFIX):
-                parsing_names = file_name[len(model_name
-                                              ):-len(INFER_PARAMS_SUFFIX) +
-                                          1].split('.')
+                INFER_PARAMS_SUFFIX
+            ):
+                parsing_names = file_name[
+                    len(model_name) : -len(INFER_PARAMS_SUFFIX) + 1
+                ].split('.')
                 if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
                     func_name = parsing_names[1]
                 else:
@@ -792,15 +851,17 @@ def _construct_params_and_buffers(
                     continue
                 var_info_path = os.path.join(model_path, var_info_filename)
                 var_dict.update(
-                    _load_persistable_vars(model_path, var_info_path,
-                                           programs[func_name], file_name))
+                    _load_persistable_vars(
+                        model_path, var_info_path, programs[func_name], file_name
+                    )
+                )
     elif params_filename is not None and not os.path.exists(params_path):
         # When saving XX, there is only '*.pdmodel'
         return dict()
     else:
-        var_dict = _load_persistable_vars_by_program(model_path,
-                                                     programs['forward'],
-                                                     params_filename)
+        var_dict = _load_persistable_vars_by_program(
+            model_path, programs['forward'], params_filename
+        )

     if not append_suffix:
         var_dict = _remove_varname_suffix(var_dict, programs['forward'])
@@ -813,13 +874,23 @@ def _valid_vars(vars):
         return vars
     if framework._in_eager_without_dygraph_check():
         return [
-            core.eager.Tensor(core.VarDesc.VarType.FP32, [], "Fake_var",
-                              core.VarDesc.VarType.RAW, False)
+            core.eager.Tensor(
+                core.VarDesc.VarType.FP32,
+                [],
+                "Fake_var",
+                core.VarDesc.VarType.RAW,
+                False,
+            )
         ]
     else:
         return [
-            core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
-                         core.VarDesc.VarType.RAW, False)
+            core.VarBase(
+                core.VarDesc.VarType.FP32,
+                [],
+                "Fake_var",
+                core.VarDesc.VarType.RAW,
+                False,
+            )
         ]
@@ -831,7 +902,8 @@ def _run_dygraph(instance, input, program_holder):
         if not isinstance(value, (np.ndarray, core.VarBase, core.eager.Tensor)):
             raise TypeError(
                 "The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s."
-                % type(value))
+                % type(value)
+            )
         # NOTE: In order to unify the API, firstly convert the input to VarBase
         if isinstance(value, np.ndarray):
             if framework._in_eager_without_dygraph_check():
@@ -840,13 +912,16 @@ def _run_dygraph(instance, input, program_holder):
                     name=program_holder.input_descs[i].name(),
                     persistable=False,
                     place=framework._current_expected_place(),
-                    zero_copy=True)
+                    zero_copy=True,
+                )
             else:
-                var = core.VarBase(value=value,
-                                   name=program_holder.input_descs[i].name(),
-                                   persistable=False,
-                                   place=framework._current_expected_place(),
-                                   zero_copy=True)
+                var = core.VarBase(
+                    value=value,
+                    name=program_holder.input_descs[i].name(),
+                    persistable=False,
+                    place=framework._current_expected_place(),
+                    zero_copy=True,
+                )
         else:
             var = value
             # NOTE: we changed var name here,
@@ -868,67 +943,112 @@ def _run_dygraph(instance, input, program_holder):
         else:
             raise ValueError(
                 "The persistable variable %s does not exist in current TranslatedLayer."
-                % var_name)
+                % var_name
+            )

     output_vars = []
     for var_desc in program_holder.output_descs:
         if framework._in_eager_without_dygraph_check():
-            var = core.eager.Tensor(dtype=var_desc.dtype(),
-                                    dims=var_desc.shape(),
-                                    name=var_desc.name(),
-                                    type=var_desc.type(),
-                                    persistable=False)
+            var = core.eager.Tensor(
+                dtype=var_desc.dtype(),
+                dims=var_desc.shape(),
+                name=var_desc.name(),
+                type=var_desc.type(),
+                persistable=False,
+            )
         else:
-            var = core.VarBase(var_desc.dtype(), var_desc.shape(),
-                               var_desc.name(), var_desc.type(), False)
+            var = core.VarBase(
+                var_desc.dtype(),
+                var_desc.shape(),
+                var_desc.name(),
+                var_desc.type(),
+                False,
+            )
         output_vars.append(var)

     # hold forward variables
     if framework._in_eager_without_dygraph_check():
         tmp_scope_vec = [program_holder.scope]
     else:
-        tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
-                                     "program_out_scope",
-                                     core.VarDesc.VarType.STEP_SCOPES, True)
+        tmp_scope_vec = core.VarBase(
+            core.VarDesc.VarType.FP32,
+            [],
+            "program_out_scope",
+            core.VarDesc.VarType.STEP_SCOPES,
+            True,
+        )
         tmp_scope_vec.value().set_scope(program_holder.scope)

     double_grad_vars = []
     for var_desc in program_holder.double_grad_descs:
         if framework._in_eager_without_dygraph_check():
-            var = core.eager.Tensor(dtype=var_desc.dtype(),
-                                    dims=var_desc.shape(),
-                                    name=var_desc.name(),
-                                    type=var_desc.type(),
-                                    persistable=False)
+            var = core.eager.Tensor(
+                dtype=var_desc.dtype(),
+                dims=var_desc.shape(),
+                name=var_desc.name(),
+                type=var_desc.type(),
+                persistable=False,
+            )
         else:
-            var = core.VarBase(var_desc.dtype(), var_desc.shape(),
-                               var_desc.name(), var_desc.type(), False)
+            var = core.VarBase(
+                var_desc.dtype(),
+                var_desc.shape(),
+                var_desc.name(),
+                var_desc.type(),
+                False,
+            )
         double_grad_vars.append(var)

     # 2. run program by op
-    trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program
-    forward_program = program_holder._infer_program_desc if instance._is_test else program_holder.forward_program
+    trace_program = (
+        program_holder.infer_program
+        if instance._is_test
+        else program_holder.train_program
+    )
+    forward_program = (
+        program_holder._infer_program_desc
+        if instance._is_test
+        else program_holder.forward_program
+    )
     end_op_index = program_holder.infer_program.block(0).op_size()
     attrs = [
         'global_block',
-        trace_program.block(0), 'start_op_index', 0, 'end_op_index',
-        end_op_index, 'is_test', instance._is_test, 'program_id',
-        _hash_with_id(trace_program, instance)
+        trace_program.block(0),
+        'start_op_index',
+        0,
+        'end_op_index',
+        end_op_index,
+        'is_test',
+        instance._is_test,
+        'program_id',
+        _hash_with_id(trace_program, instance),
     ]

-    use_interpretorcore = _is_enable_standalone_executor(
-    ) and _is_dy2st_enable_standalone_executor()
+    use_interpretorcore = (
+        _is_enable_standalone_executor()
+        and _is_dy2st_enable_standalone_executor()
+    )
     attrs.extend(('use_interpretorcore', use_interpretorcore))
     if use_interpretorcore:
         attrs.extend(
-            ('forward_global_block', forward_program.block(0),
-             'backward_global_block', program_holder.backward_program.block(0)))
+            (
+                'forward_global_block',
+                forward_program.block(0),
+                'backward_global_block',
+                program_holder.backward_program.block(0),
+            )
+        )

-    _legacy_C_ops.run_program(_valid_vars(input_vars),
-                              _valid_vars(persistable_vars),
-                              _valid_vars(output_vars), tmp_scope_vec,
-                              _valid_vars(double_grad_vars), None, *attrs)
+    _legacy_C_ops.run_program(
+        _valid_vars(input_vars),
+        _valid_vars(persistable_vars),
+        _valid_vars(output_vars),
+        tmp_scope_vec,
+        _valid_vars(double_grad_vars),
+        None,
+        *attrs
+    )

     # NOTE: [ why need set param's gradient type here ]
     # if user set sparse gradient mode, the param's gradient
@@ -946,8 +1066,6 @@ def _run_dygraph(instance, input, program_holder):
             continue
         persistable_var._set_grad_type(grad_var.type())

-    drop_scope_if_no_grad(instance, tmp_scope_vec)
-
     # 3. prepare output, keep same form with inputs
     outs = output_vars
     if len(output_vars) == 1:
@@ -955,27 +1073,26 @@ def _run_dygraph(instance, input, program_holder):
     return outs


-def drop_scope_if_no_grad(instance, scope_vec):
-    tracer = framework._dygraph_tracer()
-    scope = scope_vec.value().get_scope() if isinstance(
-        scope_vec, (core.VarBase)) else scope_vec[0]
-    if (not instance._is_test) and (not tracer._has_grad):
-        scope.drop_kids()
-
-
 def _run_static_graph(input, program_holder, trace_program):
     main_program = framework.default_main_program()
     param_var_names = _get_persistable_var_names(trace_program)
     _, dict_rename_var_old_new = _rename_var_program_desc(
-        trace_program, exclude=param_var_names)
+        trace_program, exclude=param_var_names
+    )
     trace_program.flush()
     output_names = [var.name() for var in program_holder.output_descs]
     # append blocks from 'trace_program'
-    _append_block(main_program, trace_program, program_holder, input,
-                  dict_rename_var_old_new)
+    _append_block(
+        main_program,
+        trace_program,
+        program_holder,
+        input,
+        dict_rename_var_old_new,
+    )
     main_program._sync_with_cpp()
-    outs = _get_output_from_program(main_program, program_holder,
-                                    dict_rename_var_old_new)
+    outs = _get_output_from_program(
+        main_program, program_holder, dict_rename_var_old_new
+    )
     if len(outs) == 1:
         outs = outs[0]
     return outs
@@ -984,7 +1101,7 @@ def _run_static_graph(input, program_holder, trace_program):
 def _collect_current_and_parent_var(program, block_idx):
     '''
     Get variables in current block and its parent block.

     Args:
         program(Program): The program containing the current block.
         block_idx(int): index of current block.
@@ -1003,46 +1120,55 @@ def _collect_current_and_parent_var(program, block_idx):
     return vars


-def _append_block(dest_program,
-                  src_program_desc,
-                  program_holder,
-                  input_variables,
-                  dict_rename_var_old_new=None):
+def _append_block(
+    dest_program,
+    src_program_desc,
+    program_holder,
+    input_variables,
+    dict_rename_var_old_new=None,
+):
     '''
     Append Variables and Operators in 'src_program_desc' to dest_program.

     Args:
         dest_program(Program): Variables and Operators are appended to it.
         src_program_desc(ProgramDesc): Variables in it will be appended to 'dest_program'.
         program_holder(_ProgramHolder): program_holder of TranslatedLayer
         input_variables(list): list of input variables
         dict_rename_var_old_new(None|dict): When using '_rename_var_program_desc',
         use it to map the name of the variable before it was modified and the new name.
     '''

     origin_block_idx = dest_program.current_block_idx
-    param_var_names = _collect_current_and_parent_var(dest_program,
-                                                      origin_block_idx)
-    append_var_from_block_desc_static(dest_program.block(origin_block_idx),
-                                      src_program_desc.block(0),
-                                      exclude=param_var_names)
+    param_var_names = _collect_current_and_parent_var(
+        dest_program, origin_block_idx
+    )
+    append_var_from_block_desc_static(
+        dest_program.block(origin_block_idx),
+        src_program_desc.block(0),
+        exclude=param_var_names,
+    )

     name_inp_desc = [inp.name() for inp in program_holder.input_descs]
     input_names = [inp.name for inp in input_variables]
     if len(name_inp_desc) != len(input_names):
         raise ValueError(
-            "The number of input is invalid, expected {}, but received {}.".
-            format(len(name_inp_desc), len(input_names)))
+            "The number of input is invalid, expected {}, but received {}.".format(
+                len(name_inp_desc), len(input_names)
+            )
+        )

     for i, out_name in enumerate(name_inp_desc):
         if dict_rename_var_old_new:
             out_name = dict_rename_var_old_new[out_name]
         dest_program.block(origin_block_idx).append_op(
             type='assign',
             inputs={'X': [input_names[i]]},
-            outputs={'Out': [out_name]})
+            outputs={'Out': [out_name]},
+        )

     append_ops = append_op_from_block_desc_static(
-        dest_program.block(origin_block_idx), src_program_desc.block(0))
+        dest_program.block(origin_block_idx), src_program_desc.block(0)
+    )
     dest_program._sync_with_cpp()

     offset_block_idx = dest_program.num_blocks - 1
@@ -1056,11 +1182,12 @@ def _append_block(
             else:
                 parent_idx = origin_block_idx
             dest_block = dest_program._create_block(parent_idx=parent_idx)
-            append_var_from_block_desc_static(dest_block,
-                                              src_block,
-                                              exclude=param_var_names)
+            append_var_from_block_desc_static(
+                dest_block, src_block, exclude=param_var_names
+            )
             append_ops += append_op_from_block_desc_static(
-                dest_block, src_block)
+                dest_block, src_block
+            )

     dest_program._sync_with_cpp()
     for op in append_ops:
@@ -1070,15 +1197,16 @@ def _append_block(
             origin_id = sub.id
             if isinstance(sub, framework.Block):
                 origin_id = sub.idx
-            op._set_attr('sub_block',
-                         dest_program.block(offset_block_idx + origin_id))
+            op._set_attr(
+                'sub_block', dest_program.block(offset_block_idx + origin_id)
+            )
     dest_program._sync_with_cpp()
     dest_program.current_block_idx = origin_block_idx


-def _get_output_from_program(program,
-                             program_holder,
-                             dict_rename_var_old_new=None):
+def _get_output_from_program(
+    program, program_holder, dict_rename_var_old_new=None
+):
     """
     Get output name of 'program' according to program_holder
     """
@@ -1127,20 +1255,21 @@ def append_op_from_desc_static(block, op_desc):
     op_type = op_desc.type()
     op_append = block.desc.append_op()
     op_append.copy_from(op_desc)
-    op = framework.Operator(block=block,
-                            desc=op_append,
-                            type=op_type,
-                            inputs=None,
-                            outputs=None,
-                            attrs=None)
+    op = framework.Operator(
+        block=block,
+        desc=op_append,
+        type=op_type,
+        inputs=None,
+        outputs=None,
+        attrs=None,
+    )
     block.ops.append(op)
     return op


-def append_var_from_block_desc_static(block,
-                                      src_block_desc,
-                                      include=None,
-                                      exclude=None):
+def append_var_from_block_desc_static(
+    block, src_block_desc, include=None, exclude=None
+):
     """
     Append Variables of 'src_block_desc' to current block.
     If 'include' is not `None`,variables that are not in include are not append.
@@ -1159,13 +1288,14 @@ def append_var_from_block_desc_static(
     for var_desc in src_block_desc.all_vars():
         var_desc_name = var_desc.name()
         should_append = (include is None or var_desc_name in include) and (
-            exclude is None or var_desc_name not in exclude)
+            exclude is None or var_desc_name not in exclude
+        )
         if not block.has_var(var_desc_name) and should_append:
             var_type = var_desc.type()
             if var_type in [
                 core.VarDesc.VarType.SELECTED_ROWS,
                 core.VarDesc.VarType.LOD_TENSOR,
-                core.VarDesc.VarType.LOD_TENSOR_ARRAY
+                core.VarDesc.VarType.LOD_TENSOR_ARRAY,
             ]:
                 data_type = var_desc.dtype()
                 var_shape = var_desc.shape()
@@ -1173,8 +1303,8 @@ def append_var_from_block_desc_static(
                 data_type = None
                 var_shape = None
             if var_type in [
                 core.VarDesc.VarType.LOD_TENSOR,
-                core.VarDesc.VarType.LOD_TENSOR_ARRAY
+                core.VarDesc.VarType.LOD_TENSOR_ARRAY,
             ]:
                 lod_level = var_desc.lod_level()
             else:
@@ -1193,16 +1323,18 @@ def append_var_from_block_desc_static(
                     shape=var_shape,
                     lod_level=lod_level,
                     persistable=var_desc.persistable(),
-                    set_need_check_feed=var_desc.need_check_feed()))
+                    set_need_check_feed=var_desc.need_check_feed(),
+                )
+            )
     return vars_append


 class TranslatedLayer(layers.Layer):
     """
     TranslatedLayer is a ``paddle.nn.Layer`` for holding the model
     loaded by :ref:`api_paddle_jit_load` . It can be used like a
     general Layer object in eval or train mode.

     .. note:
         The TranslatedLayer objects should not be created by constructor, it only can be loaded and constructed by :ref:`api_paddle_jit_load` .
@@ -1318,8 +1450,9 @@ class TranslatedLayer(layers.Layer):
         # the TranslatedLayer object holded var names count started from 0
         with unique_name.guard():
             for name, var in persistable_vars.items():
-                if isinstance(var,
-                              (framework.ParamBase, framework.EagerParamBase)):
+                if isinstance(
+                    var, (framework.ParamBase, framework.EagerParamBase)
+                ):
                     dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
                     self._persistable_var_name_dict[name] = dy_name
                     self.add_parameter(dy_name, var)
@@ -1353,7 +1486,8 @@ class TranslatedLayer(layers.Layer):
         # 2. load layer parameters & buffers
         persistable_vars = _construct_params_and_buffers(
-            model_path, programs, params_filename)
+            model_path, programs, params_filename
+        )

         # 3. construct TranslatedLayer object
         translated_layer = TranslatedLayer(programs, persistable_vars)
@@ -1365,9 +1499,12 @@ class TranslatedLayer(layers.Layer):
                 ins.name() for ins in program_holder.input_descs
             ]
             setattr(
-                TranslatedLayer, method_name,
+                TranslatedLayer,
+                method_name,
                 TranslatedLayer._execution_method_creator(
-                    method_name, program_holder))
+                    method_name, program_holder
+                ),
+            )

         # 5. set TranslatedLayer's default mode to eval
         translated_layer.eval()
@@ -1376,7 +1513,6 @@ class TranslatedLayer(layers.Layer):
     @staticmethod
     def _execution_method_creator(method_name, program_holder):
-
         def __i_m_p_l__(self, *input):
             program_holder = self._program_holder_dict[__i_m_p_l__.__name__]
             # When using jit.save, it runs in static graph mode.
@@ -1389,7 +1525,8 @@ class TranslatedLayer(layers.Layer):
                 # because '_run_static_graph' modifies 'ProgramDesc', 'OpDesc.op_size()' will return a very large wrong number.
                 # A Segmentation fault error may occur if used 'p=ProgramDesc(program_holder.infer_program)'.
                 p = framework.Program._construct_from_desc(
-                    core.ProgramDesc(program_holder.infer_program))
+                    core.ProgramDesc(program_holder.infer_program)
+                )
                 return _run_static_graph(input, program_holder, p.desc)

         __i_m_p_l__.__name__ = method_name
@@ -1410,13 +1547,13 @@ class TranslatedLayer(layers.Layer):
         Args:
             - method_name (string): mehtod name corresponding to the program
                 to be obtained. Default: 'forward'.

         Returns:
             Program

         Examples:
             .. code-block:: python

                 import numpy as np
                 import paddle
                 import paddle.nn as nn
@@ -1502,8 +1639,9 @@ class TranslatedLayer(layers.Layer):
         program_holder = self._program_holder_dict.get(method_name, None)
         if program_holder is None:
             raise ValueError(
-                "The method `%s` does not exist in loaded TranslatedLayer." %
-                method_name)
+                "The method `%s` does not exist in loaded TranslatedLayer."
+                % method_name
+            )
         return program_holder

     def _input_spec(self, method_name='forward'):
@@ -1513,9 +1651,11 @@ class TranslatedLayer(layers.Layer):
         # 2. build input spec by input desc
         input_spec = []
         for var_desc in program_holder.input_descs:
-            spec = paddle.static.InputSpec(shape=var_desc.shape(),
-                                           dtype=var_desc.dtype(),
-                                           name=var_desc.name())
+            spec = paddle.static.InputSpec(
+                shape=var_desc.shape(),
+                dtype=var_desc.dtype(),
+                name=var_desc.name(),
+            )
             input_spec.append(spec)

         return input_spec
@@ -1530,9 +1670,11 @@ class TranslatedLayer(layers.Layer):
             # NOTE(chenweihang): InputSpec describes a tensor, not just input.
             # Maybe the name is not good enough. Here we use InputSpec to
             # construct the description of Output tensor
-            spec = paddle.static.InputSpec(shape=var_desc.shape(),
-                                           dtype=var_desc.dtype(),
-                                           name=var_desc.name())
+            spec = paddle.static.InputSpec(
+                shape=var_desc.shape(),
+                dtype=var_desc.dtype(),
+                name=var_desc.name(),
+            )
             output_spec.append(spec)

         return output_spec