未验证 提交 51c4028c 编写于 作者: F feifei-111 提交者: GitHub

[dy2s] fix error when using same tensor as inputs in one call & fix bugs in jit.save (#55963)

上级 752f29a1
......@@ -77,10 +77,11 @@ static void clear_unused_out_var_in_backward(
static std::vector<paddle::Tensor> filter_unused_input_var_in_backward(
const std::vector<paddle::Tensor>& x,
const std::vector<std::string>& x_names,
const paddle::framework::BlockDesc* backward_block) {
auto filter_x = std::vector<paddle::Tensor>(x);
for (size_t i = 0; i < x.size(); i++) {
if (!backward_block->HasVar(x[i].name())) {
if (!backward_block->HasVar(x_names[i])) {
auto fake = paddle::Tensor(std::make_shared<phi::DenseTensor>());
fake.set_name(paddle::framework::kFakeVarName);
filter_x[i] = fake;
......@@ -117,6 +118,9 @@ inline void run_program_ad_func(
VLOG(2) << "start run run_program grad";
if (require_any_grad) {
auto x_names =
PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));
egr::EagerUtils::PassStopGradient(false, &p_autograd_outs);
// Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad])
auto grad_node = std::make_shared<GradNodeRunProgram>(1, 2);
......@@ -130,7 +134,7 @@ inline void run_program_ad_func(
paddle::framework::BlockDesc*, attrs.at("backward_global_block"));
// Clear unused x vars
auto filter_x =
filter_unused_input_var_in_backward(x, backward_global_block);
filter_unused_input_var_in_backward(x, x_names, backward_global_block);
// Set TensorWrappers
grad_node->SetFwdX(filter_x);
// Clear unused out vars
......@@ -145,7 +149,7 @@ inline void run_program_ad_func(
std::vector<const paddle::Tensor*> x_require_grad;
for (size_t i = 0; i < x.size(); ++i) {
auto& name = x[i].name();
auto& name = x_names[i];
if (forward_global_block->HasVar(name) ||
backward_global_block->HasVar(name)) {
x_require_grad.push_back(&x[i]);
......
......@@ -150,6 +150,31 @@ static void ShareTensorsIntoScope(const std::vector<Tensor> &tensors,
}
}
static void ShareTensorsIntoScopeWithName(
const std::vector<Tensor> &tensors,
const std::vector<std::string> &tensor_names,
paddle::framework::Scope *scope) {
for (size_t i = 0; i < tensors.size(); ++i) {
auto name = tensor_names[i];
if (name == paddle::framework::kFakeVarName) {
continue;
}
auto *var = scope->Var(name);
CheckInputVarStatus(tensors[i]);
// share tensor
auto tensor_base = tensors[i].impl();
if (phi::DenseTensor::classof(tensor_base.get())) {
auto *dst_tensor = var->GetMutable<phi::DenseTensor>();
auto t = std::dynamic_pointer_cast<phi::DenseTensor>(tensor_base);
*dst_tensor = *t;
} else if (phi::SelectedRows::classof(tensor_base.get())) {
auto *dst_tensor = var->GetMutable<phi::SelectedRows>();
auto t = std::dynamic_pointer_cast<phi::SelectedRows>(tensor_base);
*dst_tensor = *t;
}
}
}
static void ShareTensorsFromScope(
const std::vector<Tensor *> &tensors,
const paddle::framework::BlockDesc &global_block,
......@@ -320,7 +345,8 @@ inline void RunProgramAPI(
VLOG(4) << "global_inner_scope:" << global_inner_scope;
auto input_names = details::GetTensorsName(x);
auto input_names =
PADDLE_GET_CONST(std::vector<std::string>, attrs.at("x_names"));
auto output_names = details::GetTensorsName(out);
auto param_names = details::GetTensorsName(params);
auto dout_names = details::GetTensorsName(dout);
......@@ -371,7 +397,7 @@ inline void RunProgramAPI(
"for program: "
<< program_id;
// Step 1. share input_vars & parameters into scope
details::ShareTensorsIntoScope(x, global_inner_scope);
details::ShareTensorsIntoScopeWithName(x, input_names, global_inner_scope);
details::ShareTensorsIntoScope(params, global_inner_scope);
// Step 2. create new interpretercore
......@@ -433,7 +459,7 @@ inline void RunProgramAPI(
program_id, global_inner_scope, /*is_grad=*/false);
interpreter_core = cached_value.core_;
// Step 2. update scope for cache interpretercore
details::ShareTensorsIntoScope(x, global_inner_scope);
details::ShareTensorsIntoScopeWithName(x, input_names, global_inner_scope);
details::ShareTensorsIntoScope(params, global_inner_scope);
if (interpreter_core->GetVariableScope()->GetMutableScope() !=
global_inner_scope) {
......
......@@ -139,6 +139,10 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"std::vector<std::string>"
"The names of output gradients.")
.SetDefault({});
AddAttr<std::vector<std::string>>("x_names",
"std::vector<std::string>"
"The names of input tensors.")
.SetDefault({});
AddAttr<std::vector<std::string>>("x_grad_names",
"std::vector<std::string>"
"The names of input gradients.")
......
......@@ -415,6 +415,9 @@ class _SaveLoadConfig:
# if True, multi `StaticFunction` will share params in one file.
self.combine_params = False
# when need to save a prune model, use input_names_after_prune to specify the inputs left after pruning
self.input_names_after_prune = None
@property
def output_spec(self):
return self._output_spec
......@@ -488,11 +491,12 @@ class _SaveLoadConfig:
def _parse_save_configs(configs):
supported_configs = [
'output_spec',
"output_spec",
"with_hook",
"combine_params",
"clip_extra",
"skip_forward",
"input_names_after_prune",
]
# input check
......@@ -505,11 +509,14 @@ def _parse_save_configs(configs):
# construct inner config
inner_config = _SaveLoadConfig()
inner_config.output_spec = configs.get('output_spec', None)
inner_config.with_hook = configs.get('with_hook', False)
inner_config.output_spec = configs.get("output_spec", None)
inner_config.with_hook = configs.get("with_hook", False)
inner_config.combine_params = configs.get("combine_params", False)
inner_config.clip_extra = configs.get("clip_extra", True)
inner_config.skip_forward = configs.get("skip_forward", False)
inner_config.input_names_after_prune = configs.get(
"input_names_after_prune", None
)
return inner_config
......@@ -533,7 +540,7 @@ def _parse_load_config(configs):
return inner_config
def _get_input_var_names(inputs, input_spec):
def _get_input_var_names(inputs, input_spec, input_names_after_prune):
name_none_error = (
"The %s's name is None. "
"When using jit.save, please set InputSepc's name in "
......@@ -546,6 +553,14 @@ def _get_input_var_names(inputs, input_spec):
"in input_spec is the same as the name of InputSpec in "
"`to_static` decorated on the Layer.forward method."
)
if input_names_after_prune is not None:
input_spec = [
x
for x in input_spec
if isinstance(x, paddle.static.InputSpec)
and x.name in input_names_after_prune
]
result_list = []
input_var_names = [
var.name
......@@ -1201,7 +1216,9 @@ def save(layer, path, input_spec=None, **configs):
# - the input_spec length < len((concrete_program.inputs) - 1
# - the input_spec's name should be in concrete_program.inputs
input_var_names = _get_input_var_names(
concrete_program.inputs, inner_input_spec
concrete_program.inputs,
inner_input_spec,
configs.input_names_after_prune,
)
# NOTE(chenweihang): [ Get output variables ]
......
......@@ -150,8 +150,17 @@ class FunctionSpec:
# replace argument with corresponding InputSpec.
args_with_spec = convert_to_input_spec(args, self._input_spec)
else:
args_with_spec = _replace_value_with_input_spec(args)
kwargs_with_spec = _replace_value_with_input_spec(kwargs)
args_with_spec = _replace_to_input_spec_with_new_name(
args, self._arg_names
)
kwarg_names = ["kwargs." + key for key in kwargs.keys()]
kwargs_list_with_spec = _replace_to_input_spec_with_new_name(
list(kwargs.values()), kwarg_names
)
kwargs_with_spec = {
key: kwargs_list_with_spec[idx]
for idx, key in enumerate(kwargs)
}
# If without specificing name in input_spec, add default name
# according to argument name from decorated function.
......@@ -302,6 +311,44 @@ def _replace_value_with_input_spec(args):
return args_with_spec
def _replace_to_input_spec_with_new_name(args, arg_names):
assert len(args) == len(arg_names)
order_digit = len(str(len(arg_names) - 1))
args_with_spec = []
for order, (arg, name_prefix) in enumerate(zip(args, arg_names)):
index = 0
for idx, origin_input in enumerate(paddle.utils.flatten(arg)):
if isinstance(origin_input, np.ndarray):
input_var = paddle.static.InputSpec.from_numpy(origin_input)
input_var.stop_gradient = True
elif isinstance(origin_input, core.eager.Tensor):
stop_gradient = origin_input.stop_gradient
input_var = paddle.static.InputSpec.from_tensor(origin_input)
input_var.stop_gradient = stop_gradient
elif isinstance(origin_input, paddle.fluid.framework.Variable):
stop_gradient = origin_input.stop_gradient
input_var = paddle.static.InputSpec(
origin_input.shape, origin_input.dtype, origin_input.name
)
input_var.stop_gradient = stop_gradient
else:
input_var = origin_input
if isinstance(
origin_input,
(
np.ndarray,
core.eager.Tensor,
paddle.fluid.framework.Variable,
),
):
input_var.name = f"_jst.{str(order).zfill(order_digit)}.{name_prefix}.{str(index)}"
index += 1
args_with_spec.append(input_var)
args_with_spec = paddle.utils.pack_sequence_as(args, args_with_spec)
return args_with_spec
def convert_to_input_spec(inputs, input_spec):
"""
Replaces tensor in structured `inputs` by InputSpec in `input_spec`.
......
......@@ -25,6 +25,7 @@ from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.data_feeder import check_type, convert_dtype
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass
from paddle.fluid.unique_name import guard as UniqueNameGuard
from paddle.optimizer.lr import LRScheduler
from . import logging_utils
......@@ -170,12 +171,19 @@ class PartialProgramLayer:
"""
def __init__(
self, main_program, inputs, outputs, parameters=None, **kwargs
self,
main_program,
inputs,
outputs,
name_generator,
parameters=None,
**kwargs
):
super().__init__()
self._inputs = NestSequence(inputs)
self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else []
self._name_generator = name_generator
self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
assert isinstance(self._build_strategy, BuildStrategy)
......@@ -214,9 +222,13 @@ class PartialProgramLayer:
"""
Execute static graph by Interpreter and Return dynamic Tensors.
"""
in_vars, out_vars = self._prepare(inputs)
with UniqueNameGuard(self._name_generator):
in_vars, out_vars, in_var_names, resume_name_record = self._prepare(
inputs
)
self._cast_fp16_if_pure_fp16(in_vars)
attrs = self._prepare_attributes()
attrs.extend(["x_names", in_var_names])
self._sync_lr_value_with_scheduler()
......@@ -231,6 +243,11 @@ class PartialProgramLayer:
self._cuda_graph_vec,
*attrs
)
for var in in_vars:
if var.name in resume_name_record:
var.name = resume_name_record[var.name]
self._update_stop_gradient(out_vars)
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
......@@ -887,6 +904,8 @@ class PartialProgramLayer:
flatten_inputs = paddle.utils.flatten(inputs)
# Convert variable into Tensor and feed in training data.
input_vars = []
input_var_names = []
resume_name_record = {}
expected_place = framework._current_expected_place()
for i, value in enumerate(flatten_inputs):
if isinstance(value, np.ndarray):
......@@ -909,9 +928,11 @@ class PartialProgramLayer:
var.stop_gradient = True
else:
var = value
resume_name_record[self._inputs[i].desc.name()] = var.name
var.name = self._inputs[i].desc.name()
else:
continue
input_var_names.append(self._inputs[i].desc.name())
input_vars.append(var)
# mapping from name(string) -> Tensor
......@@ -939,7 +960,7 @@ class PartialProgramLayer:
# Create Tensor to receive output data.
out_vars = list(map(create_out, self._outputs.var_ids))
return input_vars, out_vars
return input_vars, out_vars, input_var_names, resume_name_record
def _create_scope_vec(self, program_id=None, use_scope_cache=False):
# Hold forward variables
......@@ -1106,6 +1127,7 @@ def partial_program_from(concrete_program, from_method=False):
concrete_program.main_program,
inputs,
concrete_program.outputs,
concrete_program.name_generator,
concrete_program.parameters,
**concrete_program.kwargs
)
......
......@@ -27,6 +27,8 @@ from paddle.fluid.dygraph.base import (
param_guard,
switch_to_static_graph,
)
from paddle.fluid.unique_name import UniqueNameGenerator
from paddle.fluid.unique_name import guard as UniqueNameGuard
from paddle.framework import in_dynamic_mode
from paddle.nn.layer import layers
from paddle.utils import flatten, gast
......@@ -942,7 +944,6 @@ class ASTStaticFunction(StaticFunction):
# If specific `input_spec`, apply convertion from dygraph layers into static Program.
# NOTE(jiabin): is_prim_infer indicates this method called by paddle.jit.save and it is worked in prim mode
if cached_program_len == 0:
desired_input_spec = input_spec
if self._function_spec.input_spec is not None:
if input_spec is not None and not input_specs_compatible(
......@@ -973,42 +974,39 @@ class ASTStaticFunction(StaticFunction):
)
return concrete_program
else:
raise ValueError(
"No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n".format(
self._function_spec
)
)
elif with_hook:
cache_key = self._program_cache._recent_cache_key
cache_key.kwargs["with_hook"] = True
if not is_prim_infer:
concrete_program, _ = self._program_cache[cache_key]
return concrete_program
else:
concrete_program, _ = self.get_concrete_program_with_cache_key(
cache_key
if cached_program_len != 0:
logging_utils.warn(
"No input_spec is found, save cached program instead"
)
return concrete_program
# If more than one programs have been cached, return the recent converted program by default.
elif cached_program_len > 1:
if cached_program_len > 1:
logging_utils.warn(
"Current {} has more than one cached programs: {}, the last traced progam will be return by default.".format(
self._function_spec, cached_program_len
)
)
if not is_prim_infer:
cache_key, (
cache_key = self._program_cache._recent_cache_key
if with_hook:
cache_key.kwargs["with_hook"] = True
if is_prim_infer:
(
concrete_program,
partial_layer,
) = self._program_cache.last()
_,
) = self.get_concrete_program_with_cache_key(cache_key)
return concrete_program
else:
cache_key = self._program_cache._recent_cache_key
concrete_program, _ = self.get_concrete_program_with_cache_key(
cache_key
)
concrete_program, _ = self._program_cache[cache_key]
return concrete_program
else:
raise ValueError(
"No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n".format(
self._function_spec
)
)
@property
def inputs(self):
"""
......@@ -1134,6 +1132,7 @@ class ConcreteProgram:
"startup_program",
"parameters",
"function",
"name_generator",
'kwargs',
]
......@@ -1143,6 +1142,7 @@ class ConcreteProgram:
outputs,
parameters,
function,
name_generator,
main_program,
startup_program=None,
**kwargs,
......@@ -1153,6 +1153,7 @@ class ConcreteProgram:
self.startup_program = startup_program
self.parameters = parameters
self.function = function
self.name_generator = name_generator
self.kwargs = kwargs
@staticmethod
......@@ -1188,8 +1189,12 @@ class ConcreteProgram:
framework.default_startup_program().random_seed
)
new_name_generator = UniqueNameGenerator()
with framework.program_guard(main_program, startup_program):
with _switch_declarative_mode_guard_(is_declarative=True):
with _switch_declarative_mode_guard_(
is_declarative=True
), UniqueNameGuard(new_name_generator):
# 1. Adds `paddle.static.data` layers for input if needed
static_inputs = func_spec.to_static_inputs_with_spec(
input_spec, main_program
......@@ -1244,6 +1249,7 @@ class ConcreteProgram:
outputs=outputs,
parameters=all_parameters_and_buffers,
function=dygraph_function,
name_generator=new_name_generator,
main_program=main_program,
startup_program=startup_program,
**kwargs,
......
......@@ -898,6 +898,7 @@ def _valid_vars(vars):
def _run_dygraph(instance, input, program_holder):
# 1. prepare inputs, outputs, attrs
input_vars = []
input_var_names = []
for i, value in enumerate(input):
if not isinstance(value, (np.ndarray, core.eager.Tensor)):
raise TypeError(
......@@ -918,6 +919,7 @@ def _run_dygraph(instance, input, program_holder):
# NOTE: we changed var name here,
# but it may be an important name set by user
var.name = program_holder.input_descs[i].name()
input_var_names.append(var.name)
input_vars.append(var)
if instance._input_args_names is None:
instance._input_args_names = [
......@@ -986,6 +988,8 @@ def _run_dygraph(instance, input, program_holder):
instance._is_test,
'program_id',
paddle.utils._hash_with_id(trace_program, instance),
'x_names',
input_var_names,
]
if not instance._is_test:
attrs.extend(
......
......@@ -83,7 +83,7 @@ def export(layer, path, input_spec=None, opset_version=9, **configs):
... # Static and run model.
... paddle.jit.to_static(model)
... out = model(x, y, z=True)
... paddle.onnx.export(model, 'pruned', input_spec=[x], output_spec=[out])
... paddle.onnx.export(model, 'pruned', input_spec=[x, y, z], output_spec=[out], input_names_after_prune=[x])
...
>>> export_logic()
"""
......
......@@ -603,8 +603,12 @@ class TestLACModel(unittest.TestCase):
paddle.jit.save(
layer=model,
path=self.model_save_prefix,
input_spec=[input_specs[0], input_specs[-1]],
input_spec=input_specs,
output_spec=[crf_decode],
input_names_after_prune=[
input_specs[0].name,
input_specs[-1].name,
],
)
else:
paddle.save(
......
......@@ -238,12 +238,25 @@ class TestMNISTWithToStatic(TestMNIST):
loss_data.append(float(avg_loss))
# new save load check
self.check_jit_save_load(
mnist, [dy_x_data], [img], to_static, prediction
mnist,
[dy_x_data],
[img, label],
to_static,
prediction,
[img.name],
)
break
return loss_data
def check_jit_save_load(self, model, inputs, input_spec, to_static, gt_out):
def check_jit_save_load(
self,
model,
inputs,
input_spec,
to_static,
gt_out,
input_names_after_prune,
):
if to_static:
infer_model_path = os.path.join(
self.temp_dir.name, 'test_mnist_inference_model_by_jit_save'
......@@ -257,6 +270,7 @@ class TestMNISTWithToStatic(TestMNIST):
path=model_save_prefix,
input_spec=input_spec,
output_spec=[gt_out],
input_names_after_prune=input_names_after_prune,
)
# load in static graph mode
static_infer_out = self.jit_load_and_run_inference_static(
......
......@@ -454,8 +454,9 @@ class TestSeResnet(unittest.TestCase):
paddle.jit.save(
se_resnext,
self.model_save_prefix,
[img],
[img, label],
output_spec=[pred],
input_names_after_prune=[img.name],
)
else:
paddle.save(
......
......@@ -47,59 +47,42 @@ class TestArgsSpecName(unittest.TestCase):
def test_spec_name_hash(self):
net = Net()
net = paddle.jit.to_static(net)
# Convert into program with four input
self.read_from_dataset()
self.run_test(net, [self.x, self.y, self.m, self.n], 1, [0, 1, 2, 3])
# Convert into program with three input
self.read_from_dataset()
self.run_test(net, [self.x, self.x, self.m, self.n], 2, [0, 0, 1, 2])
self.run_test(net, [self.x, self.x, self.m, self.n], 1, [0, 0, 1, 2])
# Convert into program with two input
self.read_from_dataset()
self.run_test(net, [self.x, self.x, self.m, self.m], 3, [0, 0, 1, 1])
self.run_test(net, [self.x, self.x, self.m, self.m], 1, [0, 0, 1, 1])
# Use Cache Program
self.read_from_dataset()
self.run_test(net, [self.n, self.n, self.y, self.y], 3, [0, 0, 1, 1])
self.run_test(net, [self.n, self.n, self.y, self.y], 1, [0, 0, 1, 1])
# Convert into program with two input
self.read_from_dataset()
self.run_test(net, [self.x, self.y, self.x, self.y], 4, [0, 1, 0, 1])
self.run_test(net, [self.x, self.y, self.x, self.y], 1, [0, 1, 0, 1])
# Use Cache Program
self.read_from_dataset()
self.run_test(net, [self.m, self.n, self.m, self.n], 4, [0, 1, 0, 1])
self.run_test(net, [self.m, self.n, self.m, self.n], 1, [0, 1, 0, 1])
# Convert into program with one input
self.read_from_dataset()
self.run_test(net, [self.x, self.x, self.x, self.x], 5, [0, 0, 0, 0])
self.run_test(net, [self.x, self.x, self.x, self.x], 1, [0, 0, 0, 0])
# Use Cache Program
self.read_from_dataset()
self.run_test(net, [self.m, self.m, self.m, self.m], 5, [0, 0, 0, 0])
self.run_test(net, [self.m, self.m, self.m, self.m], 1, [0, 0, 0, 0])
def run_test(self, net, inputs, trace_count, mode):
out = net(*inputs)
self.assertEqual(net.forward.get_traced_count(), trace_count)
self.assert_feed_mode(net.forward.inputs, mode)
def assert_feed_mode(self, inputs, expect_mode):
assert isinstance(inputs, list)
assert isinstance(expect_mode, list)
in_names = [var.name for var in inputs]
i, name_ids = 0, {}
def to_idx(name):
nonlocal i
if name not in name_ids:
name_ids[name] = i
i += 1
return name_ids[name]
mode = [to_idx(name) for name in in_names]
self.assertEqual(mode, expect_mode)
if __name__ == '__main__':
......
......@@ -135,6 +135,8 @@ class TestRunProgram(unittest.TestCase):
[out.name + '@GRAD'],
'x_grad_names',
[x_t.name + '@GRAD', y_t.name + '@GRAD'],
'x_names',
[x_t.name, y_t.name],
]
use_interpretorcore = True
......
......@@ -295,7 +295,13 @@ class TestNetWithNonTensorSpecWithPrune(unittest.TestCase):
# jit.save and jit.load with prune y and loss
prune_specs = [self.x_spec, True]
paddle.jit.save(net, path, prune_specs, output_spec=[st_out])
paddle.jit.save(
net,
path,
prune_specs,
output_spec=[st_out],
input_names_after_prune=[self.x_spec.name],
)
load_net = paddle.jit.load(path)
load_net.eval()
load_out = load_net(self.x) # no y and no loss
......
......@@ -647,7 +647,13 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
self.temp_dir.name, "multi_inout1.output_spec2/model"
)
output_spec = net.forward.outputs[:1]
paddle.jit.save(net, model_path, (input_x,), output_spec=output_spec)
paddle.jit.save(
net,
model_path,
net.forward.inputs,
output_spec=output_spec,
input_names_after_prune=[input_x.name],
)
# 2. load again
infer_layer2 = paddle.jit.load(model_path)
# 3. predict
......@@ -945,9 +951,11 @@ class TestJitSaveMultiCases(unittest.TestCase):
layer,
model_path,
input_spec=[
InputSpec(shape=[None, 784], dtype='float32', name="image")
InputSpec(shape=[None, 784], dtype='float32', name="image"),
True,
],
output_spec=[out],
input_names_after_prune=["image"],
)
self.verify_inference_correctness(
......@@ -967,9 +975,11 @@ class TestJitSaveMultiCases(unittest.TestCase):
layer,
model_path,
input_spec=[
InputSpec(shape=[None, 784], dtype='float32', name="image")
InputSpec(shape=[None, 784], dtype='float32', name="image"),
True,
],
output_spec=output_spec,
input_names_after_prune=["image"],
)
self.verify_inference_correctness(
......@@ -1082,9 +1092,11 @@ class TestJitSaveMultiCases(unittest.TestCase):
layer,
model_path,
input_spec=[
InputSpec(shape=[None, 784], dtype='float32', name="image")
InputSpec(shape=[None, 784], dtype='float32', name="image"),
True,
],
output_spec=[out],
input_names_after_prune=["image"],
)
......
......@@ -63,7 +63,11 @@ class TestExportPrunedGraph(unittest.TestCase):
paddle.jit.to_static(model)
out = model(self.x, self.y, z=True)
paddle.onnx.export(
model, 'pruned', input_spec=[self.x], output_spec=[out]
model,
'pruned',
input_spec=[self.x, self.y, True],
output_spec=[out],
input_names_after_prune=[self.x.name],
)
......
......@@ -246,6 +246,8 @@ class RunProgramOpTest(unittest.TestCase):
[out.name + '@GRAD' for out in outputs['Out']],
'x_grad_names',
[p.name + '@GRAD' for p in inputs['X']],
'x_names',
[t.name for t in inputs['X']],
)
)
......@@ -297,6 +299,8 @@ class RunProgramOpTest(unittest.TestCase):
[out.name + '@GRAD' for out in outputs['Out']],
'x_grad_names',
[p.name + '@GRAD' for p in inputs['X']],
'x_names',
[t.name for t in inputs['X']],
)
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册