未验证 提交 eab44e1f 编写于 作者: W wangchaochaohu 提交者: GitHub

refine (#29622)

上级 d0b789d2
...@@ -148,6 +148,8 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out, ...@@ -148,6 +148,8 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
size_t width_stride = gridDim.x * blockDim.x; size_t width_stride = gridDim.x * blockDim.x;
size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) + size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
((width & (BLOCK_W - 1)) ? BLOCK_W : 0); ((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
size_t full_height = (height & (~((uint64_t)(BLOCK_H - 1)))) +
((height & (BLOCK_H - 1)) ? BLOCK_H : 0);
#pragma unroll #pragma unroll
for (size_t w = idx; w < full_width; w += width_stride) { for (size_t w = idx; w < full_width; w += width_stride) {
...@@ -155,10 +157,10 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out, ...@@ -155,10 +157,10 @@ __global__ void MatrixColReduce(const T *__restrict__ in, T *__restrict__ out,
__syncthreads(); __syncthreads();
size_t offset = w + threadIdx.y * width; size_t offset = w + threadIdx.y * width;
#pragma unroll #pragma unroll
for (size_t h = threadIdx.y; h < height; for (size_t h = threadIdx.y; h < full_height;
h += BLOCK_H) { // block-stride loop across matrix height h += BLOCK_H) { // block-stride loop across matrix height
sdata[threadIdx.y][threadIdx.x] += sdata[threadIdx.y][threadIdx.x] +=
(w < width) ? in[offset] : (static_cast<T>(0)); (w < width && h < height) ? in[offset] : (static_cast<T>(0));
offset += width * BLOCK_H; offset += width * BLOCK_H;
} }
__syncthreads(); __syncthreads();
...@@ -184,21 +186,24 @@ __global__ void FP16MatrixColReduce( ...@@ -184,21 +186,24 @@ __global__ void FP16MatrixColReduce(
size_t width_stride = gridDim.x * blockDim.x; size_t width_stride = gridDim.x * blockDim.x;
size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) + size_t full_width = (width & (~((uint64_t)(BLOCK_W - 1)))) +
((width & (BLOCK_W - 1)) ? BLOCK_W : 0); ((width & (BLOCK_W - 1)) ? BLOCK_W : 0);
size_t full_height = (height & (~((uint64_t)(BLOCK_H - 1)))) +
((height & (BLOCK_H - 1)) ? BLOCK_H : 0);
#pragma unroll #pragma unroll
for (size_t w = idx; w < full_width; w += width_stride) { for (size_t w = idx; w < full_width; w += width_stride) {
for (int r = 0; r < repeats; r++) { for (int r = 0; r < repeats; r++) {
sdata[threadIdx.y + r * BLOCK_W][threadIdx.x] = 0; sdata[threadIdx.y + r * BLOCK_W][threadIdx.x] = 0;
} }
__syncthreads(); __syncthreads();
#pragma unroll
for (int r = 0; r < repeats; r++) { for (int r = 0; r < repeats; r++) {
size_t offset = w + (r * BLOCK_W + threadIdx.y) * width; size_t offset = w + (r * BLOCK_W + threadIdx.y) * width;
#pragma unroll #pragma unroll
for (size_t h = r * BLOCK_H + threadIdx.y; h < height; for (size_t h = threadIdx.y + r * BLOCK_W; h < full_height;
h += BLOCK_H) { // block-stride loop across matrix height h += BLOCK_H) { // block-stride loop across matrix height
sdata[r * BLOCK_W + threadIdx.y][threadIdx.x] += sdata[r * BLOCK_W + threadIdx.y][threadIdx.x] +=
(w < width) ? in[offset + r * BLOCK_W * width] (w < width && h < height)
: (static_cast<paddle::platform::float16>(0)); ? in[offset]
: (static_cast<paddle::platform::float16>(0));
offset += width * BLOCK_H; offset += width * BLOCK_H;
} }
} }
...@@ -373,6 +378,7 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> { ...@@ -373,6 +378,7 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes,
dout_data, out_data, nums, stream); dout_data, out_data, nums, stream);
PADDLE_ENFORCE_CUDA_SUCCESS(err); PADDLE_ENFORCE_CUDA_SUCCESS(err);
return;
} }
constexpr int block_x = 32; constexpr int block_x = 32;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册