未验证 提交 81030125 编写于 作者: M mapingshuo 提交者: GitHub

convert input vars' dtype for range op (#22028)

* convert dtype of vars for range op, test=develop
上级 95d79b6d
......@@ -1158,7 +1158,7 @@ def range(start, end, step, dtype):
it is a 1-D Tensor with shape [1].
step(float32 | float64 | int32 | int64 | Variable): Spacing between values. For any output out, this is the
distance between two adjacent values, out[i+1] - out[i].
dtype(str): the data type of the output tensor, can be float32, float64, int32, int64.
dtype(str|core.VarDesc.VarType): the data type of the output tensor, can be float32, float64, int32, int64.
Returns: a 1-D Tensor which is evenly spaced values within a given interval. Its data type is set by dtype.
......@@ -1174,12 +1174,26 @@ def range(start, end, step, dtype):
"""
helper = LayerHelper("range", **locals())
check_dtype(dtype, 'create data type',
['float32', 'float64', 'int32', 'int64'], 'range')
dtype = convert_dtype(dtype)
if not isinstance(start, Variable):
start = fill_constant([1], dtype, start)
elif convert_dtype(start.dtype) != dtype:
# make sure that start, end, step has the same dtype as
# `dtype`
start = cast(x=start, dtype=dtype)
if not isinstance(end, Variable):
end = fill_constant([1], dtype, end)
elif convert_dtype(end.dtype) != dtype:
end = cast(x=end, dtype=dtype)
if not isinstance(step, Variable):
step = fill_constant([1], dtype, step)
elif convert_dtype(step.dtype) != dtype:
step = cast(x=step, dtype=dtype)
out = helper.create_variable_for_type_inference(dtype=start.dtype)
......
......@@ -2561,7 +2561,12 @@ class TestBook(LayerTest):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
layers.range(0, 10, 2, 'int32')
y = layers.range(0.1, 10.0, 0.2, 'float32')
layers.range(0.1, 10.0, 0.2, 'float32')
layers.range(0.1, 10.0, 0.2, 'float64')
start = layers.fill_constant(shape=[1], value=0.1, dtype="float32")
end = layers.fill_constant(shape=[1], value=10.0, dtype="float32")
step = layers.fill_constant(shape=[1], value=0.2, dtype="float32")
y = layers.range(start, end, step, 'float64')
return y
def make_spectral_norm(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册