Unverified commit 51c4028c, authored by: feifei-111, committed by: GitHub

[dy2s] fix error when using same tensor as inputs in one call & fix bugs in jit.save (#55963)

Parent: 752f29a1
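For orientation, a minimal hypothetical sketch (layer and tensor names are illustrative, not taken from the patch) of the first issue in the title: the same tensor object passed for two different inputs of a single `@to_static` call.

```python
import paddle

class AddNet(paddle.nn.Layer):
    def forward(self, a, b):
        return a + b

net = paddle.jit.to_static(AddNet())
x = paddle.randn([4, 16])

# Both parameters receive the same tensor object in one call; the
# dynamic-to-static feed logic must still give `a` and `b` distinct
# variable names (see the `_jst.{order}.{name}.{index}` renaming
# introduced further below).
out = net(x, x)
```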
@@ -77,10 +77,11 @@ static void clear_unused_out_var_in_backward(
 static std::vector<paddle::Tensor> filter_unused_input_var_in_backward(
     const std::vector<paddle::Tensor>& x,
+    const std::vector<std::string>& x_names,
     const paddle::framework::BlockDesc* backward_block) {
   auto filter_x = std::vector<paddle::Tensor>(x);
   for (size_t i = 0; i < x.size(); i++) {
-    if (!backward_block->HasVar(x[i].name())) {
+    if (!backward_block->HasVar(x_names[i])) {
       auto fake = paddle::Tensor(std::make_shared<phi::DenseTensor>());
       fake.set_name(paddle::framework::kFakeVarName);
       filter_x[i] = fake;
@@ -117,6 +118,9 @@ inline void run_program_ad_func(
   VLOG(2) << "start run run_program grad";
   if (require_any_grad) {
+    auto x_names =
+        PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));
     egr::EagerUtils::PassStopGradient(false, &p_autograd_outs);
     // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad])
     auto grad_node = std::make_shared<GradNodeRunProgram>(1, 2);
@@ -130,7 +134,7 @@ inline void run_program_ad_func(
         paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
     // Clear unused x vars
     auto filter_x =
-        filter_unused_input_var_in_backward(x, backward_global_block);
+        filter_unused_input_var_in_backward(x, x_names, backward_global_block);
     // Set TensorWrappers
     grad_node->SetFwdX(filter_x);
     // Clear unused out vars
@@ -145,7 +149,7 @@ inline void run_program_ad_func(
     std::vector<const paddle::Tensor*> x_require_grad;
     for (size_t i = 0; i < x.size(); ++i) {
-      auto& name = x[i].name();
+      auto& name = x_names[i];
       if (forward_global_block->HasVar(name) ||
           backward_global_block->HasVar(name)) {
         x_require_grad.push_back(&x[i]);
......
@@ -150,6 +150,31 @@ static void ShareTensorsIntoScope(const std::vector<Tensor> &tensors,
   }
 }

+static void ShareTensorsIntoScopeWithName(
+    const std::vector<Tensor> &tensors,
+    const std::vector<std::string> &tensor_names,
+    paddle::framework::Scope *scope) {
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    auto name = tensor_names[i];
+    if (name == paddle::framework::kFakeVarName) {
+      continue;
+    }
+    auto *var = scope->Var(name);
+    CheckInputVarStatus(tensors[i]);
+    // share tensor
+    auto tensor_base = tensors[i].impl();
+    if (phi::DenseTensor::classof(tensor_base.get())) {
+      auto *dst_tensor = var->GetMutable<phi::DenseTensor>();
+      auto t = std::dynamic_pointer_cast<phi::DenseTensor>(tensor_base);
+      *dst_tensor = *t;
+    } else if (phi::SelectedRows::classof(tensor_base.get())) {
+      auto *dst_tensor = var->GetMutable<phi::SelectedRows>();
+      auto t = std::dynamic_pointer_cast<phi::SelectedRows>(tensor_base);
+      *dst_tensor = *t;
+    }
+  }
+}
+
 static void ShareTensorsFromScope(
     const std::vector<Tensor *> &tensors,
     const paddle::framework::BlockDesc &global_block,
@@ -320,7 +345,8 @@ inline void RunProgramAPI(
   VLOG(4) << "global_inner_scope:" << global_inner_scope;
-  auto input_names = details::GetTensorsName(x);
+  auto input_names =
+      PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));
   auto output_names = details::GetTensorsName(out);
   auto param_names = details::GetTensorsName(params);
   auto dout_names = details::GetTensorsName(dout);
@@ -371,7 +397,7 @@ inline void RunProgramAPI(
                "for program: "
             << program_id;
     // Step 1. share input_vars & parameters into scope
-    details::ShareTensorsIntoScope(x, global_inner_scope);
+    details::ShareTensorsIntoScopeWithName(x, input_names, global_inner_scope);
     details::ShareTensorsIntoScope(params, global_inner_scope);
     // Step 2. create new interpretercore
@@ -433,7 +459,7 @@ inline void RunProgramAPI(
         program_id, global_inner_scope, /*is_grad=*/false);
     interpreter_core = cached_value.core_;
     // Step 2. update scope for cache interpretercore
-    details::ShareTensorsIntoScope(x, global_inner_scope);
+    details::ShareTensorsIntoScopeWithName(x, input_names, global_inner_scope);
     details::ShareTensorsIntoScope(params, global_inner_scope);
     if (interpreter_core->GetVariableScope()->GetMutableScope() !=
         global_inner_scope) {
......
@@ -139,6 +139,10 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
                  "std::vector<std::string>"
                  "The names of output gradients.")
         .SetDefault({});
+    AddAttr<std::vector<std::string>>("x_names",
+                                      "std::vector<std::string>"
+                                      "The names of input tensors.")
+        .SetDefault({});
     AddAttr<std::vector<std::string>>("x_grad_names",
                                       "std::vector<std::string>"
                                       "The names of input gradients.")
......
@@ -415,6 +415,9 @@ class _SaveLoadConfig:
         # if True, multi `StaticFunction` will share params in one file.
         self.combine_params = False

+        # when need to save a prune model, use input_names_after_prune to specify the inputs left after pruning
+        self.input_names_after_prune = None
+
     @property
     def output_spec(self):
         return self._output_spec
@@ -488,11 +491,12 @@ class _SaveLoadConfig:
 def _parse_save_configs(configs):
     supported_configs = [
-        'output_spec',
+        "output_spec",
         "with_hook",
         "combine_params",
         "clip_extra",
         "skip_forward",
+        "input_names_after_prune",
     ]

     # input check
@@ -505,11 +509,14 @@ def _parse_save_configs(configs):
     # construct inner config
     inner_config = _SaveLoadConfig()
-    inner_config.output_spec = configs.get('output_spec', None)
-    inner_config.with_hook = configs.get('with_hook', False)
+    inner_config.output_spec = configs.get("output_spec", None)
+    inner_config.with_hook = configs.get("with_hook", False)
     inner_config.combine_params = configs.get("combine_params", False)
     inner_config.clip_extra = configs.get("clip_extra", True)
     inner_config.skip_forward = configs.get("skip_forward", False)
+    inner_config.input_names_after_prune = configs.get(
+        "input_names_after_prune", None
+    )

     return inner_config
@@ -533,7 +540,7 @@ def _parse_load_config(configs):
     return inner_config


-def _get_input_var_names(inputs, input_spec):
+def _get_input_var_names(inputs, input_spec, input_names_after_prune):
     name_none_error = (
         "The %s's name is None. "
         "When using jit.save, please set InputSepc's name in "
@@ -546,6 +553,14 @@ def _get_input_var_names(inputs, input_spec):
         "in input_spec is the same as the name of InputSpec in "
         "`to_static` decorated on the Layer.forward method."
     )
+
+    if input_names_after_prune is not None:
+        input_spec = [
+            x
+            for x in input_spec
+            if isinstance(x, paddle.static.InputSpec)
+            and x.name in input_names_after_prune
+        ]
+
     result_list = []
     input_var_names = [
         var.name
@@ -1201,7 +1216,9 @@ def save(layer, path, input_spec=None, **configs):
     # - the input_spec length < len((concrete_program.inputs) - 1
     # - the input_spec's name should be in concrete_program.inputs
     input_var_names = _get_input_var_names(
-        concrete_program.inputs, inner_input_spec
+        concrete_program.inputs,
+        inner_input_spec,
+        configs.input_names_after_prune,
     )

     # NOTE(chenweihang): [ Get output variables ]
......
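A hedged usage sketch (model and path names are hypothetical) of the `input_names_after_prune` save config added above: the full `input_spec` is passed to `paddle.jit.save`, and the config lists the input names that survive pruning, mirroring how the updated unit tests call it.

```python
import paddle
from paddle.static import InputSpec

class TwoBranchNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(16, 10)

    def forward(self, img, label):
        pred = self.fc(img)                      # depends only on `img`
        loss = paddle.mean((pred - self.fc(label)) ** 2)
        return pred, loss

net = paddle.jit.to_static(TwoBranchNet())
pred, loss = net(paddle.randn([4, 16]), paddle.randn([4, 16]))

# Keep only the inference input `img`; the training-only `label` is pruned.
paddle.jit.save(
    net,
    path='./pruned_net',                         # hypothetical save path
    input_spec=[
        InputSpec([None, 16], 'float32', name='img'),
        InputSpec([None, 16], 'float32', name='label'),
    ],
    output_spec=[pred],                          # keep only the prediction output
    input_names_after_prune=['img'],
)
```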
@@ -150,8 +150,17 @@ class FunctionSpec:
             # replace argument with corresponding InputSpec.
             args_with_spec = convert_to_input_spec(args, self._input_spec)
         else:
-            args_with_spec = _replace_value_with_input_spec(args)
-            kwargs_with_spec = _replace_value_with_input_spec(kwargs)
+            args_with_spec = _replace_to_input_spec_with_new_name(
+                args, self._arg_names
+            )
+            kwarg_names = ["kwargs." + key for key in kwargs.keys()]
+            kwargs_list_with_spec = _replace_to_input_spec_with_new_name(
+                list(kwargs.values()), kwarg_names
+            )
+            kwargs_with_spec = {
+                key: kwargs_list_with_spec[idx]
+                for idx, key in enumerate(kwargs)
+            }

         # If without specificing name in input_spec, add default name
         # according to argument name from decorated function.
@@ -302,6 +311,44 @@ def _replace_value_with_input_spec(args):
     return args_with_spec


+def _replace_to_input_spec_with_new_name(args, arg_names):
+    assert len(args) == len(arg_names)
+    order_digit = len(str(len(arg_names) - 1))
+    args_with_spec = []
+    for order, (arg, name_prefix) in enumerate(zip(args, arg_names)):
+        index = 0
+        for idx, origin_input in enumerate(paddle.utils.flatten(arg)):
+            if isinstance(origin_input, np.ndarray):
+                input_var = paddle.static.InputSpec.from_numpy(origin_input)
+                input_var.stop_gradient = True
+            elif isinstance(origin_input, core.eager.Tensor):
+                stop_gradient = origin_input.stop_gradient
+                input_var = paddle.static.InputSpec.from_tensor(origin_input)
+                input_var.stop_gradient = stop_gradient
+            elif isinstance(origin_input, paddle.fluid.framework.Variable):
+                stop_gradient = origin_input.stop_gradient
+                input_var = paddle.static.InputSpec(
+                    origin_input.shape, origin_input.dtype, origin_input.name
+                )
+                input_var.stop_gradient = stop_gradient
+            else:
+                input_var = origin_input
+
+            if isinstance(
+                origin_input,
+                (
+                    np.ndarray,
+                    core.eager.Tensor,
+                    paddle.fluid.framework.Variable,
+                ),
+            ):
+                input_var.name = f"_jst.{str(order).zfill(order_digit)}.{name_prefix}.{str(index)}"
+                index += 1
+            args_with_spec.append(input_var)
+    args_with_spec = paddle.utils.pack_sequence_as(args, args_with_spec)
+    return args_with_spec
+
+
 def convert_to_input_spec(inputs, input_spec):
     """
     Replaces tensor in structured `inputs` by InputSpec in `input_spec`.
......
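To make the renaming rule above concrete, a small standalone illustration (plain Python, not from the patch) of the `_jst.{argument order}.{argument name}.{index within the argument}` pattern that the helper assigns, so two parameters that receive the same tensor still get distinct feed names:

```python
# Illustrative only: reproduces the naming pattern for flat (non-nested) arguments.
def demo_spec_names(arg_names):
    order_digit = len(str(len(arg_names) - 1))
    # One flattened element per argument in this simplified demo, so the
    # per-argument index is always 0.
    return [
        f"_jst.{str(order).zfill(order_digit)}.{name}.0"
        for order, name in enumerate(arg_names)
    ]

# forward(self, a, b) called as net(x, x): both positions get unique names.
print(demo_spec_names(["a", "b"]))  # ['_jst.0.a.0', '_jst.1.b.0']
```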
@@ -25,6 +25,7 @@ from paddle.fluid.compiler import BuildStrategy
 from paddle.fluid.data_feeder import check_type, convert_dtype
 from paddle.fluid.dygraph.base import switch_to_static_graph
 from paddle.fluid.framework import _apply_pass
+from paddle.fluid.unique_name import guard as UniqueNameGuard
 from paddle.optimizer.lr import LRScheduler

 from . import logging_utils
@@ -170,12 +171,19 @@ class PartialProgramLayer:
     """

     def __init__(
-        self, main_program, inputs, outputs, parameters=None, **kwargs
+        self,
+        main_program,
+        inputs,
+        outputs,
+        name_generator,
+        parameters=None,
+        **kwargs
     ):
         super().__init__()
         self._inputs = NestSequence(inputs)
         self._outputs = NestSequence(outputs, need_check=True)
         self._params = parameters if parameters is not None else []
+        self._name_generator = name_generator

         self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
         assert isinstance(self._build_strategy, BuildStrategy)
@@ -214,26 +222,35 @@ class PartialProgramLayer:
         """
         Execute static graph by Interpreter and Return dynamic Tensors.
         """
-        in_vars, out_vars = self._prepare(inputs)
-        self._cast_fp16_if_pure_fp16(in_vars)
-        attrs = self._prepare_attributes()
-
-        self._sync_lr_value_with_scheduler()
-
-        _legacy_C_ops.run_program(
-            self._valid_vars(in_vars),
-            self._valid_vars(self._params),
-            self._valid_vars(out_vars),
-            self._create_scope_vec(
-                program_id=self.program_id, use_scope_cache=True
-            ),
-            self._double_grads,
-            self._cuda_graph_vec,
-            *attrs
-        )
-        self._update_stop_gradient(out_vars)
-        restored_nest_out = self._restore_out(out_vars)
-        return self._remove_no_value(restored_nest_out)
+        with UniqueNameGuard(self._name_generator):
+            in_vars, out_vars, in_var_names, resume_name_record = self._prepare(
+                inputs
+            )
+            self._cast_fp16_if_pure_fp16(in_vars)
+            attrs = self._prepare_attributes()
+            attrs.extend(["x_names", in_var_names])
+
+            self._sync_lr_value_with_scheduler()
+
+            _legacy_C_ops.run_program(
+                self._valid_vars(in_vars),
+                self._valid_vars(self._params),
+                self._valid_vars(out_vars),
+                self._create_scope_vec(
+                    program_id=self.program_id, use_scope_cache=True
+                ),
+                self._double_grads,
+                self._cuda_graph_vec,
+                *attrs
+            )
+
+            for var in in_vars:
+                if var.name in resume_name_record:
+                    var.name = resume_name_record[var.name]
+
+            self._update_stop_gradient(out_vars)
+            restored_nest_out = self._restore_out(out_vars)
+            return self._remove_no_value(restored_nest_out)

     def _sync_lr_value_with_scheduler(self):
         """Update lr_var value with calculated by lr_scheduler."""
@@ -887,6 +904,8 @@ class PartialProgramLayer:
         flatten_inputs = paddle.utils.flatten(inputs)
         # Convert variable into Tensor and feed in training data.
         input_vars = []
+        input_var_names = []
+        resume_name_record = {}
         expected_place = framework._current_expected_place()
         for i, value in enumerate(flatten_inputs):
             if isinstance(value, np.ndarray):
@@ -909,9 +928,11 @@ class PartialProgramLayer:
                     var.stop_gradient = True
                 else:
                     var = value
+                resume_name_record[self._inputs[i].desc.name()] = var.name
                 var.name = self._inputs[i].desc.name()
             else:
                 continue
+            input_var_names.append(self._inputs[i].desc.name())
             input_vars.append(var)

         # mapping from name(string) -> Tensor
@@ -939,7 +960,7 @@ class PartialProgramLayer:
         # Create Tensor to receive output data.
         out_vars = list(map(create_out, self._outputs.var_ids))

-        return input_vars, out_vars
+        return input_vars, out_vars, input_var_names, resume_name_record

     def _create_scope_vec(self, program_id=None, use_scope_cache=False):
         # Hold forward variables
@@ -1106,6 +1127,7 @@ def partial_program_from(concrete_program, from_method=False):
         concrete_program.main_program,
         inputs,
         concrete_program.outputs,
+        concrete_program.name_generator,
         concrete_program.parameters,
         **concrete_program.kwargs
     )
......
@@ -27,6 +27,8 @@ from paddle.fluid.dygraph.base import (
     param_guard,
     switch_to_static_graph,
 )
+from paddle.fluid.unique_name import UniqueNameGenerator
+from paddle.fluid.unique_name import guard as UniqueNameGuard
 from paddle.framework import in_dynamic_mode
 from paddle.nn.layer import layers
 from paddle.utils import flatten, gast
@@ -942,72 +944,68 @@ class ASTStaticFunction(StaticFunction):
         # If specific `input_spec`, apply convertion from dygraph layers into static Program.
         # NOTE(jiabin): is_prim_infer indicates this method called by paddle.jit.save and it is worked in prim mode
-        if cached_program_len == 0:
-            desired_input_spec = input_spec
-            if self._function_spec.input_spec is not None:
-                if input_spec is not None and not input_specs_compatible(
-                    flatten(input_spec), flatten(self._function_spec.input_spec)
-                ):
-                    raise ValueError(
-                        "The `input_spec`: {} used to construct concrete_program is conflict with the `input_spec`: {} in `@paddle.jit.to_static`".format(
-                            input_spec, self._function_spec.input_spec
-                        )
-                    )
-                # NOTE(chenweihang): we should always translated program based on the `input_spec`
-                # decorated on forward if it is valid
-                desired_input_spec = self._function_spec.input_spec
-                if input_spec is not None:
-                    logging_utils.warn(
-                        "\n\nYou have specified `input_spec` both in function definition (higher priority) and `paddle.jit.save` (will be ignored.)\n\n\t Using: {}\n\n\t Ignore: {}\n".format(
-                            desired_input_spec, input_spec
-                        )
-                    )
-
-            has_input_spec = desired_input_spec is not None
-            if has_input_spec:
-                concrete_program, _ = self.get_concrete_program(
-                    *desired_input_spec,
-                    with_hook=with_hook,
-                    is_train=self._is_train_mode(),
-                    is_prim_infer=is_prim_infer,
-                )
-                return concrete_program
-            else:
-                raise ValueError(
-                    "No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n".format(
-                        self._function_spec
-                    )
-                )
-        elif with_hook:
-            cache_key = self._program_cache._recent_cache_key
-            cache_key.kwargs["with_hook"] = True
-            if not is_prim_infer:
-                concrete_program, _ = self._program_cache[cache_key]
-                return concrete_program
-            else:
-                concrete_program, _ = self.get_concrete_program_with_cache_key(
-                    cache_key
-                )
-                return concrete_program
-
-        # If more than one programs have been cached, return the recent converted program by default.
-        elif cached_program_len > 1:
-            logging_utils.warn(
-                "Current {} has more than one cached programs: {}, the last traced progam will be return by default.".format(
-                    self._function_spec, cached_program_len
-                )
-            )
-
-        if not is_prim_infer:
-            cache_key, (
-                concrete_program,
-                partial_layer,
-            ) = self._program_cache.last()
-            return concrete_program
-        else:
-            cache_key = self._program_cache._recent_cache_key
-            concrete_program, _ = self.get_concrete_program_with_cache_key(
-                cache_key
-            )
-            return concrete_program
+        desired_input_spec = input_spec
+        if self._function_spec.input_spec is not None:
+            if input_spec is not None and not input_specs_compatible(
+                flatten(input_spec), flatten(self._function_spec.input_spec)
+            ):
+                raise ValueError(
+                    "The `input_spec`: {} used to construct concrete_program is conflict with the `input_spec`: {} in `@paddle.jit.to_static`".format(
+                        input_spec, self._function_spec.input_spec
+                    )
+                )
+            # NOTE(chenweihang): we should always translated program based on the `input_spec`
+            # decorated on forward if it is valid
+            desired_input_spec = self._function_spec.input_spec
+            if input_spec is not None:
+                logging_utils.warn(
+                    "\n\nYou have specified `input_spec` both in function definition (higher priority) and `paddle.jit.save` (will be ignored.)\n\n\t Using: {}\n\n\t Ignore: {}\n".format(
+                        desired_input_spec, input_spec
+                    )
+                )
+
+        has_input_spec = desired_input_spec is not None
+        if has_input_spec:
+            concrete_program, _ = self.get_concrete_program(
+                *desired_input_spec,
+                with_hook=with_hook,
+                is_train=self._is_train_mode(),
+                is_prim_infer=is_prim_infer,
+            )
+            return concrete_program
+        else:
+            if cached_program_len != 0:
+                logging_utils.warn(
+                    "No input_spec is found, save cached program instead"
+                )
+                if cached_program_len > 1:
+                    logging_utils.warn(
+                        "Current {} has more than one cached programs: {}, the last traced progam will be return by default.".format(
+                            self._function_spec, cached_program_len
+                        )
+                    )
+
+                cache_key = self._program_cache._recent_cache_key
+
+                if with_hook:
+                    cache_key.kwargs["with_hook"] = True
+
+                if is_prim_infer:
+                    (
+                        concrete_program,
+                        _,
+                    ) = self.get_concrete_program_with_cache_key(cache_key)
+                    return concrete_program
+                else:
+                    concrete_program, _ = self._program_cache[cache_key]
+                    return concrete_program
+            else:
+                raise ValueError(
+                    "No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n".format(
+                        self._function_spec
+                    )
+                )

     @property
     def inputs(self):
@@ -1134,6 +1132,7 @@ class ConcreteProgram:
         "startup_program",
         "parameters",
         "function",
+        "name_generator",
         'kwargs',
     ]
@@ -1143,6 +1142,7 @@ class ConcreteProgram:
         outputs,
         parameters,
         function,
+        name_generator,
         main_program,
         startup_program=None,
         **kwargs,
@@ -1153,6 +1153,7 @@ class ConcreteProgram:
         self.startup_program = startup_program
         self.parameters = parameters
         self.function = function
+        self.name_generator = name_generator
         self.kwargs = kwargs

     @staticmethod
@@ -1188,8 +1189,12 @@ class ConcreteProgram:
             framework.default_startup_program().random_seed
         )

+        new_name_generator = UniqueNameGenerator()
         with framework.program_guard(main_program, startup_program):
-            with _switch_declarative_mode_guard_(is_declarative=True):
+            with _switch_declarative_mode_guard_(
+                is_declarative=True
+            ), UniqueNameGuard(new_name_generator):
                 # 1. Adds `paddle.static.data` layers for input if needed
                 static_inputs = func_spec.to_static_inputs_with_spec(
                     input_spec, main_program
@@ -1244,6 +1249,7 @@ class ConcreteProgram:
             outputs=outputs,
             parameters=all_parameters_and_buffers,
             function=dygraph_function,
+            name_generator=new_name_generator,
             main_program=main_program,
             startup_program=startup_program,
             **kwargs,
......
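A small standalone sketch (assuming the `paddle.fluid.unique_name` helpers imported in the hunk above behave as their names suggest) of why each `ConcreteProgram` now owns a `UniqueNameGenerator`: names minted while a program is built or re-run under its guard come from that program's private counters, so replaying one cached program does not shift the names generated for another.

```python
# Toy illustration only, not from the patch.
from paddle.fluid import unique_name
from paddle.fluid.unique_name import UniqueNameGenerator
from paddle.fluid.unique_name import guard as UniqueNameGuard

gen_a = UniqueNameGenerator()
gen_b = UniqueNameGenerator()

with UniqueNameGuard(gen_a):
    print(unique_name.generate("tmp"))  # e.g. tmp_0
    print(unique_name.generate("tmp"))  # e.g. tmp_1

with UniqueNameGuard(gen_b):
    # A different generator starts from its own counters, independent of
    # gen_a and of the global generator used outside any guard.
    print(unique_name.generate("tmp"))  # e.g. tmp_0
```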
@@ -898,6 +898,7 @@ def _valid_vars(vars):
 def _run_dygraph(instance, input, program_holder):
     # 1. prepare inputs, outputs, attrs
     input_vars = []
+    input_var_names = []
     for i, value in enumerate(input):
         if not isinstance(value, (np.ndarray, core.eager.Tensor)):
             raise TypeError(
@@ -918,6 +919,7 @@ def _run_dygraph(instance, input, program_holder):
         # NOTE: we changed var name here,
         # but it may be an important name set by user
         var.name = program_holder.input_descs[i].name()
+        input_var_names.append(var.name)
         input_vars.append(var)
     if instance._input_args_names is None:
         instance._input_args_names = [
@@ -986,6 +988,8 @@ def _run_dygraph(instance, input, program_holder):
         instance._is_test,
         'program_id',
         paddle.utils._hash_with_id(trace_program, instance),
+        'x_names',
+        input_var_names,
     ]

     if not instance._is_test:
         attrs.extend(
......
@@ -83,7 +83,7 @@ def export(layer, path, input_spec=None, opset_version=9, **configs):
             ...     # Static and run model.
             ...     paddle.jit.to_static(model)
             ...     out = model(x, y, z=True)
-            ...     paddle.onnx.export(model, 'pruned', input_spec=[x], output_spec=[out])
+            ...     paddle.onnx.export(model, 'pruned', input_spec=[x, y, z], output_spec=[out], input_names_after_prune=[x])
             ...
             >>> export_logic()
     """
......
@@ -603,8 +603,12 @@ class TestLACModel(unittest.TestCase):
             paddle.jit.save(
                 layer=model,
                 path=self.model_save_prefix,
-                input_spec=[input_specs[0], input_specs[-1]],
+                input_spec=input_specs,
                 output_spec=[crf_decode],
+                input_names_after_prune=[
+                    input_specs[0].name,
+                    input_specs[-1].name,
+                ],
             )
         else:
             paddle.save(
......
@@ -238,12 +238,25 @@ class TestMNISTWithToStatic(TestMNIST):
                 loss_data.append(float(avg_loss))
                 # new save load check
                 self.check_jit_save_load(
-                    mnist, [dy_x_data], [img], to_static, prediction
+                    mnist,
+                    [dy_x_data],
+                    [img, label],
+                    to_static,
+                    prediction,
+                    [img.name],
                 )
                 break
         return loss_data

-    def check_jit_save_load(self, model, inputs, input_spec, to_static, gt_out):
+    def check_jit_save_load(
+        self,
+        model,
+        inputs,
+        input_spec,
+        to_static,
+        gt_out,
+        input_names_after_prune,
+    ):
         if to_static:
             infer_model_path = os.path.join(
                 self.temp_dir.name, 'test_mnist_inference_model_by_jit_save'
@@ -257,6 +270,7 @@ class TestMNISTWithToStatic(TestMNIST):
                 path=model_save_prefix,
                 input_spec=input_spec,
                 output_spec=[gt_out],
+                input_names_after_prune=input_names_after_prune,
             )
             # load in static graph mode
             static_infer_out = self.jit_load_and_run_inference_static(
......
@@ -454,8 +454,9 @@ class TestSeResnet(unittest.TestCase):
             paddle.jit.save(
                 se_resnext,
                 self.model_save_prefix,
-                [img],
+                [img, label],
                 output_spec=[pred],
+                input_names_after_prune=[img.name],
             )
         else:
             paddle.save(
......
@@ -47,59 +47,42 @@ class TestArgsSpecName(unittest.TestCase):
     def test_spec_name_hash(self):
         net = Net()
         net = paddle.jit.to_static(net)

         # Convert into program with four input
         self.read_from_dataset()
         self.run_test(net, [self.x, self.y, self.m, self.n], 1, [0, 1, 2, 3])

         # Convert into program with three input
         self.read_from_dataset()
-        self.run_test(net, [self.x, self.x, self.m, self.n], 2, [0, 0, 1, 2])
+        self.run_test(net, [self.x, self.x, self.m, self.n], 1, [0, 0, 1, 2])

         # Convert into program with two input
         self.read_from_dataset()
-        self.run_test(net, [self.x, self.x, self.m, self.m], 3, [0, 0, 1, 1])
+        self.run_test(net, [self.x, self.x, self.m, self.m], 1, [0, 0, 1, 1])

         # Use Cache Program
         self.read_from_dataset()
-        self.run_test(net, [self.n, self.n, self.y, self.y], 3, [0, 0, 1, 1])
+        self.run_test(net, [self.n, self.n, self.y, self.y], 1, [0, 0, 1, 1])

         # Convert into program with two input
         self.read_from_dataset()
-        self.run_test(net, [self.x, self.y, self.x, self.y], 4, [0, 1, 0, 1])
+        self.run_test(net, [self.x, self.y, self.x, self.y], 1, [0, 1, 0, 1])

         # Use Cache Program
         self.read_from_dataset()
-        self.run_test(net, [self.m, self.n, self.m, self.n], 4, [0, 1, 0, 1])
+        self.run_test(net, [self.m, self.n, self.m, self.n], 1, [0, 1, 0, 1])

         # Convert into program with one input
         self.read_from_dataset()
-        self.run_test(net, [self.x, self.x, self.x, self.x], 5, [0, 0, 0, 0])
+        self.run_test(net, [self.x, self.x, self.x, self.x], 1, [0, 0, 0, 0])

         # Use Cache Program
         self.read_from_dataset()
-        self.run_test(net, [self.m, self.m, self.m, self.m], 5, [0, 0, 0, 0])
+        self.run_test(net, [self.m, self.m, self.m, self.m], 1, [0, 0, 0, 0])

     def run_test(self, net, inputs, trace_count, mode):
         out = net(*inputs)
         self.assertEqual(net.forward.get_traced_count(), trace_count)
-        self.assert_feed_mode(net.forward.inputs, mode)
-
-    def assert_feed_mode(self, inputs, expect_mode):
-        assert isinstance(inputs, list)
-        assert isinstance(expect_mode, list)
-        in_names = [var.name for var in inputs]
-
-        i, name_ids = 0, {}
-
-        def to_idx(name):
-            nonlocal i
-            if name not in name_ids:
-                name_ids[name] = i
-                i += 1
-            return name_ids[name]
-
-        mode = [to_idx(name) for name in in_names]
-        self.assertEqual(mode, expect_mode)


 if __name__ == '__main__':
......
@@ -135,6 +135,8 @@ class TestRunProgram(unittest.TestCase):
             [out.name + '@GRAD'],
             'x_grad_names',
             [x_t.name + '@GRAD', y_t.name + '@GRAD'],
+            'x_names',
+            [x_t.name, y_t.name],
         ]

         use_interpretorcore = True
......
@@ -295,7 +295,13 @@ class TestNetWithNonTensorSpecWithPrune(unittest.TestCase):
         # jit.save and jit.load with prune y and loss
         prune_specs = [self.x_spec, True]
-        paddle.jit.save(net, path, prune_specs, output_spec=[st_out])
+        paddle.jit.save(
+            net,
+            path,
+            prune_specs,
+            output_spec=[st_out],
+            input_names_after_prune=[self.x_spec.name],
+        )
         load_net = paddle.jit.load(path)
         load_net.eval()
         load_out = load_net(self.x)  # no y and no loss
......
@@ -647,7 +647,13 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
             self.temp_dir.name, "multi_inout1.output_spec2/model"
         )
         output_spec = net.forward.outputs[:1]
-        paddle.jit.save(net, model_path, (input_x,), output_spec=output_spec)
+        paddle.jit.save(
+            net,
+            model_path,
+            net.forward.inputs,
+            output_spec=output_spec,
+            input_names_after_prune=[input_x.name],
+        )
         # 2. load again
         infer_layer2 = paddle.jit.load(model_path)
         # 3. predict
@@ -945,9 +951,11 @@ class TestJitSaveMultiCases(unittest.TestCase):
             layer,
             model_path,
             input_spec=[
-                InputSpec(shape=[None, 784], dtype='float32', name="image")
+                InputSpec(shape=[None, 784], dtype='float32', name="image"),
+                True,
             ],
             output_spec=[out],
+            input_names_after_prune=["image"],
         )

         self.verify_inference_correctness(
@@ -967,9 +975,11 @@ class TestJitSaveMultiCases(unittest.TestCase):
             layer,
             model_path,
             input_spec=[
-                InputSpec(shape=[None, 784], dtype='float32', name="image")
+                InputSpec(shape=[None, 784], dtype='float32', name="image"),
+                True,
            ],
             output_spec=output_spec,
+            input_names_after_prune=["image"],
         )

         self.verify_inference_correctness(
@@ -1082,9 +1092,11 @@ class TestJitSaveMultiCases(unittest.TestCase):
             layer,
             model_path,
             input_spec=[
-                InputSpec(shape=[None, 784], dtype='float32', name="image")
+                InputSpec(shape=[None, 784], dtype='float32', name="image"),
+                True,
             ],
             output_spec=[out],
+            input_names_after_prune=["image"],
         )
......
@@ -63,7 +63,11 @@ class TestExportPrunedGraph(unittest.TestCase):
         paddle.jit.to_static(model)
         out = model(self.x, self.y, z=True)
         paddle.onnx.export(
-            model, 'pruned', input_spec=[self.x], output_spec=[out]
+            model,
+            'pruned',
+            input_spec=[self.x, self.y, True],
+            output_spec=[out],
+            input_names_after_prune=[self.x.name],
         )
......
@@ -246,6 +246,8 @@ class RunProgramOpTest(unittest.TestCase):
                 [out.name + '@GRAD' for out in outputs['Out']],
                 'x_grad_names',
                 [p.name + '@GRAD' for p in inputs['X']],
+                'x_names',
+                [t.name for t in inputs['X']],
             )
         )
@@ -297,6 +299,8 @@ class RunProgramOpTest(unittest.TestCase):
                 [out.name + '@GRAD' for out in outputs['Out']],
                 'x_grad_names',
                 [p.name + '@GRAD' for p in inputs['X']],
+                'x_names',
+                [t.name for t in inputs['X']],
             )
         )
......