From 82866d4a1810f8a1c3a8a9b7e866a133c4fe5c4b Mon Sep 17 00:00:00 2001 From: yuyang18 Date: Mon, 2 Jul 2018 16:54:41 +0800 Subject: [PATCH] Add register kernel functor and shrink reshape op * Shrink reshape_op library size * User can register a standard C++ functor as a op kernel --- paddle/fluid/framework/op_registry.h | 13 +++++-------- paddle/fluid/operators/reshape_op.cc | 18 +++++++++--------- paddle/fluid/operators/reshape_op.cu.cc | 16 ++++++++-------- paddle/fluid/operators/reshape_op.h | 14 ++++---------- 4 files changed, 26 insertions(+), 35 deletions(-) diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 751e150845..3314e41cc5 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -146,7 +146,7 @@ struct OpKernelRegistrarFunctorEx struct OpKernelRegistrarFunctorEx { - using KERNEL_TYPE = + using Functor = typename std::tuple_element>::type; using T = @@ -154,10 +154,7 @@ struct OpKernelRegistrarFunctorEx>::type; void operator()(const char* op_type, const char* library_type) const { - RegisterKernelClass( - op_type, library_type, [](const framework::ExecutionContext& ctx) { - KERNEL_TYPE().Compute(ctx); - }); + RegisterKernelClass(op_type, library_type, Functor()); constexpr auto size = std::tuple_size>::value; @@ -238,11 +235,11 @@ struct OpKernelRegistrarFunctorEx("Out"); auto *in = ctx.Input("X"); @@ -147,7 +147,7 @@ void ReshapeKernel::Compute(const framework::ExecutionContext &ctx) const { out->Resize(out_dims); } } -void ReshapeGradKernelBase::Compute( +void ReshapeGradKernel::operator()( const framework::ExecutionContext &ctx) const { auto *d_out = ctx.Input(framework::GradVarName("Out")); auto *d_x = ctx.Output(framework::GradVarName("X")); @@ -172,10 +172,10 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp); -REGISTER_OP_CPU_KERNEL_EX(reshape, float, ops::ReshapeKernel, double, - ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, - ops::ReshapeKernel); -REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel); +REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, + int64_t, ops::ReshapeKernel); +REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, + double, ops::ReshapeGradKernel, int, + ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel); diff --git a/paddle/fluid/operators/reshape_op.cu.cc b/paddle/fluid/operators/reshape_op.cu.cc index 8a09321eef..374b2dbc6a 100644 --- a/paddle/fluid/operators/reshape_op.cu.cc +++ b/paddle/fluid/operators/reshape_op.cu.cc @@ -14,11 +14,11 @@ limitations under the License. */ #include "paddle/fluid/operators/reshape_op.h" namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL_EX(reshape, float, ops::ReshapeKernel, double, - ops::ReshapeKernel, int, ops::ReshapeKernel, int64_t, - ops::ReshapeKernel); -REGISTER_OP_CUDA_KERNEL(reshape_grad, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel); + +REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, + int64_t, ops::ReshapeKernel); +REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, + double, ops::ReshapeGradKernel, int, + ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel); diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index c0b57d11d3..68e1690a53 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -118,21 +118,15 @@ class ReshapeOp : public framework::OperatorWithKernel { } }; -class ReshapeKernel : public framework::OpKernelBase { +class ReshapeKernel { public: - void Compute(const framework::ExecutionContext &ctx) const final; + void operator()(const framework::ExecutionContext &ctx) const; }; -class ReshapeGradKernelBase : public framework::OpKernelBase { +class ReshapeGradKernel { public: - void Compute(const framework::ExecutionContext &ctx) const; + void operator()(const framework::ExecutionContext &ctx) const; }; -template -class ReshapeGradKernel : public ReshapeGradKernelBase { - public: - // Tell register element type. - using ELEMENT_TYPE = T; -}; } // namespace operators } // namespace paddle -- GitLab