diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 22b96b93121a1d88e955aa71e909bf2f6c572447..09ab42b501b3f06de33db5b6aaf96abbbeb11d84 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -18,6 +18,10 @@ #include "paddle/framework/operator.h" #include "paddle/platform/transform.h" +#ifdef __NVCC__ +#include +#endif + #include "paddle/operators/math/math_function.h" namespace paddle { @@ -74,12 +78,12 @@ struct RowwiseTransformIterator { bool operator==( const RowwiseTransformIterator& rhs) const { - return &(this->operator*()) == &(*rhs); + return (ptr_ + i_) == &(*rhs); } bool operator!=( const RowwiseTransformIterator& rhs) const { - return &(this->operator*()) &= &(*rhs); + return (ptr_ + i_) &= &(*rhs); } const T& operator*() { return ptr_[i_]; } @@ -108,12 +112,12 @@ struct MidWiseTransformIterator { bool operator==( const MidWiseTransformIterator& rhs) const { - return &(this->operator*()) == &(*rhs); + return (ptr_ + i_) == &(*rhs); } bool operator!=( const MidWiseTransformIterator& rhs) const { - return &(this->operator*()) &= &(*rhs); + return (ptr_ + i_) &= &(*rhs); } const T& operator*() { return ptr_[i_]; } @@ -125,6 +129,49 @@ struct MidWiseTransformIterator { int post_; }; +#ifdef __NVCC__ +template +struct RowwiseTransformIterator + : public thrust::iterator_adaptor< + RowwiseTransformIterator, const T*> { + public: + typedef thrust::iterator_adaptor< + RowwiseTransformIterator, const T*> + super_t; + __host__ __device__ RowwiseTransformIterator(const T* x, int n) + : super_t(x), begin_(x), n_(n){}; + friend class thrust::iterator_core_access; + + private: + unsigned int n_; + const T* begin_; + __host__ __device__ typename super_t::reference dereference() const { + return *(begin_ + (this->base() - begin_) % n_); + } +}; + +template +struct MidWiseTransformIterator + : public thrust::iterator_adaptor< + MidWiseTransformIterator, const T*> { + public: + typedef thrust::iterator_adaptor< + MidWiseTransformIterator, const T*> + super_t; + __host__ __device__ MidWiseTransformIterator(const T* x, int n, int post) + : super_t(x), begin_(x), n_(n), post_(post){}; + friend class thrust::iterator_core_access; + + private: + unsigned int post_; + unsigned int n_; + const T* begin_; + __host__ __device__ typename super_t::reference dereference() const { + return *(begin_ + (((this->base() - begin_) / post_) % n_)); + } +}; +#endif + template struct TransformFunctor { TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,