diff --git a/paddle/fluid/operators/bipartite_match_op.cc b/paddle/fluid/operators/bipartite_match_op.cc index 2b3f26c0a890c33f9b4f4c8a5a271123d7ff0b31..1218d9fdc1e6101d17bc09a4ae769f5fbf8e7b15 100644 --- a/paddle/fluid/operators/bipartite_match_op.cc +++ b/paddle/fluid/operators/bipartite_match_op.cc @@ -41,6 +41,14 @@ class BipartiteMatchOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ColToRowMatchIndices", dims); ctx->SetOutputDim("ColToRowMatchDist", dims); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("DistMat")->type()), + platform::CPUPlace()); + } }; template diff --git a/paddle/fluid/operators/multiclass_nms_op.cc b/paddle/fluid/operators/multiclass_nms_op.cc index 2565e7e9efad415c5e4db2489afa9553683b7b0a..c4e70cde6f8c6bdf1f28b010b0b90091772fdffb 100644 --- a/paddle/fluid/operators/multiclass_nms_op.cc +++ b/paddle/fluid/operators/multiclass_nms_op.cc @@ -62,7 +62,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { return framework::OpKernelType( framework::ToDataType( ctx.Input("Scores")->type()), - ctx.device_context()); + platform::CPUPlace()); } }; diff --git a/paddle/fluid/operators/prior_box_op.cc b/paddle/fluid/operators/prior_box_op.cc index 922b2bd237a1ec54aea895a05ccd78cd624e88ae..be7898c22190339e0717317807b91e038f4949f6 100644 --- a/paddle/fluid/operators/prior_box_op.cc +++ b/paddle/fluid/operators/prior_box_op.cc @@ -67,6 +67,14 @@ class PriorBoxOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Boxes", framework::make_ddim(dim_vec)); ctx->SetOutputDim("Variances", framework::make_ddim(dim_vec)); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + platform::CPUPlace()); + } }; class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {