From 2d24f56a7acaec53450e98c0266f76fa780bdf4f Mon Sep 17 00:00:00 2001 From: Zhang Ting <709968123@qq.com> Date: Fri, 31 Jul 2020 20:20:38 +0800 Subject: [PATCH] avoid data transfer, test=develop (#25810) --- paddle/fluid/operators/range_op.cc | 8 ++++++++ python/paddle/fluid/layers/tensor.py | 11 +++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/range_op.cc b/paddle/fluid/operators/range_op.cc index 31ef777e5f..8585ecd2f9 100644 --- a/paddle/fluid/operators/range_op.cc +++ b/paddle/fluid/operators/range_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/range_op.h" +#include namespace paddle { namespace operators { @@ -65,6 +66,13 @@ class RangeOp : public framework::OperatorWithKernel { } ctx->SetOutputDim("Out", {-1}); } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return expected_kernel_type; + } }; class RangeOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 34b847f0e2..bba0baac01 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -18,7 +18,7 @@ from six.moves import reduce from ..layer_helper import LayerHelper from ..param_attr import ParamAttr from ..initializer import Initializer -from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator +from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard from ..framework import Variable from ..initializer import Constant from ..core import VarDesc @@ -1394,17 +1394,20 @@ def range(start, end, step, dtype, name=None): dtype = convert_np_dtype_to_dtype_(dtype) if not isinstance(start, Variable): - start = fill_constant([1], dtype, start) + with device_guard("cpu"): + start = fill_constant([1], dtype, start) elif start.dtype != dtype: start = cast(start, dtype) if not isinstance(end, Variable): - end = fill_constant([1], dtype, end) + with device_guard("cpu"): + end = fill_constant([1], dtype, end) elif end.dtype != dtype: end = cast(end, dtype) if not isinstance(step, Variable): - step = fill_constant([1], dtype, step) + with device_guard("cpu"): + step = fill_constant([1], dtype, step) elif step.dtype != dtype: step = cast(step, dtype) -- GitLab