diff --git a/paddle/operators/elementwise_max_op.h b/paddle/operators/elementwise_max_op.h index 92152f7cb6ed49ff6a749a98fe5d6410dccc5ded..255728e8e620665a7de225b228c19d6c510da1c8 100644 --- a/paddle/operators/elementwise_max_op.h +++ b/paddle/operators/elementwise_max_op.h @@ -28,39 +28,7 @@ template class ElementwiseMaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - TransformFunctor, T, DeviceContext> functor( - x, y, z, ctx.template device_context(), MaxFunctor()); - - auto x_dims = x->dims(); - auto y_dims = y->dims(); - PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), - "Rank of first input must >= rank of second input."); - - if (x_dims == y_dims) { - functor.Run(); - return; - } - - int axis = ctx.Attr("axis"); - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); - - int pre, n, post; - get_mid_dims(x_dims, y_dims, axis, pre, n, post); - if (post == 1) { - functor.RunRowWise(n, pre); - return; - } else { - functor.RunMidWise(n, pre, post); - return; - } + ElementwiseComputeEx, DeviceContext, T>(ctx); } }; diff --git a/paddle/operators/elementwise_min_op.h b/paddle/operators/elementwise_min_op.h index 53b7f59fa07cd93ee75ef667de21f277df63701b..e6627a0f1bb468c8e4661b83489cb964b72dddb0 100644 --- a/paddle/operators/elementwise_min_op.h +++ b/paddle/operators/elementwise_min_op.h @@ -28,39 +28,7 @@ template class ElementwiseMinKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - z->mutable_data(ctx.GetPlace()); - TransformFunctor, T, DeviceContext> functor( - x, y, z, ctx.template device_context(), MinFunctor()); - - auto x_dims = x->dims(); - auto y_dims = y->dims(); - PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), - "Rank of first input must >= rank of second input."); - - if (x_dims == y_dims) { - functor.Run(); - return; - } - - int axis = ctx.Attr("axis"); - axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); - PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), - "Axis should be in range [0, x_dims)"); - - int pre, n, post; - get_mid_dims(x_dims, y_dims, axis, pre, n, post); - if (post == 1) { - functor.RunRowWise(n, pre); - return; - } else { - functor.RunMidWise(n, pre, post); - return; - } + ElementwiseComputeEx, DeviceContext, T>(ctx); } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 0c75276b03140473cee4b57a4022ff3c6989ab4c..be11d5cc9de0c73dd5b3fa127c9f043b6b5d3972 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -356,5 +356,43 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { return; } } + +template +void ElementwiseComputeEx(const framework::ExecutionContext& ctx) { + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + z->mutable_data(ctx.GetPlace()); + TransformFunctor functor( + x, y, z, ctx.template device_context(), Functor()); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input."); + + if (x_dims == y_dims) { + functor.Run(); + return; + } + + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, pre, n, post); + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } +} + } // namespace operators } // namespace paddle