From 01d7ccd4b65f5c1aa822c570c1d985804022e94a Mon Sep 17 00:00:00 2001 From: zhaoyuchen2018 <45989343+zhaoyuchen2018@users.noreply.github.com> Date: Fri, 3 Apr 2020 20:25:45 +0800 Subject: [PATCH] Fix elementwise compile error, test=develop (#23381) An elementwise function was used before its definition, which failed to compile with CUDA 8; move the definition ahead of its use. --- .../elementwise/elementwise_op_function.h | 257 +++++++++--------- 1 file changed, 129 insertions(+), 128 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index f6969611820..d0b8e97c71f 100755 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -336,6 +336,135 @@ inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs, } #ifdef __NVCC__ +template +static __global__ void ElemwiseGradBroadcast1CUDAKernel( + const T *x, const T *y, const T *out, const T *dout, int h, int w, + bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { + int j = blockIdx.x; + int i = threadIdx.x; + int tid = threadIdx.x; + T val(0); + if (is_xsize_larger) { + do { + int x_offset = i * w + j; + if (dx) { + dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + if (dy) { + val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); + } + i += ELEMWISE_MAX_BLOCK_DIM; + } while (i < h); + + if (dy) { + h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dy[j] = val; + } + } + } else { // x.dims < y.dims, broadcast for x. + do { + int y_offset = i * w + j; + if (dy) { + dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + if (dx) { + val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); + } + i += ELEMWISE_MAX_BLOCK_DIM; + } while (i < h); + + if (dx) { + h = h > ELEMWISE_MAX_BLOCK_DIM ? 
ELEMWISE_MAX_BLOCK_DIM : h; + val = paddle::platform::reduceSum(val, tid, h); + if (threadIdx.x == 0) { + dx[j] = val; + } + } + } +} + +// A 2D block is assumed to be faster here: it offers more parallelism +// and coalesced memory access. +template +static __global__ void FastElemwiseGradBroadcast1CUDAKernel( + const T *x, const T *y, const T *out, const T *dout, int h, int w, + bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { + __shared__ T sdata[BLOCK_Y][BLOCK_X + 1]; + + T val(0); + size_t width_stride = gridDim.x * blockDim.x; + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + size_t full_width = + (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0); + size_t full_height = + (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0); + if (is_xsize_larger) { + for (int m = idx; m < full_width; m += width_stride) { + sdata[threadIdx.y][threadIdx.x] = 0; + for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { + int x_offset = n * w + m; + if (dx && m < w && n < h) { + dx[x_offset] = + dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]); + } + if (dy) { + if (m < w && n < h) { + T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]); + sdata[threadIdx.y][threadIdx.x] += val; + } + __syncthreads(); + } + } + if (dy) { + T my_val = sdata[threadIdx.x][threadIdx.y]; + for (int i = warpSize >> 1; i > 0; i >>= 1) + my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + __syncthreads(); + if ((threadIdx.x == 0)) { + sdata[0][threadIdx.y] = my_val; + } + __syncthreads(); + if (threadIdx.y == 0 && m < w) { + dy[m] = sdata[0][threadIdx.x]; + } + } + } + } else { // x.dims < y.dims, broadcast for x. 
+ for (int m = idx; m < full_width; m += width_stride) { + sdata[threadIdx.y][threadIdx.x] = 0; + for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { + int y_offset = n * w + m; + if (dy && m < w && n < h) { + dy[y_offset] = + dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]); + } + if (dx) { + if (m < w && n < h) { + T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]); + sdata[threadIdx.y][threadIdx.x] += val; + } + __syncthreads(); + } + } + if (dx) { + T my_val = sdata[threadIdx.x][threadIdx.y]; + for (int i = warpSize >> 1; i > 0; i >>= 1) + my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); + __syncthreads(); + if ((threadIdx.x == 0)) { + sdata[0][threadIdx.y] = my_val; + } + __syncthreads(); + if (threadIdx.y == 0 && m < w) { + dx[m] = sdata[0][threadIdx.x]; + } + } + } + } +} + template __global__ void CommonGradBroadcastCUDAKernel( const int *x_strides_array, const int *y_strides_array, @@ -1326,134 +1455,6 @@ static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out, } #ifdef __NVCC__ -template -static __global__ void ElemwiseGradBroadcast1CUDAKernel( - const T *x, const T *y, const T *out, const T *dout, int h, int w, - bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { - int j = blockIdx.x; - int i = threadIdx.x; - int tid = threadIdx.x; - T val(0); - if (is_xsize_larger) { - do { - int x_offset = i * w + j; - if (dx) { - dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - if (dy) { - val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]); - } - i += ELEMWISE_MAX_BLOCK_DIM; - } while (i < h); - - if (dy) { - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dy[j] = val; - } - } - } else { // x.dims < y.dims, broadcast for x. 
- do { - int y_offset = i * w + j; - if (dy) { - dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - if (dx) { - val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]); - } - i += ELEMWISE_MAX_BLOCK_DIM; - } while (i < h); - - if (dx) { - h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h; - val = paddle::platform::reduceSum(val, tid, h); - if (threadIdx.x == 0) { - dx[j] = val; - } - } - } -} - -// suppose use 2D block is fast because more parallel -// and memory coalesced -template -static __global__ void FastElemwiseGradBroadcast1CUDAKernel( - const T *x, const T *y, const T *out, const T *dout, int h, int w, - bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) { - __shared__ T sdata[BLOCK_Y][BLOCK_X + 1]; - - T val(0); - size_t width_stride = gridDim.x * blockDim.x; - size_t idx = threadIdx.x + blockDim.x * blockIdx.x; - size_t full_width = - (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0); - size_t full_height = - (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0); - if (is_xsize_larger) { - for (int m = idx; m < full_width; m += width_stride) { - sdata[threadIdx.y][threadIdx.x] = 0; - for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { - int x_offset = n * w + m; - if (dx && m < w && n < h) { - dx[x_offset] = - dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]); - } - if (dy) { - if (m < w && n < h) { - T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]); - sdata[threadIdx.y][threadIdx.x] += val; - } - __syncthreads(); - } - } - if (dy) { - T my_val = sdata[threadIdx.x][threadIdx.y]; - for (int i = warpSize >> 1; i > 0; i >>= 1) - my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); - __syncthreads(); - if ((threadIdx.x == 0)) { - sdata[0][threadIdx.y] = my_val; - } - __syncthreads(); - if (threadIdx.y == 0 && m < w) { - dy[m] = sdata[0][threadIdx.x]; - } - } - } - } else { // x.dims < y.dims, broadcast for x. 
- for (int m = idx; m < full_width; m += width_stride) { - sdata[threadIdx.y][threadIdx.x] = 0; - for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) { - int y_offset = n * w + m; - if (dy && m < w && n < h) { - dy[y_offset] = - dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]); - } - if (dx) { - if (m < w && n < h) { - T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]); - sdata[threadIdx.y][threadIdx.x] += val; - } - __syncthreads(); - } - } - if (dx) { - T my_val = sdata[threadIdx.x][threadIdx.y]; - for (int i = warpSize >> 1; i > 0; i >>= 1) - my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i); - __syncthreads(); - if ((threadIdx.x == 0)) { - sdata[0][threadIdx.y] = my_val; - } - __syncthreads(); - if (threadIdx.y == 0 && m < w) { - dx[m] = sdata[0][threadIdx.x]; - } - } - } - } -} template static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x, -- GitLab