From 9e1ec8c919179eae3527124c4bbd278d59f9ad4e Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Thu, 1 Mar 2018 21:24:23 +0800 Subject: [PATCH] Enable device switching automatically for serveral operators (#8684) --- paddle/fluid/operators/bipartite_match_op.cc | 8 ++++++++ paddle/fluid/operators/multiclass_nms_op.cc | 2 +- paddle/fluid/operators/prior_box_op.cc | 8 ++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/bipartite_match_op.cc b/paddle/fluid/operators/bipartite_match_op.cc index 2b3f26c0a8..1218d9fdc1 100644 --- a/paddle/fluid/operators/bipartite_match_op.cc +++ b/paddle/fluid/operators/bipartite_match_op.cc @@ -41,6 +41,14 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ColToRowMatchIndices", dims); ctx->SetOutputDim("ColToRowMatchDist", dims); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("DistMat")->type()), + platform::CPUPlace()); + } }; template diff --git a/paddle/fluid/operators/multiclass_nms_op.cc b/paddle/fluid/operators/multiclass_nms_op.cc index 2565e7e9ef..c4e70cde6f 100644 --- a/paddle/fluid/operators/multiclass_nms_op.cc +++ b/paddle/fluid/operators/multiclass_nms_op.cc @@ -62,7 +62,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { return framework::OpKernelType( framework::ToDataType( ctx.Input("Scores")->type()), - ctx.device_context()); + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/prior_box_op.cc b/paddle/fluid/operators/prior_box_op.cc index 922b2bd237..be7898c221 100644 --- a/paddle/fluid/operators/prior_box_op.cc +++ b/paddle/fluid/operators/prior_box_op.cc @@ -67,6 +67,14 @@ class PriorBoxOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec)); ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + platform::CPUPlace()); + } }; class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker { -- GitLab