Unverified commit 01d7ccd4, authored by zhaoyuchen2018, committed by GitHub

Fix elementwise compile error, test=develop (#23381)

The elementwise kernel functions were used before their definitions, which failed to compile on CUDA 8; move the definitions ahead of their first use.
Parent 24a063f6
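For context, here is a minimal sketch of the pattern this commit fixes: a templated `__global__` kernel instantiated by a host launcher that appears above the kernel's definition in the same header. According to the commit message, CUDA 8 rejects that forward use, so the fix is simply to move the kernel definitions ahead of their first use, as the diff below does. The names in the sketch are hypothetical and are not code from this patch.

```cuda
#include <cuda_runtime.h>

// 1) Kernel defined first -- this is the ordering the commit restores.
template <typename T>
__global__ void ScaleKernel(T *data, T factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

// 2) Host launcher that instantiates the kernel. If ScaleKernel<T> were only
//    defined further down in the header, this instantiation could fail to
//    compile on CUDA 8's nvcc.
template <typename T>
void Scale(cudaStream_t stream, T *data, T factor, int n) {
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  ScaleKernel<T><<<blocks, threads, 0, stream>>>(data, factor, n);
}
```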
@@ -336,6 +336,135 @@ inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs,
}
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
T val(0);
if (is_xsize_larger) {
do {
int x_offset = i * w + j;
if (dx) {
dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy) {
val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
if (dy) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
} else { // x.dims < y.dims, broadcast for x.
do {
int y_offset = i * w + j;
if (dy) {
dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx) {
val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
if (dx) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
}
// Using a 2D block is presumably faster because it exposes more parallelism
// and yields coalesced memory accesses.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
__shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
T val(0);
size_t width_stride = gridDim.x * blockDim.x;
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
size_t full_width =
(w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
size_t full_height =
(h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
if (is_xsize_larger) {
for (int m = idx; m < full_width; m += width_stride) {
sdata[threadIdx.y][threadIdx.x] = 0;
for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
int x_offset = n * w + m;
if (dx && m < w && n < h) {
dx[x_offset] =
dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
}
if (dy) {
if (m < w && n < h) {
T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
sdata[threadIdx.y][threadIdx.x] += val;
}
__syncthreads();
}
}
if (dy) {
T my_val = sdata[threadIdx.x][threadIdx.y];
for (int i = warpSize >> 1; i > 0; i >>= 1)
my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
__syncthreads();
if ((threadIdx.x == 0)) {
sdata[0][threadIdx.y] = my_val;
}
__syncthreads();
if (threadIdx.y == 0 && m < w) {
dy[m] = sdata[0][threadIdx.x];
}
}
}
} else { // x.dims < y.dims, broadcast for x.
for (int m = idx; m < full_width; m += width_stride) {
sdata[threadIdx.y][threadIdx.x] = 0;
for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
int y_offset = n * w + m;
if (dy && m < w && n < h) {
dy[y_offset] =
dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx) {
if (m < w && n < h) {
T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
sdata[threadIdx.y][threadIdx.x] += val;
}
__syncthreads();
}
}
if (dx) {
T my_val = sdata[threadIdx.x][threadIdx.y];
for (int i = warpSize >> 1; i > 0; i >>= 1)
my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
__syncthreads();
if ((threadIdx.x == 0)) {
sdata[0][threadIdx.y] = my_val;
}
__syncthreads();
if (threadIdx.y == 0 && m < w) {
dx[m] = sdata[0][threadIdx.x];
}
}
}
}
}
template <typename T, typename DX_OP>
__global__ void CommonGradBroadcastCUDAKernel(
const int *x_strides_array, const int *y_strides_array,
@@ -1326,134 +1455,6 @@ static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
}
#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
int j = blockIdx.x;
int i = threadIdx.x;
int tid = threadIdx.x;
T val(0);
if (is_xsize_larger) {
do {
int x_offset = i * w + j;
if (dx) {
dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
if (dy) {
val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
if (dy) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
}
} else { // x.dims < y.dims, broadcast for x.
do {
int y_offset = i * w + j;
if (dy) {
dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx) {
val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
}
i += ELEMWISE_MAX_BLOCK_DIM;
} while (i < h);
if (dx) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dx[j] = val;
}
}
}
}
// Using a 2D block is presumably faster because it exposes more parallelism
// and yields coalesced memory accesses.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
__shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
T val(0);
size_t width_stride = gridDim.x * blockDim.x;
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
size_t full_width =
(w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
size_t full_height =
(h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
if (is_xsize_larger) {
for (int m = idx; m < full_width; m += width_stride) {
sdata[threadIdx.y][threadIdx.x] = 0;
for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
int x_offset = n * w + m;
if (dx && m < w && n < h) {
dx[x_offset] =
dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
}
if (dy) {
if (m < w && n < h) {
T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
sdata[threadIdx.y][threadIdx.x] += val;
}
__syncthreads();
}
}
if (dy) {
T my_val = sdata[threadIdx.x][threadIdx.y];
for (int i = warpSize >> 1; i > 0; i >>= 1)
my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
__syncthreads();
if ((threadIdx.x == 0)) {
sdata[0][threadIdx.y] = my_val;
}
__syncthreads();
if (threadIdx.y == 0 && m < w) {
dy[m] = sdata[0][threadIdx.x];
}
}
}
} else { // x.dims < y.dims, broadcast for x.
for (int m = idx; m < full_width; m += width_stride) {
sdata[threadIdx.y][threadIdx.x] = 0;
for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
int y_offset = n * w + m;
if (dy && m < w && n < h) {
dy[y_offset] =
dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
}
if (dx) {
if (m < w && n < h) {
T val = dx_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
sdata[threadIdx.y][threadIdx.x] += val;
}
__syncthreads();
}
}
if (dx) {
T my_val = sdata[threadIdx.x][threadIdx.y];
for (int i = warpSize >> 1; i > 0; i >>= 1)
my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
__syncthreads();
if ((threadIdx.x == 0)) {
sdata[0][threadIdx.y] = my_val;
}
__syncthreads();
if (threadIdx.y == 0 && m < w) {
dx[m] = sdata[0][threadIdx.x];
}
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
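The diff is cut off before the body of `ElemwiseGradBroadcast1CUDA`. As a rough sketch of how a launcher like this typically dispatches between the two kernels shown above: small shapes can use the 1-D per-column kernel, larger ones the tiled 2-D kernel. The size threshold, grid math, and reliance on the header's `ELEMWISE_MAX_BLOCK_DIM`, `BLOCK_X`, and `BLOCK_Y` macros are illustrative assumptions, not necessarily the file's actual values.

```cuda
// Illustrative sketch only (not the file's real launcher body). Assumes the
// surrounding header defines ELEMWISE_MAX_BLOCK_DIM, BLOCK_X, and BLOCK_Y.
#include <algorithm>

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDASketch(cudaStream_t stream, const T *x,
                                             const T *y, const T *out,
                                             const T *dout, int h, int w,
                                             bool is_xsize_larger, DX_OP dx_op,
                                             DY_OP dy_op, T *dx, T *dy) {
  constexpr int kSmall = 16;  // assumed cut-off between the two kernels
  if (w < kSmall || h < kSmall) {
    // One block per output column; threads cooperate along the reduced axis.
    int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
    int grid_size = w;
    ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
        x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
  } else {
    // 2-D tiles (BLOCK_X x BLOCK_Y) keep loads along the width coalesced.
    dim3 block_size(BLOCK_X, BLOCK_Y);
    int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
    FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
        x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
  }
}
```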