From b10d84bc5aaee83c2f25e077c4f38461aafe3928 Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Mon, 21 Jan 2019 03:05:53 +0000 Subject: [PATCH] fix bug when run on GPU, test=develop --- paddle/fluid/operators/detection/box_clip_op.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/detection/box_clip_op.cc b/paddle/fluid/operators/detection/box_clip_op.cc index 609bd5606b2..fb94d0fbc61 100644 --- a/paddle/fluid/operators/detection/box_clip_op.cc +++ b/paddle/fluid/operators/detection/box_clip_op.cc @@ -20,7 +20,7 @@ class BoxClipOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("InputBox"), "Input(InputBox) of BoxClipOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("ImInfo"), @@ -41,6 +41,13 @@ class BoxClipOp : public framework::OperatorWithKernel { ctx->ShareDim("InputBox", /*->*/ "OutputBox"); ctx->ShareLoD("InputBox", /*->*/ "OutputBox"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("RpnRois")); + return framework::OpKernelType(data_type, platform::CPUPlace()); + } }; class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker { -- GitLab