未验证 提交 ae1f3209 编写于 作者: C Chen Weihang 提交者: GitHub

fix prune input bug (#30384)

上级 cf786d22
......@@ -470,19 +470,22 @@ class StaticFunction(object):
cached_program_len = len(self._program_cache)
# If specific `input_spec`, apply convertion from dygraph layers into static Program.
if cached_program_len == 0:
if input_spec is None:
input_spec = self._function_spec.input_spec
elif self._function_spec.input_spec is not None:
if not input_specs_compatible(
desired_input_spec = input_spec
if self._function_spec.input_spec is not None:
if input_spec is not None and not input_specs_compatible(
flatten(input_spec),
flatten(self._function_spec.input_spec)):
raise ValueError(
"The `input_spec`: {} used to construct concrete_program is conflict with the `input_spec`: {} in `@paddle.jit.to_static`".
format(input_spec, self._function_spec.input_spec))
# NOTE(chenweihang): we should always translated program based on the `input_spec`
# decorated on forward if it is valid
desired_input_spec = self._function_spec.input_spec
has_input_spec = (input_spec is not None)
has_input_spec = (desired_input_spec is not None)
if has_input_spec:
concrete_program, _ = self.get_concrete_program(*input_spec)
concrete_program, _ = self.get_concrete_program(
*desired_input_spec)
return concrete_program
else:
raise ValueError(
......
......@@ -1222,23 +1222,27 @@ def unwrap(func):
return unwrapped_f
def input_specs_compatible(src_input_specs, other_input_specs):
def input_specs_compatible(src_input_specs, desired_input_specs):
"""
Returns True if the two input specs are compatible, otherwise False.
args:
src_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
other_input_spec (list[InputSpec]|tuple(InputSpec)): list/tuple of
desired_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of
paddle.static.InputSpec
"""
len_specs = len(src_input_specs)
if len_specs != len(other_input_specs):
if len_specs != len(desired_input_specs):
# NOTE(chenweihang): if the input_spec of jit.save is a subset of
# input_spec of to_static, also compatible
for spec in src_input_specs:
if spec not in desired_input_specs:
return False
else:
for i in range(len_specs):
src_shape = src_input_specs[i].shape
other_shape = other_input_specs[i].shape
other_shape = desired_input_specs[i].shape
len_shape = len(src_shape)
if len_shape != len(other_shape):
return False
......@@ -1251,7 +1255,7 @@ def input_specs_compatible(src_input_specs, other_input_specs):
return False
src_dtype = convert_dtype(src_input_specs[i].dtype)
other_dtype = convert_dtype(other_input_specs[i].dtype)
other_dtype = convert_dtype(desired_input_specs[i].dtype)
if src_dtype != other_dtype:
return False
......
......@@ -95,6 +95,38 @@ class LinerNetWithLabel(paddle.nn.Layer):
return out, avg_loss
class LinerNetWithPruneInput(paddle.nn.Layer):
def __init__(self, in_size, out_size):
super(LinerNetWithPruneInput, self).__init__()
self._linear = Linear(in_size, out_size)
@declarative(input_spec=[
InputSpec(
shape=[None, 784], dtype='float32', name="image"), InputSpec(
shape=[None, 1], dtype='int64', name="label")
])
def forward(self, x, label):
out = self._linear(x)
loss = fluid.layers.cross_entropy(out, label)
avg_loss = fluid.layers.mean(loss)
return out
class LinerNetWithUselessInput(paddle.nn.Layer):
def __init__(self, in_size, out_size):
super(LinerNetWithUselessInput, self).__init__()
self._linear = Linear(in_size, out_size)
@declarative(input_spec=[
InputSpec(
shape=[None, 784], dtype='float32', name="image"), InputSpec(
shape=[None, 1], dtype='int64', name="label")
])
def forward(self, x, label):
out = self._linear(x)
return out
class LinearNetReturnLoss(fluid.dygraph.Layer):
def __init__(self, in_size, out_size):
super(LinearNetReturnLoss, self).__init__()
......@@ -627,16 +659,24 @@ class TestJitSaveMultiCases(unittest.TestCase):
paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED)
def verify_inference_correctness(self, layer, model_path, with_label=False):
def verify_inference_correctness(self,
layer,
model_path,
with_label_and_loss=False,
with_label=False):
layer.eval()
loaded_layer = paddle.jit.load(model_path)
loaded_layer.eval()
# inference & compare
x = paddle.to_tensor(np.random.random((1, 784)).astype('float32'))
if with_label:
if with_label_and_loss:
y = paddle.to_tensor(np.random.random((1, 1)).astype('int64'))
pred, _ = layer(x, y)
pred = pred.numpy()
elif with_label:
y = paddle.to_tensor(np.random.random((1, 1)).astype('int64'))
pred = layer(x, y)
pred = pred.numpy()
else:
pred = layer(x).numpy()
loaded_pred = loaded_layer(x).numpy()
......@@ -714,7 +754,8 @@ class TestJitSaveMultiCases(unittest.TestCase):
],
output_spec=[out])
self.verify_inference_correctness(layer, model_path, True)
self.verify_inference_correctness(
layer, model_path, with_label_and_loss=True)
def test_prune_to_static_no_train(self):
layer = LinerNetWithLabel(784, 1)
......@@ -732,7 +773,36 @@ class TestJitSaveMultiCases(unittest.TestCase):
],
output_spec=output_spec)
self.verify_inference_correctness(layer, model_path, True)
self.verify_inference_correctness(
layer, model_path, with_label_and_loss=True)
def test_prune_input_to_static_no_train(self):
layer = LinerNetWithPruneInput(784, 1)
model_path = "test_prune_input_to_static_no_train/model"
paddle.jit.save(
layer,
model_path,
input_spec=[
InputSpec(
shape=[None, 784], dtype='float32', name="image")
])
self.verify_inference_correctness(layer, model_path, with_label=True)
def test_prune_useless_input_to_static_no_train(self):
layer = LinerNetWithUselessInput(784, 1)
model_path = "test_prune_useless_input_to_static_no_train/model"
paddle.jit.save(
layer,
model_path,
input_spec=[
InputSpec(
shape=[None, 784], dtype='float32', name="image")
])
self.verify_inference_correctness(layer, model_path, with_label=True)
def test_no_prune_input_spec_name_warning(self):
layer = LinearNetWithInputSpec(784, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册