From 9e244a8cbe8e0d089e0f3d402230a1d5f2ffcbb9 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 5 Dec 2017 10:59:43 +0800 Subject: [PATCH] follow comments --- paddle/operators/elementwise_add_op.h | 4 ++-- paddle/operators/elementwise_op_function.h | 28 +++++++++++----------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h index 686d45573d8..3a198c167e4 100644 --- a/paddle/operators/elementwise_add_op.h +++ b/paddle/operators/elementwise_add_op.h @@ -34,8 +34,8 @@ class ElementwiseAddKernel : public framework::OpKernel { auto* y = ctx.Input("Y"); auto* z = ctx.Output("Out"); z->mutable_data(ctx.GetPlace()); - TransformFunctor, T, Place> functor(x, y, z, ctx, - AddFunctor()); + TransformFunctor, T, Place> functor( + x, y, z, ctx.device_context(), AddFunctor()); auto x_dims = x->dims(); auto y_dims = y->dims(); diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 14da42a786f..6d849bff499 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -81,7 +81,7 @@ struct RowwiseTransformIterator { bool operator!=( const RowwiseTransformIterator& rhs) const { - return (ptr_ + i_) &= &(*rhs); + return (ptr_ + i_) != &(*rhs); } const T& operator*() { return ptr_[i_]; } @@ -97,7 +97,7 @@ struct MidWiseTransformIterator { : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} MidWiseTransformIterator& operator++() { - i_ = ++j_ / post_ % n_; + i_ = (++j_ / post_) % n_; return *this; } @@ -108,7 +108,7 @@ struct MidWiseTransformIterator { bool operator!=( const MidWiseTransformIterator& rhs) const { - return (ptr_ + i_) &= &(*rhs); + return (ptr_ + i_) != &(*rhs); } const T& operator*() { return ptr_[i_]; } @@ -129,14 +129,14 @@ struct RowwiseTransformIterator typedef thrust::iterator_adaptor< RowwiseTransformIterator, const T*> super_t; - __host__ __device__ RowwiseTransformIterator(const T* x, int n) + HOSTDEVICE RowwiseTransformIterator(const T* x, int n) : super_t(x), begin_(x), n_(n){}; friend class thrust::iterator_core_access; private: unsigned int n_; const T* begin_; - __host__ __device__ typename super_t::reference dereference() const { + HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (this->base() - begin_) % n_); } }; @@ -149,7 +149,7 @@ struct MidWiseTransformIterator typedef thrust::iterator_adaptor< MidWiseTransformIterator, const T*> super_t; - __host__ __device__ MidWiseTransformIterator(const T* x, int n, int post) + HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post) : super_t(x), begin_(x), n_(n), post_(post){}; friend class thrust::iterator_core_access; @@ -157,7 +157,7 @@ struct MidWiseTransformIterator unsigned int post_; unsigned int n_; const T* begin_; - __host__ __device__ typename super_t::reference dereference() const { + HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (((this->base() - begin_) / post_) % n_)); } }; @@ -166,7 +166,7 @@ struct MidWiseTransformIterator template struct TransformFunctor { TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z, const framework::ExecutionContext& ctx, + framework::Tensor* z, const platform::DeviceContext& ctx, Functor func) : x_(x->data()), y_(y->data()), @@ -177,26 +177,26 @@ struct TransformFunctor { inline void Run() const { platform::Transform trans; - trans(ctx_.device_context(), x_, x_ + nx_, y_, z_, func_); + trans(ctx_, x_, x_ + nx_, y_, z_, func_); } inline void RunRowWise(int n, int pre) const { platform::Transform trans; - trans(ctx_.device_context(), x_, x_ + nx_, - RowwiseTransformIterator(y_, n), z_, func_); + trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator(y_, n), z_, + func_); } inline void RunMidWise(int n, int pre, int post) const { platform::Transform trans; - trans(ctx_.device_context(), x_, x_ + nx_, - MidWiseTransformIterator(y_, n, post), z_, func_); + trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator(y_, n, post), + z_, func_); } const T* x_; const T* y_; T* z_; int64_t nx_; - const framework::ExecutionContext& ctx_; + const platform::DeviceContext& ctx_; Functor func_; }; -- GitLab