提交 488908e9 编写于 作者: C chengduoZH

refine cuda

上级 fbbfe8b8
...@@ -18,6 +18,10 @@ ...@@ -18,6 +18,10 @@
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/platform/transform.h" #include "paddle/platform/transform.h"
#ifdef __NVCC__
#include <thrust/iterator/iterator_adaptor.h>
#endif
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
...@@ -74,12 +78,12 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> { ...@@ -74,12 +78,12 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
bool operator==( bool operator==(
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const { const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return &(this->operator*()) == &(*rhs); return (ptr_ + i_) == &(*rhs);
} }
bool operator!=( bool operator!=(
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const { const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return &(this->operator*()) != &(*rhs); return (ptr_ + i_) != &(*rhs);
} }
const T& operator*() { return ptr_[i_]; } const T& operator*() { return ptr_[i_]; }
...@@ -108,12 +112,12 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> { ...@@ -108,12 +112,12 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
bool operator==( bool operator==(
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const { const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return &(this->operator*()) == &(*rhs); return (ptr_ + i_) == &(*rhs);
} }
bool operator!=( bool operator!=(
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const { const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return &(this->operator*()) != &(*rhs); return (ptr_ + i_) != &(*rhs);
} }
const T& operator*() { return ptr_[i_]; } const T& operator*() { return ptr_[i_]; }
...@@ -125,6 +129,49 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> { ...@@ -125,6 +129,49 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
int post_; int post_;
}; };
#ifdef __NVCC__
// GPU specialization of RowwiseTransformIterator, built on
// thrust::iterator_adaptor over a `const T*` base iterator.
//
// The adaptor's base pointer advances linearly, but dereferencing maps the
// current offset (measured from the pointer passed to the constructor) back
// into the first `n` elements:
//   value = begin_[(base() - begin_) % n_]
// so reading past the n-th position wraps around to the start of the buffer.
template <typename T>
struct RowwiseTransformIterator<T, platform::GPUPlace>
    : public thrust::iterator_adaptor<
          RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
 public:
  typedef thrust::iterator_adaptor<
      RowwiseTransformIterator<T, platform::GPUPlace>, const T*>
      super_t;

  // x: pointer to the first element of the buffer being cycled over.
  // n: cycle length, i.e. the number of elements before the index wraps.
  // Initializers are listed in member declaration order (n_, then begin_).
  __host__ __device__ RowwiseTransformIterator(const T* x, int n)
      : super_t(x), n_(n), begin_(x) {}

  // Lets thrust's iterator machinery call the private dereference() below.
  friend class thrust::iterator_core_access;

 private:
  unsigned int n_;   // cycle length used as the modulus in dereference()
  const T* begin_;   // start of the buffer; origin for offset computation

  // Returns the element at the current offset wrapped modulo n_.
  __host__ __device__ typename super_t::reference dereference() const {
    const T* const current = this->base();
    return begin_[(current - begin_) % n_];
  }
};
// GPU specialization of MidWiseTransformIterator, built on
// thrust::iterator_adaptor over a `const T*` base iterator.
//
// The adaptor's base pointer advances linearly, but dereferencing maps the
// current offset (measured from the pointer passed to the constructor) to
//   value = begin_[((base() - begin_) / post_) % n_]
// i.e. each of the first `n` elements is yielded for `post` consecutive
// positions, and the whole sequence repeats after n * post positions.
template <typename T>
struct MidWiseTransformIterator<T, platform::GPUPlace>
    : public thrust::iterator_adaptor<
          MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
 public:
  typedef thrust::iterator_adaptor<
      MidWiseTransformIterator<T, platform::GPUPlace>, const T*>
      super_t;

  // x:    pointer to the first element of the buffer being indexed.
  // n:    number of distinct elements before the index wraps.
  // post: number of consecutive positions that map to the same element.
  // Initializers are listed in member declaration order
  // (post_, n_, then begin_).
  __host__ __device__ MidWiseTransformIterator(const T* x, int n, int post)
      : super_t(x), post_(post), n_(n), begin_(x) {}

  // Lets thrust's iterator machinery call the private dereference() below.
  friend class thrust::iterator_core_access;

 private:
  unsigned int post_;  // divisor applied to the linear offset
  unsigned int n_;     // modulus applied after the division
  const T* begin_;     // start of the buffer; origin for offset computation

  // Returns the element selected by dividing the current offset by post_
  // and wrapping the quotient modulo n_.
  __host__ __device__ typename super_t::reference dereference() const {
    const T* const current = this->base();
    return begin_[((current - begin_) / post_) % n_];
  }
};
#endif
template <typename Functor, typename T, typename Place> template <typename Functor, typename T, typename Place>
struct TransformFunctor { struct TransformFunctor {
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册