提交 37671ac0 编写于 作者: C chengduoZH

follow comments

上级 9e244a8c
...@@ -60,12 +60,13 @@ inline void get_mid_dims(const framework::DDim& x_dims, ...@@ -60,12 +60,13 @@ inline void get_mid_dims(const framework::DDim& x_dims,
} }
template <typename T, typename Place> template <typename T, typename Place>
struct RowwiseTransformIterator; class RowwiseTransformIterator;
template <typename T, typename Place> template <typename T, typename Place>
struct MidWiseTransformIterator; class MidWiseTransformIterator;
template <typename T> template <typename T>
struct RowwiseTransformIterator<T, platform::CPUPlace> { class RowwiseTransformIterator<T, platform::CPUPlace> {
public:
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {} RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() { RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
...@@ -86,13 +87,15 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> { ...@@ -86,13 +87,15 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
const T& operator*() { return ptr_[i_]; } const T& operator*() { return ptr_[i_]; }
private:
const T* ptr_; const T* ptr_;
int i_; int i_;
int64_t n_; int64_t n_;
}; };
template <typename T> template <typename T>
struct MidWiseTransformIterator<T, platform::CPUPlace> { class MidWiseTransformIterator<T, platform::CPUPlace> {
public:
MidWiseTransformIterator(const T* ptr, int n, int post) MidWiseTransformIterator(const T* ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
...@@ -113,6 +116,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> { ...@@ -113,6 +116,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
const T& operator*() { return ptr_[i_]; } const T& operator*() { return ptr_[i_]; }
private:
const T* ptr_; const T* ptr_;
int i_; int i_;
int64_t j_; int64_t j_;
...@@ -122,7 +126,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> { ...@@ -122,7 +126,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
#ifdef __NVCC__ #ifdef __NVCC__
template <typename T> template <typename T>
struct RowwiseTransformIterator<T, platform::GPUPlace> class RowwiseTransformIterator<T, platform::GPUPlace>
: public thrust::iterator_adaptor< : public thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> { RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
public: public:
...@@ -142,7 +146,7 @@ struct RowwiseTransformIterator<T, platform::GPUPlace> ...@@ -142,7 +146,7 @@ struct RowwiseTransformIterator<T, platform::GPUPlace>
}; };
template <typename T> template <typename T>
struct MidWiseTransformIterator<T, platform::GPUPlace> class MidWiseTransformIterator<T, platform::GPUPlace>
: public thrust::iterator_adaptor< : public thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> { MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
public: public:
...@@ -164,7 +168,8 @@ struct MidWiseTransformIterator<T, platform::GPUPlace> ...@@ -164,7 +168,8 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
#endif #endif
template <typename Functor, typename T, typename Place> template <typename Functor, typename T, typename Place>
struct TransformFunctor { class TransformFunctor {
public:
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z, const platform::DeviceContext& ctx, framework::Tensor* z, const platform::DeviceContext& ctx,
Functor func) Functor func)
...@@ -192,6 +197,7 @@ struct TransformFunctor { ...@@ -192,6 +197,7 @@ struct TransformFunctor {
z_, func_); z_, func_);
} }
private:
const T* x_; const T* x_;
const T* y_; const T* y_;
T* z_; T* z_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册