From e52dae6ef61762130fcfde59cd92fb687275e937 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 31 Jul 2020 10:06:50 +0800 Subject: [PATCH] Using input.place() in GetExpectedKernel in slice_op (#25595) * modify GetExpectedKernelType * use input place * add ENFORCE check --- paddle/fluid/operators/slice_op.cc | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 13dd8980545..8f5df7b6d5d 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -148,9 +148,17 @@ class SliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { + auto *in_var = ctx.InputVar("Input"); + if (in_var->IsType()) { + auto &in_tensor = in_var->Get(); + PADDLE_ENFORCE_EQ( + in_tensor.IsInitialized(), true, + platform::errors::InvalidArgument( + "The tensor Input (Input) of Slice op is not initialized.")); + return framework::OpKernelType(in_tensor.type(), in_tensor.place()); + } return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor, -- GitLab