未验证 提交 46d6080d 编写于 作者: X xiongkun 提交者: GitHub

[dy2static] fix the speed problem introduced by #50883 (#51606)

上级 dca81a43
...@@ -373,7 +373,14 @@ def convert_to_input_spec(inputs, input_spec): ...@@ -373,7 +373,14 @@ def convert_to_input_spec(inputs, input_spec):
) )
real_spec.name = input_spec.name real_spec.name = input_spec.name
if spec_greater(input_spec, real_spec): if spec_greater(input_spec, real_spec):
return input_spec # change shape but keep the others (stop_gradient / dtype) .
real_spec.shape = input_spec.shape
else:
logging_utils.warn(
"input spec is not compatitable with real inputs. input_spec: {input_spec} , real_spec: {real_spec} ".format(
input_spec=input_spec, real_spec=real_spec
)
)
return real_spec return real_spec
else: else:
# NOTE(Aurelius84): Support non-Tensor type as input spec info # NOTE(Aurelius84): Support non-Tensor type as input spec info
...@@ -480,8 +487,4 @@ def spec_greater(first, other): ...@@ -480,8 +487,4 @@ def spec_greater(first, other):
return False return False
return True return True
return ( return _shape_greater(first.shape, other.shape)
other.stop_gradient == first.stop_gradient
and other.dtype == first.dtype
and _shape_greater(first.shape, other.shape)
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册