diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc index 13dd89805453d1bdd8a41dcbdd0ad40a18ab5cbf..8f5df7b6d5d3cb6cee6f08edaeaa4269c70be937 100644 --- a/paddle/fluid/operators/slice_op.cc +++ b/paddle/fluid/operators/slice_op.cc @@ -148,9 +148,17 @@ class SliceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { + auto *in_var = ctx.InputVar("Input"); + if (in_var->IsType()) { + auto &in_tensor = in_var->Get(); + PADDLE_ENFORCE_EQ( + in_tensor.IsInitialized(), true, + platform::errors::InvalidArgument( + "The tensor Input (Input) of Slice op is not initialized.")); + return framework::OpKernelType(in_tensor.type(), in_tensor.place()); + } return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const Tensor &tensor,