未验证 提交 9e1ec8c9 编写于 作者: Q qingqing01 提交者: GitHub

Enable device switching automatically for serveral operators (#8684)

上级 ae2026e1
...@@ -41,6 +41,14 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { ...@@ -41,6 +41,14 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ColToRowMatchIndices", dims); ctx->SetOutputDim("ColToRowMatchIndices", dims);
ctx->SetOutputDim("ColToRowMatchDist", dims); ctx->SetOutputDim("ColToRowMatchDist", dims);
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("DistMat")->type()),
platform::CPUPlace());
}
}; };
template <typename T> template <typename T>
......
...@@ -62,7 +62,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { ...@@ -62,7 +62,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
return framework::OpKernelType( return framework::OpKernelType(
framework::ToDataType( framework::ToDataType(
ctx.Input<framework::LoDTensor>("Scores")->type()), ctx.Input<framework::LoDTensor>("Scores")->type()),
ctx.device_context()); platform::CPUPlace());
} }
}; };
......
...@@ -67,6 +67,14 @@ class PriorBoxOp : public framework::OperatorWithKernel { ...@@ -67,6 +67,14 @@ class PriorBoxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec)); ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec));
ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec));
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
platform::CPUPlace());
}
}; };
class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册