提交 b10d84bc 编写于 作者: J jerrywgz

fix bug when run on GPU, test=develop

上级 5fb28565
...@@ -20,7 +20,7 @@ class BoxClipOp : public framework::OperatorWithKernel { ...@@ -20,7 +20,7 @@ class BoxClipOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("InputBox"), PADDLE_ENFORCE(ctx->HasInput("InputBox"),
"Input(InputBox) of BoxClipOp should not be null."); "Input(InputBox) of BoxClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ImInfo"), PADDLE_ENFORCE(ctx->HasInput("ImInfo"),
...@@ -41,6 +41,13 @@ class BoxClipOp : public framework::OperatorWithKernel { ...@@ -41,6 +41,13 @@ class BoxClipOp : public framework::OperatorWithKernel {
ctx->ShareDim("InputBox", /*->*/ "OutputBox"); ctx->ShareDim("InputBox", /*->*/ "OutputBox");
ctx->ShareLoD("InputBox", /*->*/ "OutputBox"); ctx->ShareLoD("InputBox", /*->*/ "OutputBox");
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("RpnRois"));
return framework::OpKernelType(data_type, platform::CPUPlace());
}
}; };
class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker { class BoxClipOpMaker : public framework::OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册