diff --git a/paddle/fluid/operators/detection/box_clip_op.cc b/paddle/fluid/operators/detection/box_clip_op.cc index 609bd5606b297d0d4b15c2d2dadbea061fd47ff6..fb94d0fbc619f13e73da5608b9ed2980bb91b124 100644 --- a/paddle/fluid/operators/detection/box_clip_op.cc +++ b/paddle/fluid/operators/detection/box_clip_op.cc @@ -20,7 +20,7 @@ class BoxClipOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("InputBox"), "Input(InputBox) of BoxClipOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("ImInfo"), @@ -41,6 +41,13 @@ class BoxClipOp : public framework::OperatorWithKernel { ctx->ShareDim("InputBox", /*->*/ "OutputBox"); ctx->ShareLoD("InputBox", /*->*/ "OutputBox"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("RpnRois")); + return framework::OpKernelType(data_type, platform::CPUPlace()); + } }; class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker {