From 58615a62723c81be54ef051a32e7e192e695b951 Mon Sep 17 00:00:00 2001
From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com>
Date: Sun, 29 Mar 2020 07:42:50 +0800
Subject: [PATCH] Improve elementwise performance. (#23001)

* Improve elementwise performance.

Elementwise performance is poor when the gradient falls into
CommonGradBroadcastCUDA, so add some new kernels for the different
data patterns.

* Add some cuda kernels to speed up common broadcast cases. test=develop

* Add more test cases and fix a cuda kernel bug. test=develop

* Remove tests as CPU precision fails. test=develop

* Refine SplitDims. test=develop

* Change file mode. test=develop
---
 .../elementwise/elementwise_op_function.h     | 407 +++++++++++++++++-
 1 file changed, 398 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index 23afa752796..b243a9a18e4 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -33,6 +33,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
+#define BLOCK_X 32
+#define BLOCK_Y 32
 #endif
 
 #include "paddle/fluid/operators/math/math_function.h"
@@ -141,8 +143,8 @@ inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
         "ShapeError: broadcast dimension mismatch. Operands could "
         "not be broadcast together with the shape of X = [%s] and "
         "the shape of Y = [%s]. Received [%d] in X is not equal to "
-        "[%d] in Y",
-        x_dims, y_dims, x_dims_array[i], y_dims_array[i]);
+        "[%d] in Y at i:%d",
+        x_dims, y_dims, x_dims_array[i], y_dims_array[i], i);
     if ((x_dims_array[i] > 1 || y_dims_array[i] > 1) ||
         (x_dims_array[i] == 1 && y_dims_array[i] == 1)) {
       out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
@@ -373,6 +375,199 @@ __global__ void CommonGradBroadcastCUDAKernel(
   }
 }
 
+template <typename T, typename DY_OP>
+static __global__ void CommonGradBroadcast1CUDAKernelHeight(
+    const T *x, const T *y, const T *out, const T *dout, int h, int w,
+    DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) {
+  int j = blockIdx.x;
+  int i = threadIdx.x;
+  int tid = threadIdx.x;
+  T val(0);
+
+  if (is_y) {
+    do {
+      int out_offset = i * w + j;
+      int x_offset = (i % x_h) * x_w + j % x_w;
+      if (dy) {
+        val += dy_op(x[x_offset], y[j], out[out_offset], dout[out_offset]);
+      }
+      i += ELEMWISE_MAX_BLOCK_DIM;
+    } while (i < h);
+
+    if (dy) {
+      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
+      val = paddle::platform::reduceSum(val, tid, h);
+      if (threadIdx.x == 0) {
+        dy[j] = val;
+      }
+    }
+  } else {
+    do {
+      int out_offset = i * w + j;
+      int y_offset = (i % x_h) * x_w + j % x_w;
+      if (dy) {
+        val += dy_op(x[j], y[y_offset], out[out_offset], dout[out_offset]);
+      }
+      i += ELEMWISE_MAX_BLOCK_DIM;
+    } while (i < h);
+
+    if (dy) {
+      h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
+      val = paddle::platform::reduceSum(val, tid, h);
+      if (threadIdx.x == 0) {
+        dy[j] = val;
+      }
+    }
+  }
+}
+
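+// The kernel below is a tiled variant of CommonGradBroadcast1CUDAKernelHeight:
+// each BLOCK_X x BLOCK_Y thread block strides down the height accumulating
+// per-thread partial sums in shared memory, then combines the partial sums of
+// a warp with CudaShuffleXorSync before a single thread writes the reduced
+// gradient of its output column.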
+template <typename T, typename DY_OP>
+static __global__ void FastCommonGradBroadcastCUDAKernelHeight(
+    const T *x, const T *y, const T *out, const T *dout, int h, int w,
+    DY_OP dy_op, T *dy, int x_h, int x_w, bool is_y) {
+  __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
+
+  T val(0);
+  size_t width_stride = gridDim.x * blockDim.x;
+  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+  size_t full_width =
+      (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
+  size_t full_height =
+      (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
+  if (is_y) {
+    for (int m = idx; m < full_width; m += width_stride) {
+      sdata[threadIdx.y][threadIdx.x] = 0;
+      for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
+        int out_offset = n * w + m;
+        int x_offset = (n % x_h) * x_w + m % x_w;
+        if (dy) {
+          if (m < w && n < h) {
+            T val = dy_op(x[x_offset], y[m], out[out_offset], dout[out_offset]);
+            sdata[threadIdx.y][threadIdx.x] += val;
+          }
+          __syncthreads();
+        }
+      }
+      if (dy) {
+        T my_val = sdata[threadIdx.x][threadIdx.y];
+        for (int i = warpSize >> 1; i > 0; i >>= 1) {
+          my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
+        }
+        __syncthreads();
+        if ((threadIdx.x == 0)) {
+          sdata[0][threadIdx.y] = my_val;
+        }
+        __syncthreads();
+        if (threadIdx.y == 0 && m < w) {
+          dy[m] = sdata[0][threadIdx.x];
+        }
+      }
+    }
+  } else {
+    for (int m = idx; m < full_width; m += width_stride) {
+      sdata[threadIdx.y][threadIdx.x] = 0;
+      for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
+        int out_offset = n * w + m;
+        int y_offset = (n % x_h) * x_w + m % x_w;
+        if (dy) {
+          if (m < w && n < h) {
+            T val = dy_op(x[m], y[y_offset], out[out_offset], dout[out_offset]);
+            sdata[threadIdx.y][threadIdx.x] += val;
+          }
+          __syncthreads();
+        }
+      }
+      if (dy) {
+        T my_val = sdata[threadIdx.x][threadIdx.y];
+        for (int i = warpSize >> 1; i > 0; i >>= 1) {
+          my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
+        }
+        __syncthreads();
+        if ((threadIdx.x == 0)) {
+          sdata[0][threadIdx.y] = my_val;
+        }
+        __syncthreads();
+        if (threadIdx.y == 0 && m < w) {
+          dy[m] = sdata[0][threadIdx.x];
+        }
+      }
+    }
+  }
+}
+
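+// The kernel below views the output as [pre, n, post] with the broadcast
+// input collapsed along n: one thread block handles one (pre, post) pair,
+// applies the elementwise grad op, and reduces the n partial values with
+// reduceSum. Illustrative example: out dims [2, 3, 4, 5] broadcast against a
+// [2, 1, 4, 5] input give pre = 2, n = 3, post = 20, i.e. pre * post = 40
+// blocks.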
+template <typename T, typename DX_OP, typename DY_OP>
+static __global__ void FastCommonGradBroadcastAllCUDAKernel(
+    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
+    int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+
+  T val(0);
+  if (is_xsize_larger) {
+    for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) {
+      int b_i = bid / post;
+      int b_j = bid % post;
+      int x_offset = b_i * n * post + i * post + b_j;
+      int y_offset = b_i * post + b_j;
+      if (dx) {
+        dx[x_offset] =
+            dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+      }
+      if (dy) {
+        val += dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+      }
+    }
+    if (dy) {
+      int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
+      val = paddle::platform::reduceSum(val, tid, h);
+      if (tid == 0) {
+        dy[bid] = val;
+      }
+    }
+  } else {
+    for (int i = tid; i < n; i += ELEMWISE_MAX_BLOCK_DIM) {
+      int b_i = bid / post;
+      int b_j = bid % post;
+      int y_offset = b_i * n * post + i * post + b_j;
+      int x_offset = b_i * post + b_j;
+      if (dy) {
+        dy[y_offset] =
+            dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+      }
+      if (dx) {
+        val += dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
+      }
+    }
+    if (dx) {
+      int h = n > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : n;
+      val = paddle::platform::reduceSum(val, tid, h);
+      if (tid == 0) {
+        dx[bid] = val;
+      }
+    }
+  }
+}
+
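+// A broadcast-position list can be folded into a single height/width split
+// only when it is one contiguous run of dims that starts at dim 0 or ends at
+// the last dim. Illustrative example with max_dim = 4: {0, 1} and {2, 3}
+// qualify, while {0, 2} (not contiguous) and {1, 2} (touches neither end) do
+// not.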
+// Check whether the input can be split into 2 parts
+static inline bool SplitDims(const std::vector<int> &y_broadcast_pos,
+                             int max_dim) {
+  bool can_split_dim2 = true;
+  // the run must start at the front or end at the back.
+  if (y_broadcast_pos[0] != 0 &&
+      y_broadcast_pos[y_broadcast_pos.size() - 1] != max_dim - 1) {
+    can_split_dim2 = false;
+  } else {
+    for (int i = 1; i < y_broadcast_pos.size(); ++i) {
+      // dims must be contiguous
+      if (y_broadcast_pos[i] != y_broadcast_pos[i - 1] + 1) {
+        can_split_dim2 = false;
+        break;
+      }
+    }
+  }
+  return can_split_dim2;
+}
+
 template <typename T, typename DX_OP, typename DY_OP>
 void CommonGradBroadcastCUDA(
     const framework::Tensor &x, const framework::Tensor &y,
@@ -436,7 +631,203 @@ void CommonGradBroadcastCUDA(
     x_dims_order[i] = out_dims_array[x_trans_indexs[i]];
     y_dims_order[i] = out_dims_array[y_trans_indexs[i]];
   }
+  std::vector<int> x_broadcast_pos;
+  std::vector<int> y_broadcast_pos;
+
+  int bytes = max_dim * sizeof(int);
+
+  for (int i = 0; i < max_dim; ++i) {
+    if (x_dims_array[i] != out_dims_array[i] && x_dims_array[i] == 1) {
+      x_broadcast_pos.emplace_back(i);
+    }
+    if (y_dims_array[i] != out_dims_array[i] && y_dims_array[i] == 1) {
+      y_broadcast_pos.emplace_back(i);
+    }
+  }
+  auto stream = ctx.stream();
+  bool can_split_x = false;
+  bool can_split_y = false;
+
+  auto FastCommonCUDAF = [&](const std::vector<int> &broadcast_pos,
+                             bool is_y) {
+    int h =
+        std::accumulate(out_dims_array, out_dims_array + broadcast_pos.size(),
+                        1, std::multiplies<int>());
+    int w =
+        std::accumulate(out_dims_array + broadcast_pos.size(),
+                        out_dims_array + max_dim, 1, std::multiplies<int>());
+
+    VLOG(3) << "FastCommonCUDAF elementwise w:" << w << " h:" << h
+            << " is_y:" << is_y;
+
+    int split_h;
+    int split_w;
+    int kh = h;
+    int kw = w;
+
+    if (is_y) {
+      split_h =
+          std::accumulate(x_dims_array, x_dims_array + broadcast_pos.size(), 1,
+                          std::multiplies<int>());
+      split_w =
+          std::accumulate(x_dims_array + broadcast_pos.size(),
+                          x_dims_array + max_dim, 1, std::multiplies<int>());
+
+    } else {
+      split_h =
+          std::accumulate(y_dims_array, y_dims_array + broadcast_pos.size(), 1,
+                          std::multiplies<int>());
+      split_w =
+          std::accumulate(y_dims_array + broadcast_pos.size(),
+                          y_dims_array + max_dim, 1, std::multiplies<int>());
+    }
+
+    if (h > split_h) kh = split_h;
+    if (w > split_w) kw = split_w;
+
+    if (is_y) {
+      if (w < 16 || h < 16) {
+        int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
+        int grid_size = w;
+        CommonGradBroadcast1CUDAKernelHeight<<<grid_size, block_size, 0,
+                                               stream>>>(
+            x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw,
+            is_y);
+      } else {
+        dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
+        int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
+        FastCommonGradBroadcastCUDAKernelHeight<<<grid_size, block_size, 0,
+                                                  stream>>>(
+            x_data, y_data, out_data, dout_data, h, w, dy_op, dy_data, kh, kw,
+            is_y);
+      }
+    } else {
+      if (w < 16 || h < 16) {
+        int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
+        int grid_size = w;
+        CommonGradBroadcast1CUDAKernelHeight<<<grid_size, block_size, 0,
+                                               stream>>>(
+            x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw,
+            is_y);
+      } else {
+        dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
+        int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
+        FastCommonGradBroadcastCUDAKernelHeight<<<grid_size, block_size, 0,
+                                                  stream>>>(
+            x_data, y_data, out_data, dout_data, h, w, dx_op, dx_data, kh, kw,
+            is_y);
+      }
+    }
+  };
+
+  auto FastBroadCastHeightCUDAF = [&](const std::vector<int> &broadcast_pos,
+                                      bool x_large) {
+    int h =
+        std::accumulate(out_dims_array, out_dims_array + broadcast_pos.size(),
+                        1, std::multiplies<int>());
+    int w =
+        std::accumulate(out_dims_array + broadcast_pos.size(),
+                        out_dims_array + max_dim, 1, std::multiplies<int>());
+
+    VLOG(3) << "FastBroadCastHeightCUDAF w:" << w << " h:" << h;
+
+    if (w < 16 || h < 16) {
+      int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
+      int grid_size = w;
+      ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
+          x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op,
+          dx_data, dy_data);
+    } else {
+      dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
+      int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
+      FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0,
+                                             stream>>>(
+          x_data, y_data, out_data, dout_data, h, w, x_large, dx_op, dy_op,
+          dx_data, dy_data);
+    }
+  };
+
+  auto FastBroadCastAllCUDAF = [&](const std::vector<int> &broadcast_pos,
+                                   int max_dim, bool is_x_large) {
+    int axis = broadcast_pos[0];
+    int pre = std::accumulate(out_dims_array, out_dims_array + axis, 1,
+                              std::multiplies<int>());
+    int mid = out_dims_array[axis];
+    int post =
+        std::accumulate(out_dims_array + axis + 1, out_dims_array + max_dim, 1,
+                        std::multiplies<int>());
+
+    VLOG(3) << "FastBroadCastAllCUDAF pre:" << pre << " mid:" << mid
+            << " post:" << post;
+
+    int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
+    int grid_size = pre * post;
+
+    FastCommonGradBroadcastAllCUDAKernel<<<grid_size, block_size, 0, stream>>>(
+        x_data, y_data, out_data, dout_data, pre, mid, post, is_x_large, dx_op,
+        dy_op, dx_data, dy_data);
+  };
+
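+  // Dispatch overview (descriptive): a splittable broadcast run starting at
+  // dim 0 on a single input maps to FastBroadCastHeightCUDAF; a single
+  // mid-dim broadcast maps to FastBroadCastAllCUDAF; when both inputs
+  // broadcast, FastCommonCUDAF reduces each splittable side in turn. All
+  // remaining cases fall through to the generic
+  // CommonGradBroadcastCUDAKernel path below.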
+  // Take the fast elementwise path if:
+  //   1. only one input needs broadcast, so we can fall back to the old fast
+  //      path;
+  //   2. both x and y need broadcast, then handle them one by one.
+  if (x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
+    can_split_y = SplitDims(y_broadcast_pos, max_dim);
+    if (can_split_y) {
+      // only y needs to broadcast on h
+      if (y_broadcast_pos[0] == 0) {
+        FastBroadCastHeightCUDAF(y_broadcast_pos, true);
+      } else {
+        LOG(ERROR) << "Error, broadcast should not fall into w broadcast";
+      }
      return;
+    } else if (y_broadcast_pos.size() == 1) {  // for only one dim broadcast.
+      // If it cannot be split, the input has 3 parts
+      FastBroadCastAllCUDAF(y_broadcast_pos, max_dim, true);
+      return;
+    }
+  } else if (y_broadcast_pos.empty() && !x_broadcast_pos.empty()) {
+    // only x needs broadcast
+    can_split_x = SplitDims(x_broadcast_pos, max_dim);
+    if (can_split_x) {
+      if (x_broadcast_pos[0] == 0) {
+        FastBroadCastHeightCUDAF(x_broadcast_pos, false);
+      } else {
+        // x would need to broadcast on w
+        LOG(ERROR) << "Error, broadcast should not fall into w broadcast";
+      }
+      return;
+    } else if (x_broadcast_pos.size() == 1) {
+      FastBroadCastAllCUDAF(x_broadcast_pos, max_dim, false);
+      return;
+    }
+  } else if (!x_broadcast_pos.empty() && !y_broadcast_pos.empty()) {
+    // handle the x and y broadcasts one by one.
+    can_split_y = SplitDims(y_broadcast_pos, max_dim);
+    if (can_split_y) {
+      // the run begins at dim 0
+      if (y_broadcast_pos[0] == 0) {
+        FastCommonCUDAF(y_broadcast_pos, true);
+      } else {
+        // or ends at the last dim
+        LOG(ERROR) << "Error, broadcast should not fall into w broadcast";
+      }
+    }
+    can_split_x = SplitDims(x_broadcast_pos, max_dim);
+    if (can_split_x) {
+      if (x_broadcast_pos[0] == 0) {
+        FastCommonCUDAF(x_broadcast_pos, false);
+      } else {
+        LOG(ERROR) << "Error, broadcast should not fall into w broadcast";
+      }
    }
+    VLOG(3) << "CommonBroadcast can_split_y:" << can_split_y
+            << " can_split_x:" << can_split_x;
+    // if both x and y took the fast path then return
+    if (can_split_y && can_split_x) return;
+  }
+
   // Should remove memory copy, use reg instead.
   int x_blocks = 0;
   int x_threads = 0;
   ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks,
@@ -446,7 +837,6 @@ void CommonGradBroadcastCUDA(
   ComputeBroadcastKernelSize(y_dims_array, out_dims_array, &y_blocks,
                              &y_threads, max_dim);
 
-  int bytes = max_dim * sizeof(int);
   auto x_strides_array_tmp = memory::Alloc(ctx, bytes);
   int *x_strides_array_gpu =
       reinterpret_cast<int *>(x_strides_array_tmp->ptr());
@@ -468,7 +858,7 @@ void CommonGradBroadcastCUDA(
                                  1, std::multiplies<int>());
   int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
   int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
-  if (dx) {
+  if (dx && !can_split_x) {
     auto x_strides_order_tmp = memory::Alloc(ctx, bytes);
     int *x_strides_order_gpu =
         reinterpret_cast<int *>(x_strides_order_tmp->ptr());
@@ -485,7 +875,7 @@ void CommonGradBroadcastCUDA(
         x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
         dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
   }
-  if (dy) {
+  if (dy && !can_split_y) {
     auto y_strides_order_tmp = memory::Alloc(ctx, bytes);
     int *y_strides_order_gpu =
         reinterpret_cast<int *>(y_strides_order_tmp->ptr());
@@ -846,9 +1236,6 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
   }
 }
 
-#define BLOCK_X 32
-#define BLOCK_Y 32
-
 // suppose use 2D block is fast because more parallel
 // and memory coalesced
 template <typename T, typename DX_OP, typename DY_OP>
@@ -906,7 +1293,7 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
       }
       if (dx) {
         if (m < w && n < h) {
-          T val = dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
+          T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
           sdata[threadIdx.y][threadIdx.x] += val;
         }
         __syncthreads();
@@ -1151,6 +1538,7 @@ void ElemwiseGradComputeWithBroadcast(
     const framework::Tensor &dout, int axis, framework::Tensor *dx,
     framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
   bool is_xsize_larger = true;
+  int max_dim = x_dims.size();
   if (x_dims.size() < y_dims.size()) {
     is_xsize_larger = false;
@@ -1173,6 +1561,7 @@ void ElemwiseGradComputeWithBroadcast(
     get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
                  &is_run_common_broadcast);
   }
+  // special case for the common backward implementation.
   if (is_run_common_broadcast) {
     CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
-- 
GitLab