未验证 提交 3cdc7a01 编写于 作者: S ShiningZhang 提交者: GitHub

range can not return shape when enable_static (#42275)

上级 eb64983a
...@@ -1470,6 +1470,11 @@ def range(start, end, step, dtype, name=None): ...@@ -1470,6 +1470,11 @@ def range(start, end, step, dtype, name=None):
# [3, 4, 5, 6] # [3, 4, 5, 6]
""" """
out_shape = None
if not isinstance(start, Variable) and not isinstance(
end, Variable) and not isinstance(step, Variable):
out_shape = [int(math.ceil((end - start) / step))]
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
...@@ -1500,11 +1505,6 @@ def range(start, end, step, dtype, name=None): ...@@ -1500,11 +1505,6 @@ def range(start, end, step, dtype, name=None):
out.stop_gradient = True out.stop_gradient = True
return out return out
out_shape = None
if not isinstance(start, Variable) and not isinstance(
end, Variable) and not isinstance(step, Variable):
out_shape = [int(math.ceil((end - start) / step))]
check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'],
'range/arange') 'range/arange')
helper = LayerHelper('range', **locals()) helper = LayerHelper('range', **locals())
...@@ -1516,6 +1516,8 @@ def range(start, end, step, dtype, name=None): ...@@ -1516,6 +1516,8 @@ def range(start, end, step, dtype, name=None):
'Step': step}, 'Step': step},
outputs={'Out': out}) outputs={'Out': out})
out.stop_gradient = True out.stop_gradient = True
if out_shape is not None:
out.desc.set_shape(out_shape)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册