Unverified commit 7618cbdc, authored by WangZhen, committed by GitHub

[Dy2St]Fix abnormal growth of memory in train mode and no_grad for Dy2St (#47398) (#47414)

* [Dy2St]Fix abnormal growth of memory in train mode and no_grad for Dy2St 
Parent commit c42929c5
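The scenario behind this fix, as a minimal sketch (the layer, sizes, and tensor names below are illustrative, not taken from the patch): a dygraph-to-static (Dy2St) program that stays in train mode but runs its forward pass under paddle.no_grad(), for example periodic evaluation inside a training loop. Before this change, only is_test let the inner scope be reused and garbage-collected, so each such forward pass kept another child scope alive and memory grew steadily.

    import paddle

    class Net(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc = paddle.nn.Linear(16, 16)

        @paddle.jit.to_static
        def forward(self, x):
            return self.fc(x)

    net = Net()
    net.train()  # stays in train mode, so is_test is False

    with paddle.no_grad():  # gradient recording is off, so HasGrad() is False
        for _ in range(100):
            out = net(paddle.randn([4, 16]))
            # With this commit the forward-only inner scope can be reused and
            # garbage-collected here, instead of a new child scope surviving
            # every iteration.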
@@ -394,7 +394,7 @@ inline void RunProgramAPI(
   }
   VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());
-  if (is_test) {
+  if (is_test || !egr::Controller::Instance().HasGrad()) {
     VLOG(4) << "is test, set this scope can reused";
     global_inner_scope->SetCanReuesd(true);
     details::GcScope(global_inner_scope);
@@ -470,7 +470,7 @@ inline void RunProgramAPI(
   // Debug info: scope info when run end
   VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front());
   // Step 5. Drop all children scopes while testing.
-  if (is_test) {
+  if (is_test || !egr::Controller::Instance().HasGrad()) {
     out_scope_vec->front()->DropKids();
   }
   VLOG(2) << "The number of sub scopes after forward: "
@@ -30,8 +30,14 @@ from paddle.fluid.layers import nn
 from paddle.fluid.layers.utils import _hash_with_id
 from paddle.fluid.dygraph.base import switch_to_static_graph
 from paddle.fluid.framework import _non_static_mode
-from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor
-from paddle.fluid.dygraph.dygraph_to_static.partial_program import add_build_strategy_for, LazyInitialized
+from paddle.fluid.executor import (
+    _is_enable_standalone_executor,
+    _is_dy2st_enable_standalone_executor,
+)
+from paddle.fluid.dygraph.dygraph_to_static.partial_program import (
+    add_build_strategy_for,
+    LazyInitialized,
+)
 from paddle import _C_ops, _legacy_C_ops

 __all__ = ['TranslatedLayer']
@@ -53,17 +59,20 @@ def _load_program_desc(model_file_path):
     program_desc = core.ProgramDesc(program_desc_str)
     if not core._is_program_version_supported(program_desc._version()):
-        raise ValueError("Unsupported program version: %d\n" %
-                         program_desc._version())
+        raise ValueError(
+            "Unsupported program version: %d\n" % program_desc._version()
+        )
     return program_desc


 def _is_persistable(var_desc):
-    if var_desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
-            var_desc.type() == core.VarDesc.VarType.FETCH_LIST or \
-            var_desc.type() == core.VarDesc.VarType.READER or \
-            var_desc.type() == core.VarDesc.VarType.RAW:
+    if (
+        var_desc.type() == core.VarDesc.VarType.FEED_MINIBATCH
+        or var_desc.type() == core.VarDesc.VarType.FETCH_LIST
+        or var_desc.type() == core.VarDesc.VarType.READER
+        or var_desc.type() == core.VarDesc.VarType.RAW
+    ):
         return False
     return var_desc.persistable()
@@ -208,9 +217,11 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
             name_old = var.name()
             is_double_grad_var = "@GRAD" in name_old
             has_double_grad = has_double_grad or is_double_grad_var
-            should_rename = (include is None or name_old in include) and (
-                exclude is None
-                or name_old not in exclude) and not is_double_grad_var
+            should_rename = (
+                (include is None or name_old in include)
+                and (exclude is None or name_old not in exclude)
+                and not is_double_grad_var
+            )
             if should_rename:
                 temp_name = name_old.split('_')
                 if len(temp_name) > 1 and temp_name[-1].isnumeric():
@@ -219,15 +230,19 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
                     temp_name = name_old
                 while True:
-                    name_new = _generate_unique_var_name_sync_with_main_program(
-                        temp_name)
-                    if name_new not in old_names[:var_idx] + old_names[var_idx +
-                                                                        1:]:
+                    name_new = _generate_unique_var_name_sync_with_main_program(
+                        temp_name
+                    )
+                    if (
+                        name_new
+                        not in old_names[:var_idx] + old_names[var_idx + 1 :]
+                    ):
                         break
             else:
                 name_new = name_old
             if name_old != name_new:
-                cur_block._rename_var(cpt.to_bytes(name_old),
-                                      cpt.to_bytes(name_new))
+                cur_block._rename_var(
+                    cpt.to_bytes(name_old), cpt.to_bytes(name_new)
+                )
             if not is_double_grad_var:
                 dict_rename_var_old_new[name_old] = name_new
                 dict_rename_var_new_old[name_new] = name_old
@@ -242,13 +257,16 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
                 var_name = var.name()
                 if "@GRAD" in var_name and name_old in var_name:
-                    new_var_name = var_name.replace(
-                        name_old, dict_rename_var_old_new[name_old])
+                    new_var_name = var_name.replace(
+                        name_old, dict_rename_var_old_new[name_old]
+                    )
                     double_grad_rename_dict[var_name] = new_var_name
         for var_name in double_grad_rename_dict:
-            dict_rename_var_old_new[var_name] = double_grad_rename_dict[
-                var_name]
-            dict_rename_var_new_old[
-                double_grad_rename_dict[var_name]] = var_name
+            dict_rename_var_old_new[var_name] = double_grad_rename_dict[
+                var_name
+            ]
+            dict_rename_var_new_old[
+                double_grad_rename_dict[var_name]
+            ] = var_name

     # Rename on program desc
     for b_idx in six.moves.range(program_desc.num_blocks()):
...@@ -257,27 +275,38 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None): ...@@ -257,27 +275,38 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
op = cur_block.op(op_idx) op = cur_block.op(op_idx)
for input_arg_name in op.input_arg_names(): for input_arg_name in op.input_arg_names():
if input_arg_name in dict_rename_var_old_new: if input_arg_name in dict_rename_var_old_new:
if input_arg_name != dict_rename_var_old_new[input_arg_name]: if (
input_arg_name
!= dict_rename_var_old_new[input_arg_name]
):
op._rename_input( op._rename_input(
input_arg_name, input_arg_name,
dict_rename_var_old_new[input_arg_name]) dict_rename_var_old_new[input_arg_name],
)
if cur_block.has_var(cpt.to_bytes(input_arg_name)): if cur_block.has_var(cpt.to_bytes(input_arg_name)):
cur_block._rename_var( cur_block._rename_var(
cpt.to_bytes(input_arg_name), cpt.to_bytes(input_arg_name),
cpt.to_bytes( cpt.to_bytes(
dict_rename_var_old_new[input_arg_name])) dict_rename_var_old_new[input_arg_name]
),
)
for output_arg_name in op.output_arg_names(): for output_arg_name in op.output_arg_names():
if output_arg_name in dict_rename_var_old_new: if output_arg_name in dict_rename_var_old_new:
if output_arg_name != dict_rename_var_old_new[ if (
output_arg_name]: output_arg_name
!= dict_rename_var_old_new[output_arg_name]
):
op._rename_output( op._rename_output(
output_arg_name, output_arg_name,
dict_rename_var_old_new[output_arg_name]) dict_rename_var_old_new[output_arg_name],
)
if cur_block.has_var(cpt.to_bytes(output_arg_name)): if cur_block.has_var(cpt.to_bytes(output_arg_name)):
cur_block._rename_var( cur_block._rename_var(
cpt.to_bytes(output_arg_name), cpt.to_bytes(output_arg_name),
cpt.to_bytes( cpt.to_bytes(
dict_rename_var_old_new[output_arg_name])) dict_rename_var_old_new[output_arg_name]
),
)
program_desc.flush() program_desc.flush()
return dict_rename_var_new_old, dict_rename_var_old_new return dict_rename_var_new_old, dict_rename_var_old_new
...@@ -333,7 +362,8 @@ class _ProgramHolder(object): ...@@ -333,7 +362,8 @@ class _ProgramHolder(object):
self._infer_program_desc = self._preprocess(program_desc) self._infer_program_desc = self._preprocess(program_desc)
# forward + backward program # forward + backward program
self._train_program_desc = self._append_backward_desc( self._train_program_desc = self._append_backward_desc(
self._infer_program_desc) self._infer_program_desc
)
# forward: # forward:
@switch_to_static_graph @switch_to_static_graph
...@@ -354,11 +384,13 @@ class _ProgramHolder(object): ...@@ -354,11 +384,13 @@ class _ProgramHolder(object):
def _create_backward_train_program(self): def _create_backward_train_program(self):
whole_program = _build_program_by_desc(self._train_program_desc) whole_program = _build_program_by_desc(self._train_program_desc)
start_op_index = self._infer_program_desc.block(0).op_size() + 2 * len( start_op_index = self._infer_program_desc.block(0).op_size() + 2 * len(
self._output_descs) self._output_descs
)
end_op_index = whole_program.desc.block(0).op_size() end_op_index = whole_program.desc.block(0).op_size()
if (start_op_index < end_op_index): if start_op_index < end_op_index:
return add_build_strategy_for(whole_program, start_op_index, return add_build_strategy_for(
end_op_index) whole_program, start_op_index, end_op_index
)
else: else:
return paddle.static.Program() return paddle.static.Program()
...@@ -406,7 +438,8 @@ class _ProgramHolder(object): ...@@ -406,7 +438,8 @@ class _ProgramHolder(object):
# rename persistable variables of 'program_desc' # rename persistable variables of 'program_desc'
list_persistable_var = _get_persistable_var_names(program_desc) list_persistable_var = _get_persistable_var_names(program_desc)
rename_new_old_dict, _ = _rename_var_program_desc( rename_new_old_dict, _ = _rename_var_program_desc(
program_desc, list_persistable_var) program_desc, list_persistable_var
)
# 1. Prune original program # 1. Prune original program
# remove feed, fetch and scale-1 op, remove op_callstack attr # remove feed, fetch and scale-1 op, remove op_callstack attr
ops_to_remove = [] ops_to_remove = []
...@@ -418,14 +451,17 @@ class _ProgramHolder(object): ...@@ -418,14 +451,17 @@ class _ProgramHolder(object):
feed_var_name = cpt.to_bytes(op.input('X')[0]) feed_var_name = cpt.to_bytes(op.input('X')[0])
root_block._remove_var(feed_var_name) root_block._remove_var(feed_var_name)
self._input_descs.append( self._input_descs.append(
root_block.find_var(cpt.to_bytes(op.output('Out')[0]))) root_block.find_var(cpt.to_bytes(op.output('Out')[0]))
)
elif op.type() == 'scale' and op.output('Out')[0].startswith( elif op.type() == 'scale' and op.output('Out')[0].startswith(
'save_infer_model/scale_'): 'save_infer_model/scale_'
):
ops_to_remove.append(i) ops_to_remove.append(i)
out_var_name = cpt.to_bytes(op.output('Out')[0]) out_var_name = cpt.to_bytes(op.output('Out')[0])
root_block._remove_var(out_var_name) root_block._remove_var(out_var_name)
self._output_descs.append( self._output_descs.append(
root_block.find_var(cpt.to_bytes(op.input('X')[0]))) root_block.find_var(cpt.to_bytes(op.input('X')[0]))
)
elif op.type() == 'fetch': elif op.type() == 'fetch':
ops_to_remove.append(i) ops_to_remove.append(i)
fetch_var_name = cpt.to_bytes(op.output('Out')[0]) fetch_var_name = cpt.to_bytes(op.output('Out')[0])
...@@ -433,7 +469,8 @@ class _ProgramHolder(object): ...@@ -433,7 +469,8 @@ class _ProgramHolder(object):
# NOTE: some old pre-train models have no extra scale_op # NOTE: some old pre-train models have no extra scale_op
if not op.input('X')[0].startswith('save_infer_model/scale_'): if not op.input('X')[0].startswith('save_infer_model/scale_'):
self._output_descs.append( self._output_descs.append(
root_block.find_var(cpt.to_bytes(op.input('X')[0]))) root_block.find_var(cpt.to_bytes(op.input('X')[0]))
)
else: else:
if op.has_attr("op_callstack"): if op.has_attr("op_callstack"):
op.remove_attr("op_callstack") op.remove_attr("op_callstack")
...@@ -478,7 +515,8 @@ class _ProgramHolder(object): ...@@ -478,7 +515,8 @@ class _ProgramHolder(object):
# there will be a problem of duplicate names, so here is unified # there will be a problem of duplicate names, so here is unified
# to add the LOADED suffix to the parameters of the model loaded # to add the LOADED suffix to the parameters of the model loaded
self._suffix_varname_dict = _get_loaded_var_new_old( self._suffix_varname_dict = _get_loaded_var_new_old(
program_desc, rename_new_old_dict) program_desc, rename_new_old_dict
)
# - get persistable var # - get persistable var
self._persistable_names = _get_persistable_var_names(program_desc) self._persistable_names = _get_persistable_var_names(program_desc)
...@@ -492,9 +530,9 @@ class _ProgramHolder(object): ...@@ -492,9 +530,9 @@ class _ProgramHolder(object):
with framework.program_guard(program): with framework.program_guard(program):
for i, out in enumerate(self._output_descs): for i, out in enumerate(self._output_descs):
var = program.global_block().var(out.name()) var = program.global_block().var(out.name())
var = nn.scale(var, var = nn.scale(
1., var, 1.0, name="translated_layer/scale_{}".format(i)
name="translated_layer/scale_{}".format(i)) )
scale_output_vars.append(var) scale_output_vars.append(var)
# 2. update output names & descs # 2. update output names & descs
for i, var in enumerate(scale_output_vars): for i, var in enumerate(scale_output_vars):
...@@ -519,15 +557,19 @@ class _ProgramHolder(object): ...@@ -519,15 +557,19 @@ class _ProgramHolder(object):
block = program.block(block_idx) block = program.block(block_idx)
for op in block.ops: for op in block.ops:
if op.type == "batch_norm": if op.type == "batch_norm":
if "ReserveSpace" not in op.output_names or len( if (
op.output("ReserveSpace")) == 0: "ReserveSpace" not in op.output_names
or len(op.output("ReserveSpace")) == 0
):
reserve_space = block.create_var( reserve_space = block.create_var(
name=unique_name.generate_with_ignorable_key( name=unique_name.generate_with_ignorable_key(
".".join(["reserve_space", 'tmp'])), ".".join(["reserve_space", 'tmp'])
),
dtype=block.var(op.input("X")[0]).dtype, dtype=block.var(op.input("X")[0]).dtype,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=True) stop_gradient=True,
)
op.desc.set_output("ReserveSpace", [reserve_space.name]) op.desc.set_output("ReserveSpace", [reserve_space.name])
return program return program
...@@ -573,9 +615,9 @@ class _ProgramHolder(object): ...@@ -573,9 +615,9 @@ class _ProgramHolder(object):
# NOTE: [compatible] deal with model saved by save_inference_model, # NOTE: [compatible] deal with model saved by save_inference_model,
# which need get var info from program desc # which need get var info from program desc
def _load_persistable_vars_by_program(model_path, def _load_persistable_vars_by_program(
program_holder, model_path, program_holder, params_filename=None
params_filename=None): ):
# make sure the path has been checked # make sure the path has been checked
persistable_vars = _get_persistable_vars(program_holder.infer_program) persistable_vars = _get_persistable_vars(program_holder.infer_program)
load_var_dict = {} load_var_dict = {}
...@@ -584,37 +626,43 @@ def _load_persistable_vars_by_program(model_path, ...@@ -584,37 +626,43 @@ def _load_persistable_vars_by_program(model_path,
if _is_parameter(each_var, program_holder.infer_program): if _is_parameter(each_var, program_holder.infer_program):
# create output varbase # create output varbase
if framework._in_eager_without_dygraph_check(): if framework._in_eager_without_dygraph_check():
new_var = framework.EagerParamBase(shape=each_var.shape(), new_var = framework.EagerParamBase(
shape=each_var.shape(),
dtype=each_var.dtype(), dtype=each_var.dtype(),
name=each_var.name(), name=each_var.name(),
type=each_var.type(), type=each_var.type(),
persistable=True) persistable=True,
)
else: else:
new_var = framework.ParamBase(shape=each_var.shape(), new_var = framework.ParamBase(
shape=each_var.shape(),
dtype=each_var.dtype(), dtype=each_var.dtype(),
name=each_var.name(), name=each_var.name(),
type=each_var.type(), type=each_var.type(),
persistable=True) persistable=True,
)
else: else:
new_var = framework._varbase_creator(type=each_var.type(), new_var = framework._varbase_creator(
type=each_var.type(),
name=each_var.name(), name=each_var.name(),
shape=each_var.shape(), shape=each_var.shape(),
dtype=each_var.dtype(), dtype=each_var.dtype(),
persistable=True) persistable=True,
)
if params_filename is None: if params_filename is None:
framework._dygraph_tracer().trace_op( framework._dygraph_tracer().trace_op(
type='load', type='load',
inputs={}, inputs={},
outputs={'Out': new_var}, outputs={'Out': new_var},
attrs={'file_path': os.path.join(model_path, orig_each_name)}) attrs={'file_path': os.path.join(model_path, orig_each_name)},
)
new_var.stop_gradient = False new_var.stop_gradient = False
load_var_dict[each_var.name()] = new_var load_var_dict[each_var.name()] = new_var
if params_filename is not None: if params_filename is not None:
load_var_list = [] load_var_list = []
dict_name_old_new = { dict_name_old_new = {
v: k v: k for k, v in program_holder._suffix_varname_dict.items()
for k, v in program_holder._suffix_varname_dict.items()
} }
for name in sorted(dict_name_old_new.keys()): for name in sorted(dict_name_old_new.keys()):
load_var_list.append(load_var_dict[dict_name_old_new[name]]) load_var_list.append(load_var_dict[dict_name_old_new[name]])
...@@ -623,7 +671,8 @@ def _load_persistable_vars_by_program(model_path, ...@@ -623,7 +671,8 @@ def _load_persistable_vars_by_program(model_path,
type='load_combine', type='load_combine',
inputs={}, inputs={},
outputs={'Out': load_var_list}, outputs={'Out': load_var_list},
attrs={'file_path': os.path.join(model_path, params_filename)}) attrs={'file_path': os.path.join(model_path, params_filename)},
)
for each_var in persistable_vars: for each_var in persistable_vars:
if not _is_parameter(each_var, program_holder.infer_program): if not _is_parameter(each_var, program_holder.infer_program):
...@@ -645,8 +694,9 @@ def _load_persistable_vars_by_program(model_path, ...@@ -645,8 +694,9 @@ def _load_persistable_vars_by_program(model_path,
return load_var_dict return load_var_dict
def _load_persistable_vars(model_path, var_info_path, program_holder, def _load_persistable_vars(
params_filename): model_path, var_info_path, program_holder, params_filename
):
# 1. load extra var info # 1. load extra var info
with open(var_info_path, 'rb') as f: with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f) extra_var_info = pickle.load(f)
...@@ -655,8 +705,7 @@ def _load_persistable_vars(model_path, var_info_path, program_holder, ...@@ -655,8 +705,7 @@ def _load_persistable_vars(model_path, var_info_path, program_holder,
load_var_dict = dict() load_var_dict = dict()
load_var_list = [] load_var_list = []
inv_suffix_varname_dict = { inv_suffix_varname_dict = {
value: key value: key for key, value in program_holder._suffix_varname_dict.items()
for key, value in program_holder._suffix_varname_dict.items()
} }
# NOTE(chenweihang): we need load persistable vars based the program, # NOTE(chenweihang): we need load persistable vars based the program,
...@@ -667,7 +716,8 @@ def _load_persistable_vars(model_path, var_info_path, program_holder, ...@@ -667,7 +716,8 @@ def _load_persistable_vars(model_path, var_info_path, program_holder,
raise RuntimeError( raise RuntimeError(
"The model to be loaded is not complete." "The model to be loaded is not complete."
"The variable `%s` of program cannot be found in loaded model.", "The variable `%s` of program cannot be found in loaded model.",
name) name,
)
# get suffix var name, see [why need to append suffix to persistable vars] # get suffix var name, see [why need to append suffix to persistable vars]
new_name = inv_suffix_varname_dict[name] new_name = inv_suffix_varname_dict[name]
# create output varbase # create output varbase
...@@ -680,7 +730,8 @@ def _load_persistable_vars(model_path, var_info_path, program_holder, ...@@ -680,7 +730,8 @@ def _load_persistable_vars(model_path, var_info_path, program_holder,
], # only to pass check, this shape is not meaningful ], # only to pass check, this shape is not meaningful
dtype=core.VarDesc.VarType.FP32, dtype=core.VarDesc.VarType.FP32,
name=new_name, name=new_name,
persistable=True) persistable=True,
)
else: else:
new_var = framework.ParamBase( new_var = framework.ParamBase(
shape=[ shape=[
...@@ -688,10 +739,12 @@ def _load_persistable_vars(model_path, var_info_path, program_holder, ...@@ -688,10 +739,12 @@ def _load_persistable_vars(model_path, var_info_path, program_holder,
], # only to pass check, this shape is not meaningful ], # only to pass check, this shape is not meaningful
dtype=core.VarDesc.VarType.FP32, dtype=core.VarDesc.VarType.FP32,
name=new_name, name=new_name,
persistable=True) persistable=True,
)
else: else:
new_var = framework._varbase_creator(name=new_name, new_var = framework._varbase_creator(
persistable=True) name=new_name, persistable=True
)
new_var.stop_gradient = extra_var_info[name]['stop_gradient'] new_var.stop_gradient = extra_var_info[name]['stop_gradient']
load_var_dict[new_name] = new_var load_var_dict[new_name] = new_var
...@@ -704,10 +757,12 @@ def _load_persistable_vars(model_path, var_info_path, program_holder, ...@@ -704,10 +757,12 @@ def _load_persistable_vars(model_path, var_info_path, program_holder,
if len(extra_var_info) != 0: if len(extra_var_info) != 0:
raise ValueError("The model to be loaded is incomplete.") raise ValueError("The model to be loaded is incomplete.")
else: else:
framework._dygraph_tracer().trace_op(type='load_combine', framework._dygraph_tracer().trace_op(
type='load_combine',
inputs={}, inputs={},
outputs={'Out': load_var_list}, outputs={'Out': load_var_list},
attrs={'file_path': var_file_path}) attrs={'file_path': var_file_path},
)
return load_var_dict return load_var_dict
...@@ -729,17 +784,18 @@ def _construct_program_holders(model_path, model_filename=None): ...@@ -729,17 +784,18 @@ def _construct_program_holders(model_path, model_filename=None):
# [compatible] if assign model_filename, only can load one program as Layer.forward # [compatible] if assign model_filename, only can load one program as Layer.forward
model_filename = os.path.basename(model_filename) model_filename = os.path.basename(model_filename)
model_file_path = os.path.join(model_path, model_filename) model_file_path = os.path.join(model_path, model_filename)
model_name = model_filename[:-len(INFER_MODEL_SUFFIX)] model_name = model_filename[: -len(INFER_MODEL_SUFFIX)]
#Load every file that meets the requirements in the directory model_path. # Load every file that meets the requirements in the directory model_path.
for filename in os.listdir(model_path): for filename in os.listdir(model_path):
if model_filename == filename: if model_filename == filename:
func_name = 'forward' func_name = 'forward'
model_file_path = os.path.join(model_path, model_filename) model_file_path = os.path.join(model_path, model_filename)
elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith( elif filename.endswith(INFER_MODEL_SUFFIX) and filename.startswith(
model_name): model_name
parsing_names = filename[len(model_name ):
):-len(INFER_MODEL_SUFFIX) + parsing_names = filename[
1].split('.') len(model_name) : -len(INFER_MODEL_SUFFIX) + 1
].split('.')
if len(parsing_names) == 3 and len(parsing_names[1]) > 0: if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
func_name = parsing_names[1] func_name = parsing_names[1]
model_file_path = os.path.join(model_path, filename) model_file_path = os.path.join(model_path, filename)
...@@ -748,7 +804,8 @@ def _construct_program_holders(model_path, model_filename=None): ...@@ -748,7 +804,8 @@ def _construct_program_holders(model_path, model_filename=None):
else: else:
continue continue
program_holder_dict[func_name] = _ProgramHolder( program_holder_dict[func_name] = _ProgramHolder(
_load_program_desc(model_file_path)) _load_program_desc(model_file_path)
)
else: else:
for _, _, file_names in os.walk(model_path): for _, _, file_names in os.walk(model_path):
for name in file_names: for name in file_names:
...@@ -760,30 +817,32 @@ def _construct_program_holders(model_path, model_filename=None): ...@@ -760,30 +817,32 @@ def _construct_program_holders(model_path, model_filename=None):
else: else:
method_name.replace('model', '') method_name.replace('model', '')
program_holder_dict[method_name] = _ProgramHolder( program_holder_dict[method_name] = _ProgramHolder(
_load_program_desc(model_file_path)) _load_program_desc(model_file_path)
)
return program_holder_dict return program_holder_dict
def _construct_params_and_buffers(model_path, def _construct_params_and_buffers(
programs, model_path, programs, params_filename=None, append_suffix=True
params_filename=None, ):
append_suffix=True):
var_info_filename = str(params_filename) + ".info" var_info_filename = str(params_filename) + ".info"
var_info_path = os.path.join(model_path, var_info_filename) var_info_path = os.path.join(model_path, var_info_filename)
params_path = os.path.join(model_path, str(params_filename)) params_path = os.path.join(model_path, str(params_filename))
if os.path.exists(var_info_path): if os.path.exists(var_info_path):
var_dict = _load_persistable_vars(model_path, var_info_path, var_dict = _load_persistable_vars(
programs['forward'], params_filename) model_path, var_info_path, programs['forward'], params_filename
model_name = params_filename[:-len(INFER_PARAMS_SUFFIX)] )
#Load every file that meets the requirements in the directory model_path. model_name = params_filename[: -len(INFER_PARAMS_SUFFIX)]
# Load every file that meets the requirements in the directory model_path.
for file_name in os.listdir(model_path): for file_name in os.listdir(model_path):
if file_name.startswith(model_name) and file_name.endswith( if file_name.startswith(model_name) and file_name.endswith(
INFER_PARAMS_SUFFIX): INFER_PARAMS_SUFFIX
parsing_names = file_name[len(model_name ):
):-len(INFER_PARAMS_SUFFIX) + parsing_names = file_name[
1].split('.') len(model_name) : -len(INFER_PARAMS_SUFFIX) + 1
].split('.')
if len(parsing_names) == 3 and len(parsing_names[1]) > 0: if len(parsing_names) == 3 and len(parsing_names[1]) > 0:
func_name = parsing_names[1] func_name = parsing_names[1]
else: else:
...@@ -792,15 +851,17 @@ def _construct_params_and_buffers(model_path, ...@@ -792,15 +851,17 @@ def _construct_params_and_buffers(model_path,
continue continue
var_info_path = os.path.join(model_path, var_info_filename) var_info_path = os.path.join(model_path, var_info_filename)
var_dict.update( var_dict.update(
_load_persistable_vars(model_path, var_info_path, _load_persistable_vars(
programs[func_name], file_name)) model_path, var_info_path, programs[func_name], file_name
)
)
elif params_filename is not None and not os.path.exists(params_path): elif params_filename is not None and not os.path.exists(params_path):
# When saving XX, there is only '*.pdmodel' # When saving XX, there is only '*.pdmodel'
return dict() return dict()
else: else:
var_dict = _load_persistable_vars_by_program(model_path, var_dict = _load_persistable_vars_by_program(
programs['forward'], model_path, programs['forward'], params_filename
params_filename) )
if not append_suffix: if not append_suffix:
var_dict = _remove_varname_suffix(var_dict, programs['forward']) var_dict = _remove_varname_suffix(var_dict, programs['forward'])
@@ -813,13 +874,23 @@ def _valid_vars(vars):
         return vars
     if framework._in_eager_without_dygraph_check():
         return [
-            core.eager.Tensor(core.VarDesc.VarType.FP32, [], "Fake_var",
-                              core.VarDesc.VarType.RAW, False)
+            core.eager.Tensor(
+                core.VarDesc.VarType.FP32,
+                [],
+                "Fake_var",
+                core.VarDesc.VarType.RAW,
+                False,
+            )
         ]
     else:
         return [
-            core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
-                         core.VarDesc.VarType.RAW, False)
+            core.VarBase(
+                core.VarDesc.VarType.FP32,
+                [],
+                "Fake_var",
+                core.VarDesc.VarType.RAW,
+                False,
+            )
         ]
...@@ -831,7 +902,8 @@ def _run_dygraph(instance, input, program_holder): ...@@ -831,7 +902,8 @@ def _run_dygraph(instance, input, program_holder):
if not isinstance(value, (np.ndarray, core.VarBase, core.eager.Tensor)): if not isinstance(value, (np.ndarray, core.VarBase, core.eager.Tensor)):
raise TypeError( raise TypeError(
"The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s." "The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s."
% type(value)) % type(value)
)
# NOTE: In order to unify the API, firstly convert the input to VarBase # NOTE: In order to unify the API, firstly convert the input to VarBase
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
if framework._in_eager_without_dygraph_check(): if framework._in_eager_without_dygraph_check():
...@@ -840,13 +912,16 @@ def _run_dygraph(instance, input, program_holder): ...@@ -840,13 +912,16 @@ def _run_dygraph(instance, input, program_holder):
name=program_holder.input_descs[i].name(), name=program_holder.input_descs[i].name(),
persistable=False, persistable=False,
place=framework._current_expected_place(), place=framework._current_expected_place(),
zero_copy=True) zero_copy=True,
)
else: else:
var = core.VarBase(value=value, var = core.VarBase(
value=value,
name=program_holder.input_descs[i].name(), name=program_holder.input_descs[i].name(),
persistable=False, persistable=False,
place=framework._current_expected_place(), place=framework._current_expected_place(),
zero_copy=True) zero_copy=True,
)
else: else:
var = value var = value
# NOTE: we changed var name here, # NOTE: we changed var name here,
...@@ -868,67 +943,112 @@ def _run_dygraph(instance, input, program_holder): ...@@ -868,67 +943,112 @@ def _run_dygraph(instance, input, program_holder):
else: else:
raise ValueError( raise ValueError(
"The persistable variable %s does not exist in current TranslatedLayer." "The persistable variable %s does not exist in current TranslatedLayer."
% var_name) % var_name
)
output_vars = [] output_vars = []
for var_desc in program_holder.output_descs: for var_desc in program_holder.output_descs:
if framework._in_eager_without_dygraph_check(): if framework._in_eager_without_dygraph_check():
var = core.eager.Tensor(dtype=var_desc.dtype(), var = core.eager.Tensor(
dtype=var_desc.dtype(),
dims=var_desc.shape(), dims=var_desc.shape(),
name=var_desc.name(), name=var_desc.name(),
type=var_desc.type(), type=var_desc.type(),
persistable=False) persistable=False,
)
else: else:
var = core.VarBase(var_desc.dtype(), var_desc.shape(), var = core.VarBase(
var_desc.name(), var_desc.type(), False) var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
output_vars.append(var) output_vars.append(var)
# hold forward variables # hold forward variables
if framework._in_eager_without_dygraph_check(): if framework._in_eager_without_dygraph_check():
tmp_scope_vec = [program_holder.scope] tmp_scope_vec = [program_holder.scope]
else: else:
tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [], tmp_scope_vec = core.VarBase(
core.VarDesc.VarType.FP32,
[],
"program_out_scope", "program_out_scope",
core.VarDesc.VarType.STEP_SCOPES, True) core.VarDesc.VarType.STEP_SCOPES,
True,
)
tmp_scope_vec.value().set_scope(program_holder.scope) tmp_scope_vec.value().set_scope(program_holder.scope)
double_grad_vars = [] double_grad_vars = []
for var_desc in program_holder.double_grad_descs: for var_desc in program_holder.double_grad_descs:
if framework._in_eager_without_dygraph_check(): if framework._in_eager_without_dygraph_check():
var = core.eager.Tensor(dtype=var_desc.dtype(), var = core.eager.Tensor(
dtype=var_desc.dtype(),
dims=var_desc.shape(), dims=var_desc.shape(),
name=var_desc.name(), name=var_desc.name(),
type=var_desc.type(), type=var_desc.type(),
persistable=False) persistable=False,
)
else: else:
var = core.VarBase(var_desc.dtype(), var_desc.shape(), var = core.VarBase(
var_desc.name(), var_desc.type(), False) var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
double_grad_vars.append(var) double_grad_vars.append(var)
# 2. run program by op # 2. run program by op
trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program trace_program = (
forward_program = program_holder._infer_program_desc if instance._is_test else program_holder.forward_program program_holder.infer_program
if instance._is_test
else program_holder.train_program
)
forward_program = (
program_holder._infer_program_desc
if instance._is_test
else program_holder.forward_program
)
end_op_index = program_holder.infer_program.block(0).op_size() end_op_index = program_holder.infer_program.block(0).op_size()
attrs = [ attrs = [
'global_block', 'global_block',
trace_program.block(0), 'start_op_index', 0, 'end_op_index', trace_program.block(0),
end_op_index, 'is_test', instance._is_test, 'program_id', 'start_op_index',
_hash_with_id(trace_program, instance) 0,
'end_op_index',
end_op_index,
'is_test',
instance._is_test,
'program_id',
_hash_with_id(trace_program, instance),
] ]
use_interpretorcore = _is_enable_standalone_executor( use_interpretorcore = (
) and _is_dy2st_enable_standalone_executor() _is_enable_standalone_executor()
and _is_dy2st_enable_standalone_executor()
)
attrs.extend(('use_interpretorcore', use_interpretorcore)) attrs.extend(('use_interpretorcore', use_interpretorcore))
if use_interpretorcore: if use_interpretorcore:
attrs.extend( attrs.extend(
('forward_global_block', forward_program.block(0), (
'backward_global_block', program_holder.backward_program.block(0))) 'forward_global_block',
forward_program.block(0),
'backward_global_block',
program_holder.backward_program.block(0),
)
)
_legacy_C_ops.run_program(_valid_vars(input_vars), _legacy_C_ops.run_program(
_valid_vars(input_vars),
_valid_vars(persistable_vars), _valid_vars(persistable_vars),
_valid_vars(output_vars), tmp_scope_vec, _valid_vars(output_vars),
_valid_vars(double_grad_vars), None, *attrs) tmp_scope_vec,
_valid_vars(double_grad_vars),
None,
*attrs
)
# NOTE: [ why need set param's gradient type here ] # NOTE: [ why need set param's gradient type here ]
# if user set sparse gradient mode, the param's gradient # if user set sparse gradient mode, the param's gradient
@@ -946,8 +1066,6 @@ def _run_dygraph(instance, input, program_holder):
             continue
         persistable_var._set_grad_type(grad_var.type())

-    drop_scope_if_no_grad(instance, tmp_scope_vec)
-
     # 3. prepare output, keep same form with inputs
     outs = output_vars
     if len(output_vars) == 1:
@@ -955,27 +1073,26 @@
     return outs


-def drop_scope_if_no_grad(instance, scope_vec):
-    tracer = framework._dygraph_tracer()
-    scope = scope_vec.value().get_scope() if isinstance(
-        scope_vec, (core.VarBase)) else scope_vec[0]
-    if (not instance._is_test) and (not tracer._has_grad):
-        scope.drop_kids()
-
-
 def _run_static_graph(input, program_holder, trace_program):
     main_program = framework.default_main_program()
     param_var_names = _get_persistable_var_names(trace_program)
     _, dict_rename_var_old_new = _rename_var_program_desc(
-        trace_program, exclude=param_var_names)
+        trace_program, exclude=param_var_names
+    )
     trace_program.flush()
     output_names = [var.name() for var in program_holder.output_descs]
     # append blocks from 'trace_program'
-    _append_block(main_program, trace_program, program_holder, input,
-                  dict_rename_var_old_new)
+    _append_block(
+        main_program,
+        trace_program,
+        program_holder,
+        input,
+        dict_rename_var_old_new,
+    )
     main_program._sync_with_cpp()
-    outs = _get_output_from_program(main_program, program_holder,
-                                    dict_rename_var_old_new)
+    outs = _get_output_from_program(
+        main_program, program_holder, dict_rename_var_old_new
+    )
     if len(outs) == 1:
         outs = outs[0]
     return outs
...@@ -1003,11 +1120,13 @@ def _collect_current_and_parent_var(program, block_idx): ...@@ -1003,11 +1120,13 @@ def _collect_current_and_parent_var(program, block_idx):
return vars return vars
def _append_block(dest_program, def _append_block(
dest_program,
src_program_desc, src_program_desc,
program_holder, program_holder,
input_variables, input_variables,
dict_rename_var_old_new=None): dict_rename_var_old_new=None,
):
''' '''
Append Variables and Operators in 'src_program_desc' to dest_program. Append Variables and Operators in 'src_program_desc' to dest_program.
...@@ -1021,28 +1140,35 @@ def _append_block(dest_program, ...@@ -1021,28 +1140,35 @@ def _append_block(dest_program,
''' '''
origin_block_idx = dest_program.current_block_idx origin_block_idx = dest_program.current_block_idx
param_var_names = _collect_current_and_parent_var(dest_program, param_var_names = _collect_current_and_parent_var(
origin_block_idx) dest_program, origin_block_idx
append_var_from_block_desc_static(dest_program.block(origin_block_idx), )
append_var_from_block_desc_static(
dest_program.block(origin_block_idx),
src_program_desc.block(0), src_program_desc.block(0),
exclude=param_var_names) exclude=param_var_names,
)
name_inp_desc = [inp.name() for inp in program_holder.input_descs] name_inp_desc = [inp.name() for inp in program_holder.input_descs]
input_names = [inp.name for inp in input_variables] input_names = [inp.name for inp in input_variables]
if len(name_inp_desc) != len(input_names): if len(name_inp_desc) != len(input_names):
raise ValueError( raise ValueError(
"The number of input is invalid, expected {}, but received {}.". "The number of input is invalid, expected {}, but received {}.".format(
format(len(name_inp_desc), len(input_names))) len(name_inp_desc), len(input_names)
)
)
for i, out_name in enumerate(name_inp_desc): for i, out_name in enumerate(name_inp_desc):
if dict_rename_var_old_new: if dict_rename_var_old_new:
out_name = dict_rename_var_old_new[out_name] out_name = dict_rename_var_old_new[out_name]
dest_program.block(origin_block_idx).append_op( dest_program.block(origin_block_idx).append_op(
type='assign', type='assign',
inputs={'X': [input_names[i]]}, inputs={'X': [input_names[i]]},
outputs={'Out': [out_name]}) outputs={'Out': [out_name]},
)
append_ops = append_op_from_block_desc_static( append_ops = append_op_from_block_desc_static(
dest_program.block(origin_block_idx), src_program_desc.block(0)) dest_program.block(origin_block_idx), src_program_desc.block(0)
)
dest_program._sync_with_cpp() dest_program._sync_with_cpp()
offset_block_idx = dest_program.num_blocks - 1 offset_block_idx = dest_program.num_blocks - 1
...@@ -1056,11 +1182,12 @@ def _append_block(dest_program, ...@@ -1056,11 +1182,12 @@ def _append_block(dest_program,
else: else:
parent_idx = origin_block_idx parent_idx = origin_block_idx
dest_block = dest_program._create_block(parent_idx=parent_idx) dest_block = dest_program._create_block(parent_idx=parent_idx)
append_var_from_block_desc_static(dest_block, append_var_from_block_desc_static(
src_block, dest_block, src_block, exclude=param_var_names
exclude=param_var_names) )
append_ops += append_op_from_block_desc_static( append_ops += append_op_from_block_desc_static(
dest_block, src_block) dest_block, src_block
)
dest_program._sync_with_cpp() dest_program._sync_with_cpp()
for op in append_ops: for op in append_ops:
...@@ -1070,15 +1197,16 @@ def _append_block(dest_program, ...@@ -1070,15 +1197,16 @@ def _append_block(dest_program,
origin_id = sub.id origin_id = sub.id
if isinstance(sub, framework.Block): if isinstance(sub, framework.Block):
origin_id = sub.idx origin_id = sub.idx
op._set_attr('sub_block', op._set_attr(
dest_program.block(offset_block_idx + origin_id)) 'sub_block', dest_program.block(offset_block_idx + origin_id)
)
dest_program._sync_with_cpp() dest_program._sync_with_cpp()
dest_program.current_block_idx = origin_block_idx dest_program.current_block_idx = origin_block_idx
def _get_output_from_program(program, def _get_output_from_program(
program_holder, program, program_holder, dict_rename_var_old_new=None
dict_rename_var_old_new=None): ):
""" """
Get output name of 'program' according to program_holder Get output name of 'program' according to program_holder
""" """
...@@ -1127,20 +1255,21 @@ def append_op_from_desc_static(block, op_desc): ...@@ -1127,20 +1255,21 @@ def append_op_from_desc_static(block, op_desc):
op_type = op_desc.type() op_type = op_desc.type()
op_append = block.desc.append_op() op_append = block.desc.append_op()
op_append.copy_from(op_desc) op_append.copy_from(op_desc)
op = framework.Operator(block=block, op = framework.Operator(
block=block,
desc=op_append, desc=op_append,
type=op_type, type=op_type,
inputs=None, inputs=None,
outputs=None, outputs=None,
attrs=None) attrs=None,
)
block.ops.append(op) block.ops.append(op)
return op return op
def append_var_from_block_desc_static(block, def append_var_from_block_desc_static(
src_block_desc, block, src_block_desc, include=None, exclude=None
include=None, ):
exclude=None):
""" """
Append Variables of 'src_block_desc' to current block. Append Variables of 'src_block_desc' to current block.
If 'include' is not `None`,variables that are not in include are not append. If 'include' is not `None`,variables that are not in include are not append.
...@@ -1159,13 +1288,14 @@ def append_var_from_block_desc_static(block, ...@@ -1159,13 +1288,14 @@ def append_var_from_block_desc_static(block,
for var_desc in src_block_desc.all_vars(): for var_desc in src_block_desc.all_vars():
var_desc_name = var_desc.name() var_desc_name = var_desc.name()
should_append = (include is None or var_desc_name in include) and ( should_append = (include is None or var_desc_name in include) and (
exclude is None or var_desc_name not in exclude) exclude is None or var_desc_name not in exclude
)
if not block.has_var(var_desc_name) and should_append: if not block.has_var(var_desc_name) and should_append:
var_type = var_desc.type() var_type = var_desc.type()
if var_type in [ if var_type in [
core.VarDesc.VarType.SELECTED_ROWS, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.LOD_TENSOR_ARRAY core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]: ]:
data_type = var_desc.dtype() data_type = var_desc.dtype()
var_shape = var_desc.shape() var_shape = var_desc.shape()
...@@ -1174,7 +1304,7 @@ def append_var_from_block_desc_static(block, ...@@ -1174,7 +1304,7 @@ def append_var_from_block_desc_static(block,
var_shape = None var_shape = None
if var_type in [ if var_type in [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.LOD_TENSOR_ARRAY core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]: ]:
lod_level = var_desc.lod_level() lod_level = var_desc.lod_level()
else: else:
...@@ -1193,7 +1323,9 @@ def append_var_from_block_desc_static(block, ...@@ -1193,7 +1323,9 @@ def append_var_from_block_desc_static(block,
shape=var_shape, shape=var_shape,
lod_level=lod_level, lod_level=lod_level,
persistable=var_desc.persistable(), persistable=var_desc.persistable(),
set_need_check_feed=var_desc.need_check_feed())) set_need_check_feed=var_desc.need_check_feed(),
)
)
return vars_append return vars_append
...@@ -1318,8 +1450,9 @@ class TranslatedLayer(layers.Layer): ...@@ -1318,8 +1450,9 @@ class TranslatedLayer(layers.Layer):
# the TranslatedLayer object holded var names count started from 0 # the TranslatedLayer object holded var names count started from 0
with unique_name.guard(): with unique_name.guard():
for name, var in persistable_vars.items(): for name, var in persistable_vars.items():
if isinstance(var, if isinstance(
(framework.ParamBase, framework.EagerParamBase)): var, (framework.ParamBase, framework.EagerParamBase)
):
dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX) dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
self._persistable_var_name_dict[name] = dy_name self._persistable_var_name_dict[name] = dy_name
self.add_parameter(dy_name, var) self.add_parameter(dy_name, var)
...@@ -1353,7 +1486,8 @@ class TranslatedLayer(layers.Layer): ...@@ -1353,7 +1486,8 @@ class TranslatedLayer(layers.Layer):
# 2. load layer parameters & buffers # 2. load layer parameters & buffers
persistable_vars = _construct_params_and_buffers( persistable_vars = _construct_params_and_buffers(
model_path, programs, params_filename) model_path, programs, params_filename
)
# 3. construct TranslatedLayer object # 3. construct TranslatedLayer object
translated_layer = TranslatedLayer(programs, persistable_vars) translated_layer = TranslatedLayer(programs, persistable_vars)
...@@ -1365,9 +1499,12 @@ class TranslatedLayer(layers.Layer): ...@@ -1365,9 +1499,12 @@ class TranslatedLayer(layers.Layer):
ins.name() for ins in program_holder.input_descs ins.name() for ins in program_holder.input_descs
] ]
setattr( setattr(
TranslatedLayer, method_name, TranslatedLayer,
method_name,
TranslatedLayer._execution_method_creator( TranslatedLayer._execution_method_creator(
method_name, program_holder)) method_name, program_holder
),
)
# 5. set TranslatedLayer's default mode to eval # 5. set TranslatedLayer's default mode to eval
translated_layer.eval() translated_layer.eval()
...@@ -1376,7 +1513,6 @@ class TranslatedLayer(layers.Layer): ...@@ -1376,7 +1513,6 @@ class TranslatedLayer(layers.Layer):
@staticmethod @staticmethod
def _execution_method_creator(method_name, program_holder): def _execution_method_creator(method_name, program_holder):
def __i_m_p_l__(self, *input): def __i_m_p_l__(self, *input):
program_holder = self._program_holder_dict[__i_m_p_l__.__name__] program_holder = self._program_holder_dict[__i_m_p_l__.__name__]
# When using jit.save, it runs in static graph mode. # When using jit.save, it runs in static graph mode.
...@@ -1389,7 +1525,8 @@ class TranslatedLayer(layers.Layer): ...@@ -1389,7 +1525,8 @@ class TranslatedLayer(layers.Layer):
# because '_run_static_graph' modifies 'ProgramDesc', 'OpDesc.op_size()' will return a very large wrong number. # because '_run_static_graph' modifies 'ProgramDesc', 'OpDesc.op_size()' will return a very large wrong number.
# A Segmentation fault error may occur if used 'p=ProgramDesc(program_holder.infer_program)'. # A Segmentation fault error may occur if used 'p=ProgramDesc(program_holder.infer_program)'.
p = framework.Program._construct_from_desc( p = framework.Program._construct_from_desc(
core.ProgramDesc(program_holder.infer_program)) core.ProgramDesc(program_holder.infer_program)
)
return _run_static_graph(input, program_holder, p.desc) return _run_static_graph(input, program_holder, p.desc)
__i_m_p_l__.__name__ = method_name __i_m_p_l__.__name__ = method_name
...@@ -1502,8 +1639,9 @@ class TranslatedLayer(layers.Layer): ...@@ -1502,8 +1639,9 @@ class TranslatedLayer(layers.Layer):
program_holder = self._program_holder_dict.get(method_name, None) program_holder = self._program_holder_dict.get(method_name, None)
if program_holder is None: if program_holder is None:
raise ValueError( raise ValueError(
"The method `%s` does not exist in loaded TranslatedLayer." % "The method `%s` does not exist in loaded TranslatedLayer."
method_name) % method_name
)
return program_holder return program_holder
def _input_spec(self, method_name='forward'): def _input_spec(self, method_name='forward'):
...@@ -1513,9 +1651,11 @@ class TranslatedLayer(layers.Layer): ...@@ -1513,9 +1651,11 @@ class TranslatedLayer(layers.Layer):
# 2. build input spec by input desc # 2. build input spec by input desc
input_spec = [] input_spec = []
for var_desc in program_holder.input_descs: for var_desc in program_holder.input_descs:
spec = paddle.static.InputSpec(shape=var_desc.shape(), spec = paddle.static.InputSpec(
shape=var_desc.shape(),
dtype=var_desc.dtype(), dtype=var_desc.dtype(),
name=var_desc.name()) name=var_desc.name(),
)
input_spec.append(spec) input_spec.append(spec)
return input_spec return input_spec
...@@ -1530,9 +1670,11 @@ class TranslatedLayer(layers.Layer): ...@@ -1530,9 +1670,11 @@ class TranslatedLayer(layers.Layer):
# NOTE(chenweihang): InputSpec describes a tensor, not just input. # NOTE(chenweihang): InputSpec describes a tensor, not just input.
# Maybe the name is not good enough. Here we use InputSpec to # Maybe the name is not good enough. Here we use InputSpec to
# construct the description of Output tensor # construct the description of Output tensor
spec = paddle.static.InputSpec(shape=var_desc.shape(), spec = paddle.static.InputSpec(
shape=var_desc.shape(),
dtype=var_desc.dtype(), dtype=var_desc.dtype(),
name=var_desc.name()) name=var_desc.name(),
)
output_spec.append(spec) output_spec.append(spec)
return output_spec return output_spec