diff --git a/paddle/fluid/operators/range_op.cc b/paddle/fluid/operators/range_op.cc index 31ef777e5f041c6bedf17095a1302dd976923726..8585ecd2f94cc86c4d130b47b14c7c7f68620237 100644 --- a/paddle/fluid/operators/range_op.cc +++ b/paddle/fluid/operators/range_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/range_op.h" +#include namespace paddle { namespace operators { @@ -65,6 +66,13 @@ class RangeOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Out", {-1}); } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return expected_kernel_type; + } }; class RangeOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 34b847f0e2bfa3eaae9bc6775dcce50a2e30ae4c..bba0baac016e86aa25260a7687a235c6cf5baaf4 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -18,7 +18,7 @@ from six.moves import reduce from ..layer_helper import LayerHelper from ..param_attr import ParamAttr from ..initializer import Initializer -from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator +from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard from ..framework import Variable from ..initializer import Constant from ..core import VarDesc @@ -1394,17 +1394,20 @@ def range(start, end, step, dtype, name=None): dtype = convert_np_dtype_to_dtype_(dtype) if not isinstance(start, Variable): - start = fill_constant([1], dtype, start) + with device_guard("cpu"): + start = fill_constant([1], dtype, start) elif start.dtype != dtype: start = cast(start, dtype) if not isinstance(end, Variable): - end = fill_constant([1], dtype, end) + with device_guard("cpu"): + end = fill_constant([1], dtype, end) elif end.dtype != dtype: end = cast(end, dtype) if not isinstance(step, Variable): - step = fill_constant([1], dtype, step) + with device_guard("cpu"): + step = fill_constant([1], dtype, step) elif step.dtype != dtype: step = cast(step, dtype)