Unverified commit 3bd1d22a, authored by chengduo, committed by GitHub

Enhance fused_elementwise_activation_op (#12837)

* Enhance the function of fused_elementwise_activation_op

* enhance unit test

* Clean Code And Add Doc

* Add compound functors

* Fix doc and enhance unit test

* define Dx and Dy for d_binary_func

* add mul_scale

* add mul_scale

* add elementwise_mul

* code refine

* code refine

* add doc

* add  AsIntermediate
Parent a615ad46
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <vector>
@@ -46,9 +47,9 @@ namespace operators {
 * pre=2*3, n=4*5, post=1
 * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
 */
inline void get_mid_dims(const framework::DDim &x_dims,
const framework::DDim &y_dims, const int axis,
int *pre, int *n, int *post) {
*pre = 1;
*n = 1;
*post = 1;
@@ -68,7 +69,7 @@ inline void get_mid_dims(const framework::DDim& x_dims,
}
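For reference, the pre/n/post factorization sketched in the comment above can be checked with a small standalone program. This is a hedged sketch, not part of this diff: get_mid_dims_sketch is a hypothetical stand-in that uses std::vector<int> instead of framework::DDim and skips the dimension-match checks.

#include <cassert>
#include <vector>

// Hypothetical stand-in for get_mid_dims: pre = product of dims before
// `axis`, n = product of Y's dims, post = product of the remaining dims.
void get_mid_dims_sketch(const std::vector<int> &x_dims,
                         const std::vector<int> &y_dims, int axis,
                         int *pre, int *n, int *post) {
  *pre = *n = *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= x_dims[i];
  for (size_t i = 0; i < y_dims.size(); ++i) *n *= y_dims[i];
  for (size_t i = axis + y_dims.size(); i < x_dims.size(); ++i) *post *= x_dims[i];
}

int main() {
  int pre, n, post;
  // shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), axis = 2: the example above.
  get_mid_dims_sketch({2, 3, 4, 5}, {4, 5}, 2, &pre, &n, &post);
  assert(pre == 2 * 3 && n == 4 * 5 && post == 1);
  return 0;
}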
inline framework::DDim trim_trailing_singular_dims(
const framework::DDim &dims) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims.size();
for (; actual_dims_size != 0; --actual_dims_size) {
@@ -89,15 +90,16 @@ inline framework::DDim trim_trailing_singular_dims(
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;
template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
public:
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
@@ -105,20 +107,20 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
return *this;
}
bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T &operator*() { return ptr_[i_]; }
private:
const T *ptr_;
int i_;
int64_t n_;
};
@@ -126,10 +128,10 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
template <typename T>
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
public:
MidWiseTransformIterator(const T *ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
MidWiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
++j_;
if (UNLIKELY(j_ == post_)) {
++i_;
@@ -141,20 +143,20 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
return *this;
}
bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T &operator*() { return ptr_[i_]; }
private:
const T *ptr_;
int64_t i_;
int64_t j_;
int64_t n_;
@@ -165,18 +167,18 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
template <typename T>
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
: public thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
public:
typedef thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
super_t;
HOSTDEVICE RowwiseTransformIterator(const T *x, int n)
: super_t(x), begin_(x), n_(n) {}
friend class thrust::iterator_core_access;
private:
unsigned int n_;
const T *begin_;
HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (this->base() - begin_) % n_);
}
@@ -185,19 +187,19 @@ class RowwiseTransformIterator<T, platform::CUDADeviceContext>
template <typename T>
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
: public thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
public:
typedef thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
super_t;
HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post)
: super_t(x), begin_(x), n_(n), post_(post) {}
friend class thrust::iterator_core_access;
private:
unsigned int post_;
unsigned int n_;
const T *begin_;
HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (((this->base() - begin_) / post_) % n_));
}
@@ -208,8 +210,8 @@ template <typename Functor, typename T, typename DeviceContext,
typename OutType = T>
class TransformFunctor {
public:
TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z, const DeviceContext &ctx, Functor func)
: x_(x->data<T>()),
y_(y->data<T>()),
z_(z->mutable_data<OutType>(ctx.GetPlace())),
@@ -235,20 +237,20 @@ class TransformFunctor {
}
private:
const T *x_;
const T *y_;
OutType *z_;
int64_t nx_;
const DeviceContext &ctx_;
Functor func_;
};
#define EIGEN_FUNCTOR(name, eigen_op) \
struct Eigen##name##Functor { \
template <typename DeviceContext, typename T> \
inline void Run(const framework::Tensor *x, const framework::Tensor *y, \
framework::Tensor *z, \
const framework::ExecutionContext &ctx) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \
@@ -257,9 +259,9 @@ class TransformFunctor {
eigen_op(x_e, y_e); \
} \
template <typename DeviceContext, typename T> \
inline void RunBroadCast(const framework::Tensor *x, \
const framework::Tensor *y, framework::Tensor *z, \
const framework::ExecutionContext &ctx, int pre, \
int n) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
@@ -272,10 +274,10 @@ class TransformFunctor {
eigen_op(x_e, y_bcast); \
} \
template <typename DeviceContext, typename T> \
inline void RunBroadCast2(const framework::Tensor *x, \
const framework::Tensor *y, \
framework::Tensor *z, \
const framework::ExecutionContext &ctx, int pre, \
int n, int post) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
@@ -290,23 +292,27 @@ class TransformFunctor {
}
#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);
#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);
#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
const T *x_;
const T *y_;
const T *out_;
const T *dout_;
HOSTDEVICE void operator()(size_t i) {
if (dx_ != nullptr) {
@@ -319,14 +325,14 @@ struct ElemwiseGradNoBroadcast {
DX_OP dx_op_;
DY_OP dy_op_;
T *dx_;
T *dy_;
};
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
const T *dout, int h, int w, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int x_offset = i * w + j;
@@ -348,8 +354,8 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
@@ -376,10 +382,10 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
}
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
const T *y, const T *out, const T *dout,
int h, int w, DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
@@ -389,9 +395,9 @@ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
#endif
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
const T *dout, int pre, int n, int post,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
@@ -416,8 +422,8 @@ static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int tid = threadIdx.x;
int j = blockIdx.x;
@@ -453,10 +459,10 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
}
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
const T *y, const T *out, const T *dout,
int pre, int n, int post, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
int gird_size = n;
ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
@@ -467,11 +473,11 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim, const framework::Tensor &x,
const framework::Tensor &y, const framework::Tensor &out,
const framework::Tensor &dout, int axis, framework::Tensor *dx,
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
size_t N = static_cast<size_t>(framework::product(x_dim));
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
@@ -483,11 +489,11 @@ void ElemwiseGradComputeNoBroadcast(
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
const framework::Tensor &y, const framework::Tensor &out,
const framework::Tensor &dout, int axis, framework::Tensor *dx,
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
@@ -531,14 +537,14 @@ void ElemwiseGradComputeWithBroadcast(
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor &x, const framework::Tensor &y,
const framework::Tensor &out,
const framework::Tensor &dout, int axis,
framework::Tensor *dx, framework::Tensor *dy,
DX_OP dx_op, DY_OP dy_op) {
const framework::DDim &x_dim = x.dims();
const framework::DDim &y_dim = y.dims();
if (x.dims() == y.dims()) {
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
@@ -553,27 +559,27 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor &x,
const framework::Tensor &y,
const framework::Tensor &out,
const framework::Tensor &dout, int axis,
framework::Tensor *dx, framework::Tensor *dy,
DX_OP dx_op, DY_OP dy_op) {
if (dy == nullptr) {
const framework::DDim &dx_dims = dout.dims();
auto dy_dims = dx_dims;
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else {
if (dout.dims() == dy->dims()) {
const framework::DDim &dx_dims = dout.dims();
const framework::DDim &dy_dims = dy->dims();
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else { // Y is a scalar
auto dx_dims = dout.dims();
const framework::DDim &dy_dims = dy->dims();
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}
@@ -583,13 +589,13 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
// Deprecated
template <typename DeviceContext, typename T, typename functor,
typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout, int axis,
framework::Tensor *dx, framework::Tensor *dy) {
auto &place = *ctx.template device_context<DeviceContext>().eigen_device();
auto x_dims = x->dims();
auto y_dims = y->dims();
@@ -627,10 +633,10 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, int axis, Functor func,
framework::Tensor *z) {
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), func);
@@ -661,5 +667,823 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
}
}
// FusedElemwiseAndAct
// --- forward
template <typename T, typename CompoundFunctor, bool KeepIntermediateOut>
struct FusedElemwiseAndActNoBroadcast {
HOSTDEVICE void operator()(size_t i) {
T y_val = y_[i];
T x_val = x_[i];
if (KeepIntermediateOut) {
T intermeidiate_out = compound_functor_.GetIntermediateOut(x_val, y_val);
intermediate_out_[i] = intermeidiate_out;
out_[i] =
compound_functor_.GetOutUseIntermediateOut(x_val, intermeidiate_out);
} else {
out_[i] = compound_functor_.GetOut(x_val, y_val);
}
}
const T *x_;
const T *y_;
CompoundFunctor compound_functor_;
T *out_;
T *intermediate_out_;
};
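The CompoundFunctor template parameter used above only needs to expose GetIntermediateOut, GetOutUseIntermediateOut and GetOut. As a hedged illustration (a hypothetical functor, not one of the functors added by this commit), a compound for Out = scale * (X + Y), i.e. Unary(Binary(X, Y)) with Binary = elementwise_add and Unary = scale, could look like:

// Hypothetical compound functor sketch for Out = scale * (X + Y).
template <typename T>
struct ScaleAddCompoundFunctorSketch {
  T scale_;
  // f2(x, y): the intermediate (binary) result.
  HOSTDEVICE T GetIntermediateOut(T x, T y) const { return x + y; }
  // f1(intermediate_out): final result, reusing the stored intermediate.
  HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediate_out) const {
    return scale_ * intermediate_out;
  }
  // Fused form used when the intermediate value is not kept.
  HOSTDEVICE T GetOut(T x, T y) const { return scale_ * (x + y); }
};

With KeepIntermediateOut set, the loops above store x + y into intermediate_out and then scale it via GetOutUseIntermediateOut; otherwise GetOut fuses both steps in one call.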
// FusedElemwiseAndActBroadcast1:
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2,
// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20)
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CPU(const T *x, const T *y,
CompoundFunctor compound_functor,
int h, int w, T *out,
T *intermediate_out) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int offset = i * w + j;
T y_val = BcastY ? y[j] : y[offset];
T x_val = BcastY ? x[offset] : x[j];
int64_t intermediate_out_offset;
if (KeepIntermediateOut) {
T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val);
if (SameShapeOfIntermediateOutAndOut) {
// for the case of f1(f2(x, y))
intermediate_out_offset = offset;
} else if (BcastY) {
intermediate_out_offset = j;
} else {
intermediate_out_offset = offset;
}
intermediate_out[intermediate_out_offset] = intermeidiate_out;
out[offset] =
compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out);
} else {
out[offset] = compound_functor.GetOut(x_val, y_val);
}
}
}
}
// FusedElemwiseAndActBroadcast2
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1,
// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1)
// pre = 2, n = 12, post = 5
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CPU(const T *x, const T *y, int pre,
int n, int post,
CompoundFunctor compound_functor,
T *out, T *intermediate_out) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int offset = i * n * post + j * post + k;
T y_val = BcastY ? y[j] : y[offset];
T x_val = BcastY ? x[offset] : x[j];
int64_t intermediate_out_offset;
if (KeepIntermediateOut) {
T intermeidiate_out =
compound_functor.GetIntermediateOut(x_val, y_val);
if (SameShapeOfIntermediateOutAndOut) {
// for the case of f1(f2(x, y))
intermediate_out_offset = offset;
} else if (BcastY) {
intermediate_out_offset = j;
} else {
intermediate_out_offset = offset;
}
intermediate_out[intermediate_out_offset] = intermeidiate_out;
out[offset] = compound_functor.GetOutUseIntermediateOut(
x_val, intermeidiate_out);
} else {
out[offset] = compound_functor.GetOut(x_val, y_val);
}
}
}
}
}
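A small standalone check of the (pre, n, post) index arithmetic used by the Broadcast2 paths (a hedged example, not part of this diff): it verifies that the flat offset i * n * post + j * post + k decomposes back into (i, j, k), so broadcasting Y only needs the middle index j.

#include <cassert>

int main() {
  // shape(X) = (2, 3, 4, 5) viewed as (pre, n, post) = (2, 12, 5),
  // shape(Y) = (3, 4) viewed as (1, 12, 1), as in the comment above.
  const int pre = 2, n = 12, post = 5;
  for (int i = 0; i < pre; ++i)
    for (int j = 0; j < n; ++j)
      for (int k = 0; k < post; ++k) {
        int offset = i * n * post + j * post + k;  // flat index into X / Out
        assert(offset / (n * post) == i);
        assert((offset / post) % n == j);  // index used for the broadcast Y
        assert(offset % post == k);
      }
  return 0;
}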
#ifdef __NVCC__
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
const T *x, const T *y, int h, int w, CompoundFunctor compound_functor,
T *out, T *intermediate_out) {
int j = blockIdx.x;
int i = threadIdx.x;
while (i < h) {
int offset = i * w + j;
T y_val = BcastY ? y[j] : y[offset];
T x_val = BcastY ? x[offset] : x[j];
int64_t intermediate_out_offset;
if (KeepIntermediateOut) {
T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val);
if (SameShapeOfIntermediateOutAndOut) {
// for the case of f1(f2(x, y))
intermediate_out_offset = offset;
} else if (BcastY) {
intermediate_out_offset = j;
} else {
intermediate_out_offset = offset;
}
intermediate_out[intermediate_out_offset] = intermeidiate_out;
out[offset] =
compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out);
} else {
out[offset] = compound_functor.GetOut(x_val, y_val);
}
i += ELEMWISE_MAX_BLOCK_DIM;
}
}
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CUDA(cudaStream_t stream, const T *x,
const T *y,
CompoundFunctor compound_functor,
int h, int w, T *out,
T *intermediate_out) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
FusedElemwiseAndActBroadcast1CUDAKernel<
T, CompoundFunctor, BcastY, KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
x, y, h, w, compound_functor, out, intermediate_out);
}
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast2CUDAKernel(
const T *x, const T *y, CompoundFunctor compound_functor, int pre, int n,
int post, T *out, T *intermediate_out) {
int tid = threadIdx.x;
int j = blockIdx.x;
while (true) {
int i = tid / post;
int k = tid % post;
if (i >= pre) break;
int offset = i * n * post + j * post + k;
T y_val = BcastY ? y[j] : y[offset];
T x_val = BcastY ? x[offset] : x[j];
int64_t intermediate_out_offset;
if (KeepIntermediateOut) {
T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val);
if (SameShapeOfIntermediateOutAndOut) {
// for the case of f1(f2(x, y))
intermediate_out_offset = offset;
} else if (BcastY) {
intermediate_out_offset = j;
} else {
intermediate_out_offset = offset;
}
intermediate_out[intermediate_out_offset] = intermeidiate_out;
out[offset] =
compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out);
} else {
out[offset] = compound_functor.GetOut(x_val, y_val);
}
tid += ELEMWISE_MAX_BLOCK_DIM;
}
}
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CUDA(cudaStream_t stream, const T *x,
const T *y, int pre, int n,
int post,
CompoundFunctor compound_functor,
T *out, T *intermediate_out) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
int gird_size = n;
FusedElemwiseAndActBroadcast2CUDAKernel<
T, CompoundFunctor, BcastY, KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
x, y, compound_functor, pre, n, post, out, intermediate_out);
}
#endif
template <typename DeviceContext, typename T, typename CompoundFunctor,
bool KeepIntermediateOut>
void FusedElemwiseAndActComputeNoBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::Tensor &x, const framework::Tensor &y,
CompoundFunctor compound_functor, framework::Tensor *out,
framework::Tensor *intermediate_out) {
size_t N = static_cast<size_t>(framework::product(x_dim));
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
for_range(
FusedElemwiseAndActNoBroadcast<T, CompoundFunctor, KeepIntermediateOut>{
x.data<T>(), y.data<T>(), compound_functor,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace())});
}
template <typename DeviceContext, typename T, typename CompoundFunctor,
bool BcastY, bool KeepIntermediateOut,
bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeWithBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
const framework::Tensor &y, CompoundFunctor compound_functor, int axis,
framework::Tensor *out, framework::Tensor *intermediate_out) {
axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
if (post == 1) {
int h = pre;
int w = n;
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActBroadcast1CUDA<T, CompoundFunctor, BcastY,
KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), compound_functor, h, w,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActBroadcast1CPU<T, CompoundFunctor, BcastY,
KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut>(
x.data<T>(), y.data<T>(), compound_functor, h, w,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace()));
}
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActBroadcast2CUDA<T, CompoundFunctor, BcastY,
KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), pre, n, post, compound_functor,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActBroadcast2CPU<T, CompoundFunctor, BcastY,
KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut>(
x.data<T>(), y.data<T>(), pre, n, post, compound_functor,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace()));
}
}
}
// --- backward
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast {
HOSTDEVICE void operator()(size_t i) {
if (dx_ != nullptr) {
dx_[i] = UseIntermediateOut ? dx_op_(x_[i], y_[i], intermediate_out_[i],
out_[i], dout_[i])
: dx_op_(x_[i], y_[i], out_[i], dout_[i]);
}
if (dy_ != nullptr) {
dy_[i] = UseIntermediateOut ? dy_op_(x_[i], y_[i], intermediate_out_[i],
out_[i], dout_[i])
: dy_op_(x_[i], y_[i], out_[i], dout_[i]);
}
}
const T *x_;
const T *y_;
const T *intermediate_out_;
const T *out_;
const T *dout_;
DX_OP dx_op_;
DY_OP dy_op_;
T *dx_;
T *dy_;
};
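The DX_OP / DY_OP functors above are called with (x, y, intermediate_out, out, dout) when UseIntermediateOut is true and with (x, y, out, dout) otherwise, so a gradient functor has to provide both overloads. A hedged, hypothetical example for Out = relu(X + Y), where intermediate_out = X + Y (the dY functor is identical for this compound):

// Hypothetical dx_op sketch for Out = relu(X + Y): dX = dout if Out > 0 else 0.
template <typename T>
struct ReluAddGradDxSketch {
  // 4-argument form (UseIntermediateOut == false).
  HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
    return out > static_cast<T>(0) ? dout : static_cast<T>(0);
  }
  // 5-argument form (UseIntermediateOut == true); intermediate_out = x + y.
  HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) const {
    return intermediate_out > static_cast<T>(0) ? dout : static_cast<T>(0);
  }
};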
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
bool UseIntermediateOut>
void FusedElemwiseAndActGradComputeNoBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *intermediate_out,
const framework::Tensor *out, const framework::Tensor *dout, int axis,
framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
size_t N = static_cast<size_t>(framework::product(x_dim));
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
for_range(
FusedElemwiseAndActGradNoBroadcast<T, DX_OP, DY_OP, UseIntermediateOut>{
x->data<T>(), y->data<T>(),
intermediate_out ? intermediate_out->data<T>() : nullptr,
out->data<T>(), dout->data<T>(), dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y,
const T *intermediate_out,
const T *out, const T *dout,
int h, int w, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
int64_t tmp_out_idx, x_idx, y_idx;
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int offset = i * w + j;
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
} else {
if (i == 0) {
dx[x_idx] = tmp;
} else {
dx[x_idx] += tmp;
}
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
if (i == 0) {
dy[y_idx] = tmp;
} else {
dy[y_idx] += tmp;
}
} else {
dy[y_idx] = tmp;
}
}
}
}
}
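When Y is broadcast (BcastY), every y element is shared by several positions of X, so its gradient must be accumulated, which is what the i == 0 assignment / i != 0 accumulation above implements. A tiny hand check, under the assumption that dy_op simply returns dout (as it does for elementwise_add):

#include <cassert>

int main() {
  // h = 2 rows of X share the same w = 3 element row of Y.
  const int h = 2, w = 3;
  double dout[h * w] = {1, 2, 3, 4, 5, 6};
  double dy[w] = {0, 0, 0};
  for (int i = 0; i < h; ++i)
    for (int j = 0; j < w; ++j) {
      double tmp = dout[i * w + j];  // dy_op(x, y, out, dout) == dout here
      if (i == 0) dy[j] = tmp; else dy[j] += tmp;
    }
  assert(dy[0] == 5 && dy[1] == 7 && dy[2] == 9);  // column sums of dout
  return 0;
}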
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y,
const T *intermediate_out,
const T *out, const T *dout,
int pre, int n, int post,
DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
int64_t tmp_out_idx, x_idx, y_idx;
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int offset = i * n * post + j * post + k;
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
} else {
if (i == 0 && k == 0) {
dx[x_idx] = tmp;
} else {
dx[x_idx] += tmp;
}
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
if (i == 0 && k == 0) {
dy[y_idx] = tmp;
} else {
dy[y_idx] += tmp;
}
} else {
dy[y_idx] = tmp;
}
}
}
}
}
}
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *intermediate_out, const T *out,
const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
T val(0);
int64_t tmp_out_idx, x_idx, y_idx;
do {
int offset = i * w + j;
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
} else {
val += tmp;
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
val += tmp;
} else {
dy[y_idx] = tmp;
}
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
if (BcastY) {
if (dy) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
} else {
if (dx) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CUDA(cudaStream_t stream,
const T *x, const T *y,
const T *intermediate_out,
const T *out, const T *dout,
int h, int w, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
FusedElemwiseAndActGradBroadcast1CUDAKernel<
T, DX_OP, DY_OP, UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dx, dy);
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
const T *x, const T *y, const T *intermediate_out, const T *out,
const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op, T *dx,
T *dy) {
int tid = threadIdx.x;
int j = blockIdx.x;
T val(0);
int ttid = tid;
int64_t tmp_out_idx, x_idx, y_idx;
while (true) {
int i = ttid / post;
int k = ttid % post;
if (i >= pre) break;
int offset = i * n * post + j * post + k;
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
} else {
val += tmp;
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
val += tmp;
} else {
dy[y_idx] = tmp;
}
}
ttid += ELEMWISE_MAX_BLOCK_DIM;
}
if (BcastY) {
if (dy) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
} else {
if (dx) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CUDA(
cudaStream_t stream, const T *x, const T *y, const T *intermediate_out,
const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
int gird_size = n;
FusedElemwiseAndActGradBroadcast2CUDAKernel<
T, DX_OP, DY_OP, UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}
#endif
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeWithBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim_untrimed, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *intermediate_out,
const framework::Tensor *out, const framework::Tensor *dout, int axis,
framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
if (post == 1) {
int h = pre;
int w = n;
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, UseIntermediateOut,
BcastY,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActGradBroadcast1CPU<T, DX_OP, DY_OP, UseIntermediateOut,
BcastY,
SameShapeOfIntermediateOutAndOut>(
x->data<T>(), y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActGradBroadcast2CUDA<T, DX_OP, DY_OP, UseIntermediateOut,
BcastY,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActGradBroadcast2CPU<T, DX_OP, DY_OP, UseIntermediateOut,
BcastY,
SameShapeOfIntermediateOutAndOut>(
x->data<T>(), y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
}
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
bool UseIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeEx(
const framework::ExecutionContext &ctx, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *out,
const framework::Tensor *intermediate_out, const framework::Tensor *dout,
int axis, framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op,
DY_OP dy_op) {
const framework::DDim &x_dim = x->dims();
const framework::DDim &y_dim = y->dims();
if (UseIntermediateOut) {
PADDLE_ENFORCE(intermediate_out, "intermediate_out should not be nullptr");
}
if (x_dim == y_dim) {
FusedElemwiseAndActGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP,
UseIntermediateOut>(
ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
dx_op, dy_op);
} else { // Y is a scalar
bool bcast_y = x_dim.size() >= y_dim.size();
if (x_dim.size() == y_dim.size()) {
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] < y_dim[i]) {
bcast_y = false;
break;
}
}
}
// z = f1(x, f2(y))
// z = f1(f2(x, y))
if (bcast_y) { // Y should be broadcast.
FusedElemwiseAndActGradComputeWithBroadcast<
DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, true /*BcastY*/,
SameShapeOfIntermediateOutAndOut>(ctx, x_dim, y_dim, x, y,
intermediate_out, out, dout, axis,
dx, dy, dx_op, dy_op);
} else {
FusedElemwiseAndActGradComputeWithBroadcast<
DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, false /*BcastY*/,
SameShapeOfIntermediateOutAndOut>(ctx, y_dim, x_dim, x, y,
intermediate_out, out, dout, axis,
dx, dy, dx_op, dy_op);
}
}
}
template <typename DeviceContext, typename T, typename CompoundFunctor,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor &x,
const framework::Tensor &y, int axis,
CompoundFunctor compound_functor,
framework::Tensor *out,
framework::Tensor *intermediate_out) {
if (KeepIntermediateOut) {
PADDLE_ENFORCE(intermediate_out,
"The keep_intermediate_value is opened, "
"intermediate_out should not be nullptr.");
}
const framework::DDim &x_dim = x.dims();
const framework::DDim &y_dim = y.dims();
if (x.dims() == y.dims()) {
FusedElemwiseAndActComputeNoBroadcast<DeviceContext, T, CompoundFunctor,
KeepIntermediateOut>(
ctx, x_dim, x, y, compound_functor, out, intermediate_out);
} else {
// Whether the shape of Y is a continuous subsequence of X,
// For more information please refer to the op's introduction.
bool bcast_y = x.dims().size() >= y.dims().size();
if (x.dims().size() == y.dims().size()) {
for (int i = 0; i < x.dims().size(); ++i) {
if (x.dims()[i] < y.dims()[i]) {
bcast_y = false;
break;
}
}
}
// z = f1(x, f2(y))
// z = f1(f2(x, y))
if (bcast_y) { // Y should be broadcast.
// In this case,
// for 'f2(y)', the shape of intermediate_out should be equal to the shape
// of Y.
// for 'f2(x, y)', the shape of intermediate_out should be equal to the
// shape of Out.
// the shape of Out should be equal to the shape of X.
FusedElemwiseAndActComputeWithBroadcast<
DeviceContext, T, CompoundFunctor, true /*BcastY*/,
KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
ctx, x_dim /*OutShape*/, y_dim, x, y, compound_functor, axis, out,
intermediate_out);
} else {
// In this case,
// for 'f2(y)', the shape of intermediate_out should be equal to the shape
// of Out.
// for 'f2(x, y)', the shape of intermediate_out should be equal to the
// shape of Out.
// the shape of Out should be equal to the shape of Y.
FusedElemwiseAndActComputeWithBroadcast<
DeviceContext, T, CompoundFunctor, false /*BcastY*/,
KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
ctx, y_dim /*OutShape*/, x_dim, x, y, compound_functor, axis, out,
intermediate_out);
}
}
}
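Both this function and InferShape below decide which operand is broadcast with the same rule: broadcast Y when rank(Y) < rank(X), or when the ranks match and no dimension of X is smaller than the corresponding dimension of Y. A hedged standalone restatement (hypothetical helper, plain std::vector instead of DDim):

#include <cassert>
#include <vector>

// Hypothetical restatement of the bcast_y decision used above.
bool ShouldBroadcastY(const std::vector<int> &x_dim,
                      const std::vector<int> &y_dim) {
  bool bcast_y = x_dim.size() >= y_dim.size();
  if (x_dim.size() == y_dim.size()) {
    for (size_t i = 0; i < x_dim.size(); ++i) {
      if (x_dim[i] < y_dim[i]) {
        bcast_y = false;  // X is the smaller operand, so X is broadcast
        break;
      }
    }
  }
  return bcast_y;
}

int main() {
  assert(ShouldBroadcastY({2, 3, 4, 5}, {4, 5}));     // Y broadcast over X
  assert(!ShouldBroadcastY({4, 5}, {2, 3, 4, 5}));    // X broadcast over Y
  assert(!ShouldBroadcastY({1, 20, 1}, {6, 20, 1}));  // equal rank, X smaller
  return 0;
}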
} // namespace operators
} // namespace paddle
@@ -12,14 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
namespace paddle {
namespace operators {
/*
* Whether the compound function is Unary(Binary(X, Y)).
 * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
 * final out.
*/
static bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {
"elementwise_add", "elementwise_mul", "elementwise_add_grad",
"elementwise_mul_grad"};
return binary_fun.count(functor_list[1]) != 0;
}
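In other words, the nesting is read off the second entry of functor_list: if it names a binary functor the compound is Unary(Binary(X, Y)), otherwise it is Binary(X, Unary(Y)). A hedged standalone replica of that rule (hypothetical function name, not part of this diff):

#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

static bool IsUnaryCompoundSketch(const std::vector<std::string> &functors) {
  static const std::unordered_set<std::string> binary_fun = {
      "elementwise_add", "elementwise_mul"};
  return binary_fun.count(functors[1]) != 0;
}

int main() {
  // "scale,elementwise_add" => scale(X + Y), i.e. Unary(Binary(X, Y)).
  assert(IsUnaryCompoundSketch({"scale", "elementwise_add"}));
  // "elementwise_add,scale" => X + scale(Y), i.e. Binary(X, Unary(Y)).
  assert(!IsUnaryCompoundSketch({"elementwise_add", "scale"}));
  return 0;
}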
/*
* Whether the Input(X) could be absent.
*/
static bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
return binary_fun.count(functor_list[0]) != 0 ||
binary_fun.count(functor_list[1]) != 0;
}
/*
* Whether the compound function is supported.
 * For Unary(Binary(X, Y)), the intermediate_out's shape is the same as the
 * final out.
*/
static bool IsSupportedCompound(const std::vector<std::string> &functors) {
static std::unordered_set<std::string> unary_fun = {"scale", "relu"};
static std::unordered_set<std::string> binary_fun = {"elementwise_add",
"elementwise_mul"};
std::string unary_fun_str;
if (binary_fun.count(functors[0])) {
unary_fun_str = functors[1];
} else if (binary_fun.count(functors[1])) {
unary_fun_str = functors[0];
} else {
PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
functors[1]);
}
PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
"%s is not included in fused_list.", unary_fun_str);
return true;
}
class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -37,11 +83,44 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.");
// Whether the shape of Y is a continuous subsequence of X,
// For more information please refer to the op's introduction.
bool bcast_y = x_dim.size() >= y_dim.size();
if (x_dim.size() == y_dim.size()) {
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] < y_dim[i]) {
bcast_y = false;
break;
}
}
}
auto &out_dim = bcast_y ? x_dim : y_dim;
std::string out_lod = bcast_y ? "X" : "Y";
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
"Output(IntermediateOut) of FusedElemwiseActivationOp "
"should not be null.");
if (IsUnaryCompound(
ctx->Attrs().Get<std::vector<std::string>>("functor_list"))) {
// for Unary(Binary(X, Y)), the shape and lod of out and
// intermediate_out are the same.
ctx->SetOutputDim("IntermediateOut", out_dim);
// set the lod of intermediate_out
ctx->ShareLoD(out_lod, /*->*/ "IntermediateOut");
} else {
// for Binary(X, Unary(Y)), the shape and lod of Y and
// intermediate_out are the same.
ctx->SetOutputDim("IntermediateOut", y_dim);
// set the lod of intermediate_out
ctx->ShareLoD("Y", /*->*/ "IntermediateOut");
}
}
ctx->SetOutputDim("Out", out_dim);
ctx->ShareLoD(out_lod, /*->*/ "Out");
}
protected:
@@ -59,29 +138,42 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of fused_elemwise_activation operator.");
AddInput(
"Y",
"(Tensor) The input tensor of fused_elemwise_activation operator.");
AddOutput("Out",
"vector<Tensor> The output tensor of fused_elemwise_activation "
"operator.");
AddOutput("IntermediateOut",
"Tensor The IntermediateOut tensor of fused_elemwise_activation "
"operator.")
.AsIntermediate();
AddAttr<int>("axis",
"axis is used by elementwise_op, the default value is -1.")
.SetDefault(-1);
AddAttr<float>("scale",
"scale is used by scale_op, the default value is 0.0.")
.SetDefault(0.0);
AddAttr<bool>("recomputation", AddAttr<bool>(
"recomputation",
"Whether to recompute the Out." "Whether to recompute the Out."
"fused_elemwise_activation_grad has two methods to get the " "The computation of fused_elemwise_activation_grad has two methods to "
"dx and dy, one " "get the dx and dy, one is to use the 'Out', and the other is not. "
"is to use the 'Out', and the other is not to use it. " "The former method will save the time of recomputing the 'Out', but it "
"The former method will save the time of recomputing the " "must occupy the memory to store the 'out'. While, the later method "
"'Out', but it must occupy the memory to store the 'out'. " "can avoid occupying the memory, but it must recompute the 'Out'. "
"While, the later method can avoid occupying the memory, " "It is useful for Unary(Binary(X, Y)). The default value is true.")
"but it must recompute the 'Out'. The default value is true.")
.SetDefault(true); .SetDefault(true);
AddAttr<bool>("keep_intermediate_value",
"Whether to save the intermediate_out.")
.SetDefault(false);
AddAttr<std::vector<std::string>>("functor_list", AddAttr<std::vector<std::string>>("functor_list",
"The functors that should be fused.") "The functors that should be fused.")
.AddCustomChecker([&](const std::vector<std::string> &functor_list) { .AddCustomChecker([&](const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE(ValidCheck(functor_list)); PADDLE_ENFORCE(IsSupportedCompound(functor_list));
}); });
AddComment(R"DOC( AddComment(R"DOC(
...@@ -93,30 +185,38 @@ operators (elementwise_op and activation_op): ...@@ -93,30 +185,38 @@ operators (elementwise_op and activation_op):
Z = Binary(X, Unary(Y)) Z = Binary(X, Unary(Y))
Z = Unary(Binary(X, Y)) Z = Unary(Binary(X, Y))
The attributions of activation_op can be get from fused_elemwise_activation_op's There are two cases for this operator:
attributions. functor_list records the functors to be fused, for example
"scale,elementwise_add".
)DOC"); 1. The shape of $Y$ and $X$ is the same.
} 2. The shape of $Y$ is a continuous subsequence of $X$ or the shape of $X$ is a continuous subsequence of $Y$.
private: For case 2 (assume that the shape of $Y$ is a continuous subsequence of $X$ ):
bool ValidCheck(const std::vector<std::string> &functors) {
std::unordered_set<std::string> unary_fun = {"scale", "relu"};
std::unordered_set<std::string> binary_fun = {"elementwise_add"};
std::string unary_fun_str; 1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
if (binary_fun.count(functors[0])) { for broadcasting $Y$ onto $X$.
unary_fun_str = functors[1]; 2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
} else if (binary_fun.count(functors[1])) { 3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
unary_fun_str = functors[0]; subsequence, such as shape(Y) = (2, 1) => (2).
} else {
PADDLE_THROW("%s and %s are not included in fused_list.", functors[0], For example:
functors[1]);
} .. code-block:: python
PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
"%s is not included in fused_list.", unary_fun_str); shape(X) = (2, 3, 4, 5), shape(Y) = (,)
return true; shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
The inputs $X$ and $Y$ can carry different LoD information.
The output only shares the LoD information with the input whose shape is the same as Out.
The attributes of activation_op can be obtained from fused_elemwise_activation_op's attributes.
The functor_list records the functors to be fused, for example
["scale", "elementwise_add"].
)DOC");
} }
}; };
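
A small numpy illustration of case 2 from the comment above, with functor_list = ["scale", "elementwise_add"]: Y occupies a continuous subsequence of X's shape starting at axis, is padded with size-1 dimensions, and is then broadcast (the concrete shapes and reshape target are illustrative):

import numpy as np

x = np.random.rand(2, 3, 4, 5)
y = np.random.rand(3, 4)            # continuous subsequence of X's shape
axis = 1                            # where (3, 4) starts inside (2, 3, 4, 5)
scale = 0.1

y_bcast = y.reshape(1, 3, 4, 1)     # pad with 1s outside [axis, axis + rank(Y))
intermediate_out = x + y_bcast      # Binary(X, Y)
out = intermediate_out * scale      # Unary(Binary(X, Y))
assert out.shape == x.shape
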
...@@ -141,6 +241,7 @@ class FusedElemwiseActivationGradMaker ...@@ -141,6 +241,7 @@ class FusedElemwiseActivationGradMaker
op_desc_ptr->SetInput(framework::GradVarName(output_param), op_desc_ptr->SetInput(framework::GradVarName(output_param),
this->OutputGrad(output_param)); this->OutputGrad(output_param));
} }
op_desc_ptr->SetAttrMap(this->Attrs()); op_desc_ptr->SetAttrMap(this->Attrs());
std::vector<std::string> functor_names = std::vector<std::string> functor_names =
...@@ -158,40 +259,59 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { ...@@ -158,40 +259,59 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@Grad) should not be null");
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
auto y_dims = ctx->GetInputDim("Y"); "Input(IntermediateOut) should not be null");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); } else {
PADDLE_ENFORCE_EQ(ctx->Inputs(framework::GradVarName("Out")).size(), 1);
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), }
"Rank of first input must >= rank of second input.");
auto functor_list =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) { if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims); if (ctx->HasInputs("X")) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", x_grad_name);
} else {
// Node: If "X" is absence, the shape of Y should be a continuous
// subsequence of X, if not, we could not infer the shape of dx.
// Currently, only when Binary is elementwise_add or elementwise_sub,
// the "X" could be absent.
PADDLE_ENFORCE(InputXCanBeAbsent(functor_list),
"Only when BinaryFunctor is elementwise_add, the 'X' "
"could be absent.");
// For Unary(Binary(X, Y)), IntermediateOut should not be empty.
if (IsUnaryCompound(functor_list)) {
PADDLE_ENFORCE(
ctx->HasInputs("IntermediateOut"),
"If the compound_functor is Unary(Binary(X, Y)) and Binary "
"is elementwise_add, the intermediate_out must be not absent.");
}
ctx->SetOutputDim(x_grad_name,
ctx->GetInputDim(framework::GradVarName("Out")));
ctx->ShareLoD(framework::GradVarName("Out"), x_grad_name);
}
} }
if (ctx->HasOutput(y_grad_name)) { if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
ctx->ShareLoD("Y", y_grad_name);
} }
} }
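
The shape rule implemented above, in brief: dY always takes Y's shape and LoD, while dX takes X's shape when X is given and otherwise falls back to Out@GRAD's shape, which is only sound when the binary functor is elementwise_add/elementwise_sub. An illustrative sketch:

def infer_grad_shapes(x_shape, y_shape, dout_shape, x_is_absent):
    # when X is dropped, dX can only be inferred from dOut
    dx_shape = dout_shape if x_is_absent else x_shape
    dy_shape = y_shape
    return dx_shape, dy_shape
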
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type_index = ctx.Input<framework::Tensor>("X")->type(); // PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_EQ(input_data_type_index, auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same.");
PADDLE_ENFORCE_EQ(
input_data_type_index,
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
"The element's type of input should be the same.");
auto input_data_type = framework::ToDataType(input_data_type_index); auto input_data_type = framework::ToDataType(input_data_type_index);
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -20,208 +20,114 @@ limitations under the License. */ ...@@ -20,208 +20,114 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/operators/math/functors.h"
namespace math = paddle::operators::math;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// CompoundFunctors
// For example: Z = Binary(X, Unary(Y))
template <typename T, typename BinaryFun, typename UnaryFun>
struct BinaryCompoundFunctor {
BinaryCompoundFunctor(const BinaryFun &binary_fun, const UnaryFun &unary_fun)
: binary_fun_(binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y) {
return binary_fun_(x, unary_fun_(y));
}
private:
BinaryFun binary_fun_;
UnaryFun unary_fun_;
};
// For example: Z = Unary(Binary(X, Y))
template <typename T, typename UnaryFun, typename BinaryFun>
struct UnaryCompoundFunctor {
UnaryCompoundFunctor(const UnaryFun &unary_fun, const BinaryFun &binary_fun)
: unary_fun_(unary_fun), binary_fun_(binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y) {
return unary_fun_(binary_fun_(x, y));
}
private:
UnaryFun unary_fun_;
BinaryFun binary_fun_;
};
// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get
// the dx, one is to use the 'out', and the other is not to use it.
// the former method will save the time of recomputing the
// 'out', but it must occupy the memory to store the 'out'.
// While the later method can avoid occupying this memory,
// but it must recompute the 'out'.
template <typename T, typename DBinaryFun, typename UnaryFun,
bool Recomputation = true>
struct BinaryCompoundGradDxFunctor {
BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun)
: d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
if (Recomputation) {
return dout * d_binary_fun_(x, unary_fun_(y));
} else {
return dout * d_binary_fun_(x, unary_fun_(y), out);
}
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
};
template <typename T, typename DBinaryFun, typename UnaryFun,
typename DUnaryFun, bool Recomputation = true>
struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun,
const DUnaryFun &d_unary_fun)
: d_binary_fun_(d_binary_fun),
unary_fun_(unary_fun),
d_unary_fun_(d_unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
if (Recomputation) {
return dout * d_binary_fun_(unary_fun_(y), x) * d_unary_fun_(y);
} else {
return dout * d_binary_fun_(unary_fun_(y), x, out) * d_unary_fun_(y);
}
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
DUnaryFun d_unary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_(x, y);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_(y, x);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
template <typename DeviceContext, typename T, typename BinaryFunctor, template <typename DeviceContext, typename T, typename BinaryFunctor,
typename UnaryFunctor> typename UnaryFunctor>
static void RunBinaryCompoundFunctor(const framework::ExecutionContext &ctx, static void RunBinaryCompoundFunctor(
const BinaryFunctor &binary_functor, const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor,
const UnaryFunctor &unary_functor, const UnaryFunctor &unary_functor, const framework::Tensor &in_x,
const framework::Tensor *in_x, const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
const framework::Tensor *in_y, // Z = Binary(X, Unary(Y))
framework::Tensor *output) { // intermediate_out = Unary(Y)
// out = Binary(X, Unary(Y))
// In this case, the shapes of intermediate_out and out are different.
paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
compound_func(binary_functor, unary_functor);
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
using BinaryCompoundFunctor = if (ctx.Attr<bool>("keep_intermediate_value")) {
BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>; FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
ElementwiseComputeEx<BinaryCompoundFunctor, DeviceContext, T>( T, BinaryFunctor, UnaryFunctor>,
ctx, in_x, in_y, axis, true /*KeepIntermediateValue*/,
BinaryCompoundFunctor(binary_functor, unary_functor), output); false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
T, BinaryFunctor, UnaryFunctor>,
false /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
}
} }
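
A numpy-style sketch of what this path computes for the elementwise_add + scale pair, ignoring broadcasting (names are illustrative):

def binary_compound_forward(x, y, scale, keep_intermediate_value):
    intermediate_out = y * scale                 # Unary(Y)
    out = x + intermediate_out                   # Binary(X, Unary(Y))
    return (out, intermediate_out) if keep_intermediate_value else (out, None)
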
template <typename DeviceContext, typename T, typename UnaryFunctor, template <typename DeviceContext, typename T, typename UnaryFunctor,
typename BinaryFunctor> typename BinaryFunctor>
static void RunUnaryCompoundFunctors(const framework::ExecutionContext &ctx, static void RunUnaryCompoundFunctors(
const UnaryFunctor &unary_functor, const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor,
const BinaryFunctor &binary_functor, const BinaryFunctor &binary_functor, const framework::Tensor &in_x,
const framework::Tensor *in_x, const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
const framework::Tensor *in_y, // Z = Unary(Binary(X, Y))
framework::Tensor *output) { // intermediate_out = Binary(X, Y)
// out = Unary(Binary(X, Y))
// In this case, the shapes of intermediate_out and out are the same.
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
using UnaryCompoundFunctor = paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>; compound_func(unary_functor, binary_functor);
ElementwiseComputeEx<UnaryCompoundFunctor, DeviceContext, T>( if (ctx.Attr<bool>("keep_intermediate_value")) {
ctx, in_x, in_y, axis, FusedElemwiseAndActComputeEx<DeviceContext, T,
UnaryCompoundFunctor(unary_functor, binary_functor), output); paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
true /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
false /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
}
} }
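
The mirrored sketch for the Unary(Binary(X, Y)) path with scale + elementwise_add; here the intermediate and the final output share one shape, which is why SameShapeOfIntermediateOutAndOut is true above (illustrative):

def unary_compound_forward(x, y, scale, keep_intermediate_value):
    intermediate_out = x + y                     # Binary(X, Y)
    out = intermediate_out * scale               # Unary(Binary(X, Y))
    return (out, intermediate_out) if keep_intermediate_value else (out, None)
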
template <typename DeviceContext, typename T, typename BinaryGradFunctor, template <typename DeviceContext, typename T, typename BinaryGradFunctor,
typename UnaryFunctor, typename UnaryGradFunctor, typename UnaryFunctor, typename UnaryGradFunctor>
bool Recomputation = true>
static void RunBinaryCompoundGradFunctors( static void RunBinaryCompoundGradFunctors(
const framework::ExecutionContext &ctx, const framework::ExecutionContext &ctx,
const BinaryGradFunctor &binary_grad_functor, const BinaryGradFunctor &binary_grad_functor,
const UnaryFunctor &unary_functor, const UnaryFunctor &unary_functor,
const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x, const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
const framework::Tensor *in_y, const framework::Tensor *in_out, const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, framework::Tensor *x_grad, const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad) { framework::Tensor *y_grad) {
// Z = Binary(X, Unary(Y))
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
using BinaryCompoundDxFunctor = using BinaryCompoundDxFunctor =
BinaryCompoundGradDxFunctor<T, BinaryGradFunctor, UnaryFunctor, paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
Recomputation>; UnaryFunctor>;
using BinaryCompoundDyFunctor = using BinaryCompoundDyFunctor =
BinaryCompoundGradDyFunctor<T, BinaryGradFunctor, UnaryFunctor, paddle::operators::math::BinaryCompoundGradDyFunctor<
UnaryGradFunctor, Recomputation>; T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor>;
ElemwiseGradCompute<DeviceContext, T, BinaryCompoundDxFunctor, if (in_intermediate_out) {
BinaryCompoundDyFunctor>( FusedElemwiseAndActGradComputeEx<
ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad, DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), true /*UseIntermediateOut*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
unary_grad_functor));
} else {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
false /*UseIntermediateOut*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
unary_grad_functor)); unary_grad_functor));
}
} }
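
By the chain rule, for Z = Binary(X, Unary(Y)) we have dX = dZ * dBinary/dX(X, Unary(Y)) and dY = dZ * dBinary/dY(X, Unary(Y)) * Unary'(Y). A numpy-style sketch for the newly supported elementwise_mul + scale pair, without broadcasting (illustrative):

def mul_scale_backward(x, y, dout, scale):
    u = y * scale                  # intermediate_out = Unary(Y)
    dx = dout * u                  # d(x * u)/dx = u
    dy = dout * x * scale          # d(x * u)/du = x, times Unary'(y) = scale
    return dx, dy
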
template <typename DeviceContext, typename T, typename UnaryGradFunctor, template <typename DeviceContext, typename T, typename UnaryGradFunctor,
...@@ -233,143 +139,159 @@ static void RunUnaryCompoundGradFunctors( ...@@ -233,143 +139,159 @@ static void RunUnaryCompoundGradFunctors(
const BinaryFunctor &binary_functor, const BinaryFunctor &binary_functor,
const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x, const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x,
const framework::Tensor *in_y, const framework::Tensor *in_out, const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, framework::Tensor *x_grad, const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad) { framework::Tensor *y_grad) {
// Z = Unary(Binary(X, Y))
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
using UnaryCompoundDxFunctor = using UnaryCompoundDxFunctor =
UnaryCompoundGradDxFunctor<T, UnaryGradFunctor, BinaryFunctor, paddle::operators::math::UnaryCompoundGradDxFunctor<
BinaryGradFunctor, Recomputation>; T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>;
using UnaryCompoundDyFunctor = using UnaryCompoundDyFunctor =
UnaryCompoundGradDyFunctor<T, UnaryGradFunctor, BinaryFunctor, paddle::operators::math::UnaryCompoundGradDyFunctor<
BinaryGradFunctor, Recomputation>; T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>;
ElemwiseGradCompute<DeviceContext, T, UnaryCompoundDxFunctor, if (in_intermediate_out) {
UnaryCompoundDyFunctor>( FusedElemwiseAndActGradComputeEx<
ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad, DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, true /*UseIntermediateOut*/, true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
binary_grad_functor),
UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
binary_grad_functor));
} else {
FusedElemwiseAndActGradComputeEx<DeviceContext, T, UnaryCompoundDxFunctor,
UnaryCompoundDyFunctor,
false /*UseIntermediateOut*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
binary_grad_functor), binary_grad_functor),
UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
binary_grad_functor)); binary_grad_functor));
}
} }
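
For Z = Unary(Binary(X, Y)), the corresponding rule is dX = dZ * Unary'(Binary(X, Y)) * dBinary/dX, and likewise for dY; when Recomputation is false, Unary' can be evaluated from the saved Out instead of recomputing Binary(X, Y). A sketch for scale + elementwise_add (illustrative):

def scale_add_backward(x, y, dout, scale):
    # intermediate_out = x + y, out = scale * intermediate_out
    d_unary = scale                # derivative of scale(.) is the constant scale
    dx = dout * d_unary            # d(x + y)/dx = 1
    dy = dout * d_unary            # d(x + y)/dy = 1
    return dx, dy
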
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
static void RunFunctors(const framework::ExecutionContext &ctx, static void RunFunctors(const framework::ExecutionContext &ctx,
const framework::Tensor *in_x, const framework::Tensor &in_x,
const framework::Tensor *in_y, const framework::Tensor &in_y,
framework::Tensor *output) { std::vector<framework::Tensor *> *outputs) {
auto &functors = ctx.Attr<std::vector<std::string>>("functor_list"); auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
auto funcs_str = functors[0] + "," + functors[1];
// TODO(zcd): The following code can be refined. // TODO(zcd): The following code can be refined.
auto funcs_str = functors[0] + "," + functors[1];
if (funcs_str == "elementwise_add,scale") { if (funcs_str == "elementwise_add,scale") {
// Z = Binary(X, Unary(Y)) // Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundFunctor<DeviceContext, T, math::AddFunctor<T>, RunBinaryCompoundFunctor<DeviceContext, T,
math::ScaleFunctor<T>>( paddle::operators::math::AddFunctor<T>,
ctx, math::AddFunctor<T>(), math::ScaleFunctor<T>(scale), in_x, in_y, paddle::operators::math::ScaleFunctor<T>>(
output); ctx, paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
} else if (funcs_str == "scale,elementwise_add") { } else if (funcs_str == "scale,elementwise_add") {
// Z = Unary(Binary(X, Y)) // Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunUnaryCompoundFunctors<DeviceContext, T, math::ScaleFunctor<T>, RunUnaryCompoundFunctors<DeviceContext, T,
math::AddFunctor<T>>( paddle::operators::math::ScaleFunctor<T>,
ctx, math::ScaleFunctor<T>(scale), math::AddFunctor<T>(), in_x, in_y, paddle::operators::math::AddFunctor<T>>(
output); ctx, paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_add,relu") { } else if (funcs_str == "elementwise_add,relu") {
RunBinaryCompoundFunctor<DeviceContext, T, math::AddFunctor<T>, // Z = Binary(X, Unary(Y))
math::ReluFunctor<T>>( RunBinaryCompoundFunctor<DeviceContext, T,
ctx, math::AddFunctor<T>(), math::ReluFunctor<T>(), in_x, in_y, output); paddle::operators::math::AddFunctor<T>,
paddle::operators::math::ReluFunctor<T>>(
ctx, paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::ReluFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "relu,elementwise_add") { } else if (funcs_str == "relu,elementwise_add") {
RunUnaryCompoundFunctors<DeviceContext, T, math::ReluFunctor<T>, // Z = Unary(Binary(X, Y))
math::AddFunctor<T>>( RunUnaryCompoundFunctors<DeviceContext, T,
ctx, math::ReluFunctor<T>(), math::AddFunctor<T>(), in_x, in_y, output); paddle::operators::math::ReluFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::ReluFunctor<T>(),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_mul,scale") {
// Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::MulFunctor<T>,
paddle::operators::math::ScaleFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
} else { } else {
PADDLE_THROW("%s has not been implemented.", funcs_str); PADDLE_THROW("%s has not been implemented.", funcs_str);
} }
} }
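
The forward combinations dispatched above, written as a small table of fused expressions (Python/numpy, illustrative; s is the scale attribute):

import numpy as np

fused_forward = {
    ("elementwise_add", "scale"): lambda x, y, s: x + y * s,              # Binary(X, Unary(Y))
    ("scale", "elementwise_add"): lambda x, y, s: (x + y) * s,            # Unary(Binary(X, Y))
    ("elementwise_add", "relu"):  lambda x, y, s: x + np.maximum(y, 0),
    ("relu", "elementwise_add"):  lambda x, y, s: np.maximum(x + y, 0),
    ("elementwise_mul", "scale"): lambda x, y, s: x * (y * s),
}
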
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T, bool ReComputation>
static void RunGradFunctors(const framework::ExecutionContext &ctx, static void RunGradFunctors(const framework::ExecutionContext &ctx,
const framework::Tensor *in_x, const framework::Tensor *in_x,
const framework::Tensor *in_y, const framework::Tensor *in_y,
const framework::Tensor *in_out, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, const framework::Tensor *in_out_grad,
framework::Tensor *x_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad) { framework::Tensor *y_grad) {
auto &functors = ctx.Attr<std::vector<std::string>>("functor_list"); auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
auto funcs_str = functors[0] + "," + functors[1]; auto funcs_str = functors[0] + "," + functors[1];
bool recomputation = ctx.Attr<bool>("recomputation"); // TODO(zcd): The following code can be refined, for example, by using registration.
// TODO(zcd): The following code can be refined, for example, by using registration.
if (funcs_str == "elementwise_add_grad,scale_grad") { if (funcs_str == "elementwise_add_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y)) // The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
if (recomputation) { RunBinaryCompoundGradFunctors<DeviceContext, T,
RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>, paddle::operators::math::AddGradFunctor<T>,
math::ScaleFunctor<T>, paddle::operators::math::ScaleFunctor<T>,
math::ScaleGradFunctor<T>, true>( paddle::operators::math::ScaleGradFunctor<T>>(
ctx, math::AddGradFunctor<T>(), math::ScaleFunctor<T>(scale), ctx, paddle::operators::math::AddGradFunctor<T>(),
math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out, in_out_grad, paddle::operators::math::ScaleFunctor<T>(scale),
x_grad, y_grad); paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
} else { in_intermediate_out, in_out_grad, x_grad, y_grad);
RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>,
math::ScaleFunctor<T>,
math::ScaleGradFunctor<T>, false>(
ctx, math::AddGradFunctor<T>(), math::ScaleFunctor<T>(scale),
math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out, in_out_grad,
x_grad, y_grad);
}
} else if (funcs_str == "scale_grad,elementwise_add_grad") { } else if (funcs_str == "scale_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y)) // The backward of Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
if (recomputation) { RunUnaryCompoundGradFunctors<DeviceContext, T,
RunUnaryCompoundGradFunctors<DeviceContext, T, math::ScaleGradFunctor<T>, paddle::operators::math::ScaleGradFunctor<T>,
math::AddFunctor<T>, math::AddGradFunctor<T>, paddle::operators::math::AddFunctor<T>,
true>(ctx, math::ScaleGradFunctor<T>(scale), paddle::operators::math::AddGradFunctor<T>,
math::AddFunctor<T>(), ReComputation /*Recomputation*/>(
math::AddGradFunctor<T>(), in_x, in_y, ctx, paddle::operators::math::ScaleGradFunctor<T>(scale),
in_out, in_out_grad, x_grad, y_grad); paddle::operators::math::AddFunctor<T>(),
} else { paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
RunUnaryCompoundGradFunctors<DeviceContext, T, math::ScaleGradFunctor<T>, in_intermediate_out, in_out_grad, x_grad, y_grad);
math::AddFunctor<T>, math::AddGradFunctor<T>,
false>(ctx, math::ScaleGradFunctor<T>(scale),
math::AddFunctor<T>(),
math::AddGradFunctor<T>(), in_x, in_y,
in_out, in_out_grad, x_grad, y_grad);
}
} else if (funcs_str == "elementwise_add_grad,relu_grad") { } else if (funcs_str == "elementwise_add_grad,relu_grad") {
if (recomputation) { RunBinaryCompoundGradFunctors<DeviceContext, T,
RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>, paddle::operators::math::AddGradFunctor<T>,
math::ReluFunctor<T>, paddle::operators::math::ReluFunctor<T>,
math::ReluGradFunctor<T>, true>( paddle::operators::math::ReluGradFunctor<T>>(
ctx, math::AddGradFunctor<T>(), math::ReluFunctor<T>(), ctx, paddle::operators::math::AddGradFunctor<T>(),
math::ReluGradFunctor<T>(), in_x, in_y, in_out, in_out_grad, x_grad, paddle::operators::math::ReluFunctor<T>(),
y_grad); paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
} else { in_intermediate_out, in_out_grad, x_grad, y_grad);
RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>,
math::ReluFunctor<T>,
math::ReluGradFunctor<T>, false>(
ctx, math::AddGradFunctor<T>(), math::ReluFunctor<T>(),
math::ReluGradFunctor<T>(), in_x, in_y, in_out, in_out_grad, x_grad,
y_grad);
}
} else if (funcs_str == "relu_grad,elementwise_add_grad") { } else if (funcs_str == "relu_grad,elementwise_add_grad") {
if (recomputation) { RunUnaryCompoundGradFunctors<DeviceContext, T,
RunUnaryCompoundGradFunctors<DeviceContext, T, math::ReluGradFunctor<T>, paddle::operators::math::ReluGradFunctor<T>,
math::AddFunctor<T>, math::AddGradFunctor<T>, paddle::operators::math::AddFunctor<T>,
true>(ctx, math::ReluGradFunctor<T>(), paddle::operators::math::AddGradFunctor<T>,
math::AddFunctor<T>(), ReComputation /*Recomputation*/>(
math::AddGradFunctor<T>(), in_x, in_y, ctx, paddle::operators::math::ReluGradFunctor<T>(),
in_out, in_out_grad, x_grad, y_grad); paddle::operators::math::AddFunctor<T>(),
} else { paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
RunUnaryCompoundGradFunctors<DeviceContext, T, math::ReluGradFunctor<T>, in_intermediate_out, in_out_grad, x_grad, y_grad);
math::AddFunctor<T>, math::AddGradFunctor<T>, } else if (funcs_str == "elementwise_mul_grad,scale_grad") {
false>(ctx, math::ReluGradFunctor<T>(), // The backward of Z = Binary(X, Unary(Y))
math::AddFunctor<T>(), T scale = static_cast<T>(ctx.Attr<float>("scale"));
math::AddGradFunctor<T>(), in_x, in_y, RunBinaryCompoundGradFunctors<DeviceContext, T,
in_out, in_out_grad, x_grad, y_grad); paddle::operators::math::MulGradFunctor<T>,
} paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>>(
ctx, paddle::operators::math::MulGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
} else { } else {
PADDLE_THROW("%s has not been implemented.", funcs_str); PADDLE_THROW("%s has not been implemented.", funcs_str);
} }
...@@ -385,11 +307,23 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> { ...@@ -385,11 +307,23 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"), auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"),
"Cannot get input tensor %s, variable name = %s", "Cannot get input tensor %s, variable name = %s",
"Y", ctx.op().Input("Y")); "Y", ctx.op().Input("Y"));
auto &output = detail::Ref(ctx.Output<framework::Tensor>("Out"), PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
"Cannot get input tensor %s, variable name = %s", auto output = ctx.Output<framework::Tensor>("Out");
"Out", ctx.op().Output("Out"));
std::vector<framework::Tensor *> outputs;
outputs.emplace_back(output);
if (ctx.Attr<bool>("keep_intermediate_value")) {
PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
"The keep_intermediate_value is enable, so the "
"IntermediateOut should not be empty.");
auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
outputs.emplace_back(intermediate_out);
} else {
outputs.emplace_back(nullptr);
}
RunFunctors<DeviceContext, T>(ctx, &in_x, &in_y, &output); RunFunctors<DeviceContext, T>(ctx, in_x, in_y, &outputs);
} }
}; };
...@@ -397,28 +331,66 @@ template <typename DeviceContext, typename T> ...@@ -397,28 +331,66 @@ template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> { class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto &in_x = detail::Ref(ctx.Input<framework::Tensor>("X"), auto x = ctx.Input<framework::Tensor>("X");
"Cannot get input tensor %s, variable name = %s", auto y = ctx.Input<framework::Tensor>("Y");
"X", ctx.op().Input("X"));
auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"), auto in_out = ctx.Input<framework::Tensor>("Out");
"Cannot get input tensor %s, variable name = %s", auto in_out_grad =
"Y", ctx.op().Input("Y")); ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto &in_out = detail::Ref(ctx.Input<framework::Tensor>("Out"),
"Cannot get input tensor %s, variable name = %s",
"Out", ctx.op().Input("Out"));
auto &in_out_grad =
detail::Ref(ctx.Input<framework::Tensor>(framework::GradVarName("Out")),
"Cannot get input tensor %s, variable name = %s",
framework::GradVarName("Out"),
ctx.op().Input(framework::GradVarName("Out")));
framework::Tensor *x_grad = framework::Tensor *x_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X")); ctx.Output<framework::Tensor>(framework::GradVarName("X"));
framework::Tensor *y_grad = framework::Tensor *y_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Y")); ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
RunGradFunctors<DeviceContext, T>(ctx, &in_x, &in_y, &in_out, &in_out_grad, PADDLE_ENFORCE(y != nullptr, "Input(Y) should not be nullptr.");
x_grad, y_grad);
if (ctx.Attr<bool>("recomputation")) {
PADDLE_ENFORCE(
x != nullptr,
"The recomputation is opened, so Input(X) should not be absent.");
} else {
PADDLE_ENFORCE(in_out != nullptr,
"The recomputation is disabled, so the Input('Out') "
"should not be empty.");
}
framework::Tensor *in_x;
auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");
// If functor_list contains elementwise_add, the backward doesn't use
// in_x or in_out.
if (x == nullptr) {
PADDLE_ENFORCE(functor_list[0] == "elementwise_add_grad" ||
functor_list[1] == "elementwise_add_grad",
"Only when the compoundfunctor contains "
"elementwise_add_grad, the 'X' could be absent.");
in_x = const_cast<framework::Tensor *>(in_out_grad);
in_out = const_cast<framework::Tensor *>(in_out_grad);
} else {
in_x = const_cast<framework::Tensor *>(x);
}
framework::Tensor *in_intermediate_out;
if (ctx.Attr<bool>("keep_intermediate_value")) {
in_intermediate_out = const_cast<framework::Tensor *>(
ctx.Input<framework::Tensor>("IntermediateOut"));
PADDLE_ENFORCE(in_intermediate_out != nullptr,
"The option of 'keep_intermediate_value' is opened, "
"so the number of 'Out' should be two.");
} else {
in_intermediate_out = nullptr;
}
if (ctx.Attr<bool>("recomputation")) {
RunGradFunctors<DeviceContext, T, true /*Recomputation*/>(
ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad);
} else {
RunGradFunctors<DeviceContext, T, false /*Recomputation*/>(
ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad);
}
} }
}; };
} // namespace operators } // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace operators {
namespace math {
template <typename T, typename BinaryFunctor, typename UnaryFunctor>
struct BinaryCompoundFunctor {
BinaryCompoundFunctor(const BinaryFunctor func1, const UnaryFunctor func2)
: func1_(func1), func2_(func2) {}
// Z = BinaryFunctor(X, UnaryFunctor(Y))
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(x, func2_(y)); }
inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediate_out) {
return func1_(x, intermediate_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(y); }
BinaryFunctor func1_;
UnaryFunctor func2_;
};
template <typename T, typename UnaryFunctor, typename BinaryFunctor>
struct UnaryCompoundFunctor {
UnaryCompoundFunctor(const UnaryFunctor func1, const BinaryFunctor func2)
: func1_(func1), func2_(func2) {}
// Z = UnaryFunctor(BinaryFunctor(X, Y))
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(func2_(x, y)); }
inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediate_out) {
return func1_(intermediate_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(x, y); }
UnaryFunctor func1_;
BinaryFunctor func2_;
};
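
Each compound functor exposes three views of the same fused expression so the elementwise machinery can either recompute the inner value or reuse a stored intermediate. A Python sketch of the BinaryCompoundFunctor contract with func1 = add and func2 = scale-by-0.1 (illustrative):

class BinaryCompound(object):
    def __init__(self, func1, func2):     # func1: binary, func2: unary
        self.func1, self.func2 = func1, func2

    def get_out(self, x, y):
        return self.func1(x, self.func2(y))

    def get_out_use_intermediate_out(self, x, intermediate_out):
        return self.func1(x, intermediate_out)

    def get_intermediate_out(self, x, y):
        return self.func2(y)

fc = BinaryCompound(lambda a, b: a + b, lambda b: b * 0.1)
assert fc.get_out(2.0, 3.0) == fc.get_out_use_intermediate_out(
    2.0, fc.get_intermediate_out(2.0, 3.0))
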
// FIXME(zcd): DBinaryFun and DUnaryFun have two methods to get
// the dx: one uses the 'out', and the other does not.
// The former method saves the time of recomputing the
// 'out', but it must occupy memory to store the 'out'.
// The latter method avoids occupying this memory,
// but it must recompute the 'out'.
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDxFunctor {
BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun)
: d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dx(x, unary_fun_(y));
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
return dout * d_binary_fun_.Dx(x, intermediate_out);
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
};
template <typename T, typename DBinaryFun, typename UnaryFun,
typename DUnaryFun>
struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun,
const DUnaryFun &d_unary_fun)
: d_binary_fun_(d_binary_fun),
unary_fun_(unary_fun),
d_unary_fun_(d_unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_(y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_(y, intermediate_out);
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
DUnaryFun d_unary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dx(x, y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(intermediate_out);
} else {
base = dout * d_unary_fun_(intermediate_out, out);
}
return base * d_binary_fun_.Dx(x, y);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dy(x, y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(intermediate_out);
} else {
base = dout * d_unary_fun_(intermediate_out, out);
}
return base * d_binary_fun_.Dy(x, y);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -18,6 +18,19 @@ namespace paddle { ...@@ -18,6 +18,19 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
// MulFunctor
template <typename T>
struct MulFunctor {
// out = x * y;
inline HOSTDEVICE T operator()(T x, T y) { return x * y; }
};
template <typename T>
struct MulGradFunctor {
inline HOSTDEVICE T Dx(T x, T y) { return y; }
inline HOSTDEVICE T Dy(T x, T y) { return x; }
};
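
MulGradFunctor encodes the partials d(x*y)/dx = y and d(x*y)/dy = x. A quick central-difference sanity check of those formulas (illustrative):

def numeric_partial(f, a, b, wrt, eps=1e-6):
    if wrt == "x":
        return (f(a + eps, b) - f(a - eps, b)) / (2 * eps)
    return (f(a, b + eps) - f(a, b - eps)) / (2 * eps)

mul = lambda x, y: x * y
x, y = 2.0, 3.0
assert abs(numeric_partial(mul, x, y, "x") - y) < 1e-4   # Dx == y
assert abs(numeric_partial(mul, x, y, "y") - x) < 1e-4   # Dy == x
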
// AddFunctor // AddFunctor
template <typename T> template <typename T>
struct AddFunctor { struct AddFunctor {
...@@ -27,9 +40,8 @@ struct AddFunctor { ...@@ -27,9 +40,8 @@ struct AddFunctor {
template <typename T> template <typename T>
struct AddGradFunctor { struct AddGradFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return 1; } inline HOSTDEVICE T Dx(T x, T y) { return 1; }
inline HOSTDEVICE T Dy(T x, T y) { return 1; }
inline HOSTDEVICE T operator()(T x, T y, T out) const { return 1; }
}; };
template <typename T> template <typename T>
......
...@@ -47,7 +47,8 @@ def get_numeric_gradient(place, ...@@ -47,7 +47,8 @@ def get_numeric_gradient(place,
input_to_check, input_to_check,
output_names, output_names,
delta=0.005, delta=0.005,
in_place=False): in_place=False,
sum_outputs=None):
# FIXME: change this method by compile time concepts # FIXME: change this method by compile time concepts
set_input(scope, op, inputs, place) set_input(scope, op, inputs, place)
...@@ -58,9 +59,11 @@ def get_numeric_gradient(place, ...@@ -58,9 +59,11 @@ def get_numeric_gradient(place,
sum = [] sum = []
op.run(scope, place) op.run(scope, place)
for output_name in output_names: for output_name in output_names:
if sum_outputs and output_name not in sum_outputs:
continue
sum.append( sum.append(
np.array(scope.find_var(output_name).get_tensor()).mean()) np.array(scope.find_var(output_name).get_tensor()).mean())
return np.array(sum).mean() return np.array(sum).sum() / len(output_names)
tensor_to_check = scope.find_var(input_to_check).get_tensor() tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.shape()) tensor_size = product(tensor_to_check.shape())
...@@ -396,13 +399,14 @@ class OpTest(unittest.TestCase): ...@@ -396,13 +399,14 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005, numeric_grad_delta=0.005,
in_place=False, in_place=False,
max_relative_error=0.005, max_relative_error=0.005,
user_defined_grads=None): user_defined_grads=None,
sum_outputs=None):
places = self._get_places() places = self._get_places()
for place in places: for place in places:
self.check_grad_with_place(place, inputs_to_check, output_names, self.check_grad_with_place(place, inputs_to_check, output_names,
no_grad_set, numeric_grad_delta, no_grad_set, numeric_grad_delta,
in_place, max_relative_error, in_place, max_relative_error,
user_defined_grads) user_defined_grads, sum_outputs)
def check_grad_with_place(self, def check_grad_with_place(self,
place, place,
...@@ -412,7 +416,8 @@ class OpTest(unittest.TestCase): ...@@ -412,7 +416,8 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005, numeric_grad_delta=0.005,
in_place=False, in_place=False,
max_relative_error=0.005, max_relative_error=0.005,
user_defined_grads=None): user_defined_grads=None,
sum_outputs=None):
self.scope = core.Scope() self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict() op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict() op_outputs = self.outputs if hasattr(self, "outputs") else dict()
...@@ -435,7 +440,8 @@ class OpTest(unittest.TestCase): ...@@ -435,7 +440,8 @@ class OpTest(unittest.TestCase):
input_to_check, input_to_check,
output_names, output_names,
delta=numeric_grad_delta, delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check in_place=in_place,
sum_outputs=sum_outputs) for input_to_check in inputs_to_check
] ]
analytic_grads = self._get_gradient(inputs_to_check, place, analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set) output_names, no_grad_set)
......
...@@ -15,32 +15,31 @@ ...@@ -15,32 +15,31 @@
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from functools import partial
import paddle.fluid.core as core import paddle.fluid.core as core
from op_test import OpTest from op_test import OpTest
# scale + add # TestFusedElementwiseActivationOp
# TestElementwiseAddOp # TestFusedElementwiseActivationOp_scalar
# TestFusedOperatorsOp_scalar # TestFusedElementwiseActivationOp_scalar2
# TestFusedOperatorsOp_scalar2 # TestFusedElementwiseActivationOp_Vector
# TestFusedOperatorsOp_Vector # TestFusedElementwiseActivationOp_broadcast_0
# TestFusedOperatorsOp_broadcast_0 # TestFusedElementwiseActivationOp_broadcast_1
# TestFusedOperatorsOp_broadcast_1 # TestFusedElementwiseActivationOp_broadcast_2
# TestFusedOperatorsOp_broadcast_2 # TestFusedElementwiseActivationOp_broadcast_3
# TestFusedOperatorsOp_broadcast_3 # TestFusedElementwiseActivationOp_broadcast_4
# TestFusedOperatorsOp_broadcast_4 # TestFusedElementwiseActivationOp_rowwise_add_0
# TestFusedOperatorsOp_rowwise_add_0 # TestFusedElementwiseActivationOp_rowwise_add_1
# TestFusedOperatorsOp_rowwise_add_1 # TestFusedElementwiseActivationOp_channelwise_add
# TestFusedOperatorsOp_channelwise_add
def create_test_class(test_case, callback, attrs):
class TestElementwiseAddOp(OpTest): class TestFusedElementwiseActivationOp_base(OpTest):
def setUp(self): def setUp(self):
self.op_type = "fused_elemwise_activation" self.op_type = "fused_elemwise_activation"
self.dtype = np.float32 self.dtype = np.float32
self.axis = -1 self.axis = -1
self.init_axis()
self.init_dtype()
self.init_input() self.init_input()
self.init_output() self.init_output()
self.init_attr() self.init_attr()
...@@ -49,772 +48,294 @@ class TestElementwiseAddOp(OpTest): ...@@ -49,772 +48,294 @@ class TestElementwiseAddOp(OpTest):
'X': OpTest.np_dtype_to_fluid_dtype(self.x), 'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y) 'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
} }
if self.attrs["keep_intermediate_value"]:
self.outputs = {
'Out': self.out,
"IntermediateOut": self.intermediate_out
}
else:
self.outputs = {'Out': self.out} self.outputs = {'Out': self.out}
def init_input(self): def init_input(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.axis = -1
def init_output(self): def init_output(self):
self.scale = 0.1 self.x, self.y, self.intermediate_out, self.out = \
self.out = (self.x + self.y) * self.scale callback(self.x, self.y, self.x, self.y)
def init_attr(self): def init_attr(self):
self.attrs = { self.attrs = {'axis': self.axis, }
'axis': self.axis, for key in attrs.keys():
'scale': self.scale, self.attrs[key] = attrs[key]
'functor_list': ["scale", "elementwise_add"]
}
def init_dtype(self):
pass
def init_axis(self):
pass
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) if self.attrs["keep_intermediate_value"]:
self.check_grad(
['X', 'Y'], ['Out', 'IntermediateOut'],
max_relative_error=0.005,
sum_outputs=['Out'])
else:
self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)
def test_check_grad_ingore_x(self): def test_check_grad_ingore_x(self):
if self.attrs["keep_intermediate_value"]:
self.check_grad(
['Y'], ['Out', 'IntermediateOut'],
max_relative_error=0.005,
no_grad_set=set("X"),
sum_outputs=['Out'])
else:
self.check_grad( self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) ['Y'], ['Out'],
max_relative_error=0.005,
no_grad_set=set("X"))
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
if self.attrs["keep_intermediate_value"]:
self.check_grad( self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) ['X'], ['Out', 'IntermediateOut'],
max_relative_error=0.005,
no_grad_set=set("Y"),
sum_outputs=['Out'])
else:
self.check_grad(
['X'], ['Out'],
max_relative_error=0.005,
no_grad_set=set("Y"))
class TestFusedOperatorsOp_scalar(TestElementwiseAddOp): class TestFusedElementwiseActivationOp_scalar(
TestFusedElementwiseActivationOp_base):
def init_input(self): def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype) self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype) self.y = np.random.rand(1).astype(self.dtype)
def init_output(self): class TestFusedElementwiseActivationOp_scalar2(
self.scale = 0.1 TestFusedElementwiseActivationOp_base):
self.out = (self.x + self.y) * self.scale
class TestFusedOperatorsOp_scalar2(TestElementwiseAddOp):
def init_input(self): def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype) self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1).astype(self.dtype) self.y = np.random.rand(1, 1).astype(self.dtype)
def init_output(self): class TestFusedElementwiseActivationOp_Vector(
self.scale = 0.1 TestFusedElementwiseActivationOp_base):
self.out = (self.x + self.y) * self.scale
class TestFusedOperatorsOp_Vector(TestElementwiseAddOp):
def init_input(self): def init_input(self):
self.x = np.random.random((32, )).astype(self.dtype) self.x = np.random.random((32, )).astype(self.dtype)
self.y = np.random.random((32, )).astype(self.dtype) self.y = np.random.random((32, )).astype(self.dtype)
def init_output(self): class TestFusedElementwiseActivationOp_broadcast_0(
self.scale = 0.1 TestFusedElementwiseActivationOp_base):
self.out = (self.x + self.y) * self.scale
class TestFusedOperatorsOp_broadcast_0(TestElementwiseAddOp):
def init_input(self): def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype) self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(2).astype(self.dtype) self.y = np.random.rand(2).astype(self.dtype)
def init_axis(self):
self.axis = 0 self.axis = 0
def init_output(self): def init_output(self):
self.scale = 0.1 self.x, self.y, self.intermediate_out, self.out = \
self.out = (self.x + self.y.reshape(2, 1, 1)) * self.scale callback(self.x, self.y, self.x, self.y.reshape(2, 1, 1))
class TestFusedElementwiseActivationOp_broadcast_1(
class TestFusedOperatorsOp_broadcast_1(TestElementwiseAddOp): TestFusedElementwiseActivationOp_base):
def init_input(self): def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype) self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(3).astype(self.dtype) self.y = np.random.rand(3).astype(self.dtype)
def init_axis(self):
self.axis = 1 self.axis = 1
def init_output(self): def init_output(self):
self.scale = 0.1 self.x, self.y, self.intermediate_out, self.out = \
self.out = (self.x + self.y.reshape(1, 3, 1)) * self.scale callback(self.x, self.y, self.x, self.y.reshape(1, 3, 1))
class TestFusedOperatorsOp_broadcast_2(TestElementwiseAddOp):
    def init_input(self):
        self.x = np.random.rand(2, 3, 4).astype(self.dtype)
        self.y = np.random.rand(4).astype(self.dtype)

    def init_output(self):
        self.scale = 0.1
        self.out = (self.x + self.y.reshape(1, 1, 4)) * self.scale


class TestFusedOperatorsOp_broadcast_3(TestElementwiseAddOp):
    def init_input(self):
        self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
        self.y = np.random.rand(3, 4).astype(self.dtype)

    def init_axis(self):
        self.axis = 1

    def init_output(self):
        self.scale = 0.1
        self.out = (self.x + self.y.reshape(1, 3, 4, 1)) * self.scale


class TestFusedOperatorsOp_broadcast_4(TestElementwiseAddOp):
    def init_input(self):
        self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
        self.y = np.random.rand(2, 1).astype(self.dtype)

    def init_axis(self):
        self.axis = 0

    def init_output(self):
        self.scale = 0.1
        self.out = (self.x + self.y.reshape(2, 1, 1, 1)) * self.scale


class TestFusedOperatorsOp_rowwise_add_0(TestElementwiseAddOp):
    def init_input(self):
        self.x = np.random.rand(2, 3, 4).astype(self.dtype)
        self.y = np.random.rand(3, 4).astype(self.dtype)

    def init_axis(self):
        self.axis = 1

    def init_output(self):
        self.scale = 0.1
        self.out = (self.x + self.y.reshape(1, 3, 4)) * self.scale


class TestFusedOperatorsOp_rowwise_add_1(TestElementwiseAddOp):
    def init_input(self):
        self.x = np.random.rand(2, 1).astype(self.dtype)
        self.y = np.random.rand(1).astype(self.dtype)

    def init_axis(self):
        self.axis = 1

    def init_output(self):
        self.scale = 0.1
        self.out = (self.x + self.y.reshape(1, 1)) * self.scale


class TestFusedOperatorsOp_channelwise_add(TestElementwiseAddOp):
    def init_input(self):
        self.x = np.random.rand(3, 20, 20).astype(self.dtype)
        self.y = np.random.rand(3, 1, 1).astype(self.dtype)

    def init_axis(self):
        self.axis = -1

    def init_output(self):
        self.scale = 0.1
        self.out = (self.x + self.y) * self.scale


# add + scale
# TestElementwiseAddOp_f_add_scale
# TestFusedOperatorsOp_scalar_f_add_scale
# TestFusedOperatorsOp_scalar2_f_add_scale
# TestFusedOperatorsOp_Vector_f_add_scale
# TestFusedOperatorsOp_broadcast_0_f_add_scale
# TestFusedOperatorsOp_broadcast_1_f_add_scale
# TestFusedOperatorsOp_broadcast_2_f_add_scale
# TestFusedOperatorsOp_broadcast_3_f_add_scale
# TestFusedOperatorsOp_broadcast_4_f_add_scale
# TestFusedOperatorsOp_rowwise_add_0_f_add_scale
# TestFusedOperatorsOp_rowwise_add_1_f_add_scale
# TestFusedOperatorsOp_channelwise_add_f_add_scale


class TestFusedOperatorsOp_f_add_scale(TestElementwiseAddOp):
    def init_output(self):
        self.scale = 0.1
        self.out = self.x + self.y * self.scale

    def init_attr(self):
        self.attrs = {
            'axis': self.axis,
            'scale': self.scale,
            'functor_list': ["elementwise_add", "scale"]
        }


class TestFusedOperatorsOp_scalar_f_add_scale(TestFusedOperatorsOp_scalar):
    def init_output(self):
        self.scale = 0.1
        self.out = self.x + self.y * self.scale

    def init_attr(self):
        self.attrs = {
            'axis': self.axis,
            'scale': self.scale,
            'functor_list': ["elementwise_add", "scale"]
        }


class TestFusedOperatorsOp_scalar2_f_add_scale(TestFusedOperatorsOp_scalar2):
    def init_output(self):
self.scale = 0.1
self.out = self.x + self.y * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_Vector_f_add_scale(TestFusedOperatorsOp_Vector):
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_0_f_add_scale(
TestFusedOperatorsOp_broadcast_0):
def init_axis(self):
self.axis = 0
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(2, 1, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_1_f_add_scale(
TestFusedOperatorsOp_broadcast_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(1, 3, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_2_f_add_scale(
TestFusedOperatorsOp_broadcast_2):
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(1, 1, 4) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_3_f_add_scale(
TestFusedOperatorsOp_broadcast_3):
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(1, 3, 4, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_4_f_add_scale(
TestFusedOperatorsOp_broadcast_4):
def init_axis(self):
self.axis = 0
def init_output(self):
self.scale = 0.2
self.out = self.x + self.y.reshape(2, 1, 1, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_rowwise_add_0_f_add_scale(
TestFusedOperatorsOp_rowwise_add_0):
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(1, 3, 4) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_rowwise_add_1_f_add_scale(
TestFusedOperatorsOp_rowwise_add_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.2
self.out = self.x + self.y.reshape(1, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_channelwise_add_f_add_scale(
TestFusedOperatorsOp_channelwise_add):
def init_axis(self):
self.axis = -1
def init_output(self):
self.scale = 0.2
self.out = self.x + self.y * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
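# Illustrative sketch, not from the original test file: judging by the expected
# outputs in the cases above, the second functor in 'functor_list' produces the
# intermediate value and the first functor combines it with the other input.
# With made-up values:
_x, _y, _scale = np.array([1.0]), np.array([10.0]), 0.1
_out_add_scale = _x + _y * _scale    # ["elementwise_add", "scale"] -> [2.0]
_out_scale_add = (_x + _y) * _scale  # ["scale", "elementwise_add"] -> [1.1]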
# add + relu
# TestElementwiseAddOp_f_add_relu
# TestFusedOperatorsOp_scalar_f_add_relu
# TestFusedOperatorsOp_scalar2_f_add_relu
# TestFusedOperatorsOp_Vector_f_add_relu
# TestFusedOperatorsOp_broadcast_0_f_add_relu
# TestFusedOperatorsOp_broadcast_1_f_add_relu
# TestFusedOperatorsOp_broadcast_2_f_add_relu
# TestFusedOperatorsOp_broadcast_3_f_add_relu
# TestFusedOperatorsOp_broadcast_4_f_add_relu
# TestFusedOperatorsOp_rowwise_add_0_f_add_relu
# TestFusedOperatorsOp_rowwise_add_1_f_add_relu
# TestFusedOperatorsOp_channelwise_add_f_add_relu
class TestFusedOperatorsOp_f_add_relu(TestElementwiseAddOp):
def init_output(self):
# Copy from test_activation_op.py
# Because we set delta = 0.005 in calculating numeric gradient,
# if x is too small, such as 0.002, x_neg will be -0.003
# x_pos will be 0.007, so the numeric gradient is inaccurate.
# we should avoid this
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
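# Illustrative sketch, not from the original test file: with delta = 0.005 the
# central-difference gradient of relu straddles the kink for |v| < 0.005, which
# is why these tests nudge such values to 0.02 before checking gradients.
_delta = 0.005
_num_grad = lambda v: (max(v + _delta, 0.0) - max(v - _delta, 0.0)) / (2 * _delta)
# _num_grad(0.002) -> 0.7 (inaccurate), _num_grad(0.02) -> 1.0 (matches relu'(v))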
class TestFusedOperatorsOp_scalar_f_add_relu(TestFusedOperatorsOp_scalar):
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_scalar2_f_add_relu(TestFusedOperatorsOp_scalar2):
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_Vector_f_add_relu(TestFusedOperatorsOp_Vector):
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_0_f_add_relu(
TestFusedOperatorsOp_broadcast_0):
def init_axis(self):
self.axis = 0
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(2, 1, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_1_f_add_relu(
TestFusedOperatorsOp_broadcast_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 3, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_2_f_add_relu(
TestFusedOperatorsOp_broadcast_2):
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 1, 4), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_3_f_add_relu(
TestFusedOperatorsOp_broadcast_3):
def init_axis(self):
self.axis = 1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 3, 4, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_4_f_add_relu(
TestFusedOperatorsOp_broadcast_4):
def init_axis(self):
self.axis = 0
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(2, 1, 1, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_rowwise_add_0_f_add_relu(
TestFusedOperatorsOp_rowwise_add_0):
def init_axis(self):
self.axis = 1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 3, 4), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_rowwise_add_1_f_add_relu(
TestFusedOperatorsOp_rowwise_add_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_channelwise_add_f_add_relu(
TestFusedOperatorsOp_channelwise_add):
def init_axis(self):
self.axis = -1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
# relu + add
# TestElementwiseAddOp_f_relu_add
# TestFusedOperatorsOp_scalar_f_relu_add
# TestFusedOperatorsOp_scalar2_f_relu_add
# TestFusedOperatorsOp_Vector_f_relu_add
# TestFusedOperatorsOp_broadcast_0_f_relu_add
# TestFusedOperatorsOp_broadcast_1_f_relu_add
# TestFusedOperatorsOp_broadcast_2_f_relu_add
# TestFusedOperatorsOp_broadcast_3_f_relu_add
# TestFusedOperatorsOp_broadcast_4_f_relu_add
# TestFusedOperatorsOp_rowwise_add_0_f_relu_add
# TestFusedOperatorsOp_rowwise_add_1_f_relu_add
# TestFusedOperatorsOp_channelwise_add_f_relu_add
class TestFusedOperatorsOp_f_relu_add(TestElementwiseAddOp):
def init_output(self):
        # Copy from test_activation_op.py
        # Because we set delta = 0.005 in calculating numeric gradient,
        # if x is too small, such as 0.002, x_neg will be -0.003
        # x_pos will be 0.007, so the numeric gradient is inaccurate.
        # we should avoid this
        self.out = self.x + self.y
        self.out = np.maximum(self.out, 0)
        self.out[np.abs(self.out) < 0.005] = 0.02

    def init_attr(self):
        self.attrs = {
            'axis': self.axis,
            'functor_list': ["relu", "elementwise_add"]
        }


class TestFusedOperatorsOp_scalar_f_relu_add(TestFusedOperatorsOp_scalar):
    def init_output(self):
        self.out = self.x + self.y
        self.out = np.maximum(self.out, 0)
        self.out[np.abs(self.out) < 0.005] = 0.02

    def init_attr(self):
        self.attrs = {
            'axis': self.axis,
            'functor_list': ["relu", "elementwise_add"]
        }


class TestFusedOperatorsOp_scalar2_f_relu_add(TestFusedOperatorsOp_scalar2):
    def init_output(self):
        self.out = self.x + self.y
        self.out = np.maximum(self.out, 0)
        self.out[np.abs(self.out) < 0.005] = 0.02

    def init_attr(self):
        self.attrs = {
            'axis': self.axis,
            'functor_list': ["relu", "elementwise_add"]
        }


class TestFusedOperatorsOp_Vector_f_relu_add(TestFusedOperatorsOp_Vector):
    def init_output(self):
        self.out = self.x + self.y
        self.out = np.maximum(self.out, 0)
        self.out[np.abs(self.out) < 0.005] = 0.02

    def init_attr(self):
        self.attrs = {
            'axis': self.axis,
            'functor_list': ["relu", "elementwise_add"]
        }


class TestFusedOperatorsOp_broadcast_0_f_relu_add(
        TestFusedOperatorsOp_broadcast_0):
    def init_axis(self):
        self.axis = 0

    def init_output(self):
        self.out = self.x + self.y.reshape(2, 1, 1)
        self.out = np.maximum(self.out, 0)
        self.out[np.abs(self.out) < 0.005] = 0.02

    def init_attr(self):
        self.attrs = {
            'axis': self.axis,
            'functor_list': ["relu", "elementwise_add"]
        }


class TestFusedOperatorsOp_broadcast_1_f_relu_add(
        TestFusedOperatorsOp_broadcast_1):
    def init_axis(self):
        self.axis = 1
def init_output(self):
self.out = self.x + self.y.reshape(1, 3, 1)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_broadcast_2_f_relu_add(
TestFusedOperatorsOp_broadcast_2):
def init_output(self):
self.out = self.x + self.y.reshape(1, 1, 4)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_broadcast_3_f_relu_add(
TestFusedOperatorsOp_broadcast_3):
def init_axis(self):
self.axis = 1
def init_output(self):
self.out = self.x + self.y.reshape(1, 3, 4, 1)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_broadcast_4_f_relu_add(
TestFusedOperatorsOp_broadcast_4):
def init_axis(self):
self.axis = 0
def init_output(self):
self.out = self.x + self.y.reshape(2, 1, 1, 1)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_rowwise_add_0_f_relu_add(
TestFusedOperatorsOp_rowwise_add_0):
def init_axis(self):
self.axis = 1
def init_output(self):
self.out = self.x + self.y.reshape(1, 3, 4)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_rowwise_add_1_f_relu_add(
TestFusedOperatorsOp_rowwise_add_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.out = self.x + self.y.reshape(1, 1)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_channelwise_add_f_relu_add(
TestFusedOperatorsOp_channelwise_add):
def init_axis(self):
self.axis = -1
def init_output(self):
self.out = self.x + self.y
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}

    class TestFusedElementwiseActivationOp_broadcast_2(
            TestFusedElementwiseActivationOp_base):
        def init_input(self):
            self.x = np.random.rand(2, 3, 4).astype(self.dtype)
            self.y = np.random.rand(4).astype(self.dtype)

        def init_output(self):
            self.x, self.y, self.intermediate_out, self.out = \
                callback(self.x, self.y, self.x, self.y.reshape(1, 1, 4))

    class TestFusedElementwiseActivationOp_broadcast_3(
            TestFusedElementwiseActivationOp_base):
        def init_input(self):
            self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
            self.y = np.random.rand(3, 4).astype(self.dtype)

        def init_axis(self):
            self.axis = 1

        def init_output(self):
            self.x, self.y, self.intermediate_out, self.out = \
                callback(self.x, self.y, self.x, self.y.reshape(1, 3, 4, 1))

    class TestFusedElementwiseActivationOp_broadcast_4(
            TestFusedElementwiseActivationOp_base):
        def init_input(self):
            self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
            self.y = np.random.rand(2, 1).astype(self.dtype)

        def init_axis(self):
            self.axis = 0

        def init_output(self):
            self.x, self.y, self.intermediate_out, self.out = \
                callback(self.x, self.y, self.x, self.y.reshape(2, 1, 1, 1))

    class TestFusedElementwiseActivationOp_rowwise_add_0(
            TestFusedElementwiseActivationOp_base):
        def init_input(self):
            self.x = np.random.rand(2, 3, 4).astype(self.dtype)
            self.y = np.random.rand(3, 4).astype(self.dtype)

        def init_axis(self):
            self.axis = 1

        def init_output(self):
            self.x, self.y, self.intermediate_out, self.out = \
                callback(self.x, self.y, self.x, self.y.reshape(1, 3, 4))

    class TestFusedElementwiseActivationOp_rowwise_add_1(
            TestFusedElementwiseActivationOp_base):
        def init_input(self):
            self.x = np.random.rand(2, 1).astype(self.dtype)
            self.y = np.random.rand(1).astype(self.dtype)

        def init_axis(self):
            self.axis = 1

        def init_output(self):
            self.x, self.y, self.intermediate_out, self.out = \
                callback(self.x, self.y, self.x, self.y.reshape(1, 1))

    class TestFusedElementwiseActivationOp_channelwise_add(
            TestFusedElementwiseActivationOp_base):
        def init_input(self):
            self.x = np.random.rand(3, 20, 20).astype(self.dtype)
            self.y = np.random.rand(3, 1, 1).astype(self.dtype)

    TestFusedElementwiseActivationOp_base.__name__ = test_case + "_base"
    TestFusedElementwiseActivationOp_scalar.__name__ = test_case + "_scalar"
    TestFusedElementwiseActivationOp_scalar2.__name__ = test_case + "_scalar2"
    TestFusedElementwiseActivationOp_Vector.__name__ = test_case + "_Vector"
    TestFusedElementwiseActivationOp_broadcast_0.__name__ = test_case + "_broadcast_0"
    TestFusedElementwiseActivationOp_broadcast_1.__name__ = test_case + "_broadcast_1"
    TestFusedElementwiseActivationOp_broadcast_2.__name__ = test_case + "_broadcast_2"
    TestFusedElementwiseActivationOp_broadcast_3.__name__ = test_case + "_broadcast_3"
    TestFusedElementwiseActivationOp_broadcast_4.__name__ = test_case + "_broadcast_4"
    TestFusedElementwiseActivationOp_rowwise_add_0.__name__ = test_case + "_rowwise_add_0"
    TestFusedElementwiseActivationOp_rowwise_add_1.__name__ = test_case + "_rowwise_add_1"
    TestFusedElementwiseActivationOp_channelwise_add.__name__ = test_case + "_channelwise_add"

    globals()[test_case + "_base"] = TestFusedElementwiseActivationOp_base
    globals()[test_case + "_scalar"] = TestFusedElementwiseActivationOp_scalar
    globals()[test_case + "_scalar2"] = TestFusedElementwiseActivationOp_scalar2
    globals()[test_case + "_Vector"] = TestFusedElementwiseActivationOp_Vector
    globals()[test_case +
              "_broadcast_0"] = TestFusedElementwiseActivationOp_broadcast_0
    globals()[test_case +
              "_broadcast_1"] = TestFusedElementwiseActivationOp_broadcast_1
    globals()[test_case +
              "_broadcast_2"] = TestFusedElementwiseActivationOp_broadcast_2
    globals()[test_case +
              "_broadcast_3"] = TestFusedElementwiseActivationOp_broadcast_3
    globals()[test_case +
              "_broadcast_4"] = TestFusedElementwiseActivationOp_broadcast_4
    globals()[test_case +
              "_rowwise_add_0"] = TestFusedElementwiseActivationOp_rowwise_add_0
    globals()[test_case +
              "_rowwise_add_1"] = TestFusedElementwiseActivationOp_rowwise_add_1
    globals()[test_case +
              "_channelwise_add"] = TestFusedElementwiseActivationOp_channelwise_add


def scale_add_func(x, y, x_bcast, y_bcast, scale, mode=0):
    if mode == 0:
        return x, y, (x_bcast + y_bcast), (x_bcast + y_bcast) * scale
    else:
        return y, x, (x_bcast + y_bcast), (x_bcast + y_bcast) * scale


def add_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
    if mode == 0:
        return x, y, y * scale, x_bcast + y_bcast * scale
    else:
        return y, x, x * scale, y_bcast + x_bcast * scale


def add_relu_func(x, y, x_bcast, y_bcast, mode=0):
    # Copy from test_activation_op.py
    # Because we set delta = 0.005 in calculating numeric gradient,
    # if x is too small, such as 0.002, x_neg will be -0.003
    # x_pos will be 0.007, so the numeric gradient is inaccurate.
    # we should avoid this
    if mode == 0:
        y[np.abs(y) < 0.005] = 0.02
        y_bcast[np.abs(y_bcast) < 0.005] = 0.02
        return x, y, np.maximum(y, 0), x_bcast + np.maximum(y_bcast, 0)
    else:
        x[np.abs(x) < 0.005] = 0.02
        x_bcast[np.abs(x_bcast) < 0.005] = 0.02
        return y, x, np.maximum(x, 0), y_bcast + np.maximum(x_bcast, 0)


def relu_add_func(x, y, x_bcast, y_bcast, mode=0):
    intermediate_out = x_bcast + y_bcast
    out = np.maximum(intermediate_out, 0)
    out[np.abs(out) < 0.005] = 0.02
    if mode == 0:
        return x, y, intermediate_out, out
    else:
        return y, x, intermediate_out, out


def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
    if mode == 0:
        return x, y, y * scale, x_bcast * (y_bcast * scale)
    else:
        return y, x, x * scale, y_bcast * (x_bcast * scale)


scale = 0.1
scale_add_func = partial(scale_add_func, scale=scale)
add_scale_func = partial(add_scale_func, scale=scale)
mul_scale_func = partial(mul_scale_func, scale=scale)

for mode in {0, 1}:
    scale_add_func = partial(scale_add_func, mode=mode)
    add_scale_func = partial(add_scale_func, mode=mode)
    mul_scale_func = partial(mul_scale_func, mode=mode)
    relu_add_func = partial(relu_add_func, mode=mode)
    add_relu_func = partial(add_relu_func, mode=mode)

    for recomputation in {True, False}:
        for keep_intermediate_value in {True, False}:
            suffix = ("_keep_intermediate_value" if keep_intermediate_value else "") \
                     + ("_recomputation" if recomputation else "") \
                     + ("_mode_" + str(mode))
            create_test_class('scale_add' + suffix, scale_add_func, {
                'scale': scale,
                'functor_list': ["scale", "elementwise_add"],
                'keep_intermediate_value': keep_intermediate_value,
                'recomputation': recomputation
            })
            create_test_class('add_scale' + suffix, add_scale_func, {
                'scale': scale,
                'functor_list': ["elementwise_add", "scale"],
                'keep_intermediate_value': keep_intermediate_value,
                'recomputation': recomputation
            })
            create_test_class('add_relu' + suffix, add_relu_func, {
                'functor_list': ["elementwise_add", "relu"],
                'keep_intermediate_value': keep_intermediate_value,
                'recomputation': recomputation
            })
            create_test_class('relu_add' + suffix, relu_add_func, {
                'functor_list': ["relu", "elementwise_add"],
                'keep_intermediate_value': keep_intermediate_value,
                'recomputation': recomputation
            })
            create_test_class('mul_scale' + suffix, mul_scale_func, {
                'scale': scale,
                'functor_list': ["elementwise_mul", "scale"],
                'keep_intermediate_value': keep_intermediate_value,
                'recomputation': recomputation
            })

if __name__ == '__main__':
    unittest.main()
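# Illustrative sketch, not from the original test file: each callback returns
# (X, Y, IntermediateOut, Out) for the fused op, and create_test_class registers
# renamed copies of the base cases, e.g. "scale_add_mode_0_base" when both
# keep_intermediate_value and recomputation are False. With made-up inputs:
#
#     x = np.array([1.0, 2.0])
#     y = np.array([3.0, 4.0])
#     X, Y, intermediate_out, out = scale_add_func(x, y, x, y)
#     # with scale = 0.1 and mode = 0: intermediate_out == [4.0, 6.0],
#     # out == [0.4, 0.6]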