未验证 提交 46d6080d 编写于 作者: X xiongkun 提交者: GitHub

[dy2static] fix the speed problem introduced by #50883 (#51606)

上级 dca81a43
......@@ -373,7 +373,14 @@ def convert_to_input_spec(inputs, input_spec):
)
real_spec.name = input_spec.name
if spec_greater(input_spec, real_spec):
return input_spec
# change shape but keep the others (stop_gradient / dtype) .
real_spec.shape = input_spec.shape
else:
logging_utils.warn(
"input spec is not compatitable with real inputs. input_spec: {input_spec} , real_spec: {real_spec} ".format(
input_spec=input_spec, real_spec=real_spec
)
)
return real_spec
else:
# NOTE(Aurelius84): Support non-Tensor type as input spec info
......@@ -480,8 +487,4 @@ def spec_greater(first, other):
return False
return True
return (
other.stop_gradient == first.stop_gradient
and other.dtype == first.dtype
and _shape_greater(first.shape, other.shape)
)
return _shape_greater(first.shape, other.shape)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册