未验证 提交 f063074f 编写于 作者: A Aurelius84 提交者: GitHub

[API]Fix paddle.arange infershape always -1 (#52764)

上级 f650e901
......@@ -151,6 +151,7 @@ class TestArangeAPI(unittest.TestCase):
expected_data = np.arange(0, 5, 1).astype(np.float32)
self.assertEqual((out == expected_data).all(), True)
self.assertListEqual(list(x1.shape), [5])
class TestArangeImperative(unittest.TestCase):
......
......@@ -1293,6 +1293,14 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
end = start
start = 0
out_shape = None
if not in_dygraph_mode() and (
not isinstance(start, Variable)
and not isinstance(end, Variable)
and not isinstance(step, Variable)
):
out_shape = [int(math.ceil((end - start) / step))]
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
......@@ -1324,13 +1332,6 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
'range/arange',
)
helper = LayerHelper('range', **locals())
out_shape = None
if (
not isinstance(start, Variable)
and not isinstance(end, Variable)
and not isinstance(step, Variable)
):
out_shape = [int(math.ceil((end - start) / step))]
out = helper.create_variable_for_type_inference(dtype, shape=out_shape)
helper.append_op(
type='range',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册