diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index f90dcdc156590b776f817a4933d5a9b45868ba98..d5b9b2dac085e7abc31ef243be82eaa815d387ba 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include #include @@ -46,9 +47,9 @@ namespace operators { * pre=2*3, n=4*5, post=1 * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) */ -inline void get_mid_dims(const framework::DDim& x_dims, - const framework::DDim& y_dims, const int axis, - int* pre, int* n, int* post) { +inline void get_mid_dims(const framework::DDim &x_dims, + const framework::DDim &y_dims, const int axis, + int *pre, int *n, int *post) { *pre = 1; *n = 1; *post = 1; @@ -68,7 +69,7 @@ inline void get_mid_dims(const framework::DDim& x_dims, } inline framework::DDim trim_trailing_singular_dims( - const framework::DDim& dims) { + const framework::DDim &dims) { // Remove trailing dimensions of size 1 for y auto actual_dims_size = dims.size(); for (; actual_dims_size != 0; --actual_dims_size) { @@ -89,15 +90,16 @@ inline framework::DDim trim_trailing_singular_dims( template class RowwiseTransformIterator; + template class MidWiseTransformIterator; template class RowwiseTransformIterator { public: - RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {} + RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} - RowwiseTransformIterator& operator++() { + RowwiseTransformIterator &operator++() { ++i_; if (UNLIKELY(i_ == n_)) { i_ = 0; @@ -105,20 +107,20 @@ class RowwiseTransformIterator { return *this; } - bool operator==(const RowwiseTransformIterator& - rhs) const { + bool operator==(const RowwiseTransformIterator + &rhs) const { return (ptr_ + i_) == &(*rhs); } - bool operator!=(const RowwiseTransformIterator& - rhs) const { + bool operator!=(const RowwiseTransformIterator + &rhs) const { return (ptr_ + i_) != &(*rhs); } - const T& operator*() { return ptr_[i_]; } + const T &operator*() { return ptr_[i_]; } private: - const T* ptr_; + const T *ptr_; int i_; int64_t n_; }; @@ -126,10 +128,10 @@ class RowwiseTransformIterator { template class MidWiseTransformIterator { public: - MidWiseTransformIterator(const T* ptr, int n, int post) + MidWiseTransformIterator(const T *ptr, int n, int post) : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} - MidWiseTransformIterator& operator++() { + MidWiseTransformIterator &operator++() { ++j_; if (UNLIKELY(j_ == post_)) { ++i_; @@ -141,20 +143,20 @@ class MidWiseTransformIterator { return *this; } - bool operator==(const MidWiseTransformIterator& - rhs) const { + bool operator==(const MidWiseTransformIterator + &rhs) const { return (ptr_ + i_) == &(*rhs); } - bool operator!=(const MidWiseTransformIterator& - rhs) const { + bool operator!=(const MidWiseTransformIterator + &rhs) const { return (ptr_ + i_) != &(*rhs); } - const T& operator*() { return ptr_[i_]; } + const T &operator*() { return ptr_[i_]; } private: - const T* ptr_; + const T *ptr_; int64_t i_; int64_t j_; int64_t n_; @@ -165,18 +167,18 @@ class MidWiseTransformIterator { template class RowwiseTransformIterator : public thrust::iterator_adaptor< - RowwiseTransformIterator, const T*> { + RowwiseTransformIterator, const T *> { public: typedef thrust::iterator_adaptor< - RowwiseTransformIterator, const T*> + RowwiseTransformIterator, const T *> super_t; - HOSTDEVICE RowwiseTransformIterator(const T* x, int n) + HOSTDEVICE RowwiseTransformIterator(const T *x, int n) : super_t(x), begin_(x), n_(n) {} friend class thrust::iterator_core_access; private: unsigned int n_; - const T* begin_; + const T *begin_; HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (this->base() - begin_) % n_); } @@ -185,19 +187,19 @@ class RowwiseTransformIterator template class MidWiseTransformIterator : public thrust::iterator_adaptor< - MidWiseTransformIterator, const T*> { + MidWiseTransformIterator, const T *> { public: typedef thrust::iterator_adaptor< - MidWiseTransformIterator, const T*> + MidWiseTransformIterator, const T *> super_t; - HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post) + HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post) : super_t(x), begin_(x), n_(n), post_(post) {} friend class thrust::iterator_core_access; private: unsigned int post_; unsigned int n_; - const T* begin_; + const T *begin_; HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (((this->base() - begin_) / post_) % n_)); } @@ -208,8 +210,8 @@ template class TransformFunctor { public: - TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, - framework::Tensor* z, const DeviceContext& ctx, Functor func) + TransformFunctor(const framework::Tensor *x, const framework::Tensor *y, + framework::Tensor *z, const DeviceContext &ctx, Functor func) : x_(x->data()), y_(y->data()), z_(z->mutable_data(ctx.GetPlace())), @@ -235,20 +237,20 @@ class TransformFunctor { } private: - const T* x_; - const T* y_; - OutType* z_; + const T *x_; + const T *y_; + OutType *z_; int64_t nx_; - const DeviceContext& ctx_; + const DeviceContext &ctx_; Functor func_; }; #define EIGEN_FUNCTOR(name, eigen_op) \ struct Eigen##name##Functor { \ template \ - inline void Run(const framework::Tensor* x, const framework::Tensor* y, \ - framework::Tensor* z, \ - const framework::ExecutionContext& ctx) { \ + inline void Run(const framework::Tensor *x, const framework::Tensor *y, \ + framework::Tensor *z, \ + const framework::ExecutionContext &ctx) { \ auto x_e = framework::EigenVector::Flatten(*x); \ auto y_e = framework::EigenVector::Flatten(*y); \ auto z_e = framework::EigenVector::Flatten(*z); \ @@ -257,9 +259,9 @@ class TransformFunctor { eigen_op(x_e, y_e); \ } \ template \ - inline void RunBroadCast(const framework::Tensor* x, \ - const framework::Tensor* y, framework::Tensor* z, \ - const framework::ExecutionContext& ctx, int pre, \ + inline void RunBroadCast(const framework::Tensor *x, \ + const framework::Tensor *y, framework::Tensor *z, \ + const framework::ExecutionContext &ctx, int pre, \ int n) { \ auto x_e = framework::EigenVector::Flatten(*x); \ auto y_e = framework::EigenVector::Flatten(*y); \ @@ -272,10 +274,10 @@ class TransformFunctor { eigen_op(x_e, y_bcast); \ } \ template \ - inline void RunBroadCast2(const framework::Tensor* x, \ - const framework::Tensor* y, \ - framework::Tensor* z, \ - const framework::ExecutionContext& ctx, int pre, \ + inline void RunBroadCast2(const framework::Tensor *x, \ + const framework::Tensor *y, \ + framework::Tensor *z, \ + const framework::ExecutionContext &ctx, int pre, \ int n, int post) { \ auto x_e = framework::EigenVector::Flatten(*x); \ auto y_e = framework::EigenVector::Flatten(*y); \ @@ -290,23 +292,27 @@ class TransformFunctor { } #define EIGEN_ADD(x, y) ((x) + (y)) + EIGEN_FUNCTOR(Add, EIGEN_ADD); #define EIGEN_SUB(x, y) ((x) - (y)) + EIGEN_FUNCTOR(Sub, EIGEN_SUB); #define EIGEN_MUL(x, y) ((x) * (y)) + EIGEN_FUNCTOR(Mul, EIGEN_MUL); #define EIGEN_DIV(x, y) ((x) / (y)) + EIGEN_FUNCTOR(Div, EIGEN_DIV); template struct ElemwiseGradNoBroadcast { - const T* x_; - const T* y_; - const T* out_; - const T* dout_; + const T *x_; + const T *y_; + const T *out_; + const T *dout_; HOSTDEVICE void operator()(size_t i) { if (dx_ != nullptr) { @@ -319,14 +325,14 @@ struct ElemwiseGradNoBroadcast { DX_OP dx_op_; DY_OP dy_op_; - T* dx_; - T* dy_; + T *dx_; + T *dy_; }; template -static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, - const T* dout, int h, int w, DX_OP dx_op, - DY_OP dy_op, T* dx, T* dy) { +static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out, + const T *dout, int h, int w, DX_OP dx_op, + DY_OP dy_op, T *dx, T *dy) { for (int i = 0; i < h; ++i) { for (int j = 0; j < w; ++j) { int x_offset = i * w + j; @@ -348,8 +354,8 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, #ifdef __NVCC__ template static __global__ void ElemwiseGradBroadcast1CUDAKernel( - const T* x, const T* y, const T* out, const T* dout, int h, int w, - DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { + const T *x, const T *y, const T *out, const T *dout, int h, int w, + DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { int j = blockIdx.x; int i = threadIdx.x; int tid = threadIdx.x; @@ -376,10 +382,10 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel( } template -static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x, - const T* y, const T* out, const T* dout, +static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x, + const T *y, const T *out, const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, - T* dx, T* dy) { + T *dx, T *dy) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); int gird_size = w; ElemwiseGradBroadcast1CUDAKernel<<>>( @@ -389,9 +395,9 @@ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x, #endif template -static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out, - const T* dout, int pre, int n, int post, - DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { +static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out, + const T *dout, int pre, int n, int post, + DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { for (int i = 0; i < pre; ++i) { for (int j = 0; j < n; ++j) { for (int k = 0; k < post; ++k) { @@ -416,8 +422,8 @@ static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out, #ifdef __NVCC__ template static __global__ void ElemwiseGradBroadcast2CUDAKernel( - const T* x, const T* y, const T* out, const T* dout, int pre, int n, - int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { + const T *x, const T *y, const T *out, const T *dout, int pre, int n, + int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { int tid = threadIdx.x; int j = blockIdx.x; @@ -453,10 +459,10 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel( } template -static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x, - const T* y, const T* out, const T* dout, +static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x, + const T *y, const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op, - DY_OP dy_op, T* dx, T* dy) { + DY_OP dy_op, T *dx, T *dy) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); int gird_size = n; ElemwiseGradBroadcast2CUDAKernel<<>>( @@ -467,11 +473,11 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x, template void ElemwiseGradComputeNoBroadcast( - const framework::ExecutionContext& ctx, const framework::DDim& x_dim, - const framework::DDim& y_dim, const framework::Tensor& x, - const framework::Tensor& y, const framework::Tensor& out, - const framework::Tensor& dout, int axis, framework::Tensor* dx, - framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) { + const framework::ExecutionContext &ctx, const framework::DDim &x_dim, + const framework::DDim &y_dim, const framework::Tensor &x, + const framework::Tensor &y, const framework::Tensor &out, + const framework::Tensor &dout, int axis, framework::Tensor *dx, + framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { size_t N = static_cast(framework::product(x_dim)); platform::ForRange for_range( ctx.template device_context(), N); @@ -483,11 +489,11 @@ void ElemwiseGradComputeNoBroadcast( template void ElemwiseGradComputeWithBroadcast( - const framework::ExecutionContext& ctx, const framework::DDim& x_dim, - const framework::DDim& y_dim_untrimed, const framework::Tensor& x, - const framework::Tensor& y, const framework::Tensor& out, - const framework::Tensor& dout, int axis, framework::Tensor* dx, - framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) { + const framework::ExecutionContext &ctx, const framework::DDim &x_dim, + const framework::DDim &y_dim_untrimed, const framework::Tensor &x, + const framework::Tensor &y, const framework::Tensor &out, + const framework::Tensor &dout, int axis, framework::Tensor *dx, + framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis); auto y_dim = trim_trailing_singular_dims(y_dim_untrimed); axis = (y_dim.size() == 0) ? x_dim.size() : axis; @@ -531,14 +537,14 @@ void ElemwiseGradComputeWithBroadcast( } template -void ElemwiseGradCompute(const framework::ExecutionContext& ctx, - const framework::Tensor& x, const framework::Tensor& y, - const framework::Tensor& out, - const framework::Tensor& dout, int axis, - framework::Tensor* dx, framework::Tensor* dy, +void ElemwiseGradCompute(const framework::ExecutionContext &ctx, + const framework::Tensor &x, const framework::Tensor &y, + const framework::Tensor &out, + const framework::Tensor &dout, int axis, + framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { - const framework::DDim& x_dim = x.dims(); - const framework::DDim& y_dim = y.dims(); + const framework::DDim &x_dim = x.dims(); + const framework::DDim &y_dim = y.dims(); if (x.dims() == y.dims()) { ElemwiseGradComputeNoBroadcast( ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); @@ -553,27 +559,27 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx, // In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse // elementwise code. template -void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx, - const framework::Tensor& x, - const framework::Tensor& y, - const framework::Tensor& out, - const framework::Tensor& dout, int axis, - framework::Tensor* dx, framework::Tensor* dy, +void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx, + const framework::Tensor &x, + const framework::Tensor &y, + const framework::Tensor &out, + const framework::Tensor &dout, int axis, + framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { if (dy == nullptr) { - const framework::DDim& dx_dims = dout.dims(); + const framework::DDim &dx_dims = dout.dims(); auto dy_dims = dx_dims; ElemwiseGradComputeNoBroadcast( ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op); } else { if (dout.dims() == dy->dims()) { - const framework::DDim& dx_dims = dout.dims(); - const framework::DDim& dy_dims = dy->dims(); + const framework::DDim &dx_dims = dout.dims(); + const framework::DDim &dy_dims = dy->dims(); ElemwiseGradComputeNoBroadcast( ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op); } else { // Y is a scalar auto dx_dims = dout.dims(); - const framework::DDim& dy_dims = dy->dims(); + const framework::DDim &dy_dims = dy->dims(); ElemwiseGradComputeWithBroadcast( ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op); } @@ -583,13 +589,13 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx, // Deprecated template -void ElementwiseGradCompute(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, - const framework::Tensor* out, - const framework::Tensor* dout, int axis, - framework::Tensor* dx, framework::Tensor* dy) { - auto& place = *ctx.template device_context().eigen_device(); +void ElementwiseGradCompute(const framework::ExecutionContext &ctx, + const framework::Tensor *x, + const framework::Tensor *y, + const framework::Tensor *out, + const framework::Tensor *dout, int axis, + framework::Tensor *dx, framework::Tensor *dy) { + auto &place = *ctx.template device_context().eigen_device(); auto x_dims = x->dims(); auto y_dims = y->dims(); @@ -627,10 +633,10 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx, template -void ElementwiseComputeEx(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, int axis, Functor func, - framework::Tensor* z) { +void ElementwiseComputeEx(const framework::ExecutionContext &ctx, + const framework::Tensor *x, + const framework::Tensor *y, int axis, Functor func, + framework::Tensor *z) { TransformFunctor functor( x, y, z, ctx.template device_context(), func); @@ -661,5 +667,823 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx, } } +// FusedElemwiseAndAct +// --- forward +template +struct FusedElemwiseAndActNoBroadcast { + HOSTDEVICE void operator()(size_t i) { + T y_val = y_[i]; + T x_val = x_[i]; + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor_.GetIntermediateOut(x_val, y_val); + intermediate_out_[i] = intermeidiate_out; + out_[i] = + compound_functor_.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out_[i] = compound_functor_.GetOut(x_val, y_val); + } + } + + const T *x_; + const T *y_; + CompoundFunctor compound_functor_; + T *out_; + T *intermediate_out_; +}; + +// FusedElemwiseAndActBroadcast1: +// In this case, X and Y can be reshaped to a matrix. +// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2, +// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20) +template +static void FusedElemwiseAndActBroadcast1CPU(const T *x, const T *y, + CompoundFunctor compound_functor, + int h, int w, T *out, + T *intermediate_out) { + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int offset = i * w + j; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = + compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + } + } +} + +// FusedElemwiseAndActBroadcast2 +// In this case, X and Y can be reshaped to a matrix. +// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1, +// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1) +// pre = 2, n = 12, post = 5 +template +static void FusedElemwiseAndActBroadcast2CPU(const T *x, const T *y, int pre, + int n, int post, + CompoundFunctor compound_functor, + T *out, T *intermediate_out) { + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int offset = i * n * post + j * post + k; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + + if (KeepIntermediateOut) { + T intermeidiate_out = + compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = compound_functor.GetOutUseIntermediateOut( + x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + } + } + } +} + +#ifdef __NVCC__ +template +static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel( + const T *x, const T *y, int h, int w, CompoundFunctor compound_functor, + T *out, T *intermediate_out) { + int j = blockIdx.x; + int i = threadIdx.x; + + while (i < h) { + int offset = i * w + j; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = + compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + + i += ELEMWISE_MAX_BLOCK_DIM; + } +} + +template +static void FusedElemwiseAndActBroadcast1CUDA(cudaStream_t stream, const T *x, + const T *y, + CompoundFunctor compound_functor, + int h, int w, T *out, + T *intermediate_out) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); + int gird_size = w; + FusedElemwiseAndActBroadcast1CUDAKernel< + T, CompoundFunctor, BcastY, KeepIntermediateOut, + SameShapeOfIntermediateOutAndOut><<>>( + x, y, h, w, compound_functor, out, intermediate_out); +} + +template +static __global__ void FusedElemwiseAndActBroadcast2CUDAKernel( + const T *x, const T *y, CompoundFunctor compound_functor, int pre, int n, + int post, T *out, T *intermediate_out) { + int tid = threadIdx.x; + int j = blockIdx.x; + + while (true) { + int i = tid / post; + int k = tid % post; + if (i >= pre) break; + + int offset = i * n * post + j * post + k; + + T y_val = BcastY ? y[j] : y[offset]; + T x_val = BcastY ? x[offset] : x[j]; + int64_t intermediate_out_offset; + + if (KeepIntermediateOut) { + T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val); + + if (SameShapeOfIntermediateOutAndOut) { + // for the case of f1(f2(x, y)) + intermediate_out_offset = offset; + } else if (BcastY) { + intermediate_out_offset = j; + } else { + intermediate_out_offset = offset; + } + + intermediate_out[intermediate_out_offset] = intermeidiate_out; + out[offset] = + compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out); + } else { + out[offset] = compound_functor.GetOut(x_val, y_val); + } + + tid += ELEMWISE_MAX_BLOCK_DIM; + } +} + +template +static void FusedElemwiseAndActBroadcast2CUDA(cudaStream_t stream, const T *x, + const T *y, int pre, int n, + int post, + CompoundFunctor compound_functor, + T *out, T *intermediate_out) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); + int gird_size = n; + + FusedElemwiseAndActBroadcast2CUDAKernel< + T, CompoundFunctor, BcastY, KeepIntermediateOut, + SameShapeOfIntermediateOutAndOut><<>>( + x, y, compound_functor, pre, n, post, out, intermediate_out); +} + +#endif + +template +void FusedElemwiseAndActComputeNoBroadcast( + const framework::ExecutionContext &ctx, const framework::DDim &x_dim, + const framework::Tensor &x, const framework::Tensor &y, + CompoundFunctor compound_functor, framework::Tensor *out, + framework::Tensor *intermediate_out) { + size_t N = static_cast(framework::product(x_dim)); + + platform::ForRange for_range( + ctx.template device_context(), N); + + for_range( + FusedElemwiseAndActNoBroadcast{ + x.data(), y.data(), compound_functor, + out->mutable_data(ctx.GetPlace()), + intermediate_out == nullptr + ? nullptr + : intermediate_out->mutable_data(ctx.GetPlace())}); +} + +template +void FusedElemwiseAndActComputeWithBroadcast( + const framework::ExecutionContext &ctx, const framework::DDim &x_dim, + const framework::DDim &y_dim_untrimed, const framework::Tensor &x, + const framework::Tensor &y, CompoundFunctor compound_functor, int axis, + framework::Tensor *out, framework::Tensor *intermediate_out) { + axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis); + auto y_dim = trim_trailing_singular_dims(y_dim_untrimed); + axis = (y_dim.size() == 0) ? x_dim.size() : axis; + + int pre, n, post; + get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); + + if (post == 1) { + int h = pre; + int w = n; + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef __NVCC__ + FusedElemwiseAndActBroadcast1CUDA( + ctx.template device_context().stream(), x.data(), + y.data(), compound_functor, h, w, + out->mutable_data(ctx.GetPlace()), + intermediate_out == nullptr + ? nullptr + : intermediate_out->mutable_data(ctx.GetPlace())); +#endif + } else { + FusedElemwiseAndActBroadcast1CPU( + x.data(), y.data(), compound_functor, h, w, + out->mutable_data(ctx.GetPlace()), + intermediate_out == nullptr + ? nullptr + : intermediate_out->mutable_data(ctx.GetPlace())); + } + } else { + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef __NVCC__ + FusedElemwiseAndActBroadcast2CUDA( + ctx.template device_context().stream(), x.data(), + y.data(), pre, n, post, compound_functor, + out->mutable_data(ctx.GetPlace()), + intermediate_out == nullptr + ? nullptr + : intermediate_out->mutable_data(ctx.GetPlace())); +#endif + } else { + FusedElemwiseAndActBroadcast2CPU( + x.data(), y.data(), pre, n, post, compound_functor, + out->mutable_data(ctx.GetPlace()), + intermediate_out == nullptr + ? nullptr + : intermediate_out->mutable_data(ctx.GetPlace())); + } + } +} + +// --- backward +template +struct FusedElemwiseAndActGradNoBroadcast { + HOSTDEVICE void operator()(size_t i) { + if (dx_ != nullptr) { + dx_[i] = UseIntermediateOut ? dx_op_(x_[i], y_[i], intermediate_out_[i], + out_[i], dout_[i]) + : dx_op_(x_[i], y_[i], out_[i], dout_[i]); + } + if (dy_ != nullptr) { + dy_[i] = UseIntermediateOut ? dy_op_(x_[i], y_[i], intermediate_out_[i], + out_[i], dout_[i]) + : dy_op_(x_[i], y_[i], out_[i], dout_[i]); + } + } + + const T *x_; + const T *y_; + const T *intermediate_out_; + const T *out_; + const T *dout_; + DX_OP dx_op_; + DY_OP dy_op_; + T *dx_; + T *dy_; +}; + +template +void FusedElemwiseAndActGradComputeNoBroadcast( + const framework::ExecutionContext &ctx, const framework::DDim &x_dim, + const framework::DDim &y_dim, const framework::Tensor *x, + const framework::Tensor *y, const framework::Tensor *intermediate_out, + const framework::Tensor *out, const framework::Tensor *dout, int axis, + framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { + size_t N = static_cast(framework::product(x_dim)); + platform::ForRange for_range( + ctx.template device_context(), N); + for_range( + FusedElemwiseAndActGradNoBroadcast{ + x->data(), y->data(), + intermediate_out ? intermediate_out->data() : nullptr, + out->data(), dout->data(), dx_op, dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())}); +} + +template +static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y, + const T *intermediate_out, + const T *out, const T *dout, + int h, int w, DX_OP dx_op, + DY_OP dy_op, T *dx, T *dy) { + int64_t tmp_out_idx, x_idx, y_idx; + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int offset = i * w + j; + + tmp_out_idx = BcastY ? j : offset; + y_idx = BcastY ? j : offset; + x_idx = BcastY ? offset : j; + + if (SameShapeOfIntermediateOutAndOut) { + tmp_out_idx = offset; + } + + if (dx != nullptr) { + T tmp = UseIntermediateOut + ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + + if (BcastY) { + dx[x_idx] = tmp; + } else { + if (i == 0) { + dx[x_idx] = tmp; + } else { + dx[x_idx] += tmp; + } + } + } + if (dy != nullptr) { + T tmp = UseIntermediateOut + ? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + if (BcastY) { + if (i == 0) { + dy[y_idx] = tmp; + } else { + dy[y_idx] += tmp; + } + } else { + dy[y_idx] = tmp; + } + } + } + } +} + +template +static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y, + const T *intermediate_out, + const T *out, const T *dout, + int pre, int n, int post, + DX_OP dx_op, DY_OP dy_op, + T *dx, T *dy) { + int64_t tmp_out_idx, x_idx, y_idx; + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int offset = i * n * post + j * post + k; + + tmp_out_idx = BcastY ? j : offset; + y_idx = BcastY ? j : offset; + x_idx = BcastY ? offset : j; + + if (SameShapeOfIntermediateOutAndOut) { + tmp_out_idx = offset; + } + + if (dx != nullptr) { + T tmp = UseIntermediateOut + ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + + if (BcastY) { + dx[x_idx] = tmp; + } else { + if (i == 0 && k == 0) { + dx[x_idx] = tmp; + } else { + dx[x_idx] += tmp; + } + } + } + if (dy != nullptr) { + T tmp = UseIntermediateOut + ? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + if (BcastY) { + if (i == 0 && k == 0) { + dy[y_idx] = tmp; + } else { + dy[y_idx] += tmp; + } + } else { + dy[y_idx] = tmp; + } + } + } + } + } +} + +#ifdef __NVCC__ +template +static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel( + const T *x, const T *y, const T *intermediate_out, const T *out, + const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { + int j = blockIdx.x; + int i = threadIdx.x; + int tid = threadIdx.x; + T val(0); + int64_t tmp_out_idx, x_idx, y_idx; + + do { + int offset = i * w + j; + + tmp_out_idx = BcastY ? j : offset; + y_idx = BcastY ? j : offset; + x_idx = BcastY ? offset : j; + + if (SameShapeOfIntermediateOutAndOut) { + tmp_out_idx = offset; + } + + if (dx != nullptr) { + T tmp = UseIntermediateOut + ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + + if (BcastY) { + dx[x_idx] = tmp; + } else { + val += tmp; + } + } + if (dy != nullptr) { + T tmp = UseIntermediateOut + ? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + if (BcastY) { + val += tmp; + } else { + dy[y_idx] = tmp; + } + } + + i += ELEMWISE_MAX_BLOCK_DIM; + } while (i < h); + + if (BcastY) { + if (dy) { + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dy[j] = val; + } + } + } else { + if (dx) { + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dx[j] = val; + } + } + } +} + +template +static void FusedElemwiseAndActGradBroadcast1CUDA(cudaStream_t stream, + const T *x, const T *y, + const T *intermediate_out, + const T *out, const T *dout, + int h, int w, DX_OP dx_op, + DY_OP dy_op, T *dx, T *dy) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); + int gird_size = w; + FusedElemwiseAndActGradBroadcast1CUDAKernel< + T, DX_OP, DY_OP, UseIntermediateOut, BcastY, + SameShapeOfIntermediateOutAndOut><<>>( + x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dx, dy); +} + +template +static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel( + const T *x, const T *y, const T *intermediate_out, const T *out, + const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op, T *dx, + T *dy) { + int tid = threadIdx.x; + int j = blockIdx.x; + + T val(0); + int ttid = tid; + int64_t tmp_out_idx, x_idx, y_idx; + while (true) { + int i = ttid / post; + int k = ttid % post; + if (i >= pre) break; + + int offset = i * n * post + j * post + k; + + tmp_out_idx = BcastY ? j : offset; + y_idx = BcastY ? j : offset; + x_idx = BcastY ? offset : j; + + if (SameShapeOfIntermediateOutAndOut) { + tmp_out_idx = offset; + } + + if (dx != nullptr) { + T tmp = UseIntermediateOut + ? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + + if (BcastY) { + dx[x_idx] = tmp; + } else { + val += tmp; + } + } + if (dy != nullptr) { + T tmp = UseIntermediateOut + ? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx], + out[offset], dout[offset]) + : dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]); + if (BcastY) { + val += tmp; + } else { + dy[y_idx] = tmp; + } + } + + ttid += ELEMWISE_MAX_BLOCK_DIM; + } + + if (BcastY) { + if (dy) { + int h = pre * post; + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dy[j] = val; + } + } + } else { + if (dx) { + int h = pre * post; + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dx[j] = val; + } + } + } +} + +template +static void FusedElemwiseAndActGradBroadcast2CUDA( + cudaStream_t stream, const T *x, const T *y, const T *intermediate_out, + const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op, + DY_OP dy_op, T *dx, T *dy) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); + int gird_size = n; + FusedElemwiseAndActGradBroadcast2CUDAKernel< + T, DX_OP, DY_OP, UseIntermediateOut, BcastY, + SameShapeOfIntermediateOutAndOut><<>>( + x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op, dx, dy); +} +#endif + +template +void FusedElemwiseAndActGradComputeWithBroadcast( + const framework::ExecutionContext &ctx, const framework::DDim &x_dim, + const framework::DDim &y_dim_untrimed, const framework::Tensor *x, + const framework::Tensor *y, const framework::Tensor *intermediate_out, + const framework::Tensor *out, const framework::Tensor *dout, int axis, + framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { + axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis); + auto y_dim = trim_trailing_singular_dims(y_dim_untrimed); + axis = (y_dim.size() == 0) ? x_dim.size() : axis; + + int pre, n, post; + get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post); + if (post == 1) { + int h = pre; + int w = n; + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef __NVCC__ + FusedElemwiseAndActGradBroadcast1CUDA( + ctx.template device_context().stream(), x->data(), + y->data(), + intermediate_out == nullptr ? nullptr : intermediate_out->data(), + out->data(), dout->data(), h, w, dx_op, dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); +#endif + } else { + FusedElemwiseAndActGradBroadcast1CPU( + x->data(), y->data(), + intermediate_out == nullptr ? nullptr : intermediate_out->data(), + out->data(), dout->data(), h, w, dx_op, dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + } + } else { + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef __NVCC__ + FusedElemwiseAndActGradBroadcast2CUDA( + ctx.template device_context().stream(), x->data(), + y->data(), + intermediate_out == nullptr ? nullptr : intermediate_out->data(), + out->data(), dout->data(), pre, n, post, dx_op, dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); +#endif + } else { + FusedElemwiseAndActGradBroadcast2CPU( + x->data(), y->data(), + intermediate_out == nullptr ? nullptr : intermediate_out->data(), + out->data(), dout->data(), pre, n, post, dx_op, dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + } + } +} + +template +void FusedElemwiseAndActGradComputeEx( + const framework::ExecutionContext &ctx, const framework::Tensor *x, + const framework::Tensor *y, const framework::Tensor *out, + const framework::Tensor *intermediate_out, const framework::Tensor *dout, + int axis, framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, + DY_OP dy_op) { + const framework::DDim &x_dim = x->dims(); + const framework::DDim &y_dim = y->dims(); + if (UseIntermediateOut) { + PADDLE_ENFORCE(intermediate_out, "intermediate_out should not be nullptr"); + } + if (x_dim == y_dim) { + FusedElemwiseAndActGradComputeNoBroadcast( + ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy, + dx_op, dy_op); + } else { // Y is a scalar + bool bcast_y = x_dim.size() >= y_dim.size(); + if (x_dim.size() == y_dim.size()) { + for (int i = 0; i < x_dim.size(); ++i) { + if (x_dim[i] < y_dim[i]) { + bcast_y = false; + break; + } + } + } + + // z = f1(x, f2(y)) + // z = f1(f2(x, y)) + if (bcast_y) { // Y should be broadcast. + FusedElemwiseAndActGradComputeWithBroadcast< + DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, true /*BcastY*/, + SameShapeOfIntermediateOutAndOut>(ctx, x_dim, y_dim, x, y, + intermediate_out, out, dout, axis, + dx, dy, dx_op, dy_op); + } else { + FusedElemwiseAndActGradComputeWithBroadcast< + DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, false /*BcastY*/, + SameShapeOfIntermediateOutAndOut>(ctx, y_dim, x_dim, x, y, + intermediate_out, out, dout, axis, + dx, dy, dx_op, dy_op); + } + } +} + +template +void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx, + const framework::Tensor &x, + const framework::Tensor &y, int axis, + CompoundFunctor compound_functor, + framework::Tensor *out, + framework::Tensor *intermediate_out) { + if (KeepIntermediateOut) { + PADDLE_ENFORCE(intermediate_out, + "The keep_intermediate_value is opened, " + "intermediate_out should not be nullptr."); + } + + const framework::DDim &x_dim = x.dims(); + const framework::DDim &y_dim = y.dims(); + if (x.dims() == y.dims()) { + FusedElemwiseAndActComputeNoBroadcast( + ctx, x_dim, x, y, compound_functor, out, intermediate_out); + } else { + // Whether the shape of Y is a continuous subsequence of X, + // For more information please refer to the op's introduction. + bool bcast_y = x.dims().size() >= y.dims().size(); + if (x.dims().size() == y.dims().size()) { + for (int i = 0; i < x.dims().size(); ++i) { + if (x.dims()[i] < y.dims()[i]) { + bcast_y = false; + break; + } + } + } + + // z = f1(x, f2(y)) + // z = f1(f2(x, y)) + if (bcast_y) { // Y should be broadcast. + // In this case, + // for 'f2(y)', the shape of intermediate_out should be equal to the shape + // of Y. + // for 'f2(x, y)', the shape of intermediate_out should be equal to the + // shape of Out. + // the shape of Out should be equal to the shape of X. + FusedElemwiseAndActComputeWithBroadcast< + DeviceContext, T, CompoundFunctor, true /*BcastY*/, + KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>( + ctx, x_dim /*OutShape*/, y_dim, x, y, compound_functor, axis, out, + intermediate_out); + } else { + // In this case, + // for 'f2(y)', the shape of intermediate_out should be equal to the shape + // of Out. + // for 'f2(x, y)', the shape of intermediate_out should be equal to the + // shape of Out. + // the shape of Out should be equal to the shape of Y. + FusedElemwiseAndActComputeWithBroadcast< + DeviceContext, T, CompoundFunctor, false /*BcastY*/, + KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>( + ctx, y_dim /*OutShape*/, x_dim, x, y, compound_functor, axis, out, + intermediate_out); + } + } +} } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused_elemwise_activation_op.cc index a6fd0aeb021dce40339c32251af130d5984dccd2..b54f0091b3fe21222b4690f4dcff1c081d4799e7 100644 --- a/paddle/fluid/operators/fused_elemwise_activation_op.cc +++ b/paddle/fluid/operators/fused_elemwise_activation_op.cc @@ -12,14 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/fused_elemwise_activation_op.h" #include #include -#include "paddle/fluid/operators/fused_elemwise_activation_op.h" - namespace paddle { namespace operators { +/* + * Whether the compound function is Unary(Binary(X, Y)). + * For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final + * out. + */ +static bool IsUnaryCompound(const std::vector &functor_list) { + PADDLE_ENFORCE_EQ(functor_list.size(), 2); + static std::unordered_set binary_fun = { + "elementwise_add", "elementwise_mul", "elementwise_add_grad", + "elementwise_mul_grad"}; + return binary_fun.count(functor_list[1]) != 0; +} + +/* + * Whether the Input(X) could be absent. + */ +static bool InputXCanBeAbsent(const std::vector &functor_list) { + PADDLE_ENFORCE_EQ(functor_list.size(), 2); + static std::unordered_set binary_fun = {"elementwise_add_grad"}; + return binary_fun.count(functor_list[0]) != 0 || + binary_fun.count(functor_list[1]) != 0; +} + +/* + * Whether the compound function is supported. + * For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final + * out. + */ +static bool IsSupportedCompound(const std::vector &functors) { + static std::unordered_set unary_fun = {"scale", "relu"}; + static std::unordered_set binary_fun = {"elementwise_add", + "elementwise_mul"}; + + std::string unary_fun_str; + if (binary_fun.count(functors[0])) { + unary_fun_str = functors[1]; + } else if (binary_fun.count(functors[1])) { + unary_fun_str = functors[0]; + } else { + PADDLE_THROW("%s and %s are not included in fused_list.", functors[0], + functors[1]); + } + PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1, + "%s is not included in fused_list.", unary_fun_str); + return true; +} + class FusedElemwiseActivationOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -37,11 +83,44 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { auto x_dim = ctx->GetInputDim("X"); auto y_dim = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), - "Rank of first input must >= rank of second input."); - ctx->SetOutputDim("Out", x_dim); - ctx->ShareLoD("X", /*->*/ "Out"); + // Whether the shape of Y is a continuous subsequence of X, + // For more information please refer to the op's introduction. + bool bcast_y = x_dim.size() >= y_dim.size(); + if (x_dim.size() == y_dim.size()) { + for (int i = 0; i < x_dim.size(); ++i) { + if (x_dim[i] < y_dim[i]) { + bcast_y = false; + break; + } + } + } + + auto &out_dim = bcast_y ? x_dim : y_dim; + std::string out_lod = bcast_y ? "X" : "Y"; + + if (ctx->Attrs().Get("keep_intermediate_value")) { + PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"), + "Output(IntermediateOut) of FusedElemwiseActivationOp " + "should not be null."); + + if (IsUnaryCompound( + ctx->Attrs().Get>("functor_list"))) { + // for Unary(Binary(X, Y)), the shape and lod of out and + // intermediate_out are the same. + ctx->SetOutputDim("IntermediateOut", out_dim); + // set the lod of intermediate_out + ctx->ShareLoD(out_lod, /*->*/ "IntermediateOut"); + } else { + // for Binary(X, Unary(Y)), the shape and lod of Y and + // intermediate_out are the same. + ctx->SetOutputDim("IntermediateOut", y_dim); + // set the lod of intermediate_out + ctx->ShareLoD("Y", /*->*/ "IntermediateOut"); + } + } + ctx->SetOutputDim("Out", out_dim); + ctx->ShareLoD(out_lod, /*->*/ "Out"); } protected: @@ -59,29 +138,42 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("X", "(vector)"); - AddInput("Y", "(vector)"); - AddOutput("Out", "vector"); + AddInput( + "X", + "(Tensor) The input tensor of fused_elemwise_activation operator."); + AddInput( + "Y", + "(Tensor) The input tensor of fused_elemwise_activation operator."); + AddOutput("Out", + "vector The output tensor of fused_elemwise_activation " + "operator."); + AddOutput("IntermediateOut", + "Tensor The IntermediateOut tensor of fused_elemwise_activation " + "operator.") + .AsIntermediate(); AddAttr("axis", "axis is used by elementwise_op, the default value is -1.") .SetDefault(-1); AddAttr("scale", "scale is used by scale_op, the default value is 0.0.") .SetDefault(0.0); - AddAttr("recomputation", - "Whether to recompute the Out." - "fused_elemwise_activation_grad has two methods to get the " - "dx and dy, one " - "is to use the 'Out', and the other is not to use it. " - "The former method will save the time of recomputing the " - "'Out', but it must occupy the memory to store the 'out'. " - "While, the later method can avoid occupying the memory, " - "but it must recompute the 'Out'. The default value is true.") + AddAttr( + "recomputation", + "Whether to recompute the Out." + "The computation of fused_elemwise_activation_grad has two methods to " + "get the dx and dy, one is to use the 'Out', and the other is not. " + "The former method will save the time of recomputing the 'Out', but it " + "must occupy the memory to store the 'out'. While, the later method " + "can avoid occupying the memory, but it must recompute the 'Out'. " + "It is useful for Unary(Binary(X, Y)). The default value is true.") .SetDefault(true); + AddAttr("keep_intermediate_value", + "Whether to save the intermediate_out.") + .SetDefault(false); AddAttr>("functor_list", "The functors that should be fused.") .AddCustomChecker([&](const std::vector &functor_list) { - PADDLE_ENFORCE(ValidCheck(functor_list)); + PADDLE_ENFORCE(IsSupportedCompound(functor_list)); }); AddComment(R"DOC( @@ -93,30 +185,38 @@ operators (elementwise_op and activation_op): Z = Binary(X, Unary(Y)) Z = Unary(Binary(X, Y)) -The attributions of activation_op can be get from fused_elemwise_activation_op's -attributions. functor_list records the functors to be fused, for example -"scale,elementwise_add". +There are two cases for this operator: -)DOC"); - } +1. The shape of $Y$ and $X$ is the same. +2. The shape of $Y$ is a continuous subsequence of $X$ or the shape of $X$ is a continuous subsequence of $Y$. - private: - bool ValidCheck(const std::vector &functors) { - std::unordered_set unary_fun = {"scale", "relu"}; - std::unordered_set binary_fun = {"elementwise_add"}; +For case 2 (assume that the shape of $Y$ is a continuous subsequence of $X$ ): - std::string unary_fun_str; - if (binary_fun.count(functors[0])) { - unary_fun_str = functors[1]; - } else if (binary_fun.count(functors[1])) { - unary_fun_str = functors[0]; - } else { - PADDLE_THROW("%s and %s are not included in fused_list.", functors[0], - functors[1]); - } - PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1, - "%s is not included in fused_list.", unary_fun_str); - return true; +1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index + for broadcasting $Y$ onto $X$. +2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$. +3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of + subsequence, such as shape(Y) = (2, 1) => (2). + +For example: + + .. code-block:: python + + shape(X) = (2, 3, 4, 5), shape(Y) = (,) + shape(X) = (2, 3, 4, 5), shape(Y) = (5,) + shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2 + shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 + shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0 + shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0 + + +The inputs $X$ and $Y$ can carry the different LoD information. +But the output only shares the LoD information with the one whose shape is the same with Out. +The attributions of activation_op can be get from fused_elemwise_activation_op's. +The functor_list records the functions to be fused, for example +["scale", "elementwise_add"]. + +)DOC"); } }; @@ -141,6 +241,7 @@ class FusedElemwiseActivationGradMaker op_desc_ptr->SetInput(framework::GradVarName(output_param), this->OutputGrad(output_param)); } + op_desc_ptr->SetAttrMap(this->Attrs()); std::vector functor_names = @@ -158,40 +259,59 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); - PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); - - PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), - "Rank of first input must >= rank of second input."); + "Input(Out@Grad) should not be null"); + if (ctx->Attrs().Get("keep_intermediate_value")) { + PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"), + "Input(IntermediateOut) should not be null"); + } else { + PADDLE_ENFORCE_EQ(ctx->Inputs(framework::GradVarName("Out")).size(), 1); + } + auto funtor_list = + ctx->Attrs().Get>("functor_list"); auto x_grad_name = framework::GradVarName("X"); auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { - ctx->SetOutputDim(x_grad_name, x_dims); + if (ctx->HasInputs("X")) { + ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); + ctx->ShareLoD("X", x_grad_name); + } else { + // Node: If "X" is absence, the shape of Y should be a continuous + // subsequence of X, if not, we could not infer the shape of dx. + + // Currently, only when Binary is elementwise_add or elementwise_sub, + // the "X" could be absent. + PADDLE_ENFORCE(InputXCanBeAbsent(funtor_list), + "Only when BinaryFunctor is elementwise_add, the 'X' " + "could be absent."); + + // For Unary(Binary(X, Y)), IntermediateOut should not be empty. + if (IsUnaryCompound(funtor_list)) { + PADDLE_ENFORCE( + ctx->HasInputs("IntermediateOut"), + "If the compound_functor is Unary(Binary(X, Y)) and Binary " + "is elementwise_add, the intermediate_out must be not absent."); + } + + ctx->SetOutputDim(x_grad_name, + ctx->GetInputDim(framework::GradVarName("Out"))); + ctx->ShareLoD(framework::GradVarName("Out"), x_grad_name); + } } if (ctx->HasOutput(y_grad_name)) { - ctx->SetOutputDim(y_grad_name, y_dims); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y")); + ctx->ShareLoD("Y", y_grad_name); } } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - auto input_data_type_index = ctx.Input("X")->type(); - PADDLE_ENFORCE_EQ(input_data_type_index, - ctx.Input("Y")->type(), - "The element's type of input should be the same."); - PADDLE_ENFORCE_EQ( - input_data_type_index, - ctx.Input(framework::GradVarName("Out"))->type(), - "The element's type of input should be the same."); - + // PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + auto input_data_type_index = ctx.Input("Y")->type(); auto input_data_type = framework::ToDataType(input_data_type_index); return framework::OpKernelType(input_data_type, ctx.GetPlace()); } diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused_elemwise_activation_op.h index fe0017b824532b1210d0ae3e51983d63d081f12a..6321541aab7e31cd703289bb8951245215ecb3e2 100644 --- a/paddle/fluid/operators/fused_elemwise_activation_op.h +++ b/paddle/fluid/operators/fused_elemwise_activation_op.h @@ -20,208 +20,114 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/math/compound_functors.h" #include "paddle/fluid/operators/math/functors.h" -namespace math = paddle::operators::math; - namespace paddle { namespace operators { -// CompoundFunctors -// For example: Z = Binary(X, Unary(Y)) -template -struct BinaryCompoundFunctor { - BinaryCompoundFunctor(const BinaryFun &binary_fun, const UnaryFun &unary_fun) - : binary_fun_(binary_fun), unary_fun_(unary_fun) {} - - inline HOSTDEVICE T operator()(T x, T y) { - return binary_fun_(x, unary_fun_(y)); - } - - private: - BinaryFun binary_fun_; - UnaryFun unary_fun_; -}; - -// For example: Z = Unary(Binary(X, Y)) -template -struct UnaryCompoundFunctor { - UnaryCompoundFunctor(const UnaryFun &unary_fun, const BinaryFun &binary_fun) - : unary_fun_(unary_fun), binary_fun_(binary_fun) {} - - inline HOSTDEVICE T operator()(T x, T y) { - return unary_fun_(binary_fun_(x, y)); - } - - private: - UnaryFun unary_fun_; - BinaryFun binary_fun_; -}; - -// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get -// the dx, one is to use the 'out', and the other is not to use it. -// the former method will save the time of recomputing the -// 'out', but it must occupy the memory to store the 'out'. -// While the later method can avoid occupying this memory, -// but it must recompute the 'out'. - -template -struct BinaryCompoundGradDxFunctor { - BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun, - const UnaryFun &unary_fun) - : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {} - - inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { - if (Recomputation) { - return dout * d_binary_fun_(x, unary_fun_(y)); - } else { - return dout * d_binary_fun_(x, unary_fun_(y), out); - } - } - - private: - DBinaryFun d_binary_fun_; - UnaryFun unary_fun_; -}; - -template -struct BinaryCompoundGradDyFunctor { - BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun, - const UnaryFun &unary_fun, - const DUnaryFun &d_unary_fun) - : d_binary_fun_(d_binary_fun), - unary_fun_(unary_fun), - d_unary_fun_(d_unary_fun) {} - - inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { - if (Recomputation) { - return dout * d_binary_fun_(unary_fun_(y), x) * d_unary_fun_(y); - } else { - return dout * d_binary_fun_(unary_fun_(y), x, out) * d_unary_fun_(y); - } - } - - private: - DBinaryFun d_binary_fun_; - UnaryFun unary_fun_; - DUnaryFun d_unary_fun_; -}; - -template -struct UnaryCompoundGradDxFunctor { - UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun, - const BinaryFun &binary_fun, - const DBinaryFun &d_binary_fun) - : d_unary_fun_(d_unary_fun), - binary_fun_(binary_fun), - d_binary_fun_(d_binary_fun) {} - - inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { - T base; - if (Recomputation) { - base = dout * d_unary_fun_(binary_fun_(x, y)); - } else { - base = dout * d_unary_fun_(binary_fun_(x, y), out); - } - return base * d_binary_fun_(x, y); - } - - private: - DUnaryFun d_unary_fun_; - BinaryFun binary_fun_; - DBinaryFun d_binary_fun_; -}; - -template -struct UnaryCompoundGradDyFunctor { - UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun, - const BinaryFun &binary_fun, - const DBinaryFun &d_binary_fun) - : d_unary_fun_(d_unary_fun), - binary_fun_(binary_fun), - d_binary_fun_(d_binary_fun) {} - - inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { - T base; - if (Recomputation) { - base = dout * d_unary_fun_(binary_fun_(x, y)); - } else { - base = dout * d_unary_fun_(binary_fun_(x, y), out); - } - return base * d_binary_fun_(y, x); - } - - private: - DUnaryFun d_unary_fun_; - BinaryFun binary_fun_; - DBinaryFun d_binary_fun_; -}; - template -static void RunBinaryCompoundFunctor(const framework::ExecutionContext &ctx, - const BinaryFunctor &binary_functor, - const UnaryFunctor &unary_functor, - const framework::Tensor *in_x, - const framework::Tensor *in_y, - framework::Tensor *output) { +static void RunBinaryCompoundFunctor( + const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor, + const UnaryFunctor &unary_functor, const framework::Tensor &in_x, + const framework::Tensor &in_y, std::vector *outputs) { + // Z = Binary(X, Unary(Y)) + // intermediate_out = Unary(Y) + // out = Binary(X, Unary(Y)) + // In this case, the shape of intermediate_out and out are different. + paddle::operators::math::BinaryCompoundFunctor + compound_func(binary_functor, unary_functor); int axis = ctx.Attr("axis"); - using BinaryCompoundFunctor = - BinaryCompoundFunctor; - - ElementwiseComputeEx( - ctx, in_x, in_y, axis, - BinaryCompoundFunctor(binary_functor, unary_functor), output); + if (ctx.Attr("keep_intermediate_value")) { + FusedElemwiseAndActComputeEx, + true /*KeepIntermediateValue*/, + false /*SameShapeOfIntermediateOutAndOut*/>( + ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); + } else { + FusedElemwiseAndActComputeEx, + false /*KeepIntermediateValue*/, + false /*SameShapeOfIntermediateOutAndOut*/>( + ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); + } } template -static void RunUnaryCompoundFunctors(const framework::ExecutionContext &ctx, - const UnaryFunctor &unary_functor, - const BinaryFunctor &binary_functor, - const framework::Tensor *in_x, - const framework::Tensor *in_y, - framework::Tensor *output) { +static void RunUnaryCompoundFunctors( + const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor, + const BinaryFunctor &binary_functor, const framework::Tensor &in_x, + const framework::Tensor &in_y, std::vector *outputs) { + // Z = Unary(Binary(X, Y)) + // intermediate_out = Binary(X, Y) + // out = Unary(Binary(X, Y)) + // In this case, the shape of intermediate_out and out are the same. int axis = ctx.Attr("axis"); - using UnaryCompoundFunctor = - UnaryCompoundFunctor; + paddle::operators::math::UnaryCompoundFunctor + compound_func(unary_functor, binary_functor); - ElementwiseComputeEx( - ctx, in_x, in_y, axis, - UnaryCompoundFunctor(unary_functor, binary_functor), output); + if (ctx.Attr("keep_intermediate_value")) { + FusedElemwiseAndActComputeEx, + true /*KeepIntermediateValue*/, + true /*SameShapeOfIntermediateOutAndOut*/>( + ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); + } else { + FusedElemwiseAndActComputeEx, + false /*KeepIntermediateValue*/, + true /*SameShapeOfIntermediateOutAndOut*/>( + ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); + } } template + typename UnaryFunctor, typename UnaryGradFunctor> static void RunBinaryCompoundGradFunctors( const framework::ExecutionContext &ctx, const BinaryGradFunctor &binary_grad_functor, const UnaryFunctor &unary_functor, const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x, const framework::Tensor *in_y, const framework::Tensor *in_out, + const framework::Tensor *in_intermediate_out, const framework::Tensor *in_out_grad, framework::Tensor *x_grad, framework::Tensor *y_grad) { + // Z = Binary(X, Unary(Y)) int axis = ctx.Attr("axis"); using BinaryCompoundDxFunctor = - BinaryCompoundGradDxFunctor; + paddle::operators::math::BinaryCompoundGradDxFunctor; using BinaryCompoundDyFunctor = - BinaryCompoundGradDyFunctor; - - ElemwiseGradCompute( - ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad, - BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), - BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, - unary_grad_functor)); + paddle::operators::math::BinaryCompoundGradDyFunctor< + T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor>; + + if (in_intermediate_out) { + FusedElemwiseAndActGradComputeEx< + DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor, + true /*UseIntermediateOut*/, + false /*SameShapeOfIntermediateOutAndOut*/>( + ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, + y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), + BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, + unary_grad_functor)); + } else { + FusedElemwiseAndActGradComputeEx< + DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor, + false /*UseIntermediateOut*/, + false /*SameShapeOfIntermediateOutAndOut*/>( + ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, + y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), + BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, + unary_grad_functor)); + } } template ("axis"); using UnaryCompoundDxFunctor = - UnaryCompoundGradDxFunctor; + paddle::operators::math::UnaryCompoundGradDxFunctor< + T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>; using UnaryCompoundDyFunctor = - UnaryCompoundGradDyFunctor; - - ElemwiseGradCompute( - ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad, - UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, - binary_grad_functor), - UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, - binary_grad_functor)); + paddle::operators::math::UnaryCompoundGradDyFunctor< + T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>; + + if (in_intermediate_out) { + FusedElemwiseAndActGradComputeEx< + DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor, + true /*UseIntermediateOut*/, true /*SameShapeOfIntermediateOutAndOut*/>( + ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, + y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, + binary_grad_functor), + UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, + binary_grad_functor)); + } else { + FusedElemwiseAndActGradComputeEx( + ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad, + y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, + binary_grad_functor), + UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, + binary_grad_functor)); + } } template static void RunFunctors(const framework::ExecutionContext &ctx, - const framework::Tensor *in_x, - const framework::Tensor *in_y, - framework::Tensor *output) { + const framework::Tensor &in_x, + const framework::Tensor &in_y, + std::vector *outputs) { auto &functors = ctx.Attr>("functor_list"); - auto funcs_str = functors[0] + "," + functors[1]; + // TODO(zcd): The following code can be refined. + auto funcs_str = functors[0] + "," + functors[1]; if (funcs_str == "elementwise_add,scale") { // Z = Binary(X, Unary(Y)) T scale = static_cast(ctx.Attr("scale")); - RunBinaryCompoundFunctor, - math::ScaleFunctor>( - ctx, math::AddFunctor(), math::ScaleFunctor(scale), in_x, in_y, - output); + RunBinaryCompoundFunctor, + paddle::operators::math::ScaleFunctor>( + ctx, paddle::operators::math::AddFunctor(), + paddle::operators::math::ScaleFunctor(scale), in_x, in_y, outputs); } else if (funcs_str == "scale,elementwise_add") { // Z = Unary(Binary(X, Y)) T scale = static_cast(ctx.Attr("scale")); - RunUnaryCompoundFunctors, - math::AddFunctor>( - ctx, math::ScaleFunctor(scale), math::AddFunctor(), in_x, in_y, - output); + RunUnaryCompoundFunctors, + paddle::operators::math::AddFunctor>( + ctx, paddle::operators::math::ScaleFunctor(scale), + paddle::operators::math::AddFunctor(), in_x, in_y, outputs); } else if (funcs_str == "elementwise_add,relu") { - RunBinaryCompoundFunctor, - math::ReluFunctor>( - ctx, math::AddFunctor(), math::ReluFunctor(), in_x, in_y, output); + // Z = Binary(X, Unary(Y)) + RunBinaryCompoundFunctor, + paddle::operators::math::ReluFunctor>( + ctx, paddle::operators::math::AddFunctor(), + paddle::operators::math::ReluFunctor(), in_x, in_y, outputs); } else if (funcs_str == "relu,elementwise_add") { - RunUnaryCompoundFunctors, - math::AddFunctor>( - ctx, math::ReluFunctor(), math::AddFunctor(), in_x, in_y, output); + // Z = Unary(Binary(X, Y)) + RunUnaryCompoundFunctors, + paddle::operators::math::AddFunctor>( + ctx, paddle::operators::math::ReluFunctor(), + paddle::operators::math::AddFunctor(), in_x, in_y, outputs); + } else if (funcs_str == "elementwise_mul,scale") { + // Z = Binary(X, Unary(Y)) + T scale = static_cast(ctx.Attr("scale")); + RunBinaryCompoundFunctor, + paddle::operators::math::ScaleFunctor>( + ctx, paddle::operators::math::MulFunctor(), + paddle::operators::math::ScaleFunctor(scale), in_x, in_y, outputs); } else { PADDLE_THROW("%s has not been implemented.", funcs_str); } } -template +template static void RunGradFunctors(const framework::ExecutionContext &ctx, const framework::Tensor *in_x, const framework::Tensor *in_y, const framework::Tensor *in_out, + const framework::Tensor *in_intermediate_out, const framework::Tensor *in_out_grad, framework::Tensor *x_grad, framework::Tensor *y_grad) { auto &functors = ctx.Attr>("functor_list"); auto funcs_str = functors[0] + "," + functors[1]; - bool recomputation = ctx.Attr("recomputation"); - - // TODO(zcd): The following code can be refined. for example, use registion + // TODO(zcd): The following code can be refined. for example, use registrition if (funcs_str == "elementwise_add_grad,scale_grad") { // The backward of Z = Binary(X, Unary(Y)) T scale = static_cast(ctx.Attr("scale")); - if (recomputation) { - RunBinaryCompoundGradFunctors, - math::ScaleFunctor, - math::ScaleGradFunctor, true>( - ctx, math::AddGradFunctor(), math::ScaleFunctor(scale), - math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_out_grad, - x_grad, y_grad); - } else { - RunBinaryCompoundGradFunctors, - math::ScaleFunctor, - math::ScaleGradFunctor, false>( - ctx, math::AddGradFunctor(), math::ScaleFunctor(scale), - math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_out_grad, - x_grad, y_grad); - } + RunBinaryCompoundGradFunctors, + paddle::operators::math::ScaleFunctor, + paddle::operators::math::ScaleGradFunctor>( + ctx, paddle::operators::math::AddGradFunctor(), + paddle::operators::math::ScaleFunctor(scale), + paddle::operators::math::ScaleGradFunctor(scale), in_x, in_y, in_out, + in_intermediate_out, in_out_grad, x_grad, y_grad); } else if (funcs_str == "scale_grad,elementwise_add_grad") { // The backward of Z = Unary(Binary(X, Y)) T scale = static_cast(ctx.Attr("scale")); - if (recomputation) { - RunUnaryCompoundGradFunctors, - math::AddFunctor, math::AddGradFunctor, - true>(ctx, math::ScaleGradFunctor(scale), - math::AddFunctor(), - math::AddGradFunctor(), in_x, in_y, - in_out, in_out_grad, x_grad, y_grad); - } else { - RunUnaryCompoundGradFunctors, - math::AddFunctor, math::AddGradFunctor, - false>(ctx, math::ScaleGradFunctor(scale), - math::AddFunctor(), - math::AddGradFunctor(), in_x, in_y, - in_out, in_out_grad, x_grad, y_grad); - } + RunUnaryCompoundGradFunctors, + paddle::operators::math::AddFunctor, + paddle::operators::math::AddGradFunctor, + ReComputation /*Recomputation*/>( + ctx, paddle::operators::math::ScaleGradFunctor(scale), + paddle::operators::math::AddFunctor(), + paddle::operators::math::AddGradFunctor(), in_x, in_y, in_out, + in_intermediate_out, in_out_grad, x_grad, y_grad); } else if (funcs_str == "elementwise_add_grad,relu_grad") { - if (recomputation) { - RunBinaryCompoundGradFunctors, - math::ReluFunctor, - math::ReluGradFunctor, true>( - ctx, math::AddGradFunctor(), math::ReluFunctor(), - math::ReluGradFunctor(), in_x, in_y, in_out, in_out_grad, x_grad, - y_grad); - } else { - RunBinaryCompoundGradFunctors, - math::ReluFunctor, - math::ReluGradFunctor, false>( - ctx, math::AddGradFunctor(), math::ReluFunctor(), - math::ReluGradFunctor(), in_x, in_y, in_out, in_out_grad, x_grad, - y_grad); - } + RunBinaryCompoundGradFunctors, + paddle::operators::math::ReluFunctor, + paddle::operators::math::ReluGradFunctor>( + ctx, paddle::operators::math::AddGradFunctor(), + paddle::operators::math::ReluFunctor(), + paddle::operators::math::ReluGradFunctor(), in_x, in_y, in_out, + in_intermediate_out, in_out_grad, x_grad, y_grad); } else if (funcs_str == "relu_grad,elementwise_add_grad") { - if (recomputation) { - RunUnaryCompoundGradFunctors, - math::AddFunctor, math::AddGradFunctor, - true>(ctx, math::ReluGradFunctor(), - math::AddFunctor(), - math::AddGradFunctor(), in_x, in_y, - in_out, in_out_grad, x_grad, y_grad); - } else { - RunUnaryCompoundGradFunctors, - math::AddFunctor, math::AddGradFunctor, - false>(ctx, math::ReluGradFunctor(), - math::AddFunctor(), - math::AddGradFunctor(), in_x, in_y, - in_out, in_out_grad, x_grad, y_grad); - } + RunUnaryCompoundGradFunctors, + paddle::operators::math::AddFunctor, + paddle::operators::math::AddGradFunctor, + ReComputation /*Recomputation*/>( + ctx, paddle::operators::math::ReluGradFunctor(), + paddle::operators::math::AddFunctor(), + paddle::operators::math::AddGradFunctor(), in_x, in_y, in_out, + in_intermediate_out, in_out_grad, x_grad, y_grad); + } else if (funcs_str == "elementwise_mul_grad,scale_grad") { + // The backward of Z = Binary(X, Unary(Y)) + T scale = static_cast(ctx.Attr("scale")); + RunBinaryCompoundGradFunctors, + paddle::operators::math::ScaleFunctor, + paddle::operators::math::ScaleGradFunctor>( + ctx, paddle::operators::math::MulGradFunctor(), + paddle::operators::math::ScaleFunctor(scale), + paddle::operators::math::ScaleGradFunctor(scale), in_x, in_y, in_out, + in_intermediate_out, in_out_grad, x_grad, y_grad); } else { PADDLE_THROW("%s has not been implemented.", funcs_str); } @@ -385,11 +307,23 @@ class FusedElemwiseActivationKernel : public framework::OpKernel { auto &in_y = detail::Ref(ctx.Input("Y"), "Cannot get input tensor %s, variable name = %s", "Y", ctx.op().Input("Y")); - auto &output = detail::Ref(ctx.Output("Out"), - "Cannot get input tensor %s, variable name = %s", - "Out", ctx.op().Output("Out")); + PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty"); + auto output = ctx.Output("Out"); + + std::vector outputs; + outputs.emplace_back(output); + + if (ctx.Attr("keep_intermediate_value")) { + PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"), + "The keep_intermediate_value is enable, so the " + "IntermediateOut should not be empty."); + auto intermediate_out = ctx.Output("IntermediateOut"); + outputs.emplace_back(intermediate_out); + } else { + outputs.emplace_back(nullptr); + } - RunFunctors(ctx, &in_x, &in_y, &output); + RunFunctors(ctx, in_x, in_y, &outputs); } }; @@ -397,28 +331,66 @@ template class FusedElemwiseActivationGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - auto &in_x = detail::Ref(ctx.Input("X"), - "Cannot get input tensor %s, variable name = %s", - "X", ctx.op().Input("X")); - auto &in_y = detail::Ref(ctx.Input("Y"), - "Cannot get input tensor %s, variable name = %s", - "Y", ctx.op().Input("Y")); - auto &in_out = detail::Ref(ctx.Input("Out"), - "Cannot get input tensor %s, variable name = %s", - "Out", ctx.op().Input("Out")); - auto &in_out_grad = - detail::Ref(ctx.Input(framework::GradVarName("Out")), - "Cannot get input tensor %s, variable name = %s", - framework::GradVarName("Out"), - ctx.op().Input(framework::GradVarName("Out"))); + auto x = ctx.Input("X"); + auto y = ctx.Input("Y"); + + auto in_out = ctx.Input("Out"); + auto in_out_grad = + ctx.Input(framework::GradVarName("Out")); framework::Tensor *x_grad = ctx.Output(framework::GradVarName("X")); framework::Tensor *y_grad = ctx.Output(framework::GradVarName("Y")); - RunGradFunctors(ctx, &in_x, &in_y, &in_out, &in_out_grad, - x_grad, y_grad); + PADDLE_ENFORCE(y != nullptr, "Input(Y) should not be nullptr."); + + if (ctx.Attr("recomputation")) { + PADDLE_ENFORCE( + x != nullptr, + "The recomputation is opened, so Input(X) should not be absent."); + } else { + PADDLE_ENFORCE(in_out != nullptr, + "The recomputation is disabled, so the Input('Out') " + "should not be empty."); + } + + framework::Tensor *in_x; + auto functor_list = ctx.Attr>("functor_list"); + + // If functor_list contains elementwise_add, the backward doesn't use + // in_x, and in_outs. + if (x == nullptr) { + PADDLE_ENFORCE(functor_list[0] == "elementwise_add_grad" || + functor_list[1] == "elementwise_add_grad", + "Only when the compoundfunctor contains " + "elementwise_add_grad, the 'X' could be absent."); + in_x = const_cast(in_out_grad); + in_out = const_cast(in_out_grad); + } else { + in_x = const_cast(x); + } + + framework::Tensor *in_intermediate_out; + if (ctx.Attr("keep_intermediate_value")) { + in_intermediate_out = const_cast( + ctx.Input("IntermediateOut")); + PADDLE_ENFORCE(in_intermediate_out != nullptr, + "The option of 'keep_intermediate_value' is opened, " + "so the number of 'Out' should be two."); + } else { + in_intermediate_out = nullptr; + } + + if (ctx.Attr("recomputation")) { + RunGradFunctors( + ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad, + y_grad); + } else { + RunGradFunctors( + ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad, + y_grad); + } } }; } // namespace operators diff --git a/paddle/fluid/operators/math/compound_functors.h b/paddle/fluid/operators/math/compound_functors.h new file mode 100644 index 0000000000000000000000000000000000000000..1d32a9585b08a9d27730076d9f7baa6056270a42 --- /dev/null +++ b/paddle/fluid/operators/math/compound_functors.h @@ -0,0 +1,185 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +namespace paddle { +namespace operators { +namespace math { + +template +struct BinaryCompoundFunctor { + BinaryCompoundFunctor(const BinaryFunctor func1, const UnaryFunctor func2) + : func1_(func1), func2_(func2) {} + // Z = BinaryFunctor(X, UnaryFunctor(Y)) + + inline HOSTDEVICE T GetOut(T x, T y) { return func1_(x, func2_(y)); } + + inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediat_out) { + return func1_(x, intermediat_out); + } + + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(y); } + + BinaryFunctor func1_; + UnaryFunctor func2_; +}; + +template +struct UnaryCompoundFunctor { + UnaryCompoundFunctor(const UnaryFunctor func1, const BinaryFunctor func2) + : func1_(func1), func2_(func2) {} + // Z = UnaryFunctor(BinaryFunctor(X, Y)) + + inline HOSTDEVICE T GetOut(T x, T y) { return func1_(func2_(x, y)); } + + inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediat_out) { + return func1_(intermediat_out); + } + + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(x, y); } + + UnaryFunctor func1_; + BinaryFunctor func2_; +}; + +// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get +// the dx, one is to use the 'out', and the other is not to use it. +// the former method will save the time of recomputing the +// 'out', but it must occupy the memory to store the 'out'. +// While the later method can avoid occupying this memory, +// but it must recompute the 'out'. +template +struct BinaryCompoundGradDxFunctor { + BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun, + const UnaryFun &unary_fun) + : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + return dout * d_binary_fun_.Dx(x, unary_fun_(y)); + } + + inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) { + return dout * d_binary_fun_.Dx(x, intermediate_out); + } + + private: + DBinaryFun d_binary_fun_; + UnaryFun unary_fun_; +}; + +template +struct BinaryCompoundGradDyFunctor { + BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun, + const UnaryFun &unary_fun, + const DUnaryFun &d_unary_fun) + : d_binary_fun_(d_binary_fun), + unary_fun_(unary_fun), + d_unary_fun_(d_unary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_(y); + } + + inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) { + return dout * d_binary_fun_.Dy(x, intermediate_out) * + d_unary_fun_(y, intermediate_out); + } + + private: + DBinaryFun d_binary_fun_; + UnaryFun unary_fun_; + DUnaryFun d_unary_fun_; +}; + +template +struct UnaryCompoundGradDxFunctor { + UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun, + const BinaryFun &binary_fun, + const DBinaryFun &d_binary_fun) + : d_unary_fun_(d_unary_fun), + binary_fun_(binary_fun), + d_binary_fun_(d_binary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + T base; + if (Recomputation) { + base = dout * d_unary_fun_(binary_fun_(x, y)); + } else { + base = dout * d_unary_fun_(binary_fun_(x, y), out); + } + return base * d_binary_fun_.Dx(x, y); + } + + inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) { + T base; + if (Recomputation) { + base = dout * d_unary_fun_(intermediate_out); + } else { + base = dout * d_unary_fun_(intermediate_out, out); + } + return base * d_binary_fun_.Dx(x, y); + } + + private: + DUnaryFun d_unary_fun_; + BinaryFun binary_fun_; + DBinaryFun d_binary_fun_; +}; + +template +struct UnaryCompoundGradDyFunctor { + UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun, + const BinaryFun &binary_fun, + const DBinaryFun &d_binary_fun) + : d_unary_fun_(d_unary_fun), + binary_fun_(binary_fun), + d_binary_fun_(d_binary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + T base; + if (Recomputation) { + base = dout * d_unary_fun_(binary_fun_(x, y)); + } else { + base = dout * d_unary_fun_(binary_fun_(x, y), out); + } + return base * d_binary_fun_.Dy(x, y); + } + + inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) { + T base; + if (Recomputation) { + base = dout * d_unary_fun_(intermediate_out); + } else { + base = dout * d_unary_fun_(intermediate_out, out); + } + return base * d_binary_fun_.Dy(x, y); + } + + private: + DUnaryFun d_unary_fun_; + BinaryFun binary_fun_; + DBinaryFun d_binary_fun_; +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/functors.h b/paddle/fluid/operators/math/functors.h index ad2f49ccbf5ff37d33cc9e71c1a683571f4f8137..ddb01cdfc084f5ba2e9e573be461389f46fbe03f 100644 --- a/paddle/fluid/operators/math/functors.h +++ b/paddle/fluid/operators/math/functors.h @@ -18,6 +18,19 @@ namespace paddle { namespace operators { namespace math { +// MulFunctor +template +struct MulFunctor { + // out = x * y; + inline HOSTDEVICE T operator()(T x, T y) { return x * y; } +}; + +template +struct MulGradFunctor { + inline HOSTDEVICE T Dx(T x, T y) { return y; } + inline HOSTDEVICE T Dy(T x, T y) { return x; } +}; + // AddFunctor template struct AddFunctor { @@ -27,9 +40,8 @@ struct AddFunctor { template struct AddGradFunctor { - inline HOSTDEVICE T operator()(T x, T y) { return 1; } - - inline HOSTDEVICE T operator()(T x, T y, T out) const { return 1; } + inline HOSTDEVICE T Dx(T x, T y) { return 1; } + inline HOSTDEVICE T Dy(T x, T y) { return 1; } }; template diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index 44cd073379f293a1114c2c77fa80d35d112d4fb8..20f1a37a426e9697048d636bf738c9056213e5f6 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -47,7 +47,8 @@ def get_numeric_gradient(place, input_to_check, output_names, delta=0.005, - in_place=False): + in_place=False, + sum_outputs=None): # FIXME: change this method by compile time concepts set_input(scope, op, inputs, place) @@ -58,9 +59,11 @@ def get_numeric_gradient(place, sum = [] op.run(scope, place) for output_name in output_names: + if sum_outputs and output_name not in sum_outputs: + continue sum.append( np.array(scope.find_var(output_name).get_tensor()).mean()) - return np.array(sum).mean() + return np.array(sum).sum() / len(output_names) tensor_to_check = scope.find_var(input_to_check).get_tensor() tensor_size = product(tensor_to_check.shape()) @@ -396,13 +399,14 @@ class OpTest(unittest.TestCase): numeric_grad_delta=0.005, in_place=False, max_relative_error=0.005, - user_defined_grads=None): + user_defined_grads=None, + sum_outputs=None): places = self._get_places() for place in places: self.check_grad_with_place(place, inputs_to_check, output_names, no_grad_set, numeric_grad_delta, in_place, max_relative_error, - user_defined_grads) + user_defined_grads, sum_outputs) def check_grad_with_place(self, place, @@ -412,7 +416,8 @@ class OpTest(unittest.TestCase): numeric_grad_delta=0.005, in_place=False, max_relative_error=0.005, - user_defined_grads=None): + user_defined_grads=None, + sum_outputs=None): self.scope = core.Scope() op_inputs = self.inputs if hasattr(self, "inputs") else dict() op_outputs = self.outputs if hasattr(self, "outputs") else dict() @@ -435,7 +440,8 @@ class OpTest(unittest.TestCase): input_to_check, output_names, delta=numeric_grad_delta, - in_place=in_place) for input_to_check in inputs_to_check + in_place=in_place, + sum_outputs=sum_outputs) for input_to_check in inputs_to_check ] analytic_grads = self._get_gradient(inputs_to_check, place, output_names, no_grad_set) diff --git a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py index 97e1b9061afb738dd9e5f8b3b6a9c9a123c6aac6..4a213c29113e5e23af2caf7fbcb807be3d0166d2 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py @@ -15,806 +15,327 @@ from __future__ import print_function import unittest import numpy as np +from functools import partial import paddle.fluid.core as core from op_test import OpTest -# scale + add -# TestElementwiseAddOp -# TestFusedOperatorsOp_scalar -# TestFusedOperatorsOp_scalar2 -# TestFusedOperatorsOp_Vector -# TestFusedOperatorsOp_broadcast_0 -# TestFusedOperatorsOp_broadcast_1 -# TestFusedOperatorsOp_broadcast_2 -# TestFusedOperatorsOp_broadcast_3 -# TestFusedOperatorsOp_broadcast_4 -# TestFusedOperatorsOp_rowwise_add_0 -# TestFusedOperatorsOp_rowwise_add_1 -# TestFusedOperatorsOp_channelwise_add - - -class TestElementwiseAddOp(OpTest): - def setUp(self): - self.op_type = "fused_elemwise_activation" - self.dtype = np.float32 - self.axis = -1 - - self.init_axis() - self.init_dtype() - self.init_input() - self.init_output() - self.init_attr() - - self.inputs = { - 'X': OpTest.np_dtype_to_fluid_dtype(self.x), - 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) - } - self.outputs = {'Out': self.out} - - def init_input(self): - self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) - self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y) * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["scale", "elementwise_add"] - } - - def init_dtype(self): - pass - - def init_axis(self): - pass - - def test_check_output(self): - self.check_output() - - def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) - - def test_check_grad_ingore_x(self): - self.check_grad( - ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) - - def test_check_grad_ingore_y(self): - self.check_grad( - ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) - - -class TestFusedOperatorsOp_scalar(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(2, 3, 4).astype(self.dtype) - self.y = np.random.rand(1).astype(self.dtype) - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y) * self.scale - - -class TestFusedOperatorsOp_scalar2(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(2, 3, 4).astype(self.dtype) - self.y = np.random.rand(1, 1).astype(self.dtype) - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y) * self.scale - - -class TestFusedOperatorsOp_Vector(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.random((32, )).astype(self.dtype) - self.y = np.random.random((32, )).astype(self.dtype) - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y) * self.scale - - -class TestFusedOperatorsOp_broadcast_0(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(2, 3, 4).astype(self.dtype) - self.y = np.random.rand(2).astype(self.dtype) - - def init_axis(self): - self.axis = 0 - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y.reshape(2, 1, 1)) * self.scale - - -class TestFusedOperatorsOp_broadcast_1(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(2, 3, 4).astype(self.dtype) - self.y = np.random.rand(3).astype(self.dtype) - - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y.reshape(1, 3, 1)) * self.scale - - -class TestFusedOperatorsOp_broadcast_2(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(2, 3, 4).astype(self.dtype) - self.y = np.random.rand(4).astype(self.dtype) - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y.reshape(1, 1, 4)) * self.scale - - -class TestFusedOperatorsOp_broadcast_3(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) - self.y = np.random.rand(3, 4).astype(self.dtype) - - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y.reshape(1, 3, 4, 1)) * self.scale - - -class TestFusedOperatorsOp_broadcast_4(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) - self.y = np.random.rand(2, 1).astype(self.dtype) - - def init_axis(self): - self.axis = 0 - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y.reshape(2, 1, 1, 1)) * self.scale - - -class TestFusedOperatorsOp_rowwise_add_0(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(2, 3, 4).astype(self.dtype) - self.y = np.random.rand(3, 4).astype(self.dtype) - - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y.reshape(1, 3, 4)) * self.scale - - -class TestFusedOperatorsOp_rowwise_add_1(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(2, 1).astype(self.dtype) - self.y = np.random.rand(1).astype(self.dtype) - - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y.reshape(1, 1)) * self.scale - - -class TestFusedOperatorsOp_channelwise_add(TestElementwiseAddOp): - def init_input(self): - self.x = np.random.rand(3, 20, 20).astype(self.dtype) - self.y = np.random.rand(3, 1, 1).astype(self.dtype) - - def init_axis(self): - self.axis = -1 - - def init_output(self): - self.scale = 0.1 - self.out = (self.x + self.y) * self.scale - - -# add + scale -# TestElementwiseAddOp_f_add_scale -# TestFusedOperatorsOp_scalar_f_add_scale -# TestFusedOperatorsOp_scalar2_f_add_scale -# TestFusedOperatorsOp_Vector_f_add_scale -# TestFusedOperatorsOp_broadcast_0_f_add_scale -# TestFusedOperatorsOp_broadcast_1_f_add_scale -# TestFusedOperatorsOp_broadcast_2_f_add_scale -# TestFusedOperatorsOp_broadcast_3_f_add_scale -# TestFusedOperatorsOp_broadcast_4_f_add_scale -# TestFusedOperatorsOp_rowwise_add_0_f_add_scale -# TestFusedOperatorsOp_rowwise_add_1_f_add_scale -# TestFusedOperatorsOp_channelwise_add_f_add_scale - - -class TestFusedOperatorsOp_f_add_scale(TestElementwiseAddOp): - def init_output(self): - self.scale = 0.1 - self.out = self.x + self.y * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_scalar_f_add_scale(TestFusedOperatorsOp_scalar): - def init_output(self): - self.scale = 0.1 - self.out = self.x + self.y * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_scalar2_f_add_scale(TestFusedOperatorsOp_scalar2): - def init_output(self): - self.scale = 0.1 - self.out = self.x + self.y * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_Vector_f_add_scale(TestFusedOperatorsOp_Vector): - def init_output(self): - self.scale = 0.1 - self.out = self.x + self.y * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_broadcast_0_f_add_scale( - TestFusedOperatorsOp_broadcast_0): - def init_axis(self): - self.axis = 0 - - def init_output(self): - self.scale = 0.1 - self.out = self.x + self.y.reshape(2, 1, 1) * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_broadcast_1_f_add_scale( - TestFusedOperatorsOp_broadcast_1): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.scale = 0.1 - self.out = self.x + self.y.reshape(1, 3, 1) * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_broadcast_2_f_add_scale( - TestFusedOperatorsOp_broadcast_2): - def init_output(self): - self.scale = 0.1 - self.out = self.x + self.y.reshape(1, 1, 4) * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_broadcast_3_f_add_scale( - TestFusedOperatorsOp_broadcast_3): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.scale = 0.1 - self.out = self.x + self.y.reshape(1, 3, 4, 1) * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_broadcast_4_f_add_scale( - TestFusedOperatorsOp_broadcast_4): - def init_axis(self): - self.axis = 0 - - def init_output(self): - self.scale = 0.2 - self.out = self.x + self.y.reshape(2, 1, 1, 1) * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_rowwise_add_0_f_add_scale( - TestFusedOperatorsOp_rowwise_add_0): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.scale = 0.1 - self.out = self.x + self.y.reshape(1, 3, 4) * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_rowwise_add_1_f_add_scale( - TestFusedOperatorsOp_rowwise_add_1): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.scale = 0.2 - self.out = self.x + self.y.reshape(1, 1) * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -class TestFusedOperatorsOp_channelwise_add_f_add_scale( - TestFusedOperatorsOp_channelwise_add): - def init_axis(self): - self.axis = -1 - - def init_output(self): - self.scale = 0.2 - self.out = self.x + self.y * self.scale - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'scale': self.scale, - 'functor_list': ["elementwise_add", "scale"] - } - - -# add + relu -# TestElementwiseAddOp_f_add_relu -# TestFusedOperatorsOp_scalar_f_add_relu -# TestFusedOperatorsOp_scalar2_f_add_relu -# TestFusedOperatorsOp_Vector_f_add_relu -# TestFusedOperatorsOp_broadcast_0_f_add_relu -# TestFusedOperatorsOp_broadcast_1_f_add_relu -# TestFusedOperatorsOp_broadcast_2_f_add_relu -# TestFusedOperatorsOp_broadcast_3_f_add_relu -# TestFusedOperatorsOp_broadcast_4_f_add_relu -# TestFusedOperatorsOp_rowwise_add_0_f_add_relu -# TestFusedOperatorsOp_rowwise_add_1_f_add_relu -# TestFusedOperatorsOp_channelwise_add_f_add_relu - - -class TestFusedOperatorsOp_f_add_relu(TestElementwiseAddOp): - def init_output(self): - # Copy from test_activation_op.py - # Because we set delta = 0.005 in calculating numeric gradient, - # if x is too small, such as 0.002, x_neg will be -0.003 - # x_pos will be 0.007, so the numeric gradient is inaccurate. - # we should avoid this - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y, 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_scalar_f_add_relu(TestFusedOperatorsOp_scalar): - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y, 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_scalar2_f_add_relu(TestFusedOperatorsOp_scalar2): - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y, 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_Vector_f_add_relu(TestFusedOperatorsOp_Vector): - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y, 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_broadcast_0_f_add_relu( - TestFusedOperatorsOp_broadcast_0): - def init_axis(self): - self.axis = 0 - - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y.reshape(2, 1, 1), 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_broadcast_1_f_add_relu( - TestFusedOperatorsOp_broadcast_1): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y.reshape(1, 3, 1), 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_broadcast_2_f_add_relu( - TestFusedOperatorsOp_broadcast_2): - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y.reshape(1, 1, 4), 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_broadcast_3_f_add_relu( - TestFusedOperatorsOp_broadcast_3): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y.reshape(1, 3, 4, 1), 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_broadcast_4_f_add_relu( - TestFusedOperatorsOp_broadcast_4): - def init_axis(self): - self.axis = 0 - - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y.reshape(2, 1, 1, 1), 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_rowwise_add_0_f_add_relu( - TestFusedOperatorsOp_rowwise_add_0): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y.reshape(1, 3, 4), 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_rowwise_add_1_f_add_relu( - TestFusedOperatorsOp_rowwise_add_1): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y.reshape(1, 1), 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -class TestFusedOperatorsOp_channelwise_add_f_add_relu( - TestFusedOperatorsOp_channelwise_add): - def init_axis(self): - self.axis = -1 - - def init_output(self): - self.y[np.abs(self.y) < 0.005] = 0.02 - self.out = self.x + np.maximum(self.y, 0) - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["elementwise_add", "relu"] - } - - -# relu + add -# TestElementwiseAddOp_f_relu_add -# TestFusedOperatorsOp_scalar_f_relu_add -# TestFusedOperatorsOp_scalar2_f_relu_add -# TestFusedOperatorsOp_Vector_f_relu_add -# TestFusedOperatorsOp_broadcast_0_f_relu_add -# TestFusedOperatorsOp_broadcast_1_f_relu_add -# TestFusedOperatorsOp_broadcast_2_f_relu_add -# TestFusedOperatorsOp_broadcast_3_f_relu_add -# TestFusedOperatorsOp_broadcast_4_f_relu_add -# TestFusedOperatorsOp_rowwise_add_0_f_relu_add -# TestFusedOperatorsOp_rowwise_add_1_f_relu_add -# TestFusedOperatorsOp_channelwise_add_f_relu_add - - -class TestFusedOperatorsOp_f_relu_add(TestElementwiseAddOp): - def init_output(self): - # Copy from test_activation_op.py - # Because we set delta = 0.005 in calculating numeric gradient, - # if x is too small, such as 0.002, x_neg will be -0.003 - # x_pos will be 0.007, so the numeric gradient is inaccurate. - # we should avoid this - self.out = self.x + self.y - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_scalar_f_relu_add(TestFusedOperatorsOp_scalar): - def init_output(self): - self.out = self.x + self.y - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_scalar2_f_relu_add(TestFusedOperatorsOp_scalar2): - def init_output(self): - self.out = self.x + self.y - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_Vector_f_relu_add(TestFusedOperatorsOp_Vector): - def init_output(self): - self.out = self.x + self.y - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_broadcast_0_f_relu_add( - TestFusedOperatorsOp_broadcast_0): - def init_axis(self): - self.axis = 0 - - def init_output(self): - self.out = self.x + self.y.reshape(2, 1, 1) - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_broadcast_1_f_relu_add( - TestFusedOperatorsOp_broadcast_1): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.out = self.x + self.y.reshape(1, 3, 1) - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_broadcast_2_f_relu_add( - TestFusedOperatorsOp_broadcast_2): - def init_output(self): - self.out = self.x + self.y.reshape(1, 1, 4) - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_broadcast_3_f_relu_add( - TestFusedOperatorsOp_broadcast_3): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.out = self.x + self.y.reshape(1, 3, 4, 1) - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_broadcast_4_f_relu_add( - TestFusedOperatorsOp_broadcast_4): - def init_axis(self): - self.axis = 0 - - def init_output(self): - self.out = self.x + self.y.reshape(2, 1, 1, 1) - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_rowwise_add_0_f_relu_add( - TestFusedOperatorsOp_rowwise_add_0): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.out = self.x + self.y.reshape(1, 3, 4) - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_rowwise_add_1_f_relu_add( - TestFusedOperatorsOp_rowwise_add_1): - def init_axis(self): - self.axis = 1 - - def init_output(self): - self.out = self.x + self.y.reshape(1, 1) - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - - -class TestFusedOperatorsOp_channelwise_add_f_relu_add( - TestFusedOperatorsOp_channelwise_add): - def init_axis(self): - self.axis = -1 - - def init_output(self): - self.out = self.x + self.y - self.out = np.maximum(self.out, 0) - self.out[np.abs(self.out) < 0.005] = 0.02 - - def init_attr(self): - self.attrs = { - 'axis': self.axis, - 'functor_list': ["relu", "elementwise_add"] - } - +# TestFusedElementwiseActivationOp +# TestFusedElementwiseActivationOp_scalar +# TestFusedElementwiseActivationOp_scalar2 +# TestFusedElementwiseActivationOp_Vector +# TestFusedElementwiseActivationOp_broadcast_0 +# TestFusedElementwiseActivationOp_broadcast_1 +# TestFusedElementwiseActivationOp_broadcast_2 +# TestFusedElementwiseActivationOp_broadcast_3 +# TestFusedElementwiseActivationOp_broadcast_4 +# TestFusedElementwiseActivationOp_rowwise_add_0 +# TestFusedElementwiseActivationOp_rowwise_add_1 +# TestFusedElementwiseActivationOp_channelwise_add + + +def create_test_class(test_case, callback, attrs): + class TestFusedElementwiseActivationOp_base(OpTest): + def setUp(self): + self.op_type = "fused_elemwise_activation" + self.dtype = np.float32 + self.axis = -1 + + self.init_input() + self.init_output() + self.init_attr() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + if self.attrs["keep_intermediate_value"]: + self.outputs = { + 'Out': self.out, + "IntermediateOut": self.intermediate_out + } + else: + self.outputs = {'Out': self.out} + + def init_input(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.axis = -1 + + def init_output(self): + self.x, self.y, self.intermediate_out, self.out = \ + callback(self.x, self.y, self.x, self.y) + + def init_attr(self): + self.attrs = {'axis': self.axis, } + for key in attrs.keys(): + self.attrs[key] = attrs[key] + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + if self.attrs["keep_intermediate_value"]: + self.check_grad( + ['X', 'Y'], ['Out', 'IntermediateOut'], + max_relative_error=0.005, + sum_outputs=['Out']) + else: + self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005) + + def test_check_grad_ingore_x(self): + if self.attrs["keep_intermediate_value"]: + self.check_grad( + ['Y'], ['Out', 'IntermediateOut'], + max_relative_error=0.005, + no_grad_set=set("X"), + sum_outputs=['Out']) + else: + self.check_grad( + ['Y'], ['Out'], + max_relative_error=0.005, + no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + if self.attrs["keep_intermediate_value"]: + self.check_grad( + ['X'], ['Out', 'IntermediateOut'], + max_relative_error=0.005, + no_grad_set=set("Y"), + sum_outputs=['Out']) + else: + self.check_grad( + ['X'], ['Out'], + max_relative_error=0.005, + no_grad_set=set("Y")) + + class TestFusedElementwiseActivationOp_scalar( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + + class TestFusedElementwiseActivationOp_scalar2( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1, 1).astype(self.dtype) + + class TestFusedElementwiseActivationOp_Vector( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.random((32, )).astype(self.dtype) + self.y = np.random.random((32, )).astype(self.dtype) + + class TestFusedElementwiseActivationOp_broadcast_0( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(2).astype(self.dtype) + self.axis = 0 + + def init_output(self): + self.x, self.y, self.intermediate_out, self.out = \ + callback(self.x, self.y, self.x, self.y.reshape(2, 1, 1)) + + class TestFusedElementwiseActivationOp_broadcast_1( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3).astype(self.dtype) + self.axis = 1 + + def init_output(self): + self.x, self.y, self.intermediate_out, self.out = \ + callback(self.x, self.y, self.x, self.y.reshape(1, 3, 1)) + + class TestFusedElementwiseActivationOp_broadcast_2( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(4).astype(self.dtype) + + def init_output(self): + self.x, self.y, self.intermediate_out, self.out = \ + callback(self.x, self.y, self.x, self.y.reshape(1, 1, 4)) + + class TestFusedElementwiseActivationOp_broadcast_3( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + self.axis = 1 + + def init_output(self): + self.x, self.y, self.intermediate_out, self.out = \ + callback(self.x, self.y, self.x, self.y.reshape(1, 3, 4, 1)) + + class TestFusedElementwiseActivationOp_broadcast_4( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(2, 1).astype(self.dtype) + self.axis = 0 + + def init_output(self): + self.x, self.y, self.intermediate_out, self.out = \ + callback(self.x, self.y, self.x, self.y.reshape(2, 1, 1, 1)) + + class TestFusedElementwiseActivationOp_rowwise_add_0( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + self.axis = 1 + + def init_output(self): + self.x, self.y, self.intermediate_out, self.out = \ + callback(self.x, self.y, self.x, self.y.reshape(1, 3, 4)) + + class TestFusedElementwiseActivationOp_rowwise_add_1( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(2, 1).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.axis = 1 + + def init_output(self): + self.x, self.y, self.intermediate_out, self.out = \ + callback(self.x, self.y, self.x, self.y.reshape(1, 1)) + + class TestFusedElementwiseActivationOp_channelwise_add( + TestFusedElementwiseActivationOp_base): + def init_input(self): + self.x = np.random.rand(3, 20, 20).astype(self.dtype) + self.y = np.random.rand(3, 1, 1).astype(self.dtype) + + TestFusedElementwiseActivationOp_base.__name__ = test_case + "_base" + TestFusedElementwiseActivationOp_scalar.__name__ = test_case + "_scalar" + TestFusedElementwiseActivationOp_scalar2.__name__ = test_case + "_scalar2" + TestFusedElementwiseActivationOp_Vector.__name__ = test_case + "_Vector" + TestFusedElementwiseActivationOp_broadcast_0.__name__ = test_case + "_broadcast_0" + TestFusedElementwiseActivationOp_broadcast_1.__name__ = test_case + "_broadcast_1" + TestFusedElementwiseActivationOp_broadcast_2.__name__ = test_case + "_broadcast_2" + TestFusedElementwiseActivationOp_broadcast_3.__name__ = test_case + "_broadcast_3" + TestFusedElementwiseActivationOp_broadcast_4.__name__ = test_case + "_broadcast_4" + TestFusedElementwiseActivationOp_rowwise_add_0.__name__ = test_case + "_rowwise_add_0" + TestFusedElementwiseActivationOp_rowwise_add_1.__name__ = test_case + "_rowwise_add_1" + TestFusedElementwiseActivationOp_channelwise_add.__name__ = test_case + "_channelwise_add" + + globals()[test_case + "_base"] = TestFusedElementwiseActivationOp_base + globals()[test_case + "_scalar"] = TestFusedElementwiseActivationOp_scalar + globals()[test_case + "_scalar2"] = TestFusedElementwiseActivationOp_scalar2 + globals()[test_case + "_Vector"] = TestFusedElementwiseActivationOp_Vector + globals()[test_case + + "_broadcast_0"] = TestFusedElementwiseActivationOp_broadcast_0 + globals()[test_case + + "_broadcast_1"] = TestFusedElementwiseActivationOp_broadcast_1 + globals()[test_case + + "_broadcast_2"] = TestFusedElementwiseActivationOp_broadcast_2 + globals()[test_case + + "_broadcast_3"] = TestFusedElementwiseActivationOp_broadcast_3 + globals()[test_case + + "_broadcast_4"] = TestFusedElementwiseActivationOp_broadcast_4 + globals()[test_case + + "_rowwise_add_0"] = TestFusedElementwiseActivationOp_rowwise_add_0 + globals()[test_case + + "_rowwise_add_1"] = TestFusedElementwiseActivationOp_rowwise_add_1 + globals( + )[test_case + + "_channelwise_add"] = TestFusedElementwiseActivationOp_channelwise_add + + +def scale_add_func(x, y, x_bcast, y_bcast, scale, mode=0): + if mode == 0: + return x, y, (x_bcast + y_bcast), (x_bcast + y_bcast) * scale + else: + return y, x, (x_bcast + y_bcast), (x_bcast + y_bcast) * scale + + +def add_scale_func(x, y, x_bcast, y_bcast, scale, mode=0): + if mode == 0: + return x, y, y * scale, x_bcast + y_bcast * scale + else: + return y, x, x * scale, y_bcast + x_bcast * scale + + +def add_relu_func(x, y, x_bcast, y_bcast, mode=0): + # Copy from test_activation_op.py + # Because we set delta = 0.005 in calculating numeric gradient, + # if x is too small, such as 0.002, x_neg will be -0.003 + # x_pos will be 0.007, so the numeric gradient is inaccurate. + # we should avoid this + if mode == 0: + y[np.abs(y) < 0.005] = 0.02 + y_bcast[np.abs(y_bcast) < 0.005] = 0.02 + return x, y, np.maximum(y, 0), x_bcast + np.maximum(y_bcast, 0) + else: + x[np.abs(x) < 0.005] = 0.02 + x_bcast[np.abs(x_bcast) < 0.005] = 0.02 + return y, x, np.maximum(x, 0), y_bcast + np.maximum(x_bcast, 0) + + +def relu_add_func(x, y, x_bcast, y_bcast, mode=0): + intermediate_out = x_bcast + y_bcast + out = np.maximum(intermediate_out, 0) + out[np.abs(out) < 0.005] = 0.02 + if mode == 0: + return x, y, intermediate_out, out + else: + return y, x, intermediate_out, out + + +def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0): + if mode == 0: + return x, y, y * scale, x_bcast * (y_bcast * scale) + else: + return y, x, x * scale, y_bcast * (x_bcast * scale) + + +scale = 0.1 +scale_add_func = partial(scale_add_func, scale=scale) +add_scale_func = partial(add_scale_func, scale=scale) +mul_scale_func = partial(mul_scale_func, scale=scale) + +for mode in {0, 1}: + scale_add_func = partial(scale_add_func, mode=mode) + add_scale_func = partial(add_scale_func, mode=mode) + mul_scale_func = partial(mul_scale_func, mode=mode) + relu_add_func = partial(relu_add_func, mode=mode) + add_relu_func = partial(add_relu_func, mode=mode) + + for recomputation in {True, False}: + for keep_intermediate_value in {True, False}: + suffix = ("_keep_intermediate_value" if keep_intermediate_value else "") \ + + ("_recomputation" if recomputation else "") \ + + ("_mode_"+ str(mode)) + create_test_class('scale_add' + suffix, scale_add_func, { + 'scale': scale, + 'functor_list': ["scale", "elementwise_add"], + 'keep_intermediate_value': keep_intermediate_value, + 'recomputation': recomputation + }) + create_test_class('add_scale' + suffix, add_scale_func, { + 'scale': scale, + 'functor_list': ["elementwise_add", "scale"], + 'keep_intermediate_value': keep_intermediate_value, + 'recomputation': recomputation + }) + create_test_class('add_relu' + suffix, add_relu_func, { + 'functor_list': ["elementwise_add", "relu"], + 'keep_intermediate_value': keep_intermediate_value, + 'recomputation': recomputation + }) + create_test_class('relu_add' + suffix, relu_add_func, { + 'functor_list': ["relu", "elementwise_add"], + 'keep_intermediate_value': keep_intermediate_value, + 'recomputation': recomputation + }) + create_test_class('mul_scale' + suffix, mul_scale_func, { + 'scale': scale, + 'functor_list': ["elementwise_mul", "scale"], + 'keep_intermediate_value': keep_intermediate_value, + 'recomputation': recomputation + }) if __name__ == '__main__': unittest.main()