From 46d6080d0dc94a459b869c82d9fe4321e97deec8 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Tue, 14 Mar 2023 14:43:11 +0800 Subject: [PATCH] [dy2static] fix the speed problem introduced by #50883 (#51606) --- python/paddle/jit/dy2static/function_spec.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/paddle/jit/dy2static/function_spec.py b/python/paddle/jit/dy2static/function_spec.py index da82295409f..cb40c3ae7d4 100644 --- a/python/paddle/jit/dy2static/function_spec.py +++ b/python/paddle/jit/dy2static/function_spec.py @@ -373,7 +373,14 @@ def convert_to_input_spec(inputs, input_spec): ) real_spec.name = input_spec.name if spec_greater(input_spec, real_spec): - return input_spec + # change shape but keep the others (stop_gradient / dtype) . + real_spec.shape = input_spec.shape + else: + logging_utils.warn( + "input spec is not compatitable with real inputs. input_spec: {input_spec} , real_spec: {real_spec} ".format( + input_spec=input_spec, real_spec=real_spec + ) + ) return real_spec else: # NOTE(Aurelius84): Support non-Tensor type as input spec info @@ -480,8 +487,4 @@ def spec_greater(first, other): return False return True - return ( - other.stop_gradient == first.stop_gradient - and other.dtype == first.dtype - and _shape_greater(first.shape, other.shape) - ) + return _shape_greater(first.shape, other.shape) -- GitLab