diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h index f04fe3ec6069ab1bf227be6a3a5c10ee908e4824..3a198c167e4cb0e9c106038bf2ac64ba7a680421 100644 --- a/paddle/operators/elementwise_add_op.h +++ b/paddle/operators/elementwise_add_op.h @@ -19,11 +19,48 @@ namespace paddle { namespace operators { +template +struct AddFunctor { + HOSTDEVICE T operator()(T a, T b) const { return a + b; } +}; + template class ElementwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - ElementwiseCompute(ctx); + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + z->mutable_data(ctx.GetPlace()); + TransformFunctor, T, Place> functor( + x, y, z, ctx.device_context(), AddFunctor()); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input."); + + if (x_dims == y_dims) { + functor.Run(); + return; + } + + int axis = ctx.Attr("axis"); + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, pre, n, post); + if (post == 1) { + functor.RunRowWise(n, pre); + return; + } else { + functor.RunMidWise(n, pre, post); + return; + } } }; diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 8aa35b2c466785c8749739635fcd1c2b19292f3e..ec448a9e9564046aeda1e6d9c6609ac1ecc137cd 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -16,6 +16,11 @@ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/operator.h" +#include "paddle/platform/transform.h" + +#ifdef __NVCC__ +#include +#endif #include "paddle/operators/math/math_function.h" @@ -54,6 +59,153 @@ inline void get_mid_dims(const framework::DDim& x_dims, } } +template +class RowwiseTransformIterator; +template +class MidWiseTransformIterator; + +template +class RowwiseTransformIterator { + public: + RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {} + + RowwiseTransformIterator& operator++() { + ++i_; + i_ %= n_; + return *this; + } + + bool operator==( + const RowwiseTransformIterator& rhs) const { + return (ptr_ + i_) == &(*rhs); + } + + bool operator!=( + const RowwiseTransformIterator& rhs) const { + return (ptr_ + i_) != &(*rhs); + } + + const T& operator*() { return ptr_[i_]; } + + private: + const T* ptr_; + int i_; + int64_t n_; +}; + +template +class MidWiseTransformIterator { + public: + MidWiseTransformIterator(const T* ptr, int n, int post) + : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} + + MidWiseTransformIterator& operator++() { + i_ = (++j_ / post_) % n_; + return *this; + } + + bool operator==( + const MidWiseTransformIterator& rhs) const { + return (ptr_ + i_) == &(*rhs); + } + + bool operator!=( + const MidWiseTransformIterator& rhs) const { + return (ptr_ + i_) != &(*rhs); + } + + const T& operator*() { return ptr_[i_]; } + + private: + const T* ptr_; + int i_; + int64_t j_; + int64_t n_; + int post_; +}; + +#ifdef __NVCC__ +template +class RowwiseTransformIterator + : public thrust::iterator_adaptor< + RowwiseTransformIterator, const T*> { + public: + typedef thrust::iterator_adaptor< + RowwiseTransformIterator, const T*> + super_t; + HOSTDEVICE RowwiseTransformIterator(const T* x, int n) + : super_t(x), begin_(x), n_(n){}; + friend class thrust::iterator_core_access; + + private: + unsigned int n_; + const T* begin_; + HOSTDEVICE typename super_t::reference dereference() const { + return *(begin_ + (this->base() - begin_) % n_); + } +}; + +template +class MidWiseTransformIterator + : public thrust::iterator_adaptor< + MidWiseTransformIterator, const T*> { + public: + typedef thrust::iterator_adaptor< + MidWiseTransformIterator, const T*> + super_t; + HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post) + : super_t(x), begin_(x), n_(n), post_(post){}; + friend class thrust::iterator_core_access; + + private: + unsigned int post_; + unsigned int n_; + const T* begin_; + HOSTDEVICE typename super_t::reference dereference() const { + return *(begin_ + (((this->base() - begin_) / post_) % n_)); + } +}; +#endif + +template +class TransformFunctor { + public: + TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* z, const platform::DeviceContext& ctx, + Functor func) + : x_(x->data()), + y_(y->data()), + z_(z->mutable_data(ctx.GetPlace())), + nx_(x->numel()), + ctx_(ctx), + func_(func) {} + + inline void Run() const { + platform::Transform trans; + trans(ctx_, x_, x_ + nx_, y_, z_, func_); + } + + inline void RunRowWise(int n, int pre) const { + platform::Transform trans; + trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator(y_, n), z_, + func_); + } + + inline void RunMidWise(int n, int pre, int post) const { + platform::Transform trans; + trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator(y_, n, post), + z_, func_); + } + + private: + const T* x_; + const T* y_; + T* z_; + int64_t nx_; + const platform::DeviceContext& ctx_; + Functor func_; +}; + #define EIGEN_FUNCTOR(name, eigen_op) \ struct Eigen##name##Functor { \ template \