提交 9e244a8c 编写于 作者: C chengduoZH

follow comments

上级 54f09620
...@@ -34,8 +34,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> { ...@@ -34,8 +34,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto* y = ctx.Input<Tensor>("Y"); auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
TransformFunctor<AddFunctor<T>, T, Place> functor(x, y, z, ctx, TransformFunctor<AddFunctor<T>, T, Place> functor(
AddFunctor<T>()); x, y, z, ctx.device_context(), AddFunctor<T>());
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims = y->dims(); auto y_dims = y->dims();
......
...@@ -81,7 +81,7 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> { ...@@ -81,7 +81,7 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
bool operator!=( bool operator!=(
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const { const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return (ptr_ + i_) &= &(*rhs); return (ptr_ + i_) != &(*rhs);
} }
const T& operator*() { return ptr_[i_]; } const T& operator*() { return ptr_[i_]; }
...@@ -97,7 +97,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> { ...@@ -97,7 +97,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() { MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
i_ = ++j_ / post_ % n_; i_ = (++j_ / post_) % n_;
return *this; return *this;
} }
...@@ -108,7 +108,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> { ...@@ -108,7 +108,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
bool operator!=( bool operator!=(
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const { const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return (ptr_ + i_) &= &(*rhs); return (ptr_ + i_) != &(*rhs);
} }
const T& operator*() { return ptr_[i_]; } const T& operator*() { return ptr_[i_]; }
...@@ -129,14 +129,14 @@ struct RowwiseTransformIterator<T, platform::GPUPlace> ...@@ -129,14 +129,14 @@ struct RowwiseTransformIterator<T, platform::GPUPlace>
typedef thrust::iterator_adaptor< typedef thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> RowwiseTransformIterator<T, platform::GPUPlace>, const T*>
super_t; super_t;
__host__ __device__ RowwiseTransformIterator(const T* x, int n) HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
: super_t(x), begin_(x), n_(n){}; : super_t(x), begin_(x), n_(n){};
friend class thrust::iterator_core_access; friend class thrust::iterator_core_access;
private: private:
unsigned int n_; unsigned int n_;
const T* begin_; const T* begin_;
__host__ __device__ typename super_t::reference dereference() const { HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (this->base() - begin_) % n_); return *(begin_ + (this->base() - begin_) % n_);
} }
}; };
...@@ -149,7 +149,7 @@ struct MidWiseTransformIterator<T, platform::GPUPlace> ...@@ -149,7 +149,7 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
typedef thrust::iterator_adaptor< typedef thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> MidWiseTransformIterator<T, platform::GPUPlace>, const T*>
super_t; super_t;
__host__ __device__ MidWiseTransformIterator(const T* x, int n, int post) HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
: super_t(x), begin_(x), n_(n), post_(post){}; : super_t(x), begin_(x), n_(n), post_(post){};
friend class thrust::iterator_core_access; friend class thrust::iterator_core_access;
...@@ -157,7 +157,7 @@ struct MidWiseTransformIterator<T, platform::GPUPlace> ...@@ -157,7 +157,7 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
unsigned int post_; unsigned int post_;
unsigned int n_; unsigned int n_;
const T* begin_; const T* begin_;
__host__ __device__ typename super_t::reference dereference() const { HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (((this->base() - begin_) / post_) % n_)); return *(begin_ + (((this->base() - begin_) / post_) % n_));
} }
}; };
...@@ -166,7 +166,7 @@ struct MidWiseTransformIterator<T, platform::GPUPlace> ...@@ -166,7 +166,7 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
template <typename Functor, typename T, typename Place> template <typename Functor, typename T, typename Place>
struct TransformFunctor { struct TransformFunctor {
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z, const framework::ExecutionContext& ctx, framework::Tensor* z, const platform::DeviceContext& ctx,
Functor func) Functor func)
: x_(x->data<T>()), : x_(x->data<T>()),
y_(y->data<T>()), y_(y->data<T>()),
...@@ -177,26 +177,26 @@ struct TransformFunctor { ...@@ -177,26 +177,26 @@ struct TransformFunctor {
inline void Run() const { inline void Run() const {
platform::Transform<Place> trans; platform::Transform<Place> trans;
trans(ctx_.device_context(), x_, x_ + nx_, y_, z_, func_); trans(ctx_, x_, x_ + nx_, y_, z_, func_);
} }
inline void RunRowWise(int n, int pre) const { inline void RunRowWise(int n, int pre) const {
platform::Transform<Place> trans; platform::Transform<Place> trans;
trans(ctx_.device_context(), x_, x_ + nx_, trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, Place>(y_, n), z_,
RowwiseTransformIterator<T, Place>(y_, n), z_, func_); func_);
} }
inline void RunMidWise(int n, int pre, int post) const { inline void RunMidWise(int n, int pre, int post) const {
platform::Transform<Place> trans; platform::Transform<Place> trans;
trans(ctx_.device_context(), x_, x_ + nx_, trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator<T, Place>(y_, n, post),
MidWiseTransformIterator<T, Place>(y_, n, post), z_, func_); z_, func_);
} }
const T* x_; const T* x_;
const T* y_; const T* y_;
T* z_; T* z_;
int64_t nx_; int64_t nx_;
const framework::ExecutionContext& ctx_; const platform::DeviceContext& ctx_;
Functor func_; Functor func_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册