未验证 提交 3bd1d22a 编写于 作者: C chengduo 提交者: GitHub

Enhance fused_elementwise_activation_op (#12837)

* Enhance the function of fused_elementwise_activation_op

* enhance unit test

* Clean Code And Add Doc

* Add compound functors

* Fix doc and enhance unit test

* define Dx and Dy for d_binary_func

* add mul_scale

* add mul_scale

* add elementwise_mul

* code refine

* code refine

* add doc

* add  AsIntermediate
上级 a615ad46
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <vector>
......@@ -46,9 +47,9 @@ namespace operators {
* pre=2*3, n=4*5, post=1
* x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
*/
inline void get_mid_dims(const framework::DDim& x_dims,
const framework::DDim& y_dims, const int axis,
int* pre, int* n, int* post) {
inline void get_mid_dims(const framework::DDim &x_dims,
const framework::DDim &y_dims, const int axis,
int *pre, int *n, int *post) {
*pre = 1;
*n = 1;
*post = 1;
......@@ -68,7 +69,7 @@ inline void get_mid_dims(const framework::DDim& x_dims,
}
inline framework::DDim trim_trailing_singular_dims(
const framework::DDim& dims) {
const framework::DDim &dims) {
// Remove trailing dimensions of size 1 for y
auto actual_dims_size = dims.size();
for (; actual_dims_size != 0; --actual_dims_size) {
......@@ -89,15 +90,16 @@ inline framework::DDim trim_trailing_singular_dims(
template <typename T, typename DeviceContext>
class RowwiseTransformIterator;
template <typename T, typename DeviceContext>
class MidWiseTransformIterator;
template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
public:
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
RowwiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
RowwiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
++i_;
if (UNLIKELY(i_ == n_)) {
i_ = 0;
......@@ -105,20 +107,20 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
return *this;
}
bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
rhs) const {
bool operator==(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>&
rhs) const {
bool operator!=(const RowwiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T& operator*() { return ptr_[i_]; }
const T &operator*() { return ptr_[i_]; }
private:
const T* ptr_;
const T *ptr_;
int i_;
int64_t n_;
};
......@@ -126,10 +128,10 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
template <typename T>
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
public:
MidWiseTransformIterator(const T* ptr, int n, int post)
MidWiseTransformIterator(const T *ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
MidWiseTransformIterator<T, platform::CPUDeviceContext>& operator++() {
MidWiseTransformIterator<T, platform::CPUDeviceContext> &operator++() {
++j_;
if (UNLIKELY(j_ == post_)) {
++i_;
......@@ -141,20 +143,20 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
return *this;
}
bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
rhs) const {
bool operator==(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) == &(*rhs);
}
bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>&
rhs) const {
bool operator!=(const MidWiseTransformIterator<T, platform::CPUDeviceContext>
&rhs) const {
return (ptr_ + i_) != &(*rhs);
}
const T& operator*() { return ptr_[i_]; }
const T &operator*() { return ptr_[i_]; }
private:
const T* ptr_;
const T *ptr_;
int64_t i_;
int64_t j_;
int64_t n_;
......@@ -165,18 +167,18 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
template <typename T>
class RowwiseTransformIterator<T, platform::CUDADeviceContext>
: public thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
public:
typedef thrust::iterator_adaptor<
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
RowwiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
super_t;
HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
HOSTDEVICE RowwiseTransformIterator(const T *x, int n)
: super_t(x), begin_(x), n_(n) {}
friend class thrust::iterator_core_access;
private:
unsigned int n_;
const T* begin_;
const T *begin_;
HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (this->base() - begin_) % n_);
}
......@@ -185,19 +187,19 @@ class RowwiseTransformIterator<T, platform::CUDADeviceContext>
template <typename T>
class MidWiseTransformIterator<T, platform::CUDADeviceContext>
: public thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*> {
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *> {
public:
typedef thrust::iterator_adaptor<
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T*>
MidWiseTransformIterator<T, platform::CUDADeviceContext>, const T *>
super_t;
HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post)
: super_t(x), begin_(x), n_(n), post_(post) {}
friend class thrust::iterator_core_access;
private:
unsigned int post_;
unsigned int n_;
const T* begin_;
const T *begin_;
HOSTDEVICE typename super_t::reference dereference() const {
return *(begin_ + (((this->base() - begin_) / post_) % n_));
}
......@@ -208,8 +210,8 @@ template <typename Functor, typename T, typename DeviceContext,
typename OutType = T>
class TransformFunctor {
public:
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z, const DeviceContext& ctx, Functor func)
TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z, const DeviceContext &ctx, Functor func)
: x_(x->data<T>()),
y_(y->data<T>()),
z_(z->mutable_data<OutType>(ctx.GetPlace())),
......@@ -235,20 +237,20 @@ class TransformFunctor {
}
private:
const T* x_;
const T* y_;
OutType* z_;
const T *x_;
const T *y_;
OutType *z_;
int64_t nx_;
const DeviceContext& ctx_;
const DeviceContext &ctx_;
Functor func_;
};
#define EIGEN_FUNCTOR(name, eigen_op) \
struct Eigen##name##Functor { \
template <typename DeviceContext, typename T> \
inline void Run(const framework::Tensor* x, const framework::Tensor* y, \
framework::Tensor* z, \
const framework::ExecutionContext& ctx) { \
inline void Run(const framework::Tensor *x, const framework::Tensor *y, \
framework::Tensor *z, \
const framework::ExecutionContext &ctx) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
auto z_e = framework::EigenVector<T>::Flatten(*z); \
......@@ -257,9 +259,9 @@ class TransformFunctor {
eigen_op(x_e, y_e); \
} \
template <typename DeviceContext, typename T> \
inline void RunBroadCast(const framework::Tensor* x, \
const framework::Tensor* y, framework::Tensor* z, \
const framework::ExecutionContext& ctx, int pre, \
inline void RunBroadCast(const framework::Tensor *x, \
const framework::Tensor *y, framework::Tensor *z, \
const framework::ExecutionContext &ctx, int pre, \
int n) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
......@@ -272,10 +274,10 @@ class TransformFunctor {
eigen_op(x_e, y_bcast); \
} \
template <typename DeviceContext, typename T> \
inline void RunBroadCast2(const framework::Tensor* x, \
const framework::Tensor* y, \
framework::Tensor* z, \
const framework::ExecutionContext& ctx, int pre, \
inline void RunBroadCast2(const framework::Tensor *x, \
const framework::Tensor *y, \
framework::Tensor *z, \
const framework::ExecutionContext &ctx, int pre, \
int n, int post) { \
auto x_e = framework::EigenVector<T>::Flatten(*x); \
auto y_e = framework::EigenVector<T>::Flatten(*y); \
......@@ -290,23 +292,27 @@ class TransformFunctor {
}
#define EIGEN_ADD(x, y) ((x) + (y))
EIGEN_FUNCTOR(Add, EIGEN_ADD);
#define EIGEN_SUB(x, y) ((x) - (y))
EIGEN_FUNCTOR(Sub, EIGEN_SUB);
#define EIGEN_MUL(x, y) ((x) * (y))
EIGEN_FUNCTOR(Mul, EIGEN_MUL);
#define EIGEN_DIV(x, y) ((x) / (y))
EIGEN_FUNCTOR(Div, EIGEN_DIV);
template <typename T, typename DX_OP, typename DY_OP>
struct ElemwiseGradNoBroadcast {
const T* x_;
const T* y_;
const T* out_;
const T* dout_;
const T *x_;
const T *y_;
const T *out_;
const T *dout_;
HOSTDEVICE void operator()(size_t i) {
if (dx_ != nullptr) {
......@@ -319,14 +325,14 @@ struct ElemwiseGradNoBroadcast {
DX_OP dx_op_;
DY_OP dy_op_;
T* dx_;
T* dy_;
T *dx_;
T *dy_;
};
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
const T* dout, int h, int w, DX_OP dx_op,
DY_OP dy_op, T* dx, T* dy) {
static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
const T *dout, int h, int w, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int x_offset = i * w + j;
......@@ -348,8 +354,8 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T* x, const T* y, const T* out, const T* dout, int h, int w,
DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
const T *x, const T *y, const T *out, const T *dout, int h, int w,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
......@@ -376,10 +382,10 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
}
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
const T* y, const T* out, const T* dout,
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
const T *y, const T *out, const T *dout,
int h, int w, DX_OP dx_op, DY_OP dy_op,
T* dx, T* dy) {
T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
......@@ -389,9 +395,9 @@ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x,
#endif
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
const T* dout, int pre, int n, int post,
DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
const T *dout, int pre, int n, int post,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
......@@ -416,8 +422,8 @@ static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out,
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast2CUDAKernel(
const T* x, const T* y, const T* out, const T* dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) {
const T *x, const T *y, const T *out, const T *dout, int pre, int n,
int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int tid = threadIdx.x;
int j = blockIdx.x;
......@@ -453,10 +459,10 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
}
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
const T* y, const T* out, const T* dout,
static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
const T *y, const T *out, const T *dout,
int pre, int n, int post, DX_OP dx_op,
DY_OP dy_op, T* dx, T* dy) {
DY_OP dy_op, T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
int gird_size = n;
ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
......@@ -467,11 +473,11 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x,
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
const framework::DDim& y_dim, const framework::Tensor& x,
const framework::Tensor& y, const framework::Tensor& out,
const framework::Tensor& dout, int axis, framework::Tensor* dx,
framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim, const framework::Tensor &x,
const framework::Tensor &y, const framework::Tensor &out,
const framework::Tensor &dout, int axis, framework::Tensor *dx,
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
size_t N = static_cast<size_t>(framework::product(x_dim));
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
......@@ -483,11 +489,11 @@ void ElemwiseGradComputeNoBroadcast(
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
const framework::ExecutionContext& ctx, const framework::DDim& x_dim,
const framework::DDim& y_dim_untrimed, const framework::Tensor& x,
const framework::Tensor& y, const framework::Tensor& out,
const framework::Tensor& dout, int axis, framework::Tensor* dx,
framework::Tensor* dy, DX_OP dx_op, DY_OP dy_op) {
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
const framework::Tensor &y, const framework::Tensor &out,
const framework::Tensor &dout, int axis, framework::Tensor *dx,
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
......@@ -531,14 +537,14 @@ void ElemwiseGradComputeWithBroadcast(
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
const framework::Tensor& x, const framework::Tensor& y,
const framework::Tensor& out,
const framework::Tensor& dout, int axis,
framework::Tensor* dx, framework::Tensor* dy,
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor &x, const framework::Tensor &y,
const framework::Tensor &out,
const framework::Tensor &dout, int axis,
framework::Tensor *dx, framework::Tensor *dy,
DX_OP dx_op, DY_OP dy_op) {
const framework::DDim& x_dim = x.dims();
const framework::DDim& y_dim = y.dims();
const framework::DDim &x_dim = x.dims();
const framework::DDim &y_dim = y.dims();
if (x.dims() == y.dims()) {
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
......@@ -553,27 +559,27 @@ void ElemwiseGradCompute(const framework::ExecutionContext& ctx,
// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse
// elementwise code.
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
const framework::Tensor& x,
const framework::Tensor& y,
const framework::Tensor& out,
const framework::Tensor& dout, int axis,
framework::Tensor* dx, framework::Tensor* dy,
void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor &x,
const framework::Tensor &y,
const framework::Tensor &out,
const framework::Tensor &dout, int axis,
framework::Tensor *dx, framework::Tensor *dy,
DX_OP dx_op, DY_OP dy_op) {
if (dy == nullptr) {
const framework::DDim& dx_dims = dout.dims();
const framework::DDim &dx_dims = dout.dims();
auto dy_dims = dx_dims;
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else {
if (dout.dims() == dy->dims()) {
const framework::DDim& dx_dims = dout.dims();
const framework::DDim& dy_dims = dy->dims();
const framework::DDim &dx_dims = dout.dims();
const framework::DDim &dy_dims = dy->dims();
ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
} else { // Y is a scalar
auto dx_dims = dout.dims();
const framework::DDim& dy_dims = dy->dims();
const framework::DDim &dy_dims = dy->dims();
ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
ctx, dx_dims, dy_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
}
......@@ -583,13 +589,13 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext& ctx,
// Deprecated
template <typename DeviceContext, typename T, typename functor,
typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, int axis,
framework::Tensor* dx, framework::Tensor* dy) {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y,
const framework::Tensor *out,
const framework::Tensor *dout, int axis,
framework::Tensor *dx, framework::Tensor *dy) {
auto &place = *ctx.template device_context<DeviceContext>().eigen_device();
auto x_dims = x->dims();
auto y_dims = y->dims();
......@@ -627,10 +633,10 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx,
template <typename Functor, typename DeviceContext, typename T,
typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, int axis, Functor func,
framework::Tensor* z) {
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor *x,
const framework::Tensor *y, int axis, Functor func,
framework::Tensor *z) {
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), func);
......@@ -661,5 +667,823 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
}
}
// FusedElemwiseAndAct
// --- forward
template <typename T, typename CompoundFunctor, bool KeepIntermediateOut>
struct FusedElemwiseAndActNoBroadcast {
HOSTDEVICE void operator()(size_t i) {
T y_val = y_[i];
T x_val = x_[i];
if (KeepIntermediateOut) {
T intermeidiate_out = compound_functor_.GetIntermediateOut(x_val, y_val);
intermediate_out_[i] = intermeidiate_out;
out_[i] =
compound_functor_.GetOutUseIntermediateOut(x_val, intermeidiate_out);
} else {
out_[i] = compound_functor_.GetOut(x_val, y_val);
}
}
const T *x_;
const T *y_;
CompoundFunctor compound_functor_;
T *out_;
T *intermediate_out_;
};
// FusedElemwiseAndActBroadcast1:
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5) and axis = -1 or 2,
// X can be reshaped to (6, 20) and Y can be reshaped to (1, 20)
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CPU(const T *x, const T *y,
CompoundFunctor compound_functor,
int h, int w, T *out,
T *intermediate_out) {
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int offset = i * w + j;
T y_val = BcastY ? y[j] : y[offset];
T x_val = BcastY ? x[offset] : x[j];
int64_t intermediate_out_offset;
if (KeepIntermediateOut) {
T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val);
if (SameShapeOfIntermediateOutAndOut) {
// for the case of f1(f2(x, y))
intermediate_out_offset = offset;
} else if (BcastY) {
intermediate_out_offset = j;
} else {
intermediate_out_offset = offset;
}
intermediate_out[intermediate_out_offset] = intermeidiate_out;
out[offset] =
compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out);
} else {
out[offset] = compound_functor.GetOut(x_val, y_val);
}
}
}
}
// FusedElemwiseAndActBroadcast2
// In this case, X and Y can be reshaped to a matrix.
// For example shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4) and axis = 1,
// X can be reshaped to (2, 12, 5) and Y can be reshaped to (1, 12, 1)
// pre = 2, n = 12, post = 5
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CPU(const T *x, const T *y, int pre,
int n, int post,
CompoundFunctor compound_functor,
T *out, T *intermediate_out) {
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int offset = i * n * post + j * post + k;
T y_val = BcastY ? y[j] : y[offset];
T x_val = BcastY ? x[offset] : x[j];
int64_t intermediate_out_offset;
if (KeepIntermediateOut) {
T intermeidiate_out =
compound_functor.GetIntermediateOut(x_val, y_val);
if (SameShapeOfIntermediateOutAndOut) {
// for the case of f1(f2(x, y))
intermediate_out_offset = offset;
} else if (BcastY) {
intermediate_out_offset = j;
} else {
intermediate_out_offset = offset;
}
intermediate_out[intermediate_out_offset] = intermeidiate_out;
out[offset] = compound_functor.GetOutUseIntermediateOut(
x_val, intermeidiate_out);
} else {
out[offset] = compound_functor.GetOut(x_val, y_val);
}
}
}
}
}
#ifdef __NVCC__
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast1CUDAKernel(
const T *x, const T *y, int h, int w, CompoundFunctor compound_functor,
T *out, T *intermediate_out) {
int j = blockIdx.x;
int i = threadIdx.x;
while (i < h) {
int offset = i * w + j;
T y_val = BcastY ? y[j] : y[offset];
T x_val = BcastY ? x[offset] : x[j];
int64_t intermediate_out_offset;
if (KeepIntermediateOut) {
T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val);
if (SameShapeOfIntermediateOutAndOut) {
// for the case of f1(f2(x, y))
intermediate_out_offset = offset;
} else if (BcastY) {
intermediate_out_offset = j;
} else {
intermediate_out_offset = offset;
}
intermediate_out[intermediate_out_offset] = intermeidiate_out;
out[offset] =
compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out);
} else {
out[offset] = compound_functor.GetOut(x_val, y_val);
}
i += ELEMWISE_MAX_BLOCK_DIM;
}
}
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast1CUDA(cudaStream_t stream, const T *x,
const T *y,
CompoundFunctor compound_functor,
int h, int w, T *out,
T *intermediate_out) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
FusedElemwiseAndActBroadcast1CUDAKernel<
T, CompoundFunctor, BcastY, KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
x, y, h, w, compound_functor, out, intermediate_out);
}
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActBroadcast2CUDAKernel(
const T *x, const T *y, CompoundFunctor compound_functor, int pre, int n,
int post, T *out, T *intermediate_out) {
int tid = threadIdx.x;
int j = blockIdx.x;
while (true) {
int i = tid / post;
int k = tid % post;
if (i >= pre) break;
int offset = i * n * post + j * post + k;
T y_val = BcastY ? y[j] : y[offset];
T x_val = BcastY ? x[offset] : x[j];
int64_t intermediate_out_offset;
if (KeepIntermediateOut) {
T intermeidiate_out = compound_functor.GetIntermediateOut(x_val, y_val);
if (SameShapeOfIntermediateOutAndOut) {
// for the case of f1(f2(x, y))
intermediate_out_offset = offset;
} else if (BcastY) {
intermediate_out_offset = j;
} else {
intermediate_out_offset = offset;
}
intermediate_out[intermediate_out_offset] = intermeidiate_out;
out[offset] =
compound_functor.GetOutUseIntermediateOut(x_val, intermeidiate_out);
} else {
out[offset] = compound_functor.GetOut(x_val, y_val);
}
tid += ELEMWISE_MAX_BLOCK_DIM;
}
}
template <typename T, typename CompoundFunctor, bool BcastY,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActBroadcast2CUDA(cudaStream_t stream, const T *x,
const T *y, int pre, int n,
int post,
CompoundFunctor compound_functor,
T *out, T *intermediate_out) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
int gird_size = n;
FusedElemwiseAndActBroadcast2CUDAKernel<
T, CompoundFunctor, BcastY, KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
x, y, compound_functor, pre, n, post, out, intermediate_out);
}
#endif
template <typename DeviceContext, typename T, typename CompoundFunctor,
bool KeepIntermediateOut>
void FusedElemwiseAndActComputeNoBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::Tensor &x, const framework::Tensor &y,
CompoundFunctor compound_functor, framework::Tensor *out,
framework::Tensor *intermediate_out) {
size_t N = static_cast<size_t>(framework::product(x_dim));
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
for_range(
FusedElemwiseAndActNoBroadcast<T, CompoundFunctor, KeepIntermediateOut>{
x.data<T>(), y.data<T>(), compound_functor,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace())});
}
template <typename DeviceContext, typename T, typename CompoundFunctor,
bool BcastY, bool KeepIntermediateOut,
bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeWithBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim_untrimed, const framework::Tensor &x,
const framework::Tensor &y, CompoundFunctor compound_functor, int axis,
framework::Tensor *out, framework::Tensor *intermediate_out) {
axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
if (post == 1) {
int h = pre;
int w = n;
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActBroadcast1CUDA<T, CompoundFunctor, BcastY,
KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), compound_functor, h, w,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActBroadcast1CPU<T, CompoundFunctor, BcastY,
KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut>(
x.data<T>(), y.data<T>(), compound_functor, h, w,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace()));
}
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActBroadcast2CUDA<T, CompoundFunctor, BcastY,
KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
y.data<T>(), pre, n, post, compound_functor,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActBroadcast2CPU<T, CompoundFunctor, BcastY,
KeepIntermediateOut,
SameShapeOfIntermediateOutAndOut>(
x.data<T>(), y.data<T>(), pre, n, post, compound_functor,
out->mutable_data<T>(ctx.GetPlace()),
intermediate_out == nullptr
? nullptr
: intermediate_out->mutable_data<T>(ctx.GetPlace()));
}
}
}
// --- backward
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast {
HOSTDEVICE void operator()(size_t i) {
if (dx_ != nullptr) {
dx_[i] = UseIntermediateOut ? dx_op_(x_[i], y_[i], intermediate_out_[i],
out_[i], dout_[i])
: dx_op_(x_[i], y_[i], out_[i], dout_[i]);
}
if (dy_ != nullptr) {
dy_[i] = UseIntermediateOut ? dy_op_(x_[i], y_[i], intermediate_out_[i],
out_[i], dout_[i])
: dy_op_(x_[i], y_[i], out_[i], dout_[i]);
}
}
const T *x_;
const T *y_;
const T *intermediate_out_;
const T *out_;
const T *dout_;
DX_OP dx_op_;
DY_OP dy_op_;
T *dx_;
T *dy_;
};
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
bool UseIntermediateOut>
void FusedElemwiseAndActGradComputeNoBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *intermediate_out,
const framework::Tensor *out, const framework::Tensor *dout, int axis,
framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
size_t N = static_cast<size_t>(framework::product(x_dim));
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N);
for_range(
FusedElemwiseAndActGradNoBroadcast<T, DX_OP, DY_OP, UseIntermediateOut>{
x->data<T>(), y->data<T>(),
intermediate_out ? intermediate_out->data<T>() : nullptr,
out->data<T>(), dout->data<T>(), dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace())});
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CPU(const T *x, const T *y,
const T *intermediate_out,
const T *out, const T *dout,
int h, int w, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
int64_t tmp_out_idx, x_idx, y_idx;
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
int offset = i * w + j;
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
} else {
if (i == 0) {
dx[x_idx] = tmp;
} else {
dx[x_idx] += tmp;
}
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
if (i == 0) {
dy[y_idx] = tmp;
} else {
dy[y_idx] += tmp;
}
} else {
dy[y_idx] = tmp;
}
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CPU(const T *x, const T *y,
const T *intermediate_out,
const T *out, const T *dout,
int pre, int n, int post,
DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
int64_t tmp_out_idx, x_idx, y_idx;
for (int i = 0; i < pre; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < post; ++k) {
int offset = i * n * post + j * post + k;
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
} else {
if (i == 0 && k == 0) {
dx[x_idx] = tmp;
} else {
dx[x_idx] += tmp;
}
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
if (i == 0 && k == 0) {
dy[y_idx] = tmp;
} else {
dy[y_idx] += tmp;
}
} else {
dy[y_idx] = tmp;
}
}
}
}
}
}
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *intermediate_out, const T *out,
const T *dout, int h, int w, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
T val(0);
int64_t tmp_out_idx, x_idx, y_idx;
do {
int offset = i * w + j;
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
} else {
val += tmp;
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
val += tmp;
} else {
dy[y_idx] = tmp;
}
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
if (BcastY) {
if (dy) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
} else {
if (dx) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast1CUDA(cudaStream_t stream,
const T *x, const T *y,
const T *intermediate_out,
const T *out, const T *dout,
int h, int w, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
FusedElemwiseAndActGradBroadcast1CUDAKernel<
T, DX_OP, DY_OP, UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
x, y, intermediate_out, out, dout, h, w, dx_op, dy_op, dx, dy);
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static __global__ void FusedElemwiseAndActGradBroadcast2CUDAKernel(
const T *x, const T *y, const T *intermediate_out, const T *out,
const T *dout, int pre, int n, int post, DX_OP dx_op, DY_OP dy_op, T *dx,
T *dy) {
int tid = threadIdx.x;
int j = blockIdx.x;
T val(0);
int ttid = tid;
int64_t tmp_out_idx, x_idx, y_idx;
while (true) {
int i = ttid / post;
int k = ttid % post;
if (i >= pre) break;
int offset = i * n * post + j * post + k;
tmp_out_idx = BcastY ? j : offset;
y_idx = BcastY ? j : offset;
x_idx = BcastY ? offset : j;
if (SameShapeOfIntermediateOutAndOut) {
tmp_out_idx = offset;
}
if (dx != nullptr) {
T tmp = UseIntermediateOut
? dx_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dx_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
dx[x_idx] = tmp;
} else {
val += tmp;
}
}
if (dy != nullptr) {
T tmp = UseIntermediateOut
? dy_op(x[x_idx], y[y_idx], intermediate_out[tmp_out_idx],
out[offset], dout[offset])
: dy_op(x[x_idx], y[y_idx], out[offset], dout[offset]);
if (BcastY) {
val += tmp;
} else {
dy[y_idx] = tmp;
}
}
ttid += ELEMWISE_MAX_BLOCK_DIM;
}
if (BcastY) {
if (dy) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
} else {
if (dx) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP, bool UseIntermediateOut,
bool BcastY, bool SameShapeOfIntermediateOutAndOut>
static void FusedElemwiseAndActGradBroadcast2CUDA(
cudaStream_t stream, const T *x, const T *y, const T *intermediate_out,
const T *out, const T *dout, int pre, int n, int post, DX_OP dx_op,
DY_OP dy_op, T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
int gird_size = n;
FusedElemwiseAndActGradBroadcast2CUDAKernel<
T, DX_OP, DY_OP, UseIntermediateOut, BcastY,
SameShapeOfIntermediateOutAndOut><<<gird_size, block_size, 0, stream>>>(
x, y, intermediate_out, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}
#endif
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
bool UseIntermediateOut, bool BcastY,
bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeWithBroadcast(
const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
const framework::DDim &y_dim_untrimed, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *intermediate_out,
const framework::Tensor *out, const framework::Tensor *dout, int axis,
framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
axis = (axis == -1 ? x_dim.size() - y_dim_untrimed.size() : axis);
auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
axis = (y_dim.size() == 0) ? x_dim.size() : axis;
int pre, n, post;
get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post);
if (post == 1) {
int h = pre;
int w = n;
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActGradBroadcast1CUDA<T, DX_OP, DY_OP, UseIntermediateOut,
BcastY,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActGradBroadcast1CPU<T, DX_OP, DY_OP, UseIntermediateOut,
BcastY,
SameShapeOfIntermediateOutAndOut>(
x->data<T>(), y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), h, w, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
FusedElemwiseAndActGradBroadcast2CUDA<T, DX_OP, DY_OP, UseIntermediateOut,
BcastY,
SameShapeOfIntermediateOutAndOut>(
ctx.template device_context<DeviceContext>().stream(), x->data<T>(),
y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
} else {
FusedElemwiseAndActGradBroadcast2CPU<T, DX_OP, DY_OP, UseIntermediateOut,
BcastY,
SameShapeOfIntermediateOutAndOut>(
x->data<T>(), y->data<T>(),
intermediate_out == nullptr ? nullptr : intermediate_out->data<T>(),
out->data<T>(), dout->data<T>(), pre, n, post, dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
}
}
}
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP,
bool UseIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActGradComputeEx(
const framework::ExecutionContext &ctx, const framework::Tensor *x,
const framework::Tensor *y, const framework::Tensor *out,
const framework::Tensor *intermediate_out, const framework::Tensor *dout,
int axis, framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op,
DY_OP dy_op) {
const framework::DDim &x_dim = x->dims();
const framework::DDim &y_dim = y->dims();
if (UseIntermediateOut) {
PADDLE_ENFORCE(intermediate_out, "intermediate_out should not be nullptr");
}
if (x_dim == y_dim) {
FusedElemwiseAndActGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP,
UseIntermediateOut>(
ctx, x_dim, y_dim, x, y, intermediate_out, out, dout, axis, dx, dy,
dx_op, dy_op);
} else { // Y is a scalar
bool bcast_y = x_dim.size() >= y_dim.size();
if (x_dim.size() == y_dim.size()) {
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] < y_dim[i]) {
bcast_y = false;
break;
}
}
}
// z = f1(x, f2(y))
// z = f1(f2(x, y))
if (bcast_y) { // Y should be broadcast.
FusedElemwiseAndActGradComputeWithBroadcast<
DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, true /*BcastY*/,
SameShapeOfIntermediateOutAndOut>(ctx, x_dim, y_dim, x, y,
intermediate_out, out, dout, axis,
dx, dy, dx_op, dy_op);
} else {
FusedElemwiseAndActGradComputeWithBroadcast<
DeviceContext, T, DX_OP, DY_OP, UseIntermediateOut, false /*BcastY*/,
SameShapeOfIntermediateOutAndOut>(ctx, y_dim, x_dim, x, y,
intermediate_out, out, dout, axis,
dx, dy, dx_op, dy_op);
}
}
}
template <typename DeviceContext, typename T, typename CompoundFunctor,
bool KeepIntermediateOut, bool SameShapeOfIntermediateOutAndOut>
void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor &x,
const framework::Tensor &y, int axis,
CompoundFunctor compound_functor,
framework::Tensor *out,
framework::Tensor *intermediate_out) {
if (KeepIntermediateOut) {
PADDLE_ENFORCE(intermediate_out,
"The keep_intermediate_value is opened, "
"intermediate_out should not be nullptr.");
}
const framework::DDim &x_dim = x.dims();
const framework::DDim &y_dim = y.dims();
if (x.dims() == y.dims()) {
FusedElemwiseAndActComputeNoBroadcast<DeviceContext, T, CompoundFunctor,
KeepIntermediateOut>(
ctx, x_dim, x, y, compound_functor, out, intermediate_out);
} else {
// Whether the shape of Y is a continuous subsequence of X,
// For more information please refer to the op's introduction.
bool bcast_y = x.dims().size() >= y.dims().size();
if (x.dims().size() == y.dims().size()) {
for (int i = 0; i < x.dims().size(); ++i) {
if (x.dims()[i] < y.dims()[i]) {
bcast_y = false;
break;
}
}
}
// z = f1(x, f2(y))
// z = f1(f2(x, y))
if (bcast_y) { // Y should be broadcast.
// In this case,
// for 'f2(y)', the shape of intermediate_out should be equal to the shape
// of Y.
// for 'f2(x, y)', the shape of intermediate_out should be equal to the
// shape of Out.
// the shape of Out should be equal to the shape of X.
FusedElemwiseAndActComputeWithBroadcast<
DeviceContext, T, CompoundFunctor, true /*BcastY*/,
KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
ctx, x_dim /*OutShape*/, y_dim, x, y, compound_functor, axis, out,
intermediate_out);
} else {
// In this case,
// for 'f2(y)', the shape of intermediate_out should be equal to the shape
// of Out.
// for 'f2(x, y)', the shape of intermediate_out should be equal to the
// shape of Out.
// the shape of Out should be equal to the shape of Y.
FusedElemwiseAndActComputeWithBroadcast<
DeviceContext, T, CompoundFunctor, false /*BcastY*/,
KeepIntermediateOut, SameShapeOfIntermediateOutAndOut>(
ctx, y_dim /*OutShape*/, x_dim, x, y, compound_functor, axis, out,
intermediate_out);
}
}
}
} // namespace operators
} // namespace paddle
......@@ -12,14 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/fused_elemwise_activation_op.h"
namespace paddle {
namespace operators {
/*
* Whether the compound function is Unary(Binary(X, Y)).
* For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final
* out.
*/
static bool IsUnaryCompound(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {
"elementwise_add", "elementwise_mul", "elementwise_add_grad",
"elementwise_mul_grad"};
return binary_fun.count(functor_list[1]) != 0;
}
/*
* Whether the Input(X) could be absent.
*/
static bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE_EQ(functor_list.size(), 2);
static std::unordered_set<std::string> binary_fun = {"elementwise_add_grad"};
return binary_fun.count(functor_list[0]) != 0 ||
binary_fun.count(functor_list[1]) != 0;
}
/*
* Whether the compound function is supported.
* For Unary(Binary(X, Y)), the intermediate_out's shape is the same the final
* out.
*/
static bool IsSupportedCompound(const std::vector<std::string> &functors) {
static std::unordered_set<std::string> unary_fun = {"scale", "relu"};
static std::unordered_set<std::string> binary_fun = {"elementwise_add",
"elementwise_mul"};
std::string unary_fun_str;
if (binary_fun.count(functors[0])) {
unary_fun_str = functors[1];
} else if (binary_fun.count(functors[1])) {
unary_fun_str = functors[0];
} else {
PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
functors[1]);
}
PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
"%s is not included in fused_list.", unary_fun_str);
return true;
}
class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -37,11 +83,44 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.");
ctx->SetOutputDim("Out", x_dim);
ctx->ShareLoD("X", /*->*/ "Out");
// Whether the shape of Y is a continuous subsequence of X,
// For more information please refer to the op's introduction.
bool bcast_y = x_dim.size() >= y_dim.size();
if (x_dim.size() == y_dim.size()) {
for (int i = 0; i < x_dim.size(); ++i) {
if (x_dim[i] < y_dim[i]) {
bcast_y = false;
break;
}
}
}
auto &out_dim = bcast_y ? x_dim : y_dim;
std::string out_lod = bcast_y ? "X" : "Y";
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
PADDLE_ENFORCE(ctx->HasOutput("IntermediateOut"),
"Output(IntermediateOut) of FusedElemwiseActivationOp "
"should not be null.");
if (IsUnaryCompound(
ctx->Attrs().Get<std::vector<std::string>>("functor_list"))) {
// for Unary(Binary(X, Y)), the shape and lod of out and
// intermediate_out are the same.
ctx->SetOutputDim("IntermediateOut", out_dim);
// set the lod of intermediate_out
ctx->ShareLoD(out_lod, /*->*/ "IntermediateOut");
} else {
// for Binary(X, Unary(Y)), the shape and lod of Y and
// intermediate_out are the same.
ctx->SetOutputDim("IntermediateOut", y_dim);
// set the lod of intermediate_out
ctx->ShareLoD("Y", /*->*/ "IntermediateOut");
}
}
ctx->SetOutputDim("Out", out_dim);
ctx->ShareLoD(out_lod, /*->*/ "Out");
}
protected:
......@@ -59,29 +138,42 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(vector<Tensor>)");
AddInput("Y", "(vector<Tensor>)");
AddOutput("Out", "vector<Tensor>");
AddInput(
"X",
"(Tensor) The input tensor of fused_elemwise_activation operator.");
AddInput(
"Y",
"(Tensor) The input tensor of fused_elemwise_activation operator.");
AddOutput("Out",
"vector<Tensor> The output tensor of fused_elemwise_activation "
"operator.");
AddOutput("IntermediateOut",
"Tensor The IntermediateOut tensor of fused_elemwise_activation "
"operator.")
.AsIntermediate();
AddAttr<int>("axis",
"axis is used by elementwise_op, the default value is -1.")
.SetDefault(-1);
AddAttr<float>("scale",
"scale is used by scale_op, the default value is 0.0.")
.SetDefault(0.0);
AddAttr<bool>("recomputation",
AddAttr<bool>(
"recomputation",
"Whether to recompute the Out."
"fused_elemwise_activation_grad has two methods to get the "
"dx and dy, one "
"is to use the 'Out', and the other is not to use it. "
"The former method will save the time of recomputing the "
"'Out', but it must occupy the memory to store the 'out'. "
"While, the later method can avoid occupying the memory, "
"but it must recompute the 'Out'. The default value is true.")
"The computation of fused_elemwise_activation_grad has two methods to "
"get the dx and dy, one is to use the 'Out', and the other is not. "
"The former method will save the time of recomputing the 'Out', but it "
"must occupy the memory to store the 'out'. While, the later method "
"can avoid occupying the memory, but it must recompute the 'Out'. "
"It is useful for Unary(Binary(X, Y)). The default value is true.")
.SetDefault(true);
AddAttr<bool>("keep_intermediate_value",
"Whether to save the intermediate_out.")
.SetDefault(false);
AddAttr<std::vector<std::string>>("functor_list",
"The functors that should be fused.")
.AddCustomChecker([&](const std::vector<std::string> &functor_list) {
PADDLE_ENFORCE(ValidCheck(functor_list));
PADDLE_ENFORCE(IsSupportedCompound(functor_list));
});
AddComment(R"DOC(
......@@ -93,30 +185,38 @@ operators (elementwise_op and activation_op):
Z = Binary(X, Unary(Y))
Z = Unary(Binary(X, Y))
The attributions of activation_op can be get from fused_elemwise_activation_op's
attributions. functor_list records the functors to be fused, for example
"scale,elementwise_add".
There are two cases for this operator:
)DOC");
}
1. The shape of $Y$ and $X$ is the same.
2. The shape of $Y$ is a continuous subsequence of $X$ or the shape of $X$ is a continuous subsequence of $Y$.
private:
bool ValidCheck(const std::vector<std::string> &functors) {
std::unordered_set<std::string> unary_fun = {"scale", "relu"};
std::unordered_set<std::string> binary_fun = {"elementwise_add"};
For case 2 (assume that the shape of $Y$ is a continuous subsequence of $X$ ):
std::string unary_fun_str;
if (binary_fun.count(functors[0])) {
unary_fun_str = functors[1];
} else if (binary_fun.count(functors[1])) {
unary_fun_str = functors[0];
} else {
PADDLE_THROW("%s and %s are not included in fused_list.", functors[0],
functors[1]);
}
PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1,
"%s is not included in fused_list.", unary_fun_str);
return true;
1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
subsequence, such as shape(Y) = (2, 1) => (2).
For example:
.. code-block:: python
shape(X) = (2, 3, 4, 5), shape(Y) = (,)
shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5), with axis=-1(default) or axis=2
shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
The inputs $X$ and $Y$ can carry the different LoD information.
But the output only shares the LoD information with the one whose shape is the same with Out.
The attributions of activation_op can be get from fused_elemwise_activation_op's.
The functor_list records the functions to be fused, for example
["scale", "elementwise_add"].
)DOC");
}
};
......@@ -141,6 +241,7 @@ class FusedElemwiseActivationGradMaker
op_desc_ptr->SetInput(framework::GradVarName(output_param),
this->OutputGrad(output_param));
}
op_desc_ptr->SetAttrMap(this->Attrs());
std::vector<std::string> functor_names =
......@@ -158,40 +259,59 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.");
"Input(Out@Grad) should not be null");
if (ctx->Attrs().Get<bool>("keep_intermediate_value")) {
PADDLE_ENFORCE(ctx->HasInput("IntermediateOut"),
"Input(IntermediateOut) should not be null");
} else {
PADDLE_ENFORCE_EQ(ctx->Inputs(framework::GradVarName("Out")).size(), 1);
}
auto funtor_list =
ctx->Attrs().Get<std::vector<std::string>>("functor_list");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
if (ctx->HasInputs("X")) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", x_grad_name);
} else {
// Node: If "X" is absence, the shape of Y should be a continuous
// subsequence of X, if not, we could not infer the shape of dx.
// Currently, only when Binary is elementwise_add or elementwise_sub,
// the "X" could be absent.
PADDLE_ENFORCE(InputXCanBeAbsent(funtor_list),
"Only when BinaryFunctor is elementwise_add, the 'X' "
"could be absent.");
// For Unary(Binary(X, Y)), IntermediateOut should not be empty.
if (IsUnaryCompound(funtor_list)) {
PADDLE_ENFORCE(
ctx->HasInputs("IntermediateOut"),
"If the compound_functor is Unary(Binary(X, Y)) and Binary "
"is elementwise_add, the intermediate_out must be not absent.");
}
ctx->SetOutputDim(x_grad_name,
ctx->GetInputDim(framework::GradVarName("Out")));
ctx->ShareLoD(framework::GradVarName("Out"), x_grad_name);
}
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
ctx->ShareLoD("Y", y_grad_name);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type_index = ctx.Input<framework::Tensor>("X")->type();
PADDLE_ENFORCE_EQ(input_data_type_index,
ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same.");
PADDLE_ENFORCE_EQ(
input_data_type_index,
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
"The element's type of input should be the same.");
// PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
auto input_data_type_index = ctx.Input<framework::Tensor>("Y")->type();
auto input_data_type = framework::ToDataType(input_data_type_index);
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -20,208 +20,114 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h"
namespace math = paddle::operators::math;
namespace paddle {
namespace operators {
// CompoundFunctors
// For example: Z = Binary(X, Unary(Y))
template <typename T, typename BinaryFun, typename UnaryFun>
struct BinaryCompoundFunctor {
BinaryCompoundFunctor(const BinaryFun &binary_fun, const UnaryFun &unary_fun)
: binary_fun_(binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y) {
return binary_fun_(x, unary_fun_(y));
}
private:
BinaryFun binary_fun_;
UnaryFun unary_fun_;
};
// For example: Z = Unary(Binary(X, Y))
template <typename T, typename UnaryFun, typename BinaryFun>
struct UnaryCompoundFunctor {
UnaryCompoundFunctor(const UnaryFun &unary_fun, const BinaryFun &binary_fun)
: unary_fun_(unary_fun), binary_fun_(binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y) {
return unary_fun_(binary_fun_(x, y));
}
private:
UnaryFun unary_fun_;
BinaryFun binary_fun_;
};
// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get
// the dx, one is to use the 'out', and the other is not to use it.
// the former method will save the time of recomputing the
// 'out', but it must occupy the memory to store the 'out'.
// While the later method can avoid occupying this memory,
// but it must recompute the 'out'.
template <typename T, typename DBinaryFun, typename UnaryFun,
bool Recomputation = true>
struct BinaryCompoundGradDxFunctor {
BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun)
: d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
if (Recomputation) {
return dout * d_binary_fun_(x, unary_fun_(y));
} else {
return dout * d_binary_fun_(x, unary_fun_(y), out);
}
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
};
template <typename T, typename DBinaryFun, typename UnaryFun,
typename DUnaryFun, bool Recomputation = true>
struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun,
const DUnaryFun &d_unary_fun)
: d_binary_fun_(d_binary_fun),
unary_fun_(unary_fun),
d_unary_fun_(d_unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
if (Recomputation) {
return dout * d_binary_fun_(unary_fun_(y), x) * d_unary_fun_(y);
} else {
return dout * d_binary_fun_(unary_fun_(y), x, out) * d_unary_fun_(y);
}
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
DUnaryFun d_unary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_(x, y);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_(y, x);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
template <typename DeviceContext, typename T, typename BinaryFunctor,
typename UnaryFunctor>
static void RunBinaryCompoundFunctor(const framework::ExecutionContext &ctx,
const BinaryFunctor &binary_functor,
const UnaryFunctor &unary_functor,
const framework::Tensor *in_x,
const framework::Tensor *in_y,
framework::Tensor *output) {
static void RunBinaryCompoundFunctor(
const framework::ExecutionContext &ctx, const BinaryFunctor &binary_functor,
const UnaryFunctor &unary_functor, const framework::Tensor &in_x,
const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
// Z = Binary(X, Unary(Y))
// intermediate_out = Unary(Y)
// out = Binary(X, Unary(Y))
// In this case, the shape of intermediate_out and out are different.
paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
compound_func(binary_functor, unary_functor);
int axis = ctx.Attr<int>("axis");
using BinaryCompoundFunctor =
BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>;
ElementwiseComputeEx<BinaryCompoundFunctor, DeviceContext, T>(
ctx, in_x, in_y, axis,
BinaryCompoundFunctor(binary_functor, unary_functor), output);
if (ctx.Attr<bool>("keep_intermediate_value")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
T, BinaryFunctor, UnaryFunctor>,
true /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
T, BinaryFunctor, UnaryFunctor>,
false /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
}
}
template <typename DeviceContext, typename T, typename UnaryFunctor,
typename BinaryFunctor>
static void RunUnaryCompoundFunctors(const framework::ExecutionContext &ctx,
const UnaryFunctor &unary_functor,
const BinaryFunctor &binary_functor,
const framework::Tensor *in_x,
const framework::Tensor *in_y,
framework::Tensor *output) {
static void RunUnaryCompoundFunctors(
const framework::ExecutionContext &ctx, const UnaryFunctor &unary_functor,
const BinaryFunctor &binary_functor, const framework::Tensor &in_x,
const framework::Tensor &in_y, std::vector<framework::Tensor *> *outputs) {
// Z = Unary(Binary(X, Y))
// intermediate_out = Binary(X, Y)
// out = Unary(Binary(X, Y))
// In this case, the shape of intermediate_out and out are the same.
int axis = ctx.Attr<int>("axis");
using UnaryCompoundFunctor =
UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>;
paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
compound_func(unary_functor, binary_functor);
ElementwiseComputeEx<UnaryCompoundFunctor, DeviceContext, T>(
ctx, in_x, in_y, axis,
UnaryCompoundFunctor(unary_functor, binary_functor), output);
if (ctx.Attr<bool>("keep_intermediate_value")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
true /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
false /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
}
}
template <typename DeviceContext, typename T, typename BinaryGradFunctor,
typename UnaryFunctor, typename UnaryGradFunctor,
bool Recomputation = true>
typename UnaryFunctor, typename UnaryGradFunctor>
static void RunBinaryCompoundGradFunctors(
const framework::ExecutionContext &ctx,
const BinaryGradFunctor &binary_grad_functor,
const UnaryFunctor &unary_functor,
const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x,
const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad) {
// Z = Binary(X, Unary(Y))
int axis = ctx.Attr<int>("axis");
using BinaryCompoundDxFunctor =
BinaryCompoundGradDxFunctor<T, BinaryGradFunctor, UnaryFunctor,
Recomputation>;
paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
UnaryFunctor>;
using BinaryCompoundDyFunctor =
BinaryCompoundGradDyFunctor<T, BinaryGradFunctor, UnaryFunctor,
UnaryGradFunctor, Recomputation>;
ElemwiseGradCompute<DeviceContext, T, BinaryCompoundDxFunctor,
BinaryCompoundDyFunctor>(
ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad,
BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
paddle::operators::math::BinaryCompoundGradDyFunctor<
T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor>;
if (in_intermediate_out) {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
true /*UseIntermediateOut*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
unary_grad_functor));
} else {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, BinaryCompoundDxFunctor, BinaryCompoundDyFunctor,
false /*UseIntermediateOut*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, BinaryCompoundDxFunctor(binary_grad_functor, unary_functor),
BinaryCompoundDyFunctor(binary_grad_functor, unary_functor,
unary_grad_functor));
}
}
template <typename DeviceContext, typename T, typename UnaryGradFunctor,
......@@ -233,143 +139,159 @@ static void RunUnaryCompoundGradFunctors(
const BinaryFunctor &binary_functor,
const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x,
const framework::Tensor *in_y, const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad, framework::Tensor *x_grad,
framework::Tensor *y_grad) {
// Z = Unary(Binary(X, Y))
int axis = ctx.Attr<int>("axis");
using UnaryCompoundDxFunctor =
UnaryCompoundGradDxFunctor<T, UnaryGradFunctor, BinaryFunctor,
BinaryGradFunctor, Recomputation>;
paddle::operators::math::UnaryCompoundGradDxFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>;
using UnaryCompoundDyFunctor =
UnaryCompoundGradDyFunctor<T, UnaryGradFunctor, BinaryFunctor,
BinaryGradFunctor, Recomputation>;
ElemwiseGradCompute<DeviceContext, T, UnaryCompoundDxFunctor,
UnaryCompoundDyFunctor>(
ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad,
UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
paddle::operators::math::UnaryCompoundGradDyFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, Recomputation>;
if (in_intermediate_out) {
FusedElemwiseAndActGradComputeEx<
DeviceContext, T, UnaryCompoundDxFunctor, UnaryCompoundDyFunctor,
true /*UseIntermediateOut*/, true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
binary_grad_functor),
UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
binary_grad_functor));
} else {
FusedElemwiseAndActGradComputeEx<DeviceContext, T, UnaryCompoundDxFunctor,
UnaryCompoundDyFunctor,
false /*UseIntermediateOut*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, in_out, in_intermediate_out, in_out_grad, axis, x_grad,
y_grad, UnaryCompoundDxFunctor(unary_grad_functor, binary_functor,
binary_grad_functor),
UnaryCompoundDyFunctor(unary_grad_functor, binary_functor,
binary_grad_functor));
}
}
template <typename DeviceContext, typename T>
static void RunFunctors(const framework::ExecutionContext &ctx,
const framework::Tensor *in_x,
const framework::Tensor *in_y,
framework::Tensor *output) {
const framework::Tensor &in_x,
const framework::Tensor &in_y,
std::vector<framework::Tensor *> *outputs) {
auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
auto funcs_str = functors[0] + "," + functors[1];
// TODO(zcd): The following code can be refined.
auto funcs_str = functors[0] + "," + functors[1];
if (funcs_str == "elementwise_add,scale") {
// Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundFunctor<DeviceContext, T, math::AddFunctor<T>,
math::ScaleFunctor<T>>(
ctx, math::AddFunctor<T>(), math::ScaleFunctor<T>(scale), in_x, in_y,
output);
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::ScaleFunctor<T>>(
ctx, paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
} else if (funcs_str == "scale,elementwise_add") {
// Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunUnaryCompoundFunctors<DeviceContext, T, math::ScaleFunctor<T>,
math::AddFunctor<T>>(
ctx, math::ScaleFunctor<T>(scale), math::AddFunctor<T>(), in_x, in_y,
output);
RunUnaryCompoundFunctors<DeviceContext, T,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_add,relu") {
RunBinaryCompoundFunctor<DeviceContext, T, math::AddFunctor<T>,
math::ReluFunctor<T>>(
ctx, math::AddFunctor<T>(), math::ReluFunctor<T>(), in_x, in_y, output);
// Z = Binary(X, Unary(Y))
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::ReluFunctor<T>>(
ctx, paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::ReluFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "relu,elementwise_add") {
RunUnaryCompoundFunctors<DeviceContext, T, math::ReluFunctor<T>,
math::AddFunctor<T>>(
ctx, math::ReluFunctor<T>(), math::AddFunctor<T>(), in_x, in_y, output);
// Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T,
paddle::operators::math::ReluFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::ReluFunctor<T>(),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_mul,scale") {
// Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::MulFunctor<T>,
paddle::operators::math::ScaleFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
}
}
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T, bool ReComputation>
static void RunGradFunctors(const framework::ExecutionContext &ctx,
const framework::Tensor *in_x,
const framework::Tensor *in_y,
const framework::Tensor *in_out,
const framework::Tensor *in_intermediate_out,
const framework::Tensor *in_out_grad,
framework::Tensor *x_grad,
framework::Tensor *y_grad) {
auto &functors = ctx.Attr<std::vector<std::string>>("functor_list");
auto funcs_str = functors[0] + "," + functors[1];
bool recomputation = ctx.Attr<bool>("recomputation");
// TODO(zcd): The following code can be refined. for example, use registion
// TODO(zcd): The following code can be refined. for example, use registrition
if (funcs_str == "elementwise_add_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
if (recomputation) {
RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>,
math::ScaleFunctor<T>,
math::ScaleGradFunctor<T>, true>(
ctx, math::AddGradFunctor<T>(), math::ScaleFunctor<T>(scale),
math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out, in_out_grad,
x_grad, y_grad);
} else {
RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>,
math::ScaleFunctor<T>,
math::ScaleGradFunctor<T>, false>(
ctx, math::AddGradFunctor<T>(), math::ScaleFunctor<T>(scale),
math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out, in_out_grad,
x_grad, y_grad);
}
RunBinaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::AddGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>>(
ctx, paddle::operators::math::AddGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
} else if (funcs_str == "scale_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
if (recomputation) {
RunUnaryCompoundGradFunctors<DeviceContext, T, math::ScaleGradFunctor<T>,
math::AddFunctor<T>, math::AddGradFunctor<T>,
true>(ctx, math::ScaleGradFunctor<T>(scale),
math::AddFunctor<T>(),
math::AddGradFunctor<T>(), in_x, in_y,
in_out, in_out_grad, x_grad, y_grad);
} else {
RunUnaryCompoundGradFunctors<DeviceContext, T, math::ScaleGradFunctor<T>,
math::AddFunctor<T>, math::AddGradFunctor<T>,
false>(ctx, math::ScaleGradFunctor<T>(scale),
math::AddFunctor<T>(),
math::AddGradFunctor<T>(), in_x, in_y,
in_out, in_out_grad, x_grad, y_grad);
}
RunUnaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::ScaleGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>,
ReComputation /*Recomputation*/>(
ctx, paddle::operators::math::ScaleGradFunctor<T>(scale),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
} else if (funcs_str == "elementwise_add_grad,relu_grad") {
if (recomputation) {
RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>,
math::ReluFunctor<T>,
math::ReluGradFunctor<T>, true>(
ctx, math::AddGradFunctor<T>(), math::ReluFunctor<T>(),
math::ReluGradFunctor<T>(), in_x, in_y, in_out, in_out_grad, x_grad,
y_grad);
} else {
RunBinaryCompoundGradFunctors<DeviceContext, T, math::AddGradFunctor<T>,
math::ReluFunctor<T>,
math::ReluGradFunctor<T>, false>(
ctx, math::AddGradFunctor<T>(), math::ReluFunctor<T>(),
math::ReluGradFunctor<T>(), in_x, in_y, in_out, in_out_grad, x_grad,
y_grad);
}
RunBinaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::AddGradFunctor<T>,
paddle::operators::math::ReluFunctor<T>,
paddle::operators::math::ReluGradFunctor<T>>(
ctx, paddle::operators::math::AddGradFunctor<T>(),
paddle::operators::math::ReluFunctor<T>(),
paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
} else if (funcs_str == "relu_grad,elementwise_add_grad") {
if (recomputation) {
RunUnaryCompoundGradFunctors<DeviceContext, T, math::ReluGradFunctor<T>,
math::AddFunctor<T>, math::AddGradFunctor<T>,
true>(ctx, math::ReluGradFunctor<T>(),
math::AddFunctor<T>(),
math::AddGradFunctor<T>(), in_x, in_y,
in_out, in_out_grad, x_grad, y_grad);
} else {
RunUnaryCompoundGradFunctors<DeviceContext, T, math::ReluGradFunctor<T>,
math::AddFunctor<T>, math::AddGradFunctor<T>,
false>(ctx, math::ReluGradFunctor<T>(),
math::AddFunctor<T>(),
math::AddGradFunctor<T>(), in_x, in_y,
in_out, in_out_grad, x_grad, y_grad);
}
RunUnaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::ReluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>,
ReComputation /*Recomputation*/>(
ctx, paddle::operators::math::ReluGradFunctor<T>(),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
} else if (funcs_str == "elementwise_mul_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundGradFunctors<DeviceContext, T,
paddle::operators::math::MulGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>>(
ctx, paddle::operators::math::MulGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
}
......@@ -385,11 +307,23 @@ class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"),
"Cannot get input tensor %s, variable name = %s",
"Y", ctx.op().Input("Y"));
auto &output = detail::Ref(ctx.Output<framework::Tensor>("Out"),
"Cannot get input tensor %s, variable name = %s",
"Out", ctx.op().Output("Out"));
PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
auto output = ctx.Output<framework::Tensor>("Out");
std::vector<framework::Tensor *> outputs;
outputs.emplace_back(output);
if (ctx.Attr<bool>("keep_intermediate_value")) {
PADDLE_ENFORCE(ctx.HasOutput("IntermediateOut"),
"The keep_intermediate_value is enable, so the "
"IntermediateOut should not be empty.");
auto intermediate_out = ctx.Output<framework::Tensor>("IntermediateOut");
outputs.emplace_back(intermediate_out);
} else {
outputs.emplace_back(nullptr);
}
RunFunctors<DeviceContext, T>(ctx, &in_x, &in_y, &output);
RunFunctors<DeviceContext, T>(ctx, in_x, in_y, &outputs);
}
};
......@@ -397,28 +331,66 @@ template <typename DeviceContext, typename T>
class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto &in_x = detail::Ref(ctx.Input<framework::Tensor>("X"),
"Cannot get input tensor %s, variable name = %s",
"X", ctx.op().Input("X"));
auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"),
"Cannot get input tensor %s, variable name = %s",
"Y", ctx.op().Input("Y"));
auto &in_out = detail::Ref(ctx.Input<framework::Tensor>("Out"),
"Cannot get input tensor %s, variable name = %s",
"Out", ctx.op().Input("Out"));
auto &in_out_grad =
detail::Ref(ctx.Input<framework::Tensor>(framework::GradVarName("Out")),
"Cannot get input tensor %s, variable name = %s",
framework::GradVarName("Out"),
ctx.op().Input(framework::GradVarName("Out")));
auto x = ctx.Input<framework::Tensor>("X");
auto y = ctx.Input<framework::Tensor>("Y");
auto in_out = ctx.Input<framework::Tensor>("Out");
auto in_out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor *x_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
framework::Tensor *y_grad =
ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
RunGradFunctors<DeviceContext, T>(ctx, &in_x, &in_y, &in_out, &in_out_grad,
x_grad, y_grad);
PADDLE_ENFORCE(y != nullptr, "Input(Y) should not be nullptr.");
if (ctx.Attr<bool>("recomputation")) {
PADDLE_ENFORCE(
x != nullptr,
"The recomputation is opened, so Input(X) should not be absent.");
} else {
PADDLE_ENFORCE(in_out != nullptr,
"The recomputation is disabled, so the Input('Out') "
"should not be empty.");
}
framework::Tensor *in_x;
auto functor_list = ctx.Attr<std::vector<std::string>>("functor_list");
// If functor_list contains elementwise_add, the backward doesn't use
// in_x, and in_outs.
if (x == nullptr) {
PADDLE_ENFORCE(functor_list[0] == "elementwise_add_grad" ||
functor_list[1] == "elementwise_add_grad",
"Only when the compoundfunctor contains "
"elementwise_add_grad, the 'X' could be absent.");
in_x = const_cast<framework::Tensor *>(in_out_grad);
in_out = const_cast<framework::Tensor *>(in_out_grad);
} else {
in_x = const_cast<framework::Tensor *>(x);
}
framework::Tensor *in_intermediate_out;
if (ctx.Attr<bool>("keep_intermediate_value")) {
in_intermediate_out = const_cast<framework::Tensor *>(
ctx.Input<framework::Tensor>("IntermediateOut"));
PADDLE_ENFORCE(in_intermediate_out != nullptr,
"The option of 'keep_intermediate_value' is opened, "
"so the number of 'Out' should be two.");
} else {
in_intermediate_out = nullptr;
}
if (ctx.Attr<bool>("recomputation")) {
RunGradFunctors<DeviceContext, T, true /*Recomputation*/>(
ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad);
} else {
RunGradFunctors<DeviceContext, T, false /*Recomputation*/>(
ctx, in_x, y, in_out, in_intermediate_out, in_out_grad, x_grad,
y_grad);
}
}
};
} // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace operators {
namespace math {
template <typename T, typename BinaryFunctor, typename UnaryFunctor>
struct BinaryCompoundFunctor {
BinaryCompoundFunctor(const BinaryFunctor func1, const UnaryFunctor func2)
: func1_(func1), func2_(func2) {}
// Z = BinaryFunctor(X, UnaryFunctor(Y))
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(x, func2_(y)); }
inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediat_out) {
return func1_(x, intermediat_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(y); }
BinaryFunctor func1_;
UnaryFunctor func2_;
};
template <typename T, typename UnaryFunctor, typename BinaryFunctor>
struct UnaryCompoundFunctor {
UnaryCompoundFunctor(const UnaryFunctor func1, const BinaryFunctor func2)
: func1_(func1), func2_(func2) {}
// Z = UnaryFunctor(BinaryFunctor(X, Y))
inline HOSTDEVICE T GetOut(T x, T y) { return func1_(func2_(x, y)); }
inline HOSTDEVICE T GetOutUseIntermediateOut(T x, T intermediat_out) {
return func1_(intermediat_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return func2_(x, y); }
UnaryFunctor func1_;
BinaryFunctor func2_;
};
// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get
// the dx, one is to use the 'out', and the other is not to use it.
// the former method will save the time of recomputing the
// 'out', but it must occupy the memory to store the 'out'.
// While the later method can avoid occupying this memory,
// but it must recompute the 'out'.
template <typename T, typename DBinaryFun, typename UnaryFun>
struct BinaryCompoundGradDxFunctor {
BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun)
: d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dx(x, unary_fun_(y));
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
return dout * d_binary_fun_.Dx(x, intermediate_out);
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
};
template <typename T, typename DBinaryFun, typename UnaryFun,
typename DUnaryFun>
struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun,
const DUnaryFun &d_unary_fun)
: d_binary_fun_(d_binary_fun),
unary_fun_(unary_fun),
d_unary_fun_(d_unary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_(y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_(y, intermediate_out);
}
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
DUnaryFun d_unary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dx(x, y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(intermediate_out);
} else {
base = dout * d_unary_fun_(intermediate_out, out);
}
return base * d_binary_fun_.Dx(x, y);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool Recomputation = true>
struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
const DBinaryFun &d_binary_fun)
: d_unary_fun_(d_unary_fun),
binary_fun_(binary_fun),
d_binary_fun_(d_binary_fun) {}
inline HOSTDEVICE T operator()(T x, T y, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(binary_fun_(x, y));
} else {
base = dout * d_unary_fun_(binary_fun_(x, y), out);
}
return base * d_binary_fun_.Dy(x, y);
}
inline HOSTDEVICE T operator()(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (Recomputation) {
base = dout * d_unary_fun_(intermediate_out);
} else {
base = dout * d_unary_fun_(intermediate_out, out);
}
return base * d_binary_fun_.Dy(x, y);
}
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
DBinaryFun d_binary_fun_;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -18,6 +18,19 @@ namespace paddle {
namespace operators {
namespace math {
// MulFunctor
template <typename T>
struct MulFunctor {
// out = x * y;
inline HOSTDEVICE T operator()(T x, T y) { return x * y; }
};
template <typename T>
struct MulGradFunctor {
inline HOSTDEVICE T Dx(T x, T y) { return y; }
inline HOSTDEVICE T Dy(T x, T y) { return x; }
};
// AddFunctor
template <typename T>
struct AddFunctor {
......@@ -27,9 +40,8 @@ struct AddFunctor {
template <typename T>
struct AddGradFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return 1; }
inline HOSTDEVICE T operator()(T x, T y, T out) const { return 1; }
inline HOSTDEVICE T Dx(T x, T y) { return 1; }
inline HOSTDEVICE T Dy(T x, T y) { return 1; }
};
template <typename T>
......
......@@ -47,7 +47,8 @@ def get_numeric_gradient(place,
input_to_check,
output_names,
delta=0.005,
in_place=False):
in_place=False,
sum_outputs=None):
# FIXME: change this method by compile time concepts
set_input(scope, op, inputs, place)
......@@ -58,9 +59,11 @@ def get_numeric_gradient(place,
sum = []
op.run(scope, place)
for output_name in output_names:
if sum_outputs and output_name not in sum_outputs:
continue
sum.append(
np.array(scope.find_var(output_name).get_tensor()).mean())
return np.array(sum).mean()
return np.array(sum).sum() / len(output_names)
tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.shape())
......@@ -396,13 +399,14 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None):
user_defined_grads=None,
sum_outputs=None):
places = self._get_places()
for place in places:
self.check_grad_with_place(place, inputs_to_check, output_names,
no_grad_set, numeric_grad_delta,
in_place, max_relative_error,
user_defined_grads)
user_defined_grads, sum_outputs)
def check_grad_with_place(self,
place,
......@@ -412,7 +416,8 @@ class OpTest(unittest.TestCase):
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None):
user_defined_grads=None,
sum_outputs=None):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
......@@ -435,7 +440,8 @@ class OpTest(unittest.TestCase):
input_to_check,
output_names,
delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check
in_place=in_place,
sum_outputs=sum_outputs) for input_to_check in inputs_to_check
]
analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set)
......
......@@ -15,32 +15,31 @@
from __future__ import print_function
import unittest
import numpy as np
from functools import partial
import paddle.fluid.core as core
from op_test import OpTest
# scale + add
# TestElementwiseAddOp
# TestFusedOperatorsOp_scalar
# TestFusedOperatorsOp_scalar2
# TestFusedOperatorsOp_Vector
# TestFusedOperatorsOp_broadcast_0
# TestFusedOperatorsOp_broadcast_1
# TestFusedOperatorsOp_broadcast_2
# TestFusedOperatorsOp_broadcast_3
# TestFusedOperatorsOp_broadcast_4
# TestFusedOperatorsOp_rowwise_add_0
# TestFusedOperatorsOp_rowwise_add_1
# TestFusedOperatorsOp_channelwise_add
class TestElementwiseAddOp(OpTest):
# TestFusedElementwiseActivationOp
# TestFusedElementwiseActivationOp_scalar
# TestFusedElementwiseActivationOp_scalar2
# TestFusedElementwiseActivationOp_Vector
# TestFusedElementwiseActivationOp_broadcast_0
# TestFusedElementwiseActivationOp_broadcast_1
# TestFusedElementwiseActivationOp_broadcast_2
# TestFusedElementwiseActivationOp_broadcast_3
# TestFusedElementwiseActivationOp_broadcast_4
# TestFusedElementwiseActivationOp_rowwise_add_0
# TestFusedElementwiseActivationOp_rowwise_add_1
# TestFusedElementwiseActivationOp_channelwise_add
def create_test_class(test_case, callback, attrs):
class TestFusedElementwiseActivationOp_base(OpTest):
def setUp(self):
self.op_type = "fused_elemwise_activation"
self.dtype = np.float32
self.axis = -1
self.init_axis()
self.init_dtype()
self.init_input()
self.init_output()
self.init_attr()
......@@ -49,772 +48,294 @@ class TestElementwiseAddOp(OpTest):
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
if self.attrs["keep_intermediate_value"]:
self.outputs = {
'Out': self.out,
"IntermediateOut": self.intermediate_out
}
else:
self.outputs = {'Out': self.out}
def init_input(self):
self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype)
self.axis = -1
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y) * self.scale
self.x, self.y, self.intermediate_out, self.out = \
callback(self.x, self.y, self.x, self.y)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["scale", "elementwise_add"]
}
def init_dtype(self):
pass
def init_axis(self):
pass
self.attrs = {'axis': self.axis, }
for key in attrs.keys():
self.attrs[key] = attrs[key]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005)
if self.attrs["keep_intermediate_value"]:
self.check_grad(
['X', 'Y'], ['Out', 'IntermediateOut'],
max_relative_error=0.005,
sum_outputs=['Out'])
else:
self.check_grad(['X', 'Y'], ['Out'], max_relative_error=0.005)
def test_check_grad_ingore_x(self):
if self.attrs["keep_intermediate_value"]:
self.check_grad(
['Y'], ['Out', 'IntermediateOut'],
max_relative_error=0.005,
no_grad_set=set("X"),
sum_outputs=['Out'])
else:
self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X"))
['Y'], ['Out'],
max_relative_error=0.005,
no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
if self.attrs["keep_intermediate_value"]:
self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
['X'], ['Out', 'IntermediateOut'],
max_relative_error=0.005,
no_grad_set=set("Y"),
sum_outputs=['Out'])
else:
self.check_grad(
['X'], ['Out'],
max_relative_error=0.005,
no_grad_set=set("Y"))
class TestFusedOperatorsOp_scalar(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_scalar(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y) * self.scale
class TestFusedOperatorsOp_scalar2(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_scalar2(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(1, 1).astype(self.dtype)
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y) * self.scale
class TestFusedOperatorsOp_Vector(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_Vector(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.random((32, )).astype(self.dtype)
self.y = np.random.random((32, )).astype(self.dtype)
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y) * self.scale
class TestFusedOperatorsOp_broadcast_0(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_broadcast_0(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(2).astype(self.dtype)
def init_axis(self):
self.axis = 0
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y.reshape(2, 1, 1)) * self.scale
self.x, self.y, self.intermediate_out, self.out = \
callback(self.x, self.y, self.x, self.y.reshape(2, 1, 1))
class TestFusedOperatorsOp_broadcast_1(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_broadcast_1(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(3).astype(self.dtype)
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y.reshape(1, 3, 1)) * self.scale
self.x, self.y, self.intermediate_out, self.out = \
callback(self.x, self.y, self.x, self.y.reshape(1, 3, 1))
class TestFusedOperatorsOp_broadcast_2(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_broadcast_2(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(4).astype(self.dtype)
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y.reshape(1, 1, 4)) * self.scale
self.x, self.y, self.intermediate_out, self.out = \
callback(self.x, self.y, self.x, self.y.reshape(1, 1, 4))
class TestFusedOperatorsOp_broadcast_3(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_broadcast_3(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.y = np.random.rand(3, 4).astype(self.dtype)
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y.reshape(1, 3, 4, 1)) * self.scale
self.x, self.y, self.intermediate_out, self.out = \
callback(self.x, self.y, self.x, self.y.reshape(1, 3, 4, 1))
class TestFusedOperatorsOp_broadcast_4(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_broadcast_4(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype)
self.y = np.random.rand(2, 1).astype(self.dtype)
def init_axis(self):
self.axis = 0
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y.reshape(2, 1, 1, 1)) * self.scale
self.x, self.y, self.intermediate_out, self.out = \
callback(self.x, self.y, self.x, self.y.reshape(2, 1, 1, 1))
class TestFusedOperatorsOp_rowwise_add_0(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_rowwise_add_0(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(2, 3, 4).astype(self.dtype)
self.y = np.random.rand(3, 4).astype(self.dtype)
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y.reshape(1, 3, 4)) * self.scale
self.x, self.y, self.intermediate_out, self.out = \
callback(self.x, self.y, self.x, self.y.reshape(1, 3, 4))
class TestFusedOperatorsOp_rowwise_add_1(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_rowwise_add_1(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(2, 1).astype(self.dtype)
self.y = np.random.rand(1).astype(self.dtype)
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y.reshape(1, 1)) * self.scale
self.x, self.y, self.intermediate_out, self.out = \
callback(self.x, self.y, self.x, self.y.reshape(1, 1))
class TestFusedOperatorsOp_channelwise_add(TestElementwiseAddOp):
class TestFusedElementwiseActivationOp_channelwise_add(
TestFusedElementwiseActivationOp_base):
def init_input(self):
self.x = np.random.rand(3, 20, 20).astype(self.dtype)
self.y = np.random.rand(3, 1, 1).astype(self.dtype)
def init_axis(self):
self.axis = -1
def init_output(self):
self.scale = 0.1
self.out = (self.x + self.y) * self.scale
# add + scale
# TestElementwiseAddOp_f_add_scale
# TestFusedOperatorsOp_scalar_f_add_scale
# TestFusedOperatorsOp_scalar2_f_add_scale
# TestFusedOperatorsOp_Vector_f_add_scale
# TestFusedOperatorsOp_broadcast_0_f_add_scale
# TestFusedOperatorsOp_broadcast_1_f_add_scale
# TestFusedOperatorsOp_broadcast_2_f_add_scale
# TestFusedOperatorsOp_broadcast_3_f_add_scale
# TestFusedOperatorsOp_broadcast_4_f_add_scale
# TestFusedOperatorsOp_rowwise_add_0_f_add_scale
# TestFusedOperatorsOp_rowwise_add_1_f_add_scale
# TestFusedOperatorsOp_channelwise_add_f_add_scale
class TestFusedOperatorsOp_f_add_scale(TestElementwiseAddOp):
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_scalar_f_add_scale(TestFusedOperatorsOp_scalar):
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_scalar2_f_add_scale(TestFusedOperatorsOp_scalar2):
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_Vector_f_add_scale(TestFusedOperatorsOp_Vector):
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_0_f_add_scale(
TestFusedOperatorsOp_broadcast_0):
def init_axis(self):
self.axis = 0
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(2, 1, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_1_f_add_scale(
TestFusedOperatorsOp_broadcast_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(1, 3, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_2_f_add_scale(
TestFusedOperatorsOp_broadcast_2):
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(1, 1, 4) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_3_f_add_scale(
TestFusedOperatorsOp_broadcast_3):
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(1, 3, 4, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_broadcast_4_f_add_scale(
TestFusedOperatorsOp_broadcast_4):
def init_axis(self):
self.axis = 0
def init_output(self):
self.scale = 0.2
self.out = self.x + self.y.reshape(2, 1, 1, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_rowwise_add_0_f_add_scale(
TestFusedOperatorsOp_rowwise_add_0):
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.1
self.out = self.x + self.y.reshape(1, 3, 4) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_rowwise_add_1_f_add_scale(
TestFusedOperatorsOp_rowwise_add_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.scale = 0.2
self.out = self.x + self.y.reshape(1, 1) * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
class TestFusedOperatorsOp_channelwise_add_f_add_scale(
TestFusedOperatorsOp_channelwise_add):
def init_axis(self):
self.axis = -1
def init_output(self):
self.scale = 0.2
self.out = self.x + self.y * self.scale
def init_attr(self):
self.attrs = {
'axis': self.axis,
'scale': self.scale,
'functor_list': ["elementwise_add", "scale"]
}
# add + relu
# TestElementwiseAddOp_f_add_relu
# TestFusedOperatorsOp_scalar_f_add_relu
# TestFusedOperatorsOp_scalar2_f_add_relu
# TestFusedOperatorsOp_Vector_f_add_relu
# TestFusedOperatorsOp_broadcast_0_f_add_relu
# TestFusedOperatorsOp_broadcast_1_f_add_relu
# TestFusedOperatorsOp_broadcast_2_f_add_relu
# TestFusedOperatorsOp_broadcast_3_f_add_relu
# TestFusedOperatorsOp_broadcast_4_f_add_relu
# TestFusedOperatorsOp_rowwise_add_0_f_add_relu
# TestFusedOperatorsOp_rowwise_add_1_f_add_relu
# TestFusedOperatorsOp_channelwise_add_f_add_relu
class TestFusedOperatorsOp_f_add_relu(TestElementwiseAddOp):
def init_output(self):
# Copy from test_activation_op.py
# Because we set delta = 0.005 in calculating numeric gradient,
# if x is too small, such as 0.002, x_neg will be -0.003
# x_pos will be 0.007, so the numeric gradient is inaccurate.
# we should avoid this
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_scalar_f_add_relu(TestFusedOperatorsOp_scalar):
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_scalar2_f_add_relu(TestFusedOperatorsOp_scalar2):
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_Vector_f_add_relu(TestFusedOperatorsOp_Vector):
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_0_f_add_relu(
TestFusedOperatorsOp_broadcast_0):
def init_axis(self):
self.axis = 0
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(2, 1, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_1_f_add_relu(
TestFusedOperatorsOp_broadcast_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 3, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_2_f_add_relu(
TestFusedOperatorsOp_broadcast_2):
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 1, 4), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_3_f_add_relu(
TestFusedOperatorsOp_broadcast_3):
def init_axis(self):
self.axis = 1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 3, 4, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_broadcast_4_f_add_relu(
TestFusedOperatorsOp_broadcast_4):
def init_axis(self):
self.axis = 0
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(2, 1, 1, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_rowwise_add_0_f_add_relu(
TestFusedOperatorsOp_rowwise_add_0):
def init_axis(self):
self.axis = 1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 3, 4), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_rowwise_add_1_f_add_relu(
TestFusedOperatorsOp_rowwise_add_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y.reshape(1, 1), 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
class TestFusedOperatorsOp_channelwise_add_f_add_relu(
TestFusedOperatorsOp_channelwise_add):
def init_axis(self):
self.axis = -1
def init_output(self):
self.y[np.abs(self.y) < 0.005] = 0.02
self.out = self.x + np.maximum(self.y, 0)
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["elementwise_add", "relu"]
}
# relu + add
# TestElementwiseAddOp_f_relu_add
# TestFusedOperatorsOp_scalar_f_relu_add
# TestFusedOperatorsOp_scalar2_f_relu_add
# TestFusedOperatorsOp_Vector_f_relu_add
# TestFusedOperatorsOp_broadcast_0_f_relu_add
# TestFusedOperatorsOp_broadcast_1_f_relu_add
# TestFusedOperatorsOp_broadcast_2_f_relu_add
# TestFusedOperatorsOp_broadcast_3_f_relu_add
# TestFusedOperatorsOp_broadcast_4_f_relu_add
# TestFusedOperatorsOp_rowwise_add_0_f_relu_add
# TestFusedOperatorsOp_rowwise_add_1_f_relu_add
# TestFusedOperatorsOp_channelwise_add_f_relu_add
class TestFusedOperatorsOp_f_relu_add(TestElementwiseAddOp):
def init_output(self):
TestFusedElementwiseActivationOp_base.__name__ = test_case + "_base"
TestFusedElementwiseActivationOp_scalar.__name__ = test_case + "_scalar"
TestFusedElementwiseActivationOp_scalar2.__name__ = test_case + "_scalar2"
TestFusedElementwiseActivationOp_Vector.__name__ = test_case + "_Vector"
TestFusedElementwiseActivationOp_broadcast_0.__name__ = test_case + "_broadcast_0"
TestFusedElementwiseActivationOp_broadcast_1.__name__ = test_case + "_broadcast_1"
TestFusedElementwiseActivationOp_broadcast_2.__name__ = test_case + "_broadcast_2"
TestFusedElementwiseActivationOp_broadcast_3.__name__ = test_case + "_broadcast_3"
TestFusedElementwiseActivationOp_broadcast_4.__name__ = test_case + "_broadcast_4"
TestFusedElementwiseActivationOp_rowwise_add_0.__name__ = test_case + "_rowwise_add_0"
TestFusedElementwiseActivationOp_rowwise_add_1.__name__ = test_case + "_rowwise_add_1"
TestFusedElementwiseActivationOp_channelwise_add.__name__ = test_case + "_channelwise_add"
globals()[test_case + "_base"] = TestFusedElementwiseActivationOp_base
globals()[test_case + "_scalar"] = TestFusedElementwiseActivationOp_scalar
globals()[test_case + "_scalar2"] = TestFusedElementwiseActivationOp_scalar2
globals()[test_case + "_Vector"] = TestFusedElementwiseActivationOp_Vector
globals()[test_case +
"_broadcast_0"] = TestFusedElementwiseActivationOp_broadcast_0
globals()[test_case +
"_broadcast_1"] = TestFusedElementwiseActivationOp_broadcast_1
globals()[test_case +
"_broadcast_2"] = TestFusedElementwiseActivationOp_broadcast_2
globals()[test_case +
"_broadcast_3"] = TestFusedElementwiseActivationOp_broadcast_3
globals()[test_case +
"_broadcast_4"] = TestFusedElementwiseActivationOp_broadcast_4
globals()[test_case +
"_rowwise_add_0"] = TestFusedElementwiseActivationOp_rowwise_add_0
globals()[test_case +
"_rowwise_add_1"] = TestFusedElementwiseActivationOp_rowwise_add_1
globals(
)[test_case +
"_channelwise_add"] = TestFusedElementwiseActivationOp_channelwise_add
def scale_add_func(x, y, x_bcast, y_bcast, scale, mode=0):
if mode == 0:
return x, y, (x_bcast + y_bcast), (x_bcast + y_bcast) * scale
else:
return y, x, (x_bcast + y_bcast), (x_bcast + y_bcast) * scale
def add_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
if mode == 0:
return x, y, y * scale, x_bcast + y_bcast * scale
else:
return y, x, x * scale, y_bcast + x_bcast * scale
def add_relu_func(x, y, x_bcast, y_bcast, mode=0):
# Copy from test_activation_op.py
# Because we set delta = 0.005 in calculating numeric gradient,
# if x is too small, such as 0.002, x_neg will be -0.003
# x_pos will be 0.007, so the numeric gradient is inaccurate.
# we should avoid this
self.out = self.x + self.y
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_scalar_f_relu_add(TestFusedOperatorsOp_scalar):
def init_output(self):
self.out = self.x + self.y
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_scalar2_f_relu_add(TestFusedOperatorsOp_scalar2):
def init_output(self):
self.out = self.x + self.y
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_Vector_f_relu_add(TestFusedOperatorsOp_Vector):
def init_output(self):
self.out = self.x + self.y
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_broadcast_0_f_relu_add(
TestFusedOperatorsOp_broadcast_0):
def init_axis(self):
self.axis = 0
def init_output(self):
self.out = self.x + self.y.reshape(2, 1, 1)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_broadcast_1_f_relu_add(
TestFusedOperatorsOp_broadcast_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.out = self.x + self.y.reshape(1, 3, 1)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_broadcast_2_f_relu_add(
TestFusedOperatorsOp_broadcast_2):
def init_output(self):
self.out = self.x + self.y.reshape(1, 1, 4)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_broadcast_3_f_relu_add(
TestFusedOperatorsOp_broadcast_3):
def init_axis(self):
self.axis = 1
def init_output(self):
self.out = self.x + self.y.reshape(1, 3, 4, 1)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_broadcast_4_f_relu_add(
TestFusedOperatorsOp_broadcast_4):
def init_axis(self):
self.axis = 0
def init_output(self):
self.out = self.x + self.y.reshape(2, 1, 1, 1)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_rowwise_add_0_f_relu_add(
TestFusedOperatorsOp_rowwise_add_0):
def init_axis(self):
self.axis = 1
def init_output(self):
self.out = self.x + self.y.reshape(1, 3, 4)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_rowwise_add_1_f_relu_add(
TestFusedOperatorsOp_rowwise_add_1):
def init_axis(self):
self.axis = 1
def init_output(self):
self.out = self.x + self.y.reshape(1, 1)
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
class TestFusedOperatorsOp_channelwise_add_f_relu_add(
TestFusedOperatorsOp_channelwise_add):
def init_axis(self):
self.axis = -1
def init_output(self):
self.out = self.x + self.y
self.out = np.maximum(self.out, 0)
self.out[np.abs(self.out) < 0.005] = 0.02
def init_attr(self):
self.attrs = {
'axis': self.axis,
'functor_list': ["relu", "elementwise_add"]
}
if mode == 0:
y[np.abs(y) < 0.005] = 0.02
y_bcast[np.abs(y_bcast) < 0.005] = 0.02
return x, y, np.maximum(y, 0), x_bcast + np.maximum(y_bcast, 0)
else:
x[np.abs(x) < 0.005] = 0.02
x_bcast[np.abs(x_bcast) < 0.005] = 0.02
return y, x, np.maximum(x, 0), y_bcast + np.maximum(x_bcast, 0)
def relu_add_func(x, y, x_bcast, y_bcast, mode=0):
intermediate_out = x_bcast + y_bcast
out = np.maximum(intermediate_out, 0)
out[np.abs(out) < 0.005] = 0.02
if mode == 0:
return x, y, intermediate_out, out
else:
return y, x, intermediate_out, out
def mul_scale_func(x, y, x_bcast, y_bcast, scale, mode=0):
if mode == 0:
return x, y, y * scale, x_bcast * (y_bcast * scale)
else:
return y, x, x * scale, y_bcast * (x_bcast * scale)
scale = 0.1
scale_add_func = partial(scale_add_func, scale=scale)
add_scale_func = partial(add_scale_func, scale=scale)
mul_scale_func = partial(mul_scale_func, scale=scale)
for mode in {0, 1}:
scale_add_func = partial(scale_add_func, mode=mode)
add_scale_func = partial(add_scale_func, mode=mode)
mul_scale_func = partial(mul_scale_func, mode=mode)
relu_add_func = partial(relu_add_func, mode=mode)
add_relu_func = partial(add_relu_func, mode=mode)
for recomputation in {True, False}:
for keep_intermediate_value in {True, False}:
suffix = ("_keep_intermediate_value" if keep_intermediate_value else "") \
+ ("_recomputation" if recomputation else "") \
+ ("_mode_"+ str(mode))
create_test_class('scale_add' + suffix, scale_add_func, {
'scale': scale,
'functor_list': ["scale", "elementwise_add"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
})
create_test_class('add_scale' + suffix, add_scale_func, {
'scale': scale,
'functor_list': ["elementwise_add", "scale"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
})
create_test_class('add_relu' + suffix, add_relu_func, {
'functor_list': ["elementwise_add", "relu"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
})
create_test_class('relu_add' + suffix, relu_add_func, {
'functor_list': ["relu", "elementwise_add"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
})
create_test_class('mul_scale' + suffix, mul_scale_func, {
'scale': scale,
'functor_list': ["elementwise_mul", "scale"],
'keep_intermediate_value': keep_intermediate_value,
'recomputation': recomputation
})
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册