From ead7059bf93e362978c8463350ad9551db43b2b9 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 16 Jan 2018 14:35:26 +0800 Subject: [PATCH] Refine code --- paddle/operators/elementwise_max_op.h | 34 +------------------ paddle/operators/elementwise_min_op.h | 34 +------------------ paddle/operators/elementwise_op_function.h | 38 ++++++++++++++++++++++ 3 files changed, 40 insertions(+), 66 deletions(-) diff --git a/paddle/operators/elementwise_max_op.h b/paddle/operators/elementwise_max_op.h index 92152f7cb6e..255728e8e62 100644 --- a/paddle/operators/elementwise_max_op.h +++ b/paddle/operators/elementwise_max_op.h @@ -28,39 +28,7 @@ template class ElementwiseMaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - TransformFunctor, T, DeviceContext> functor( - x, y, z, ctx.template device_context(), MaxFunctor()); - - auto x_dims = x->dims(); - auto y_dims = y->dims(); - PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), - "Rank of first input must >= rank of second input."); - - if (x_dims == y_dims) { - functor.Run(); - return; - } - - int axis = ctx.Attr("axis"); - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); - - int pre, n, post; - get_mid_dims(x_dims, y_dims, axis, pre, n, post); - if (post == 1) { - functor.RunRowWise(n, pre); - return; - } else { - functor.RunMidWise(n, pre, post); - return; - } + ElementwiseComputeEx, DeviceContext, T>(ctx); } }; diff --git a/paddle/operators/elementwise_min_op.h b/paddle/operators/elementwise_min_op.h index 53b7f59fa07..e6627a0f1bb 100644 --- a/paddle/operators/elementwise_min_op.h +++ b/paddle/operators/elementwise_min_op.h @@ -28,39 +28,7 @@ template class ElementwiseMinKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - TransformFunctor, T, DeviceContext> functor( - x, y, z, ctx.template device_context(), MinFunctor()); - - auto x_dims = x->dims(); - auto y_dims = y->dims(); - PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), - "Rank of first input must >= rank of second input."); - - if (x_dims == y_dims) { - functor.Run(); - return; - } - - int axis = ctx.Attr("axis"); - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); - - int pre, n, post; - get_mid_dims(x_dims, y_dims, axis, pre, n, post); - if (post == 1) { - functor.RunRowWise(n, pre); - return; - } else { - functor.RunMidWise(n, pre, post); - return; - } + ElementwiseComputeEx, DeviceContext, T>(ctx); } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 0c75276b031..be11d5cc9de 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -356,5 +356,43 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { return; } } + +template +void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + z->mutable_data(ctx.GetPlace()); + TransformFunctor functor( + x, y, z, ctx.template device_context(), Functor()); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input."); + + if (x_dims == y_dims) { + functor.Run(); + return; + } + + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, pre, n, post); + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } +} + } // namespace operators } // namespace paddle -- GitLab