From 676903d558f7e15d61bba6e9fb0cba7c80d1cff3 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 12 Jan 2022 14:51:31 +0800 Subject: [PATCH] [PTen]Refactor impl of elementwise op grad_kernel (Part1) (#38873) * refactor the impl of elementwise grad kernel * refactor impl of elementwise grad kernel(cuda) * fix compile bugs --- .../elementwise/elementwise_op_function.h | 807 +----------------- paddle/fluid/operators/viterbi_decode_op.h | 9 +- paddle/pten/kernels/cpu/elementwise.h | 192 ++++- paddle/pten/kernels/funcs/elementwise_base.h | 54 ++ paddle/pten/kernels/gpu/elementwise.h | 612 +++++++++++++ 5 files changed, 895 insertions(+), 779 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 41cb2696f5..37d29ed91b 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -46,13 +46,6 @@ limitations under the License. */ #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#ifdef __HIPCC__ -constexpr int ELEMWISE_MAX_BLOCK_DIM = 256; -#else -constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; -#endif -#define BLOCK_X 32 -#define BLOCK_Y 32 #endif #include "paddle/fluid/operators/math/math_function.h" @@ -136,16 +129,6 @@ int PackTensorsIntoVector(const framework::ExecutionContext &ctx, return axis; } -inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim, - const int *index_array) { - return pten::GetElementwiseIndex(x_dims_array, max_dim, index_array); -} - -inline void UpdateElementwiseIndexArray(const int *out_dims_array, - const int max_dim, int *index_array) { - pten::UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array); -} - inline void GetBroadcastDimsArrays(const framework::DDim &x_dims, const framework::DDim &y_dims, int *x_dims_array, int *y_dims_array, @@ -169,205 +152,7 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x, is_xsize_larger); } -template -void CommonGradBroadcastCPU( - const framework::Tensor &x, const framework::Tensor &y, - const framework::Tensor &out, const framework::Tensor &dout, - framework::Tensor *dx, framework::Tensor *dy, int *x_dims_array, - int *y_dims_array, int *out_dims_array, int max_dim, - const platform::CPUDeviceContext &ctx, DX_OP dx_op, DY_OP dy_op) { - std::vector index_array(max_dim, 0); - const T *x_data = x.data(); - const T *y_data = y.data(); - const Tout *out_data = out.data(); - const Tout *dout_data = dout.data(); - T *dx_data = dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()); - T *dy_data = dy == nullptr ? 
nullptr : dy->mutable_data(ctx.GetPlace()); - if (dx_data != nullptr) { - memset(dx_data, 0, dx->numel() * sizeof(T)); - } - if (dy_data != nullptr) { - memset(dy_data, 0, dy->numel() * sizeof(T)); - } - const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim, - 1, std::multiplies()); - int x_index, y_index; - for (int out_index = 0; out_index < out_size; ++out_index) { - x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); - y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); - if (dx_data != nullptr) { - dx_data[x_index] += dx_op(x_data[x_index], y_data[y_index], - out_data[out_index], dout_data[out_index]); - } - if (dy_data != nullptr) { - dy_data[y_index] += dy_op(x_data[x_index], y_data[y_index], - out_data[out_index], dout_data[out_index]); - } - - UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); - } -} - -inline void ComputeBroadcastKernelSize(int *x_dims_array, int *out_dims_array, - int *x_blocks, int *x_threads, - int max_dim) { - *x_blocks = 1; - *x_threads = 1; - for (int i = 0; i < max_dim; i++) { - if (x_dims_array[i] == out_dims_array[i]) { - *x_blocks *= x_dims_array[i]; - } else { - *x_threads *= out_dims_array[i]; - } - } -} - -inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs, - int *x_trans_indexs, - const int max_dim, - const int x_one_size) { - int diff = max_dim - x_one_size; - std::copy_n(x_one_indexs, x_one_size, x_trans_indexs + diff); - int p = 0; - int q = diff; - for (int i = 0; i < max_dim; ++i) { - if (q < max_dim && i == x_trans_indexs[q]) { - ++q; - } else { - x_trans_indexs[p++] = i; - } - } -} - #if defined(__NVCC__) || defined(__HIPCC__) -template -static __global__ void ElemwiseGradBroadcast1CUDAKernel( - const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w, - bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { - int j = blockIdx.x; - int i = threadIdx.x; - int tid = threadIdx.x; - T val(0); - if (is_xsize_larger) { - do { - int x_offset = i * w + j; - if (dx) { - dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - if (dy) { - val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - i += ELEMWISE_MAX_BLOCK_DIM; - } while (i < h); - - if (dy) { - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dy[j] = val; - } - } - } else { // x.dims < y.dims, broadcast for x. - do { - int y_offset = i * w + j; - if (dy) { - dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - if (dx) { - val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - i += ELEMWISE_MAX_BLOCK_DIM; - } while (i < h); - - if (dx) { - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dx[j] = val; - } - } - } -} - -// suppose use 2D block is fast because more parallel -// and memory coalesced -template -static __global__ void FastElemwiseGradBroadcast1CUDAKernel( - const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w, - bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { - __shared__ T sdata[BLOCK_Y][BLOCK_X + 1]; - - T val(0); - size_t width_stride = gridDim.x * blockDim.x; - size_t idx = threadIdx.x + blockDim.x * blockIdx.x; - size_t full_width = - (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? 
BLOCK_X : 0); - size_t full_height = - (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0); - if (is_xsize_larger) { - for (int m = idx; m < full_width; m += width_stride) { - sdata[threadIdx.y][threadIdx.x] = 0; - for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { - int x_offset = n * w + m; - if (dx && m < w && n < h) { - dx[x_offset] = - dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]); - } - if (dy) { - if (m < w && n < h) { - T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]); - sdata[threadIdx.y][threadIdx.x] += val; - } - __syncthreads(); - } - } - if (dy) { - T my_val = sdata[threadIdx.x][threadIdx.y]; - for (int i = warpSize >> 1; i > 0; i >>= 1) - my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); - __syncthreads(); - if ((threadIdx.x == 0)) { - sdata[0][threadIdx.y] = my_val; - } - __syncthreads(); - if (threadIdx.y == 0 && m < w) { - dy[m] = sdata[0][threadIdx.x]; - } - } - } - } else { // x.dims < y.dims, broadcast for x. - for (int m = idx; m < full_width; m += width_stride) { - sdata[threadIdx.y][threadIdx.x] = 0; - for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { - int y_offset = n * w + m; - if (dy && m < w && n < h) { - dy[y_offset] = - dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]); - } - if (dx) { - if (m < w && n < h) { - T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]); - sdata[threadIdx.y][threadIdx.x] += val; - } - __syncthreads(); - } - } - if (dx) { - T my_val = sdata[threadIdx.x][threadIdx.y]; - for (int i = warpSize >> 1; i > 0; i >>= 1) - my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); - __syncthreads(); - if ((threadIdx.x == 0)) { - sdata[0][threadIdx.y] = my_val; - } - __syncthreads(); - if (threadIdx.y == 0 && m < w) { - dx[m] = sdata[0][threadIdx.x]; - } - } - } - } -} template __global__ void CommonGradBroadcastCUDAKernel( @@ -408,267 +193,6 @@ __global__ void CommonGradBroadcastCUDAKernel( } } -template -static __global__ void CommonGradBroadcast1CUDAKernelHeight( - const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w, - DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) { - int j = blockIdx.x; - int i = threadIdx.x; - int tid = threadIdx.x; - T val(0); - - if (is_y) { - do { - int out_offset = i * w + j; - int x_offset = (i % x_h) * x_w + j % x_w; - if (dy) { - val += dy_op(x[x_offset], y[j], out[out_offset], dout[out_offset]); - } - i += ELEMWISE_MAX_BLOCK_DIM; - } while (i < h); - - if (dy) { - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dy[j] = val; - } - } - } else { - do { - int out_offset = i * w + j; - int y_offset = (i % x_h) * x_w + j % x_w; - if (dy) { - val += dy_op(x[j], y[y_offset], out[out_offset], dout[out_offset]); - } - i += ELEMWISE_MAX_BLOCK_DIM; - } while (i < h); - - if (dy) { - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dy[j] = val; - } - } - } -} - -template -static __global__ void FastCommonGradBroadcastCUDAKernelHeight( - const T *x, const T *y, const Tout *out, const Tout *dout, int h, int w, - DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) { - __shared__ T sdata[BLOCK_Y][BLOCK_X + 1]; - - T val(0); - size_t width_stride = gridDim.x * blockDim.x; - size_t idx = threadIdx.x + blockDim.x * blockIdx.x; - size_t full_width = - (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? 
BLOCK_X : 0); - size_t full_height = - (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0); - if (is_y) { - for (int m = idx; m < full_width; m += width_stride) { - sdata[threadIdx.y][threadIdx.x] = 0; - for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { - int out_offset = n * w + m; - int x_offset = (n % x_h) * x_w + m % x_w; - if (dy) { - if (m < w && n < h) { - T val = dy_op(x[x_offset], y[m], out[out_offset], dout[out_offset]); - sdata[threadIdx.y][threadIdx.x] += val; - } - __syncthreads(); - } - } - if (dy) { - T my_val = sdata[threadIdx.x][threadIdx.y]; - for (int i = warpSize >> 1; i > 0; i >>= 1) { - my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); - } - __syncthreads(); - if ((threadIdx.x == 0)) { - sdata[0][threadIdx.y] = my_val; - } - __syncthreads(); - if (threadIdx.y == 0 && m < w) { - dy[m] = sdata[0][threadIdx.x]; - } - } - } - } else { - for (int m = idx; m < full_width; m += width_stride) { - sdata[threadIdx.y][threadIdx.x] = 0; - for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { - int out_offset = n * w + m; - int y_offset = (n % x_h) * x_w + m % x_w; - if (dy) { - if (m < w && n < h) { - T val = dy_op(x[m], y[y_offset], out[out_offset], dout[out_offset]); - sdata[threadIdx.y][threadIdx.x] += val; - } - __syncthreads(); - } - } - if (dy) { - T my_val = sdata[threadIdx.x][threadIdx.y]; - for (int i = warpSize >> 1; i > 0; i >>= 1) { - my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); - } - __syncthreads(); - if ((threadIdx.x == 0)) { - sdata[0][threadIdx.y] = my_val; - } - __syncthreads(); - if (threadIdx.y == 0 && m < w) { - dy[m] = sdata[0][threadIdx.x]; - } - } - } - } -} - -template -static __global__ void FastCommonGradBroadcastAllCUDAKernel( - const T *x, const T *y, const Tout *out, const Tout *dout, int pre, int n, - int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { - int tid = threadIdx.x; - int bid = blockIdx.x; - - T val(0); - if (is_xsize_larger) { - for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) { - int b_i = bid / post; - int b_j = bid % post; - int x_offset = b_i * n * post + i * post + b_j; - int y_offset = b_i * post + b_j; - if (dx) { - dx[x_offset] = - dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); - } - if (dy) { - val += dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); - } - } - if (dy) { - int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; - val = paddle::platform::reduceSum(val, tid, h); - if (tid == 0) { - dy[bid] = val; - } - } - } else { - for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) { - int b_i = bid / post; - int b_j = bid % post; - int y_offset = b_i * n * post + i * post + b_j; - int x_offset = b_i * post + b_j; - if (dy) { - dy[y_offset] = - dy_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]); - } - if (dx) { - val += dx_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]); - } - } - if (dx) { - int h = n > ELEMWISE_MAX_BLOCK_DIM ? 
ELEMWISE_MAX_BLOCK_DIM : n; - val = paddle::platform::reduceSum(val, tid, h); - if (tid == 0) { - dx[bid] = val; - } - } - } -} - -template -static __global__ void FastCommonGradBroadcastOneCUDAKernel( - const T *x, const T *y, const Tout *out, const Tout *dout, int pre, int n, - int post, int y_pre, int y_n, int y_post, bool is_xsize, OP op, T *dd) { - int tid = threadIdx.x; - int bid = blockIdx.x; - - T val(0); - if (is_xsize) { - // do reduce for x - for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) { - int b_i = bid / post; - int b_j = bid % post; - int x_offset = b_i * n * post + b_j; - int out_offset = b_i * n * post + i * post + b_j; - - // Get y pre rows id with x post and y_pre. - int b_yi = bid / (post * y_pre); - int b_yj = bid % y_post; - int y_offset = b_yi * y_n + i * y_post + b_yj; - - if (dd) { - val += op(x[x_offset], y[y_offset], out[out_offset], dout[out_offset]); - } - } - if (dd) { - int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; - val = paddle::platform::reduceSum(val, tid, h); - if (tid == 0) { - dd[bid] = val; - } - } - } else { - // do reduce for y - for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) { - int b_i = bid / post; - int b_j = bid % post; - int y_offset = b_i * n * post + b_j; - int out_offset = b_i * n * post + i * post + b_j; - - int b_yi = bid / (post * y_pre); - int b_yj = bid % y_post; - int x_offset = b_yi * y_n + i * y_post + b_yj; - - if (dd) { - val += op(x[x_offset], y[y_offset], out[out_offset], dout[out_offset]); - } - } - if (dd) { - int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; - val = paddle::platform::reduceSum(val, tid, h); - if (tid == 0) { - dd[bid] = val; - } - } - } -} - -// Check input can be split into 2 parts -static inline bool SplitDims(const std::vector &y_broadcast_pos, - int max_dim) { - bool can_split_dim2 = true; - // must at start or end. - if (y_broadcast_pos[0] != 0 && - y_broadcast_pos[y_broadcast_pos.size() - 1] != max_dim - 1) { - can_split_dim2 = false; - } else { - for (int i = 1; i < y_broadcast_pos.size(); ++i) { - // dim must be continue - if (y_broadcast_pos[i] != y_broadcast_pos[i - 1] + 1) { - can_split_dim2 = false; - break; - } - } - } - return can_split_dim2; -} - -// Suppose only has contiguous dims -static inline bool CheckContiguousDims(const std::vector &broadcast_pos) { - for (int i = 1; i < broadcast_pos.size(); ++i) { - if (broadcast_pos[i] != broadcast_pos[i - 1] + 1) { - return false; - } - } - return true; -} - template void CommonGradBroadcastCUDA( const framework::Tensor &x, const framework::Tensor &y, @@ -700,10 +224,10 @@ void CommonGradBroadcastCUDA( std::vector x_trans_indexs(max_dim); std::vector y_trans_indexs(max_dim); - ComputeBroadcastTranspositionArray(x_one_indexs.data(), x_trans_indexs.data(), - max_dim, x_one_indexs.size()); - ComputeBroadcastTranspositionArray(y_one_indexs.data(), y_trans_indexs.data(), - max_dim, y_one_indexs.size()); + pten::ComputeBroadcastTranspositionArray( + x_one_indexs.data(), x_trans_indexs.data(), max_dim, x_one_indexs.size()); + pten::ComputeBroadcastTranspositionArray( + y_one_indexs.data(), y_trans_indexs.data(), max_dim, y_one_indexs.size()); // compute array stride for cuda kernel; // e.g. 
x.dims=[2,3,4], x_stride=[12,4,1] @@ -790,15 +314,15 @@ void CommonGradBroadcastCUDA( if (w < 16 || h < 16) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); int grid_size = w; - CommonGradBroadcast1CUDAKernelHeight<<>>( + pten::CommonGradBroadcast1CUDAKernelHeight<<>>( x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw, is_y); } else { dim3 block_size = dim3(BLOCK_X, BLOCK_Y); int grid_size = (w + BLOCK_X - 1) / BLOCK_X; - FastCommonGradBroadcastCUDAKernelHeight<<>>( + pten::FastCommonGradBroadcastCUDAKernelHeight<<>>( x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw, is_y); } @@ -806,15 +330,15 @@ void CommonGradBroadcastCUDA( if (w < 16 || h < 16) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); int grid_size = w; - CommonGradBroadcast1CUDAKernelHeight<<>>( + pten::CommonGradBroadcast1CUDAKernelHeight<<>>( x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw, is_y); } else { dim3 block_size = dim3(BLOCK_X, BLOCK_Y); int grid_size = (w + BLOCK_X - 1) / BLOCK_X; - FastCommonGradBroadcastCUDAKernelHeight<<>>( + pten::FastCommonGradBroadcastCUDAKernelHeight<<>>( x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw, is_y); } @@ -835,14 +359,15 @@ void CommonGradBroadcastCUDA( if (w < 16 || h < 16) { int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); int grid_size = w; - ElemwiseGradBroadcast1CUDAKernel<<>>( + pten::ElemwiseGradBroadcast1CUDAKernel<<>>( x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op, dx_data, dy_data); } else { dim3 block_size = dim3(BLOCK_X, BLOCK_Y); int grid_size = (w + BLOCK_X - 1) / BLOCK_X; - FastElemwiseGradBroadcast1CUDAKernel<<>>( + pten::FastElemwiseGradBroadcast1CUDAKernel<<>>( x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op, dx_data, dy_data); } @@ -876,7 +401,8 @@ void CommonGradBroadcastCUDA( int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid); int grid_size = pre * post; - FastCommonGradBroadcastAllCUDAKernel<<>>( + pten::FastCommonGradBroadcastAllCUDAKernel<<>>( x_data, y_data, out_data, dout_data, pre, mid, post, is_x_large, dx_op, dy_op, dx_data, dy_data); }; @@ -907,8 +433,8 @@ void CommonGradBroadcastCUDA( // size. if (k_pre != pre) k_pre = pre / k_pre; - FastCommonGradBroadcastOneCUDAKernel<<>>( + pten::FastCommonGradBroadcastOneCUDAKernel<<>>( x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid, k_post, true, dx_op, dx_data); } else { @@ -921,8 +447,8 @@ void CommonGradBroadcastCUDA( int grid_size = pre * post; if (k_pre != pre) k_pre = pre / k_pre; - FastCommonGradBroadcastOneCUDAKernel<<>>( + pten::FastCommonGradBroadcastOneCUDAKernel<<>>( x_data, y_data, out_data, dout_data, pre, mid, post, k_pre, k_mid, k_post, false, dy_op, dy_data); } @@ -936,7 +462,7 @@ void CommonGradBroadcastCUDA( // 2. if both x and y need broadcast, then do it one by one. bool fast_broadcast = false; if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) { - can_split_y = SplitDims(y_broadcast_pos, max_dim); + can_split_y = pten::SplitDims(y_broadcast_pos, max_dim); if (can_split_y) { // only y need to do broadcast on h if (y_broadcast_pos[0] == 0) { @@ -944,28 +470,29 @@ void CommonGradBroadcastCUDA( fast_broadcast = true; } } else if (y_broadcast_pos.size() == 1 || - CheckContiguousDims(y_broadcast_pos)) { // for only one dim and - // contiguous broadcast. + pten::CheckContiguousDims( + y_broadcast_pos)) { // for only one dim and + // contiguous broadcast. 
// If cannot split, which means input has 3 parts FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true); fast_broadcast = true; } } else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) { // only x need broadcast - can_split_x = SplitDims(x_broadcast_pos, max_dim); + can_split_x = pten::SplitDims(x_broadcast_pos, max_dim); if (can_split_x) { if (x_broadcast_pos[0] == 0) { FastBroadCastHeightCUDAF(x_broadcast_pos, false); fast_broadcast = true; } } else if (x_broadcast_pos.size() == 1 || - CheckContiguousDims(x_broadcast_pos)) { + pten::CheckContiguousDims(x_broadcast_pos)) { FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false); fast_broadcast = true; } } else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) { // do x and y broadcast each. - can_split_y = SplitDims(y_broadcast_pos, max_dim); + can_split_y = pten::SplitDims(y_broadcast_pos, max_dim); bool fast_broadcast_x = false; bool fast_broadcast_y = false; if (can_split_y) { @@ -979,7 +506,7 @@ void CommonGradBroadcastCUDA( can_split_y = true; fast_broadcast_y = true; } - can_split_x = SplitDims(x_broadcast_pos, max_dim); + can_split_x = pten::SplitDims(x_broadcast_pos, max_dim); if (can_split_x) { if (x_broadcast_pos[0] == 0) { FastCommonCUDAF(x_broadcast_pos, false); @@ -1005,12 +532,12 @@ void CommonGradBroadcastCUDA( } int x_blocks = 0; int x_threads = 0; - ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks, - &x_threads, max_dim); + pten::ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks, + &x_threads, max_dim); int y_blocks = 0; int y_threads = 0; - ComputeBroadcastKernelSize(y_dims_array, out_dims_array, &y_blocks, - &y_threads, max_dim); + pten::ComputeBroadcastKernelSize(y_dims_array, out_dims_array, &y_blocks, + &y_threads, max_dim); auto x_strides_array_tmp = memory::Alloc(ctx, bytes); int *x_strides_array_gpu = @@ -1076,228 +603,6 @@ inline framework::DDim trim_trailing_singular_dims( return pten::funcs::trim_trailing_singular_dims(dims); } -template -struct ElemwiseGradNoBroadcast { - const T *x_; - const T *y_; - const Tout *out_; - const Tout *dout_; - - HOSTDEVICE void operator()(size_t i) { - if (dx_ != nullptr) { - dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); - } - if (dy_ != nullptr) { - dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]); - } - } - - DX_OP dx_op_; - DY_OP dy_op_; - T *dx_; - T *dy_; -}; - -template -static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const Tout *out, - const Tout *dout, int h, int w, - bool is_xsize_larger, DX_OP dx_op, - DY_OP dy_op, T *dx, T *dy) { - if (is_xsize_larger) { - for (int i = 0; i < h; ++i) { - for (int j = 0; j < w; ++j) { - int x_offset = i * w + j; - if (dx != nullptr) { - dx[x_offset] = - dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - if (dy != nullptr) { - T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - if (i == 0) { - dy[j] = tmp; - } else { - dy[j] += tmp; - } - } - } - } - } else { // x.dims < y.dims, broadcast for x. 
- for (int i = 0; i < h; ++i) { - for (int j = 0; j < w; ++j) { - int y_offset = i * w + j; - if (dy != nullptr) { - dy[y_offset] = - dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - if (dx != nullptr) { - T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - if (i == 0) { - dx[j] = tmp; - } else { - dx[j] += tmp; - } - } - } - } - } -} - -#if defined(__NVCC__) || defined(__HIPCC__) - -template -static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, const T *x, - const T *y, const Tout *out, - const Tout *dout, int h, int w, - bool is_xsize_larger, DX_OP dx_op, - DY_OP dy_op, T *dx, T *dy) { - // For small case use 1D block - constexpr int half_walf = 16; - if (w < half_walf || h < half_walf) { - int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); - int gird_size = w; - ElemwiseGradBroadcast1CUDAKernel<<>>( - x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy); - } else { - // suppose perfoemance improves with h increased. - dim3 block_size = dim3(BLOCK_X, BLOCK_Y); - int grid_size = (w + BLOCK_X - 1) / BLOCK_X; - FastElemwiseGradBroadcast1CUDAKernel<<>>( - x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy); - } -} - -#endif - -template -static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const Tout *out, - const Tout *dout, int pre, int n, - int post, bool is_xsize_larger, - DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { - if (is_xsize_larger) { - for (int i = 0; i < pre; ++i) { - for (int j = 0; j < n; ++j) { - for (int k = 0; k < post; ++k) { - int x_offset = i * n * post + j * post + k; - if (dx != nullptr) { - dx[x_offset] = - dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - if (dy != nullptr) { - T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - if (i == 0 && k == 0) { - dy[j] = tmp; - } else { - dy[j] += tmp; - } - } - } - } - } - } else { // x.dims < y.dims, broadcast for x. - for (int i = 0; i < pre; ++i) { - for (int j = 0; j < n; ++j) { - for (int k = 0; k < post; ++k) { - int y_offset = i * n * post + j * post + k; - if (dy != nullptr) { - dy[y_offset] = - dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - if (dx != nullptr) { - T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - if (i == 0 && k == 0) { - dx[j] = tmp; - } else { - dx[j] += tmp; - } - } - } - } - } - } -} - -#if defined(__NVCC__) || defined(__HIPCC__) -template -static __global__ void ElemwiseGradBroadcast2CUDAKernel( - const T *x, const T *y, const Tout *out, const Tout *dout, int pre, int n, - int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { - int tid = threadIdx.x; - int j = blockIdx.x; - - T val(0); - int ttid = tid; - - if (is_xsize_larger) { - while (true) { - int i = ttid / post; - int k = ttid % post; - if (i >= pre) break; - - int x_offset = i * n * post + j * post + k; - - if (dx != nullptr) { - dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - - if (dy != nullptr) { - val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - - ttid += ELEMWISE_MAX_BLOCK_DIM; - } - - if (dy) { - int h = pre * post; - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dy[j] = val; - } - } - } else { // x.dims < y.dims, broadcast for x. 
- while (true) { - int i = ttid / post; - int k = ttid % post; - if (i >= pre) break; - - int y_offset = i * n * post + j * post + k; - - if (dy != nullptr) { - dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - - if (dx != nullptr) { - val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - - ttid += ELEMWISE_MAX_BLOCK_DIM; - } - - if (dx) { - int h = pre * post; - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dx[j] = val; - } - } - } -} - -template -static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, const T *x, - const T *y, const Tout *out, - const Tout *dout, int pre, int n, - int post, bool is_xsize_larger, - DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { - int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); - int gird_size = n; - ElemwiseGradBroadcast2CUDAKernel<<>>( - x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy); -} - -#endif - template void CommonElementwiseBroadcastBackward( @@ -1334,7 +639,7 @@ void CommonElementwiseBroadcastBackward( dy_op); #endif } else { - CommonGradBroadcastCPU( + pten::CommonGradBroadcastCPU( x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(), max_dim, ctx.template device_context(), dx_op, @@ -1342,28 +647,6 @@ void CommonElementwiseBroadcastBackward( } } -template -void ElemwiseGradComputeNoBroadcast( - const framework::ExecutionContext &ctx, const framework::DDim &x_dim, - const framework::DDim &y_dim, const framework::Tensor &x, - const framework::Tensor &y, const framework::Tensor &out, - const framework::Tensor &dout, int axis, framework::Tensor *dx, - framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { - size_t N = static_cast(framework::product(x_dim)); -#if !defined(_WIN32) - platform::ForRange for_range( - ctx.template device_context(), N); -#else - platform::ForRange for_range( - ctx.device_context(), N); -#endif // !_WIN32 - for_range(ElemwiseGradNoBroadcast{ - x.data(), y.data(), out.data(), dout.data(), dx_op, - dy_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), - dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())}); -} - template void ElemwiseGradComputeWithBroadcast( @@ -1412,7 +695,7 @@ void ElemwiseGradComputeWithBroadcast( if (post == 1) { if (platform::is_gpu_place(ctx.GetPlace())) { #if defined(__NVCC__) || defined(__HIPCC__) - ElemwiseGradBroadcast1CUDA( + pten::ElemwiseGradBroadcast1CUDA( ctx.template device_context().stream(), x.data(), y.data(), out.data(), dout.data(), pre, n, is_xsize_larger, dx_op, dy_op, @@ -1420,7 +703,7 @@ void ElemwiseGradComputeWithBroadcast( dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace())); #endif } else { - ElemwiseGradBroadcast1CPU( + pten::ElemwiseGradBroadcast1CPU( x.data(), y.data(), out.data(), dout.data(), pre, n, is_xsize_larger, dx_op, dy_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), @@ -1429,7 +712,7 @@ void ElemwiseGradComputeWithBroadcast( } else { if (platform::is_gpu_place(ctx.GetPlace())) { #if defined(__NVCC__) || defined(__HIPCC__) - ElemwiseGradBroadcast2CUDA( + pten::ElemwiseGradBroadcast2CUDA( ctx.template device_context().stream(), x.data(), y.data(), out.data(), dout.data(), pre, n, post, is_xsize_larger, dx_op, dy_op, @@ -1437,7 +720,7 @@ void ElemwiseGradComputeWithBroadcast( dy == nullptr ? 
nullptr : dy->mutable_data(ctx.GetPlace())); #endif } else { - ElemwiseGradBroadcast2CPU( + pten::ElemwiseGradBroadcast2CPU( x.data(), y.data(), out.data(), dout.data(), pre, n, post, is_xsize_larger, dx_op, dy_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), @@ -1474,8 +757,10 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx, const framework::DDim &x_dim = x.dims(); const framework::DDim &y_dim = y.dims(); if (x.dims() == y.dims()) { - ElemwiseGradComputeNoBroadcast( - ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); + const auto &dev_ctx = ctx.template device_context(); + pten::funcs::ElemwiseGradComputeNoBroadcast( + dev_ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); } else { ElemwiseGradComputeWithBroadcast( ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op); @@ -1497,8 +782,10 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx, const framework::DDim &x_dim = x.dims(); const framework::DDim &y_dim = y.dims(); if (x.dims() == y.dims()) { - ElemwiseGradComputeNoBroadcast( - ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op); + const auto &dev_ctx = ctx.template device_context(); + pten::funcs::ElemwiseGradComputeNoBroadcast( + dev_ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, + dy_op); } else { ElemwiseGradComputeWithBroadcast( ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op); diff --git a/paddle/fluid/operators/viterbi_decode_op.h b/paddle/fluid/operators/viterbi_decode_op.h index 4da137f774..2b392ae74c 100644 --- a/paddle/fluid/operators/viterbi_decode_op.h +++ b/paddle/fluid/operators/viterbi_decode_op.h @@ -150,9 +150,12 @@ struct GetInputIndex { const std::vector& output_strides, int output_idx, int* index_array, int* lhs_idx, int* rhs_idx) { int out_dims_size = output_strides.size(); - *lhs_idx = GetElementwiseIndex(lhs_dims.data(), out_dims_size, index_array); - *rhs_idx = GetElementwiseIndex(rhs_dims.data(), out_dims_size, index_array); - UpdateElementwiseIndexArray(output_dims.data(), out_dims_size, index_array); + *lhs_idx = + pten::GetElementwiseIndex(lhs_dims.data(), out_dims_size, index_array); + *rhs_idx = + pten::GetElementwiseIndex(rhs_dims.data(), out_dims_size, index_array); + pten::UpdateElementwiseIndexArray(output_dims.data(), out_dims_size, + index_array); } }; diff --git a/paddle/pten/kernels/cpu/elementwise.h b/paddle/pten/kernels/cpu/elementwise.h index d3687b22fb..5a421de117 100644 --- a/paddle/pten/kernels/cpu/elementwise.h +++ b/paddle/pten/kernels/cpu/elementwise.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/kernels/funcs/elementwise_base.h" @@ -22,6 +23,8 @@ limitations under the License. 
*/ namespace pten { +// FORWARD CODE + // Add template struct SameDimsAddFunctor { @@ -206,6 +209,56 @@ inline int GetElementwiseIndex(const int* x_dims_array, return index_; } +template +void CommonGradBroadcastCPU(const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + int* x_dims_array, + int* y_dims_array, + int* out_dims_array, + int max_dim, + const CPUContext& ctx, + DX_OP dx_op, + DY_OP dy_op) { + std::vector index_array(max_dim, 0); + const T* x_data = x.data(); + const T* y_data = y.data(); + const Tout* out_data = out.data(); + const Tout* dout_data = dout.data(); + T* dx_data = dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()); + T* dy_data = dy == nullptr ? nullptr : dy->mutable_data(ctx.GetPlace()); + if (dx_data != nullptr) { + memset(dx_data, 0, dx->numel() * sizeof(T)); + } + if (dy_data != nullptr) { + memset(dy_data, 0, dy->numel() * sizeof(T)); + } + const int out_size = std::accumulate( + out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); + int x_index, y_index; + for (int out_index = 0; out_index < out_size; ++out_index) { + x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); + y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); + if (dx_data != nullptr) { + dx_data[x_index] += dx_op(x_data[x_index], + y_data[y_index], + out_data[out_index], + dout_data[out_index]); + } + if (dy_data != nullptr) { + dy_data[y_index] += dy_op(x_data[x_index], + y_data[y_index], + out_data[out_index], + dout_data[out_index]); + } + + UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); + } +} + template void CommonForwardBroadcastCPU(const DenseTensor& x, const DenseTensor& y, @@ -214,7 +267,7 @@ void CommonForwardBroadcastCPU(const DenseTensor& x, int* y_dims_array, int* out_dims_array, int max_dim, - const paddle::platform::CPUDeviceContext& ctx, + const CPUContext& ctx, Functor func, const bool is_xsize_larger = true) { std::vector index_array(max_dim, 0); @@ -245,16 +298,15 @@ void CommonForwardBroadcastCPU(const DenseTensor& x, } template -void CommonElementwiseBroadcastForward( - const paddle::platform::CPUDeviceContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& y, - DenseTensor* z, - const DDim& x_dims, - const DDim& y_dims, - Functor func, - int axis, - const bool is_xsize_larger = true) { +void CommonElementwiseBroadcastForward(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* z, + const DDim& x_dims, + const DDim& y_dims, + Functor func, + int axis, + const bool is_xsize_larger = true) { int max_dim = (std::max)(x_dims.size(), y_dims.size()); axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); PADDLE_ENFORCE_GE( @@ -302,7 +354,7 @@ void CommonElementwiseBroadcastForward( // TODO(liuyiqun): optimize the CPU implementation to support all broadcast // cases and avoid the need of XxxInverseFunctor. 
template -void ElementwiseCompute(const paddle::platform::CPUDeviceContext& dev_ctx, +void ElementwiseCompute(const CPUContext& dev_ctx, const DenseTensor& x, const DenseTensor& y, int axis, @@ -317,9 +369,8 @@ void ElementwiseCompute(const paddle::platform::CPUDeviceContext& dev_ctx, is_xsize_larger = false; max_dim = y_dims.size(); } - funcs:: - TransformFunctor - functor(x, y, z, dev_ctx, func, is_xsize_larger); + funcs::TransformFunctor functor( + x, y, z, dev_ctx, func, is_xsize_larger); if (x_dims == y_dims) { functor.Run(); return; @@ -381,7 +432,7 @@ void ElementwiseCompute(const paddle::platform::CPUDeviceContext& dev_ctx, template struct SameDimsElementwiseCompute { - void operator()(const paddle::platform::CPUDeviceContext& dev_ctx, + void operator()(const CPUContext& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* z) { @@ -389,4 +440,113 @@ struct SameDimsElementwiseCompute { } }; +// BACKWARD CODE + +template +static void ElemwiseGradBroadcast1CPU(const T* x, + const T* y, + const Tout* out, + const Tout* dout, + int h, + int w, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T* dx, + T* dy) { + if (is_xsize_larger) { + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int x_offset = i * w + j; + if (dx != nullptr) { + dx[x_offset] = + dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + if (dy != nullptr) { + T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + if (i == 0) { + dy[j] = tmp; + } else { + dy[j] += tmp; + } + } + } + } + } else { // x.dims < y.dims, broadcast for x. + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { + int y_offset = i * w + j; + if (dy != nullptr) { + dy[y_offset] = + dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + if (dx != nullptr) { + T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + if (i == 0) { + dx[j] = tmp; + } else { + dx[j] += tmp; + } + } + } + } + } +} + +template +static void ElemwiseGradBroadcast2CPU(const T* x, + const T* y, + const Tout* out, + const Tout* dout, + int pre, + int n, + int post, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T* dx, + T* dy) { + if (is_xsize_larger) { + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int x_offset = i * n * post + j * post + k; + if (dx != nullptr) { + dx[x_offset] = + dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + if (dy != nullptr) { + T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + if (i == 0 && k == 0) { + dy[j] = tmp; + } else { + dy[j] += tmp; + } + } + } + } + } + } else { // x.dims < y.dims, broadcast for x. + for (int i = 0; i < pre; ++i) { + for (int j = 0; j < n; ++j) { + for (int k = 0; k < post; ++k) { + int y_offset = i * n * post + j * post + k; + if (dy != nullptr) { + dy[y_offset] = + dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + if (dx != nullptr) { + T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + if (i == 0 && k == 0) { + dx[j] = tmp; + } else { + dx[j] += tmp; + } + } + } + } + } + } +} + } // namespace pten diff --git a/paddle/pten/kernels/funcs/elementwise_base.h b/paddle/pten/kernels/funcs/elementwise_base.h index a0c6d5ba57..be355557d5 100644 --- a/paddle/pten/kernels/funcs/elementwise_base.h +++ b/paddle/pten/kernels/funcs/elementwise_base.h @@ -14,6 +14,7 @@ limitations under the License. 
*/ #pragma once +#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/transform.h" #include "paddle/pten/backends/all_context.h" #include "paddle/pten/core/dense_tensor.h" @@ -23,6 +24,28 @@ namespace funcs { using DDim = paddle::framework::DDim; +template +struct ElemwiseGradNoBroadcast { + const T *x_; + const T *y_; + const Tout *out_; + const Tout *dout_; + + HOSTDEVICE void operator()(size_t i) { + if (dx_ != nullptr) { + dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); + } + if (dy_ != nullptr) { + dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]); + } + } + + DX_OP dx_op_; + DY_OP dy_op_; + T *dx_; + T *dy_; +}; + template class RowwiseTransformIterator; @@ -378,5 +401,36 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, } } } + +template +void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx, + const DDim &x_dim, + const DDim &y_dim, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + int axis, + DenseTensor *dx, + DenseTensor *dy, + DX_OP dx_op, + DY_OP dy_op) { + size_t N = static_cast(paddle::framework::product(x_dim)); + paddle::platform::ForRange for_range(dev_ctx, N); + for_range(ElemwiseGradNoBroadcast{ + x.data(), + y.data(), + out.data(), + dout.data(), + dx_op, + dy_op, + dx == nullptr ? nullptr : dx->mutable_data(dev_ctx.GetPlace()), + dy == nullptr ? nullptr : dy->mutable_data(dev_ctx.GetPlace())}); +} + } // namespace funcs } // namespace pten diff --git a/paddle/pten/kernels/gpu/elementwise.h b/paddle/pten/kernels/gpu/elementwise.h index 049e430154..4dfcd7a215 100644 --- a/paddle/pten/kernels/gpu/elementwise.h +++ b/paddle/pten/kernels/gpu/elementwise.h @@ -20,6 +20,14 @@ limitations under the License. */ #include "paddle/fluid/platform/function_traits.h" #include "paddle/pten/core/dense_tensor.h" +#ifdef __HIPCC__ +constexpr int ELEMWISE_MAX_BLOCK_DIM = 256; +#else +constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024; +#endif +#define BLOCK_X 32 +#define BLOCK_Y 32 + namespace pten { namespace kps = paddle::operators::kernel_primitives; @@ -31,6 +39,7 @@ template using ConditionalT = typename std::conditional_t>; +// FORWARD CODE template &broadcast_pos) { + for (int i = 1; i < broadcast_pos.size(); ++i) { + if (broadcast_pos[i] != broadcast_pos[i - 1] + 1) { + return false; + } + } + return true; +} + +inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs, + int *x_trans_indexs, + const int max_dim, + const int x_one_size) { + int diff = max_dim - x_one_size; + std::copy_n(x_one_indexs, x_one_size, x_trans_indexs + diff); + int p = 0; + int q = diff; + for (int i = 0; i < max_dim; ++i) { + if (q < max_dim && i == x_trans_indexs[q]) { + ++q; + } else { + x_trans_indexs[p++] = i; + } + } +} + +// Check input can be split into 2 parts +static inline bool SplitDims(const std::vector &y_broadcast_pos, + int max_dim) { + bool can_split_dim2 = true; + // must at start or end. 
+ if (y_broadcast_pos[0] != 0 && + y_broadcast_pos[y_broadcast_pos.size() - 1] != max_dim - 1) { + can_split_dim2 = false; + } else { + for (int i = 1; i < y_broadcast_pos.size(); ++i) { + // dim must be continue + if (y_broadcast_pos[i] != y_broadcast_pos[i - 1] + 1) { + can_split_dim2 = false; + break; + } + } + } + return can_split_dim2; +} + +inline void ComputeBroadcastKernelSize(int *x_dims_array, + int *out_dims_array, + int *x_blocks, + int *x_threads, + int max_dim) { + *x_blocks = 1; + *x_threads = 1; + for (int i = 0; i < max_dim; i++) { + if (x_dims_array[i] == out_dims_array[i]) { + *x_blocks *= x_dims_array[i]; + } else { + *x_threads *= out_dims_array[i]; + } + } +} + +template +static __global__ void FastCommonGradBroadcastOneCUDAKernel(const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int pre, + int n, + int post, + int y_pre, + int y_n, + int y_post, + bool is_xsize, + OP op, + T *dd) { + int tid = threadIdx.x; + int bid = blockIdx.x; + + T val(0); + if (is_xsize) { + // do reduce for x + for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) { + int b_i = bid / post; + int b_j = bid % post; + int x_offset = b_i * n * post + b_j; + int out_offset = b_i * n * post + i * post + b_j; + + // Get y pre rows id with x post and y_pre. + int b_yi = bid / (post * y_pre); + int b_yj = bid % y_post; + int y_offset = b_yi * y_n + i * y_post + b_yj; + + if (dd) { + val += op(x[x_offset], y[y_offset], out[out_offset], dout[out_offset]); + } + } + if (dd) { + int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; + val = paddle::platform::reduceSum(val, tid, h); + if (tid == 0) { + dd[bid] = val; + } + } + } else { + // do reduce for y + for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) { + int b_i = bid / post; + int b_j = bid % post; + int y_offset = b_i * n * post + b_j; + int out_offset = b_i * n * post + i * post + b_j; + + int b_yi = bid / (post * y_pre); + int b_yj = bid % y_post; + int x_offset = b_yi * y_n + i * y_post + b_yj; + + if (dd) { + val += op(x[x_offset], y[y_offset], out[out_offset], dout[out_offset]); + } + } + if (dd) { + int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; + val = paddle::platform::reduceSum(val, tid, h); + if (tid == 0) { + dd[bid] = val; + } + } + } +} + +template +static __global__ void FastCommonGradBroadcastAllCUDAKernel( + const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int pre, + int n, + int post, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T *dx, + T *dy) { + int tid = threadIdx.x; + int bid = blockIdx.x; + + T val(0); + if (is_xsize_larger) { + for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) { + int b_i = bid / post; + int b_j = bid % post; + int x_offset = b_i * n * post + i * post + b_j; + int y_offset = b_i * post + b_j; + if (dx) { + dx[x_offset] = + dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); + } + if (dy) { + val += dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]); + } + } + if (dy) { + int h = n > ELEMWISE_MAX_BLOCK_DIM ? 
ELEMWISE_MAX_BLOCK_DIM : n; + val = paddle::platform::reduceSum(val, tid, h); + if (tid == 0) { + dy[bid] = val; + } + } + } else { + for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) { + int b_i = bid / post; + int b_j = bid % post; + int y_offset = b_i * n * post + i * post + b_j; + int x_offset = b_i * post + b_j; + if (dy) { + dy[y_offset] = + dy_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]); + } + if (dx) { + val += dx_op(x[x_offset], y[y_offset], out[y_offset], dout[y_offset]); + } + } + if (dx) { + int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n; + val = paddle::platform::reduceSum(val, tid, h); + if (tid == 0) { + dx[bid] = val; + } + } + } +} + +template +static __global__ void FastCommonGradBroadcastCUDAKernelHeight(const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int h, + int w, + DY_OP dy_op, + T *dy, + int x_h, + int x_w, + bool is_y) { + __shared__ T sdata[BLOCK_Y][BLOCK_X + 1]; + + T val(0); + size_t width_stride = gridDim.x * blockDim.x; + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + size_t full_width = + (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0); + size_t full_height = + (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0); + if (is_y) { + for (int m = idx; m < full_width; m += width_stride) { + sdata[threadIdx.y][threadIdx.x] = 0; + for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { + int out_offset = n * w + m; + int x_offset = (n % x_h) * x_w + m % x_w; + if (dy) { + if (m < w && n < h) { + T val = dy_op(x[x_offset], y[m], out[out_offset], dout[out_offset]); + sdata[threadIdx.y][threadIdx.x] += val; + } + __syncthreads(); + } + } + if (dy) { + T my_val = sdata[threadIdx.x][threadIdx.y]; + for (int i = warpSize >> 1; i > 0; i >>= 1) { + my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + } + __syncthreads(); + if ((threadIdx.x == 0)) { + sdata[0][threadIdx.y] = my_val; + } + __syncthreads(); + if (threadIdx.y == 0 && m < w) { + dy[m] = sdata[0][threadIdx.x]; + } + } + } + } else { + for (int m = idx; m < full_width; m += width_stride) { + sdata[threadIdx.y][threadIdx.x] = 0; + for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { + int out_offset = n * w + m; + int y_offset = (n % x_h) * x_w + m % x_w; + if (dy) { + if (m < w && n < h) { + T val = dy_op(x[m], y[y_offset], out[out_offset], dout[out_offset]); + sdata[threadIdx.y][threadIdx.x] += val; + } + __syncthreads(); + } + } + if (dy) { + T my_val = sdata[threadIdx.x][threadIdx.y]; + for (int i = warpSize >> 1; i > 0; i >>= 1) { + my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + } + __syncthreads(); + if ((threadIdx.x == 0)) { + sdata[0][threadIdx.y] = my_val; + } + __syncthreads(); + if (threadIdx.y == 0 && m < w) { + dy[m] = sdata[0][threadIdx.x]; + } + } + } + } +} + +template +static __global__ void CommonGradBroadcast1CUDAKernelHeight(const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int h, + int w, + DY_OP dy_op, + T *dy, + int x_h, + int x_w, + bool is_y) { + int j = blockIdx.x; + int i = threadIdx.x; + int tid = threadIdx.x; + T val(0); + + if (is_y) { + do { + int out_offset = i * w + j; + int x_offset = (i % x_h) * x_w + j % x_w; + if (dy) { + val += dy_op(x[x_offset], y[j], out[out_offset], dout[out_offset]); + } + i += ELEMWISE_MAX_BLOCK_DIM; + } while (i < h); + + if (dy) { + h = h > ELEMWISE_MAX_BLOCK_DIM ? 
ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dy[j] = val; + } + } + } else { + do { + int out_offset = i * w + j; + int y_offset = (i % x_h) * x_w + j % x_w; + if (dy) { + val += dy_op(x[j], y[y_offset], out[out_offset], dout[out_offset]); + } + i += ELEMWISE_MAX_BLOCK_DIM; + } while (i < h); + + if (dy) { + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dy[j] = val; + } + } + } +} + +template +static __global__ void ElemwiseGradBroadcast1CUDAKernel(const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int h, + int w, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T *dx, + T *dy) { + int j = blockIdx.x; + int i = threadIdx.x; + int tid = threadIdx.x; + T val(0); + if (is_xsize_larger) { + do { + int x_offset = i * w + j; + if (dx) { + dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + if (dy) { + val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + i += ELEMWISE_MAX_BLOCK_DIM; + } while (i < h); + + if (dy) { + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dy[j] = val; + } + } + } else { // x.dims < y.dims, broadcast for x. + do { + int y_offset = i * w + j; + if (dy) { + dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + if (dx) { + val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + i += ELEMWISE_MAX_BLOCK_DIM; + } while (i < h); + + if (dx) { + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dx[j] = val; + } + } + } +} + +// suppose use 2D block is fast because more parallel +// and memory coalesced +template +static __global__ void FastElemwiseGradBroadcast1CUDAKernel( + const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int h, + int w, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T *dx, + T *dy) { + __shared__ T sdata[BLOCK_Y][BLOCK_X + 1]; + + T val(0); + size_t width_stride = gridDim.x * blockDim.x; + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + size_t full_width = + (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0); + size_t full_height = + (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0); + if (is_xsize_larger) { + for (int m = idx; m < full_width; m += width_stride) { + sdata[threadIdx.y][threadIdx.x] = 0; + for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { + int x_offset = n * w + m; + if (dx && m < w && n < h) { + dx[x_offset] = + dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]); + } + if (dy) { + if (m < w && n < h) { + T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]); + sdata[threadIdx.y][threadIdx.x] += val; + } + __syncthreads(); + } + } + if (dy) { + T my_val = sdata[threadIdx.x][threadIdx.y]; + for (int i = warpSize >> 1; i > 0; i >>= 1) + my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + __syncthreads(); + if ((threadIdx.x == 0)) { + sdata[0][threadIdx.y] = my_val; + } + __syncthreads(); + if (threadIdx.y == 0 && m < w) { + dy[m] = sdata[0][threadIdx.x]; + } + } + } + } else { // x.dims < y.dims, broadcast for x. 
+ for (int m = idx; m < full_width; m += width_stride) { + sdata[threadIdx.y][threadIdx.x] = 0; + for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { + int y_offset = n * w + m; + if (dy && m < w && n < h) { + dy[y_offset] = + dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]); + } + if (dx) { + if (m < w && n < h) { + T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]); + sdata[threadIdx.y][threadIdx.x] += val; + } + __syncthreads(); + } + } + if (dx) { + T my_val = sdata[threadIdx.x][threadIdx.y]; + for (int i = warpSize >> 1; i > 0; i >>= 1) + my_val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + __syncthreads(); + if ((threadIdx.x == 0)) { + sdata[0][threadIdx.y] = my_val; + } + __syncthreads(); + if (threadIdx.y == 0 && m < w) { + dx[m] = sdata[0][threadIdx.x]; + } + } + } + } +} + +template +static __global__ void ElemwiseGradBroadcast2CUDAKernel(const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int pre, + int n, + int post, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T *dx, + T *dy) { + int tid = threadIdx.x; + int j = blockIdx.x; + + T val(0); + int ttid = tid; + + if (is_xsize_larger) { + while (true) { + int i = ttid / post; + int k = ttid % post; + if (i >= pre) break; + + int x_offset = i * n * post + j * post + k; + + if (dx != nullptr) { + dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + + if (dy != nullptr) { + val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + + ttid += ELEMWISE_MAX_BLOCK_DIM; + } + + if (dy) { + int h = pre * post; + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dy[j] = val; + } + } + } else { // x.dims < y.dims, broadcast for x. + while (true) { + int i = ttid / post; + int k = ttid % post; + if (i >= pre) break; + + int y_offset = i * n * post + j * post + k; + + if (dy != nullptr) { + dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + + if (dx != nullptr) { + val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + + ttid += ELEMWISE_MAX_BLOCK_DIM; + } + + if (dx) { + int h = pre * post; + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dx[j] = val; + } + } + } +} + +template +static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, + const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int h, + int w, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T *dx, + T *dy) { + // For small case use 1D block + constexpr int half_walf = 16; + if (w < half_walf || h < half_walf) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h); + int gird_size = w; + ElemwiseGradBroadcast1CUDAKernel<<>>( + x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy); + } else { + // suppose perfoemance improves with h increased. 
+ dim3 block_size = dim3(BLOCK_X, BLOCK_Y); + int grid_size = (w + BLOCK_X - 1) / BLOCK_X; + FastElemwiseGradBroadcast1CUDAKernel<<>>( + x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy); + } +} + +template +static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, + const T *x, + const T *y, + const Tout *out, + const Tout *dout, + int pre, + int n, + int post, + bool is_xsize_larger, + DX_OP dx_op, + DY_OP dy_op, + T *dx, + T *dy) { + int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); + int gird_size = n; + ElemwiseGradBroadcast2CUDAKernel<<>>( + x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy); +} + } // namespace pten -- GitLab
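For reference, the CPU backward kernels moved into paddle/pten/kernels/cpu/elementwise.h by this patch all follow the same row-broadcast gradient pattern: an h x w operand is paired with a length-w operand, and the gradient of the smaller operand is obtained by reducing over the broadcast (row) dimension. The standalone sketch below shows that pattern specialized for an Add-style gradient (dX = dOut, dY = sum of dOut over the broadcast dimension); the shapes, variable names, and hard-coded values are illustrative only and are not taken from the patch.

    // Minimal standalone C++ sketch of the ElemwiseGradBroadcast1CPU pattern,
    // specialized for Add: dx passes dout through unchanged, while dy is
    // accumulated over the broadcast (row) dimension. Not Paddle code.
    #include <cstdio>
    #include <vector>

    int main() {
      const int h = 2, w = 3;                        // x is h x w, y has length w
      std::vector<float> dout = {1, 2, 3, 4, 5, 6};  // upstream gradient, h x w
      std::vector<float> dx(h * w), dy(w, 0.f);

      for (int i = 0; i < h; ++i) {
        for (int j = 0; j < w; ++j) {
          int off = i * w + j;
          dx[off] = dout[off];  // dx_op for Add: gradient passes straight through
          dy[j] += dout[off];   // dy_op result reduced over the broadcast dim
        }
      }

      for (int j = 0; j < w; ++j) printf("dy[%d] = %g\n", j, dy[j]);  // 5 7 9
      return 0;
    }

The GPU variants relocated into paddle/pten/kernels/gpu/elementwise.h implement the same reduction with a block-level sum (paddle::platform::reduceSum, or a shared-memory plus warp-shuffle combination in the Fast* kernels) instead of the serial accumulation shown above.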