diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 751e150845826d4ac350d440dae6165a3e417ba3..3314e41cc51d74f87be0e2cd5eba9bb260c16be7 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -146,7 +146,7 @@ struct OpKernelRegistrarFunctorEx struct OpKernelRegistrarFunctorEx { - using KERNEL_TYPE = + using Functor = typename std::tuple_element>::type; using T = @@ -154,10 +154,7 @@ struct OpKernelRegistrarFunctorEx>::type; void operator()(const char* op_type, const char* library_type) const { - RegisterKernelClass( - op_type, library_type, [](const framework::ExecutionContext& ctx) { - KERNEL_TYPE().Compute(ctx); - }); + RegisterKernelClass(op_type, library_type, Functor()); constexpr auto size = std::tuple_size>::value; @@ -238,11 +235,11 @@ struct OpKernelRegistrarFunctorEx("Out"); auto *in = ctx.Input("X"); @@ -147,7 +147,7 @@ void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const { out->Resize(out_dims); } } -void ReshapeGradKernelBase::Compute( +void ReshapeGradKernel::operator()( const framework::ExecutionContext &ctx) const { auto *d_out = ctx.Input(framework::GradVarName("Out")); auto *d_x = ctx.Output(framework::GradVarName("X")); @@ -172,10 +172,10 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp); -REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double, - ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, - ops::ReshapeKernel); -REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel); +REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, + int64_t, ops::ReshapeKernel); +REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, + double, ops::ReshapeGradKernel, int, + ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel); diff --git a/paddle/fluid/operators/reshape_op.cu.cc b/paddle/fluid/operators/reshape_op.cu.cc index 8a09321eef27c67c0ec1f722ad9bf66f2b728e51..374b2dbc6accef165701afdcae8d69c62a55e2a5 100644 --- a/paddle/fluid/operators/reshape_op.cu.cc +++ b/paddle/fluid/operators/reshape_op.cu.cc @@ -14,11 +14,11 @@ limitations under the License. */ #include "paddle/fluid/operators/reshape_op.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL_EX(reshape, float, ops::ReshapeKernel, double, - ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, - ops::ReshapeKernel); -REGISTER_OP_CUDA_KERNEL(reshape_grad, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel); + +REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, + int64_t, ops::ReshapeKernel); +REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, + double, ops::ReshapeGradKernel, int, + ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel); diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index c0b57d11d392620dfeefb23a6b26a3b0a609d9a3..68e1690a53c734d94800876ff179826c039d5dfb 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -118,21 +118,15 @@ class ReshapeOp : public framework::OperatorWithKernel { } }; -class ReshapeKernel : public framework::OpKernelBase { +class ReshapeKernel { public: - void Compute(const framework::ExecutionContext &ctx) const final; + void operator()(const framework::ExecutionContext &ctx) const; }; -class ReshapeGradKernelBase : public framework::OpKernelBase { +class ReshapeGradKernel { public: - void Compute(const framework::ExecutionContext &ctx) const; + void operator()(const framework::ExecutionContext &ctx) const; }; -template -class ReshapeGradKernel : public ReshapeGradKernelBase { - public: - // Tell register element type. - using ELEMENT_TYPE = T; -}; } // namespace operators } // namespace paddle