Commit 707d144c authored by Luo Tao

Unify Reduce functions and simplify register code

Parent c3b46d16
@@ -285,11 +285,9 @@ REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
-#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)        \
-  REGISTER_OP_CPU_KERNEL(                                                      \
-      act_type,                                                                \
-      paddle::operators::ActivationKernel<paddle::platform::CPUPlace,          \
-                                          paddle::operators::functor<float>>); \
-  REGISTER_OP_CPU_KERNEL(act_type##_grad,                                      \
-                         paddle::operators::ActivationGradKernel<              \
-                             paddle::platform::CPUPlace,                       \
-                             paddle::operators::grad_functor<float>>);
+#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)        \
+  REGISTER_OP_CPU_KERNEL(                                                      \
+      act_type,                                                                \
+      ops::ActivationKernel<paddle::platform::CPUPlace, ops::functor<float>>); \
+  REGISTER_OP_CPU_KERNEL(act_type##_grad,                                      \
+                         ops::ActivationGradKernel<paddle::platform::CPUPlace, \
+                                                   ops::grad_functor<float>>);
 
 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
@@ -15,14 +15,14 @@
 #define EIGEN_USE_GPU
 #include "paddle/operators/activation_op.h"
 
-#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor)        \
-  REGISTER_OP_GPU_KERNEL(                                                      \
-      act_type,                                                                \
-      paddle::operators::ActivationKernel<paddle::platform::GPUPlace,          \
-                                          paddle::operators::functor<float>>); \
-  REGISTER_OP_GPU_KERNEL(act_type##_grad,                                      \
-                         paddle::operators::ActivationGradKernel<              \
-                             paddle::platform::GPUPlace,                       \
-                             paddle::operators::grad_functor<float>>);
+namespace ops = paddle::operators;
+
+#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor)        \
+  REGISTER_OP_GPU_KERNEL(                                                      \
+      act_type,                                                                \
+      ops::ActivationKernel<paddle::platform::GPUPlace, ops::functor<float>>); \
+  REGISTER_OP_GPU_KERNEL(act_type##_grad,                                      \
+                         ops::ActivationGradKernel<paddle::platform::GPUPlace, \
+                                                   ops::grad_functor<float>>);
 
 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
@@ -168,36 +168,22 @@ namespace ops = paddle::operators;
 REGISTER_OP(reduce_sum, ops::ReduceOp, ops::ReduceSumOpMaker, reduce_sum_grad,
             ops::ReduceGradOp);
 
-REGISTER_OP_CPU_KERNEL(
-    reduce_sum,
-    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::SumFunctor>);
-REGISTER_OP_CPU_KERNEL(reduce_sum_grad,
-                       ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
-                                             ops::SumGradFunctor>);
-
 REGISTER_OP(reduce_mean, ops::ReduceOp, ops::ReduceMeanOpMaker,
             reduce_mean_grad, ops::ReduceGradOp);
 
-REGISTER_OP_CPU_KERNEL(
-    reduce_mean,
-    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MeanFunctor>);
-REGISTER_OP_CPU_KERNEL(reduce_mean_grad,
-                       ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
-                                             ops::MeanGradFunctor>);
-
 REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad,
             ops::ReduceGradOp);
 
-REGISTER_OP_CPU_KERNEL(
-    reduce_max,
-    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MaxFunctor>);
-REGISTER_OP_CPU_KERNEL(reduce_max_grad,
-                       ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
-                                             ops::MaxOrMinGradFunctor>);
-
 REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_min_grad,
             ops::ReduceGradOp);
 
-REGISTER_OP_CPU_KERNEL(
-    reduce_min,
-    ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::MinFunctor>);
-REGISTER_OP_CPU_KERNEL(reduce_min_grad,
-                       ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
-                                             ops::MaxOrMinGradFunctor>);
+#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor)     \
+  REGISTER_OP_CPU_KERNEL(                                                  \
+      reduce_type,                                                         \
+      ops::ReduceKernel<paddle::platform::CPUPlace, float, ops::functor>); \
+  REGISTER_OP_CPU_KERNEL(reduce_type##_grad,                               \
+                         ops::ReduceGradKernel<paddle::platform::CPUPlace, \
+                                               float, ops::grad_functor>);
+
+FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_CPU_KERNEL);
@@ -17,30 +17,12 @@
 namespace ops = paddle::operators;
 
-REGISTER_OP_GPU_KERNEL(
-    reduce_sum,
-    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::SumFunctor>);
-REGISTER_OP_GPU_KERNEL(reduce_sum_grad,
-                       ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
-                                             ops::SumGradFunctor>);
-
-REGISTER_OP_GPU_KERNEL(
-    reduce_mean,
-    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MeanFunctor>);
-REGISTER_OP_GPU_KERNEL(reduce_mean_grad,
-                       ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
-                                             ops::MeanGradFunctor>);
-
-REGISTER_OP_GPU_KERNEL(
-    reduce_max,
-    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MaxFunctor>);
-REGISTER_OP_GPU_KERNEL(reduce_max_grad,
-                       ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
-                                             ops::MaxOrMinGradFunctor>);
-
-REGISTER_OP_GPU_KERNEL(
-    reduce_min,
-    ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::MinFunctor>);
-REGISTER_OP_GPU_KERNEL(reduce_min_grad,
-                       ops::ReduceGradKernel<paddle::platform::GPUPlace, float,
-                                             ops::MaxOrMinGradFunctor>);
+#define REGISTER_REDUCE_GPU_KERNEL(reduce_type, functor, grad_functor)     \
+  REGISTER_OP_GPU_KERNEL(                                                  \
+      reduce_type,                                                         \
+      ops::ReduceKernel<paddle::platform::GPUPlace, float, ops::functor>); \
+  REGISTER_OP_GPU_KERNEL(reduce_type##_grad,                               \
+                         ops::ReduceGradKernel<paddle::platform::GPUPlace, \
+                                               float, ops::grad_functor>);
+
+FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_GPU_KERNEL);
@@ -198,3 +198,9 @@ class ReduceGradKernel : public framework::OpKernel<T> {
 
 }  // namespace operators
 }  // namespace paddle
+
+#define FOR_EACH_KERNEL_FUNCTOR(__macro)                \
+  __macro(reduce_sum, SumFunctor, SumGradFunctor);      \
+  __macro(reduce_mean, MeanFunctor, MeanGradFunctor);   \
+  __macro(reduce_max, MaxFunctor, MaxOrMinGradFunctor); \
+  __macro(reduce_min, MinFunctor, MaxOrMinGradFunctor);
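The FOR_EACH_KERNEL_FUNCTOR macro added above is an X-macro: the table of operators is written once in the header, and each backend supplies a per-row expansion macro, so adding a new reduce op becomes a one-line change instead of four hand-copied registration blocks. Below is a minimal, compilable sketch of the same pattern; the functor structs and the RegisterKernel function are hypothetical stand-ins for Paddle's real kernels and REGISTER_OP_CPU_KERNEL machinery, and the demo stringizes the op-name token where the commit's macro pastes reduce_type##_grad to form an identifier.

#include <iostream>
#include <string>

// Hypothetical stand-ins for the real functors; each row pairs a forward
// reduction with the functor that computes its gradient.
struct SumFunctor {};
struct SumGradFunctor {};
struct MeanFunctor {};
struct MeanGradFunctor {};
struct MaxFunctor {};
struct MinFunctor {};
struct MaxOrMinGradFunctor {};

// Hypothetical stand-in for REGISTER_OP_CPU_KERNEL: records that a kernel
// templated on Functor was registered under op_type.
template <typename Functor>
void RegisterKernel(const std::string& op_type) {
  std::cout << "registered " << op_type << "\n";
}

// The X-macro from reduce_op.h: the operator table is written exactly once.
#define FOR_EACH_KERNEL_FUNCTOR(__macro)                \
  __macro(reduce_sum, SumFunctor, SumGradFunctor);      \
  __macro(reduce_mean, MeanFunctor, MeanGradFunctor);   \
  __macro(reduce_max, MaxFunctor, MaxOrMinGradFunctor); \
  __macro(reduce_min, MinFunctor, MaxOrMinGradFunctor);

// Per-backend expansion: one forward and one gradient registration per row.
// #reduce_type stringizes the token, and "..." "_grad" concatenates string
// literals, standing in for the token paste reduce_type##_grad in the commit.
#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \
  RegisterKernel<functor>(#reduce_type);                               \
  RegisterKernel<grad_functor>(#reduce_type "_grad")

int main() {
  // Expands to eight RegisterKernel calls, two per reduce op.
  FOR_EACH_KERNEL_FUNCTOR(REGISTER_REDUCE_CPU_KERNEL);
  return 0;
}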