未验证 提交 ae1f3209 编写于 作者: C Chen Weihang 提交者: GitHub

fix prune input bug (#30384)

上级 cf786d22
...@@ -470,19 +470,22 @@ class StaticFunction(object): ...@@ -470,19 +470,22 @@ class StaticFunction(object):
cached_program_len = len(self._program_cache) cached_program_len = len(self._program_cache)
# If `input_spec` is specified, apply conversion from dygraph layers into static Program. # If `input_spec` is specified, apply conversion from dygraph layers into static Program.
if cached_program_len == 0: if cached_program_len == 0:
if input_spec is None: desired_input_spec = input_spec
input_spec = self._function_spec.input_spec if self._function_spec.input_spec is not None:
elif self._function_spec.input_spec is not None: if input_spec is not None and not input_specs_compatible(
if not input_specs_compatible(
flatten(input_spec), flatten(input_spec),
flatten(self._function_spec.input_spec)): flatten(self._function_spec.input_spec)):
raise ValueError( raise ValueError(
"The `input_spec`: {} used to construct concrete_program is conflict with the `input_spec`: {} in `@paddle.jit.to_static`". "The `input_spec`: {} used to construct concrete_program is conflict with the `input_spec`: {} in `@paddle.jit.to_static`".
format(input_spec, self._function_spec.input_spec)) format(input_spec, self._function_spec.input_spec))
# NOTE(chenweihang): we should always translate the program based on the `input_spec`
# decorated on forward if it is valid
desired_input_spec = self._function_spec.input_spec
has_input_spec = (input_spec is not None) has_input_spec = (desired_input_spec is not None)
if has_input_spec: if has_input_spec:
concrete_program, _ = self.get_concrete_program(*input_spec) concrete_program, _ = self.get_concrete_program(
*desired_input_spec)
return concrete_program return concrete_program
else: else:
raise ValueError( raise ValueError(
......
...@@ -1222,37 +1222,41 @@ def unwrap(func): ...@@ -1222,37 +1222,41 @@ def unwrap(func):
return unwrapped_f return unwrapped_f
def input_specs_compatible(src_input_specs, desired_input_specs):
    """
    Returns True if the two input specs are compatible, otherwise False.

    When the two sequences differ in length, `src_input_specs` is accepted
    only if every one of its entries is also an element of
    `desired_input_specs` (membership is tested with `in`). When the lengths
    match, the specs are compared position by position: the shapes must have
    the same number of dims, each pair of dims must be equal unless either
    one is None or negative, and the dtypes must be equal after
    `convert_dtype`.

    args:
        src_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of
            paddle.static.InputSpec
        desired_input_specs (list[InputSpec]|tuple(InputSpec)): list/tuple of
            paddle.static.InputSpec
    """
    if len(src_input_specs) != len(desired_input_specs):
        # NOTE(chenweihang): if the input_spec of jit.save is a subset of
        # input_spec of to_static, also compatible
        return all(spec in desired_input_specs for spec in src_input_specs)

    for src_spec, desired_spec in zip(src_input_specs, desired_input_specs):
        src_shape = src_spec.shape
        desired_shape = desired_spec.shape
        if len(src_shape) != len(desired_shape):
            return False
        for src_dim, desired_dim in zip(src_shape, desired_shape):
            # A None or negative dim on either side matches anything.
            if src_dim is None or src_dim < 0:
                continue
            if desired_dim is None or desired_dim < 0:
                continue
            if src_dim != desired_dim:
                return False

        # Dtypes are only compared once the shapes have been accepted.
        src_dtype = convert_dtype(src_spec.dtype)
        desired_dtype = convert_dtype(desired_spec.dtype)
        if src_dtype != desired_dtype:
            return False

    return True
...@@ -95,6 +95,38 @@ class LinerNetWithLabel(paddle.nn.Layer): ...@@ -95,6 +95,38 @@ class LinerNetWithLabel(paddle.nn.Layer):
return out, avg_loss return out, avg_loss
class LinerNetWithPruneInput(paddle.nn.Layer):
    """Layer with one Linear sub-layer whose `forward` takes an extra label.

    The label only feeds the loss ops built in `forward`; the returned value
    is the Linear output alone.
    """

    def __init__(self, in_size, out_size):
        super(LinerNetWithPruneInput, self).__init__()
        self._linear = Linear(in_size, out_size)

    @declarative(input_spec=[
        InputSpec(shape=[None, 784], dtype='float32', name="image"),
        InputSpec(shape=[None, 1], dtype='int64', name="label"),
    ])
    def forward(self, x, label):
        prediction = self._linear(x)
        # These ops consume `label` but do not contribute to the return value.
        sample_loss = fluid.layers.cross_entropy(prediction, label)
        avg_loss = fluid.layers.mean(sample_loss)
        return prediction
class LinerNetWithUselessInput(paddle.nn.Layer):
    """Layer with one Linear sub-layer whose `forward` never reads `label`."""

    def __init__(self, in_size, out_size):
        super(LinerNetWithUselessInput, self).__init__()
        self._linear = Linear(in_size, out_size)

    @declarative(input_spec=[
        InputSpec(shape=[None, 784], dtype='float32', name="image"),
        InputSpec(shape=[None, 1], dtype='int64', name="label"),
    ])
    def forward(self, x, label):
        # `label` is part of the signature and the input spec but is unused.
        return self._linear(x)
class LinearNetReturnLoss(fluid.dygraph.Layer): class LinearNetReturnLoss(fluid.dygraph.Layer):
def __init__(self, in_size, out_size): def __init__(self, in_size, out_size):
super(LinearNetReturnLoss, self).__init__() super(LinearNetReturnLoss, self).__init__()
...@@ -627,16 +659,24 @@ class TestJitSaveMultiCases(unittest.TestCase): ...@@ -627,16 +659,24 @@ class TestJitSaveMultiCases(unittest.TestCase):
paddle.seed(SEED) paddle.seed(SEED)
paddle.framework.random._manual_program_seed(SEED) paddle.framework.random._manual_program_seed(SEED)
def verify_inference_correctness(self, layer, model_path, with_label=False): def verify_inference_correctness(self,
layer,
model_path,
with_label_and_loss=False,
with_label=False):
layer.eval() layer.eval()
loaded_layer = paddle.jit.load(model_path) loaded_layer = paddle.jit.load(model_path)
loaded_layer.eval() loaded_layer.eval()
# inference & compare # inference & compare
x = paddle.to_tensor(np.random.random((1, 784)).astype('float32')) x = paddle.to_tensor(np.random.random((1, 784)).astype('float32'))
if with_label: if with_label_and_loss:
y = paddle.to_tensor(np.random.random((1, 1)).astype('int64')) y = paddle.to_tensor(np.random.random((1, 1)).astype('int64'))
pred, _ = layer(x, y) pred, _ = layer(x, y)
pred = pred.numpy() pred = pred.numpy()
elif with_label:
y = paddle.to_tensor(np.random.random((1, 1)).astype('int64'))
pred = layer(x, y)
pred = pred.numpy()
else: else:
pred = layer(x).numpy() pred = layer(x).numpy()
loaded_pred = loaded_layer(x).numpy() loaded_pred = loaded_layer(x).numpy()
...@@ -714,7 +754,8 @@ class TestJitSaveMultiCases(unittest.TestCase): ...@@ -714,7 +754,8 @@ class TestJitSaveMultiCases(unittest.TestCase):
], ],
output_spec=[out]) output_spec=[out])
self.verify_inference_correctness(layer, model_path, True) self.verify_inference_correctness(
layer, model_path, with_label_and_loss=True)
def test_prune_to_static_no_train(self): def test_prune_to_static_no_train(self):
layer = LinerNetWithLabel(784, 1) layer = LinerNetWithLabel(784, 1)
...@@ -732,7 +773,36 @@ class TestJitSaveMultiCases(unittest.TestCase): ...@@ -732,7 +773,36 @@ class TestJitSaveMultiCases(unittest.TestCase):
], ],
output_spec=output_spec) output_spec=output_spec)
self.verify_inference_correctness(layer, model_path, True) self.verify_inference_correctness(
layer, model_path, with_label_and_loss=True)
def test_prune_input_to_static_no_train(self):
    """Save LinerNetWithPruneInput with only the image spec, then verify it."""
    layer = LinerNetWithPruneInput(784, 1)
    model_path = "test_prune_input_to_static_no_train/model"

    # Only the "image" input is passed to jit.save; no "label" spec is given.
    image_spec = InputSpec(shape=[None, 784], dtype='float32', name="image")
    paddle.jit.save(layer, model_path, input_spec=[image_spec])

    self.verify_inference_correctness(layer, model_path, with_label=True)
def test_prune_useless_input_to_static_no_train(self):
    """Save LinerNetWithUselessInput with only the image spec, then verify it."""
    layer = LinerNetWithUselessInput(784, 1)
    model_path = "test_prune_useless_input_to_static_no_train/model"

    # Only the "image" input is passed to jit.save; no "label" spec is given.
    image_spec = InputSpec(shape=[None, 784], dtype='float32', name="image")
    paddle.jit.save(layer, model_path, input_spec=[image_spec])

    self.verify_inference_correctness(layer, model_path, with_label=True)
def test_no_prune_input_spec_name_warning(self): def test_no_prune_input_spec_name_warning(self):
layer = LinearNetWithInputSpec(784, 1) layer = LinearNetWithInputSpec(784, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册