From d5b4570dd98bf59a1ea81fb04fc5ef62f8826e59 Mon Sep 17 00:00:00 2001 From: ShiningZhang Date: Tue, 26 Apr 2022 11:32:21 +0800 Subject: [PATCH] fix bug: arange can not return shape when enable_static (#42182) * fix bug: arange can not return shape when enable_static * fix bug: test_arange --- python/paddle/tensor/creation.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index aeec256bc1..a5a4df6571 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -827,6 +827,11 @@ def arange(start=0, end=None, step=1, dtype=None, name=None): end = start start = 0 + out_shape = None + if not isinstance(start, Variable) and not isinstance( + end, Variable) and not isinstance(step, Variable): + out_shape = [int(math.ceil((end - start) / step))] + if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) @@ -857,11 +862,6 @@ def arange(start=0, end=None, step=1, dtype=None, name=None): out.stop_gradient = True return out - out_shape = None - if not isinstance(start, Variable) and not isinstance( - end, Variable) and not isinstance(step, Variable): - out_shape = [int(math.ceil((end - start) / step))] - check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'range/arange') helper = LayerHelper('range', **locals()) @@ -873,6 +873,8 @@ def arange(start=0, end=None, step=1, dtype=None, name=None): 'Step': step}, outputs={'Out': out}) out.stop_gradient = True + if out_shape is not None: + out.desc.set_shape(out_shape) return out -- GitLab