未验证 提交 f16981b1 编写于 作者: H Huihuang Zheng 提交者: GitHub

Make range API set its out shape when possible (#32472)

`range` API set its output shape in dygraph but not in static graph, which can cause Dy2stat error. This PR set the shape of `range` API when possible.
上级 fba46ea3
...@@ -381,7 +381,10 @@ class LayerHelperBase(object): ...@@ -381,7 +381,10 @@ class LayerHelperBase(object):
return self.main_program.global_block().create_parameter( return self.main_program.global_block().create_parameter(
dtype=dtype, shape=shape, type=type, **attr._to_kwargs()) dtype=dtype, shape=shape, type=type, **attr._to_kwargs())
def create_variable_for_type_inference(self, dtype, stop_gradient=False): def create_variable_for_type_inference(self,
dtype,
stop_gradient=False,
shape=None):
"""Create a temporary variable that should be type inferred layer. """Create a temporary variable that should be type inferred layer.
Note: Note:
...@@ -397,6 +400,7 @@ class LayerHelperBase(object): ...@@ -397,6 +400,7 @@ class LayerHelperBase(object):
name=unique_name.generate_with_ignorable_key(".".join( name=unique_name.generate_with_ignorable_key(".".join(
[self.name, 'tmp'])), [self.name, 'tmp'])),
dtype=dtype, dtype=dtype,
shape=shape,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=stop_gradient) stop_gradient=stop_gradient)
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import math
import numpy import numpy
import six import six
import warnings import warnings
...@@ -1373,6 +1374,11 @@ def range(start, end, step, dtype, name=None): ...@@ -1373,6 +1374,11 @@ def range(start, end, step, dtype, name=None):
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
out_shape = None
if not isinstance(start, Variable) and not isinstance(
end, Variable) and not isinstance(step, Variable):
out_shape = [int(math.ceil((end - start) / step))]
if not isinstance(start, Variable): if not isinstance(start, Variable):
with device_guard("cpu"): with device_guard("cpu"):
start = fill_constant([1], dtype, start, force_cpu=True) start = fill_constant([1], dtype, start, force_cpu=True)
...@@ -1397,7 +1403,7 @@ def range(start, end, step, dtype, name=None): ...@@ -1397,7 +1403,7 @@ def range(start, end, step, dtype, name=None):
check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'],
'range/arange') 'range/arange')
helper = LayerHelper('range', **locals()) helper = LayerHelper('range', **locals())
out = helper.create_variable_for_type_inference(dtype) out = helper.create_variable_for_type_inference(dtype, shape=out_shape)
helper.append_op( helper.append_op(
type='range', type='range',
inputs={'Start': start, inputs={'Start': start,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册