From 88c22e9d1a809778a7bd83de71c370688cece0b2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 23 Feb 2018 07:22:07 +0800 Subject: [PATCH] Speed up elemwise grad (#8402) * Speed up elemwise grad * Fix bug * Add macro for MAX_BLOCK_DIM --- paddle/fluid/operators/elementwise_add_op.h | 62 +---- .../fluid/operators/elementwise_op_function.h | 254 ++++++++++++++++++ 2 files changed, 259 insertions(+), 57 deletions(-) diff --git a/paddle/fluid/operators/elementwise_add_op.h b/paddle/fluid/operators/elementwise_add_op.h index 3c546bf3e4..253964562c 100644 --- a/paddle/fluid/operators/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise_add_op.h @@ -41,59 +41,8 @@ class ElementwiseAddKernel : public framework::OpKernel { }; template -struct ElementwiseAddGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) { - auto dz_e = framework::EigenVector::Flatten(*dz); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e; - } - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = dz_e; - } - } -}; - -template -struct ElementwiseAddBroadCastGradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n) { - auto dz_e = framework::EigenVector::Flatten(*dz); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = dz_e.reshape(Eigen::DSizes(pre, n)) - .sum(Eigen::array{{0}}); - } - } -}; - -template -struct ElementwiseAddBroadCast2GradFunctor { - template - void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz, Pre pre, N n, - Post post) { - auto dz_e = framework::EigenVector::Flatten(*dz); - if (dx) { - auto dx_e = framework::EigenVector::Flatten(*dx); - dx_e.device(d) = dz_e; - } - - if (dy) { - auto dy_e = framework::EigenVector::Flatten(*dy); - dy_e.device(d) = dz_e.reshape(Eigen::DSizes(pre, n, post)) - .sum(Eigen::array{{0, 2}}); - } - } +struct IdentityGrad { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } }; template @@ -109,10 +58,9 @@ class ElementwiseAddGradKernel : public framework::OpKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); int axis = ctx.Attr("axis"); - ElementwiseGradCompute, - ElementwiseAddBroadCastGradFunctor, - ElementwiseAddBroadCast2GradFunctor>( - ctx, x, y, out, dout, axis, dx, dy); + ElemwiseGradCompute, IdentityGrad>( + ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad(), + IdentityGrad()); } }; diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index 2a4a611511..2da8c10322 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -20,9 +20,11 @@ limitations under the License. */ #ifdef __NVCC__ #include +constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; #endif #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { @@ -311,6 +313,258 @@ EIGEN_FUNCTOR(Mul, EIGEN_MUL); #define EIGEN_DIV(x, y) ((x) / (y)) EIGEN_FUNCTOR(Div, EIGEN_DIV); +template +struct ElemwiseGradNoBroadcast { + const T* x_; + const T* y_; + const T* out_; + const T* dout_; + + HOSTDEVICE void operator()(size_t i) { + if (dx_ != nullptr) { + dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); + } + if (dy_ != nullptr) { + dy_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); + } + } + + DX_OP dx_op_; + DY_OP dy_op_; + T* dx_; + T* dy_; +}; + +template +static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out, + const T* dout, int h, int w, DX_OP dx_op, + DY_OP dy_op, T* dx, T* dy) { + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int x_offset = i * w + j; + if (dx != nullptr) { + dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + if (dy != nullptr) { + T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + if (i == 0) { + dy[j] = tmp; + } else { + dy[j] += tmp; + } + } + } + } +} +#ifdef __NVCC__ +template +static __global__ void ElemwiseGradBroadcast1CUDAKernel( + const T* x, const T* y, const T* out, const T* dout, int h, int w, + DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { + extern __shared__ char shm_buffer[]; + T* shm = reinterpret_cast(shm_buffer); + + int j = blockIdx.x; + int i = threadIdx.x; + int tid = threadIdx.x; + shm[tid] = 0; + + do { + int x_offset = i * w + j; + if (dx) { + dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + if (dy) { + shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + i += ELEMWISE_MAX_BLOCK_DIM; + } while (i < h); + + if (dy) { + __syncthreads(); + + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + + // Sum, could be optimized + if (threadIdx.x == 0) { + for (int k = 1; k < h; ++k) { + shm[0] += shm[k]; + } + dy[j] = shm[0]; + } + } +} + +template +static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T* x, + const T* y, const T* out, const T* dout, + int h, int w, DX_OP dx_op, DY_OP dy_op, + T* dx, T* dy) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); + int gird_size = w; + int shared_mem_size = block_size * sizeof(T); + ElemwiseGradBroadcast1CUDAKernel<<>>(x, y, out, dout, h, w, dx_op, + dy_op, dx, dy); +} + +#endif + +template +static void ElemwiseGradBroadcast2CPU(const T* x, const T* y, const T* out, + const T* dout, int pre, int n, int post, + DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int x_offset = i * n * post + j * post + k; + if (dx != nullptr) { + dx[x_offset] = + dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + if (dy != nullptr) { + T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + if (i == 0 && k == 0) { + dy[j] = tmp; + } else { + dy[j] += tmp; + } + } + } + } + } +} + +#ifdef __NVCC__ + +template +static __global__ void ElemwiseGradBroadcast2CUDAKernel( + const T* x, const T* y, const T* out, const T* dout, int pre, int n, + int post, DX_OP dx_op, DY_OP dy_op, T* dx, T* dy) { + int tid = threadIdx.x; + int j = blockIdx.x; + + extern __shared__ char shm_buffer[]; + T* shm = reinterpret_cast(shm_buffer); + shm[tid] = 0; + int ttid = tid; + + while (true) { + int i = ttid / post; + int k = ttid % post; + if (i >= pre) break; + + int x_offset = i * n * post + j * post + k; + + if (dx != nullptr) { + dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + + if (dy != nullptr) { + shm[tid] += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + + ttid += ELEMWISE_MAX_BLOCK_DIM; + } + + if (dy) { + __syncthreads(); + int h = pre * post; + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + + // Sum, could be optimized + if (tid == 0) { + for (int i = 1; i < h; ++i) { + shm[0] += shm[i]; + } + dy[j] = shm[0]; + } + } +} + +template +static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T* x, + const T* y, const T* out, const T* dout, + int pre, int n, int post, DX_OP dx_op, + DY_OP dy_op, T* dx, T* dy) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); + int gird_size = n; + int shared_mem_size = block_size * sizeof(T); + ElemwiseGradBroadcast2CUDAKernel<<>>(x, y, out, dout, pre, n, post, + dx_op, dy_op, dx, dy); +} + +#endif + +template +void ElemwiseGradCompute(const framework::ExecutionContext& ctx, + const framework::Tensor& x, const framework::Tensor& y, + const framework::Tensor& out, + const framework::Tensor& dout, int axis, + framework::Tensor* dx, framework::Tensor* dy, + DX_OP dx_op, DY_OP dy_op) { + if (x.dims() == y.dims()) { + size_t N = static_cast(framework::product(x.dims())); + platform::ForRange for_range( + ctx.template device_context(), N); + for_range(ElemwiseGradNoBroadcast{ + x.data(), y.data(), out.data(), dout.data(), dx_op, dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())}); + } else { // Y is a scalar + auto x_dim = x.dims(); + auto y_dim = y.dims(); + + if (y_dim.size() == 1 && y_dim[0] == 1) { + // y is a scalar + auto extended_dims = framework::vectorize(x_dim); + extended_dims.push_back(1); + x_dim = framework::make_ddim(extended_dims); + } + + axis = (axis == -1 ? x_dim.size() - y_dim.size() : axis); + int pre, n, post; + get_mid_dims(x_dim, y_dim, axis, pre, n, post); + if (post == 1) { + int h = pre; + int w = n; + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef __NVCC__ + ElemwiseGradBroadcast1CUDA( + ctx.template device_context().stream(), x.data(), + y.data(), out.data(), dout.data(), h, w, dx_op, dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); +#endif + } else { + ElemwiseGradBroadcast1CPU( + x.data(), y.data(), out.data(), dout.data(), h, w, + dx_op, dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + } + } else { + if (platform::is_gpu_place(ctx.GetPlace())) { +#ifdef __NVCC__ + ElemwiseGradBroadcast2CUDA( + ctx.template device_context().stream(), x.data(), + y.data(), out.data(), dout.data(), pre, n, post, dx_op, + dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); +#endif + } else { + ElemwiseGradBroadcast2CPU( + x.data(), y.data(), out.data(), dout.data(), pre, n, + post, dx_op, dy_op, + dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); + } + } + } +}; + template void ElementwiseGradCompute(const framework::ExecutionContext& ctx, -- GitLab