diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 6d849bff49909686adcd344ebd21dcf391ec8915..ec448a9e9564046aeda1e6d9c6609ac1ecc137cd 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -60,12 +60,13 @@ inline void get_mid_dims(const framework::DDim& x_dims, } template -struct RowwiseTransformIterator; +class RowwiseTransformIterator; template -struct MidWiseTransformIterator; +class MidWiseTransformIterator; template -struct RowwiseTransformIterator { +class RowwiseTransformIterator { + public: RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {} RowwiseTransformIterator& operator++() { @@ -86,13 +87,15 @@ struct RowwiseTransformIterator { const T& operator*() { return ptr_[i_]; } + private: const T* ptr_; int i_; int64_t n_; }; template -struct MidWiseTransformIterator { +class MidWiseTransformIterator { + public: MidWiseTransformIterator(const T* ptr, int n, int post) : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} @@ -113,6 +116,7 @@ struct MidWiseTransformIterator { const T& operator*() { return ptr_[i_]; } + private: const T* ptr_; int i_; int64_t j_; @@ -122,7 +126,7 @@ struct MidWiseTransformIterator { #ifdef __NVCC__ template -struct RowwiseTransformIterator +class RowwiseTransformIterator : public thrust::iterator_adaptor< RowwiseTransformIterator, const T*> { public: @@ -142,7 +146,7 @@ struct RowwiseTransformIterator }; template -struct MidWiseTransformIterator +class MidWiseTransformIterator : public thrust::iterator_adaptor< MidWiseTransformIterator, const T*> { public: @@ -164,7 +168,8 @@ struct MidWiseTransformIterator #endif template -struct TransformFunctor { +class TransformFunctor { + public: TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, framework::Tensor* z, const platform::DeviceContext& ctx, Functor func) @@ -192,6 +197,7 @@ struct TransformFunctor { z_, func_); } + private: const T* x_; const T* y_; T* z_;