From fbbfe8b8594960934529330dc1321e1fdc6c2a6d Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 4 Dec 2017 14:21:13 +0800 Subject: [PATCH] code refine --- paddle/operators/elementwise_add_op.h | 39 +++++++- paddle/operators/elementwise_op_function.h | 108 +++++++++++++++++++++ 2 files changed, 146 insertions(+), 1 deletion(-) diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h index f04fe3ec6..686d45573 100644 --- a/paddle/operators/elementwise_add_op.h +++ b/paddle/operators/elementwise_add_op.h @@ -19,11 +19,48 @@ namespace paddle { namespace operators { +template +struct AddFunctor { + HOSTDEVICE T operator()(T a, T b) const { return a + b; } +}; + template class ElementwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - ElementwiseCompute(ctx); + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + z->mutable_data(ctx.GetPlace()); + TransformFunctor, T, Place> functor(x, y, z, ctx, + AddFunctor()); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input."); + + if (x_dims == y_dims) { + functor.Run(); + return; + } + + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, pre, n, post); + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 8aa35b2c4..22b96b931 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -16,6 +16,7 @@ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" +#include "paddle/platform/transform.h" #include "paddle/operators/math/math_function.h" @@ -54,6 +55,113 @@ inline void get_mid_dims(const framework::DDim& x_dims, } } +template +struct RowwiseTransformIterator; +template +struct MidWiseTransformIterator; + +template +struct RowwiseTransformIterator { + RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {} + + RowwiseTransformIterator& operator++() { + ++i_; + if (i_ == n_) { + i_ = 0; + } + return *this; + } + + bool operator==( + const RowwiseTransformIterator& rhs) const { + return &(this->operator*()) == &(*rhs); + } + + bool operator!=( + const RowwiseTransformIterator& rhs) const { + return &(this->operator*()) &= &(*rhs); + } + + const T& operator*() { return ptr_[i_]; } + + const T* ptr_; + int i_; + int n_; +}; + +template +struct MidWiseTransformIterator { + MidWiseTransformIterator(const T* ptr, int n, int post) + : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} + + MidWiseTransformIterator& operator++() { + ++j_; + if (j_ == post_) { + j_ = 0; + ++i_; + if (i_ == n_) { + i_ = 0; + } + } + return *this; + } + + bool operator==( + const MidWiseTransformIterator& rhs) const { + return &(this->operator*()) == &(*rhs); + } + + bool operator!=( + const MidWiseTransformIterator& rhs) const { + return &(this->operator*()) &= &(*rhs); + } + + const T& operator*() { return ptr_[i_]; } + + const T* ptr_; + int i_; + int j_; + int n_; + int post_; +}; + +template +struct TransformFunctor { + TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z, const framework::ExecutionContext& ctx, + Functor func) + : x_(x->data()), + y_(y->data()), + z_(z->mutable_data(ctx.GetPlace())), + nx_(x->numel()), + ctx_(ctx), + func_(func) {} + + inline void Run() const { + platform::Transform trans; + trans(ctx_.device_context(), x_, x_ + nx_, y_, z_, func_); + } + + inline void RunRowWise(int n, int pre) const { + platform::Transform trans; + trans(ctx_.device_context(), x_, x_ + nx_, + RowwiseTransformIterator(y_, n), z_, func_); + } + + inline void RunMidWise(int n, int pre, int post) const { + platform::Transform trans; + trans(ctx_.device_context(), x_, x_ + nx_, + MidWiseTransformIterator(y_, n, post), z_, func_); + } + + const T* x_; + const T* y_; + T* z_; + int64_t nx_; + const framework::ExecutionContext& ctx_; + Functor func_; +}; + #define EIGEN_FUNCTOR(name, eigen_op) \ struct Eigen##name##Functor { \ template \ -- GitLab