未验证 提交 36444461 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #6229 from chengduoZH/profiling/updata_elementwise_op

[Profiling]  Update elementwise op
...@@ -19,11 +19,48 @@ ...@@ -19,11 +19,48 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct AddFunctor {
HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseAddKernel : public framework::OpKernel<T> { class ElementwiseAddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenAddFunctor, Place, T>(ctx); using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
TransformFunctor<AddFunctor<T>, T, Place> functor(
x, y, z, ctx.device_context(), AddFunctor<T>());
auto x_dims = x->dims();
auto y_dims = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.");
if (x_dims == y_dims) {
functor.Run();
return;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
"Axis should be in range [0, x_dims)");
int pre, n, post;
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
if (post == 1) {
functor.RunRowWise(n, pre);
return;
} else {
functor.RunMidWise(n, pre, post);
return;
}
} }
}; };
......
...@@ -16,6 +16,11 @@ ...@@ -16,6 +16,11 @@
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/platform/transform.h"
#ifdef __NVCC__
#include <thrust/iterator/iterator_adaptor.h>
#endif
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
...@@ -54,6 +59,153 @@ inline void get_mid_dims(const framework::DDim& x_dims, ...@@ -54,6 +59,153 @@ inline void get_mid_dims(const framework::DDim& x_dims,
} }
} }
template <typename T, typename Place>
class RowwiseTransformIterator;
template <typename T, typename Place>
class MidWiseTransformIterator;
template <typename T>
class RowwiseTransformIterator<T, platform::CPUPlace> {
public:
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
++i_;
i_ %= n_;
return *this;
}
bool operator==(
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T& operator*() { return ptr_[i_]; }
private:
const T* ptr_;
int i_;
int64_t n_;
};
template <typename T>
class MidWiseTransformIterator<T, platform::CPUPlace> {
public:
MidWiseTransformIterator(const T* ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
i_ = (++j_ / post_) % n_;
return *this;
}
bool operator==(
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T& operator*() { return ptr_[i_]; }
private:
const T* ptr_;
int i_;
int64_t j_;
int64_t n_;
int post_;
};
#ifdef __NVCC__
template <typename T>
class RowwiseTransformIterator<T, platform::GPUPlace>
: public thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
public:
typedef thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::GPUPlace>, const T*>
super_t;
HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
: super_t(x), begin_(x), n_(n){};
friend class thrust::iterator_core_access;
private:
unsigned int n_;
const T* begin_;
HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (this->base() - begin_) % n_);
}
};
template <typename T>
class MidWiseTransformIterator<T, platform::GPUPlace>
: public thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
public:
typedef thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::GPUPlace>, const T*>
super_t;
HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
: super_t(x), begin_(x), n_(n), post_(post){};
friend class thrust::iterator_core_access;
private:
unsigned int post_;
unsigned int n_;
const T* begin_;
HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (((this->base() - begin_) / post_) % n_));
}
};
#endif
template <typename Functor, typename T, typename Place>
class TransformFunctor {
public:
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z, const platform::DeviceContext& ctx,
Functor func)
: x_(x->data<T>()),
y_(y->data<T>()),
z_(z->mutable_data<T>(ctx.GetPlace())),
nx_(x->numel()),
ctx_(ctx),
func_(func) {}
inline void Run() const {
platform::Transform<Place> trans;
trans(ctx_, x_, x_ + nx_, y_, z_, func_);
}
inline void RunRowWise(int n, int pre) const {
platform::Transform<Place> trans;
trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, Place>(y_, n), z_,
func_);
}
inline void RunMidWise(int n, int pre, int post) const {
platform::Transform<Place> trans;
trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator<T, Place>(y_, n, post),
z_, func_);
}
private:
const T* x_;
const T* y_;
T* z_;
int64_t nx_;
const platform::DeviceContext& ctx_;
Functor func_;
};
#define EIGEN_FUNCTOR(name, eigen_op) \ #define EIGEN_FUNCTOR(name, eigen_op) \
struct Eigen##name##Functor { \ struct Eigen##name##Functor { \
template <typename Place, typename T> \ template <typename Place, typename T> \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册