未验证 提交 eab44e1f 编写于 作者: W wangchaochaohu 提交者: GitHub

refine (#29622)

上级 d0b789d2
......@@ -148,6 +148,8 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
size_t width_stride = gridDim.x * blockDim.x;
size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
size_t full_height = (height & (~((uint64_t)(BLOCK_H - 1)))) +
((height & (BLOCK_H - 1)) ? BLOCK_H : 0);
#pragma unroll
for (size_t w = idx; w < full_width; w += width_stride) {
......@@ -155,10 +157,10 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
__syncthreads();
size_t offset = w + threadIdx.y * width;
#pragma unroll
for (size_t h = threadIdx.y; h < height;
for (size_t h = threadIdx.y; h < full_height;
h += BLOCK_H) { // block-stride loop across matrix height
sdata[threadIdx.y][threadIdx.x] +=
(w < width) ? in[offset] : (static_cast<T>(0));
(w < width && h < height) ? in[offset] : (static_cast<T>(0));
offset += width * BLOCK_H;
}
__syncthreads();
......@@ -184,20 +186,23 @@ __global__ void FP16MatrixColReduce(
size_t width_stride = gridDim.x * blockDim.x;
size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
size_t full_height = (height & (~((uint64_t)(BLOCK_H - 1)))) +
((height & (BLOCK_H - 1)) ? BLOCK_H : 0);
#pragma unroll
for (size_t w = idx; w < full_width; w += width_stride) {
for (int r = 0; r < repeats; r++) {
sdata[threadIdx.y + r * BLOCK_W][threadIdx.x] = 0;
}
__syncthreads();
#pragma unroll
for (int r = 0; r < repeats; r++) {
size_t offset = w + (r * BLOCK_W + threadIdx.y) * width;
#pragma unroll
for (size_t h = r * BLOCK_H + threadIdx.y; h < height;
for (size_t h = threadIdx.y + r * BLOCK_W; h < full_height;
h += BLOCK_H) { // block-stride loop across matrix height
sdata[r * BLOCK_W + threadIdx.y][threadIdx.x] +=
(w < width) ? in[offset + r * BLOCK_W * width]
(w < width && h < height)
? in[offset]
: (static_cast<paddle::platform::float16>(0));
offset += width * BLOCK_H;
}
......@@ -373,6 +378,7 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes,
dout_data, out_data, nums, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err);
return;
}
constexpr int block_x = 32;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册