未验证 提交 1ae0730f 编写于 作者: Z zyfncg 提交者: GitHub

fix bug caused by arange (#41372)

上级 ea4b56f2
...@@ -1433,10 +1433,6 @@ def range(start, end, step, dtype, name=None): ...@@ -1433,10 +1433,6 @@ def range(start, end, step, dtype, name=None):
if not isinstance(dtype, core.VarDesc.VarType): if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
return _C_ops.final_state_arange(start, end, step, dtype,
_current_expected_place())
if not isinstance(start, Variable): if not isinstance(start, Variable):
with device_guard("cpu"): with device_guard("cpu"):
start = fill_constant([1], dtype, start, force_cpu=True) start = fill_constant([1], dtype, start, force_cpu=True)
...@@ -1455,6 +1451,10 @@ def range(start, end, step, dtype, name=None): ...@@ -1455,6 +1451,10 @@ def range(start, end, step, dtype, name=None):
elif step.dtype != dtype: elif step.dtype != dtype:
step = cast(step, dtype) step = cast(step, dtype)
if in_dygraph_mode():
return _C_ops.final_state_arange(start, end, step, dtype,
_current_expected_place())
if _in_legacy_dygraph(): if _in_legacy_dygraph():
out = _C_ops.range(start, end, step) out = _C_ops.range(start, end, step)
out.stop_gradient = True out.stop_gradient = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册