diff --git a/python/paddle/jit/dy2static/function_spec.py b/python/paddle/jit/dy2static/function_spec.py index da82295409feeb507943aa0923520ba754b9bd54..cb40c3ae7d43ef4f771ae077c2b3b8856cd2f3a8 100644 --- a/python/paddle/jit/dy2static/function_spec.py +++ b/python/paddle/jit/dy2static/function_spec.py @@ -373,7 +373,14 @@ def convert_to_input_spec(inputs, input_spec): ) real_spec.name = input_spec.name if spec_greater(input_spec, real_spec): - return input_spec + # change shape but keep the others (stop_gradient / dtype) . + real_spec.shape = input_spec.shape + else: + logging_utils.warn( + "input spec is not compatitable with real inputs. input_spec: {input_spec} , real_spec: {real_spec} ".format( + input_spec=input_spec, real_spec=real_spec + ) + ) return real_spec else: # NOTE(Aurelius84): Support non-Tensor type as input spec info @@ -480,8 +487,4 @@ def spec_greater(first, other): return False return True - return ( - other.stop_gradient == first.stop_gradient - and other.dtype == first.dtype - and _shape_greater(first.shape, other.shape) - ) + return _shape_greater(first.shape, other.shape)