diff --git a/paddle/fluid/operators/psroi_pool_op.cc b/paddle/fluid/operators/psroi_pool_op.cc index 6978d9c5dc5993e64793f420a63dcca020f47868..78989582b7a0da5b7ff326cea1606df9993bed4c 100644 --- a/paddle/fluid/operators/psroi_pool_op.cc +++ b/paddle/fluid/operators/psroi_pool_op.cc @@ -129,9 +129,8 @@ class PSROIPoolOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } }; @@ -150,9 +149,8 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); } };