未验证 提交 f16981b1 编写于 作者: H Huihuang Zheng 提交者: GitHub

Make range API set its out shape when possible (#32472)

`range` API set its output shape in dygraph but not in static graph, which can cause Dy2stat error. This PR set the shape of `range` API when possible.
上级 fba46ea3
......@@ -381,7 +381,10 @@ class LayerHelperBase(object):
return self.main_program.global_block().create_parameter(
dtype=dtype, shape=shape, type=type, **attr._to_kwargs())
def create_variable_for_type_inference(self, dtype, stop_gradient=False):
def create_variable_for_type_inference(self,
dtype,
stop_gradient=False,
shape=None):
"""Create a temporary variable that should be type inferred layer.
Note:
......@@ -397,6 +400,7 @@ class LayerHelperBase(object):
name=unique_name.generate_with_ignorable_key(".".join(
[self.name, 'tmp'])),
dtype=dtype,
shape=shape,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=stop_gradient)
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import math
import numpy
import six
import warnings
......@@ -1373,6 +1374,11 @@ def range(start, end, step, dtype, name=None):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
out_shape = None
if not isinstance(start, Variable) and not isinstance(
end, Variable) and not isinstance(step, Variable):
out_shape = [int(math.ceil((end - start) / step))]
if not isinstance(start, Variable):
with device_guard("cpu"):
start = fill_constant([1], dtype, start, force_cpu=True)
......@@ -1397,7 +1403,7 @@ def range(start, end, step, dtype, name=None):
check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'],
'range/arange')
helper = LayerHelper('range', **locals())
out = helper.create_variable_for_type_inference(dtype)
out = helper.create_variable_for_type_inference(dtype, shape=out_shape)
helper.append_op(
type='range',
inputs={'Start': start,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册