From 76fb95fe769f991685818059324664da3d1d1af4 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Tue, 22 Sep 2020 09:06:10 -0700 Subject: [PATCH] avoid data transform for linspace OP (#27444) --- paddle/fluid/operators/linspace_op.cc | 11 +++++++++-- python/paddle/fluid/layers/tensor.py | 9 ++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc index 667c6e89295..7cc07383bfa 100644 --- a/paddle/fluid/operators/linspace_op.cc +++ b/paddle/fluid/operators/linspace_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/linspace_op.h" +#include namespace paddle { namespace operators { @@ -21,7 +22,7 @@ class LinspaceOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { OP_INOUT_CHECK(ctx->HasInput("Start"), "Input", "Start", "linspace"); OP_INOUT_CHECK(ctx->HasInput("Stop"), "Input", "Stop", "linspace"); OP_INOUT_CHECK(ctx->HasInput("Num"), "Input", "Num", "linspace"); @@ -50,11 +51,17 @@ class LinspaceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { + const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return expected_kernel_type; + } }; class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index cf52f3b00fb..2fba578ec07 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1453,11 +1453,14 @@ def linspace(start, stop, num, dtype=None, name=None): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) if not isinstance(start, Variable): - tensor_start = fill_constant([1], dtype, start) + with device_guard("cpu"): + tensor_start = fill_constant([1], dtype, start) if not isinstance(stop, Variable): - tensor_stop = fill_constant([1], dtype, stop) + with device_guard("cpu"): + tensor_stop = fill_constant([1], dtype, stop) if not isinstance(num, Variable): - tensor_num = fill_constant([1], 'int32', num) + with device_guard("cpu"): + tensor_num = fill_constant([1], 'int32', num) if in_dygraph_mode(): return core.ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype', dtype) -- GitLab