未验证 提交 e52dae6e 编写于 作者: A Aurelius84 提交者: GitHub

Using input.place() in GetExpectedKernel in slice_op (#25595)

* modify GetExpectedKernelType

* use input place

* add ENFORCE check
上级 595a7197
......@@ -148,9 +148,17 @@ class SliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto *in_var = ctx.InputVar("Input");
if (in_var->IsType<framework::LoDTensor>()) {
auto &in_tensor = in_var->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ(
in_tensor.IsInitialized(), true,
platform::errors::InvalidArgument(
"The tensor Input (Input) of Slice op is not initialized."));
return framework::OpKernelType(in_tensor.type(), in_tensor.place());
}
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册