未验证 提交 8256f6fa 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix lars (#36431)

上级 3cf57646
......@@ -165,8 +165,10 @@ __global__ void L2NormKernel(
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int grid_stride = LARS_BLOCK_SIZE * gridDim.x;
const MT rescale_pow = rescale_grad * rescale_grad;
s_buffer[0] = static_cast<MT>(0);
s_buffer[1] = static_cast<MT>(0);
if (threadIdx.x == 0) {
s_buffer[0] = static_cast<MT>(0);
s_buffer[1] = static_cast<MT>(0);
}
MT p_tmp = static_cast<MT>(0);
MT g_tmp = static_cast<MT>(0);
......@@ -175,8 +177,12 @@ __global__ void L2NormKernel(
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
}
s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
} else {
/* Avoid occupy too much temp buffer. Slice the whole data into 2 parts,
the front of data whose quantity is excatly multiple of grid-thread
......@@ -185,8 +191,12 @@ __global__ void L2NormKernel(
p_tmp = static_cast<MT>(p_data[tid]);
g_tmp = static_cast<MT>(g_data[tid]);
tid += grid_stride;
s_buffer[0] += math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
MT tmp0 = math::blockReduceSum<MT>(p_tmp * p_tmp, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_tmp * g_tmp, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
__syncthreads();
}
MT p_val = 0;
......@@ -195,8 +205,12 @@ __global__ void L2NormKernel(
p_val = static_cast<MT>(p_data[tid]);
g_val = static_cast<MT>(g_data[tid]);
}
s_buffer[0] += math::blockReduceSum<MT>(p_val * p_val, FINAL_MASK);
s_buffer[1] += math::blockReduceSum<MT>(g_val * g_val, FINAL_MASK);
MT tmp0 = math::blockReduceSum<MT>(p_val * p_val, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_val * g_val, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] += tmp0;
s_buffer[1] += tmp1;
}
}
__syncthreads();
......@@ -208,8 +222,15 @@ __global__ void L2NormKernel(
cg->sync(); // Grid sync for writring partial result to gloabl memory
MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
*p_n = Sqrt(math::blockReduceSum<MT>(p_part_sum, FINAL_MASK));
*g_n = Sqrt(rescale_pow * math::blockReduceSum<MT>(g_part_sum, FINAL_MASK));
MT tmp0 = math::blockReduceSum<MT>(p_part_sum, FINAL_MASK);
MT tmp1 = math::blockReduceSum<MT>(g_part_sum, FINAL_MASK);
if (threadIdx.x == 0) {
s_buffer[0] = tmp0;
s_buffer[1] = tmp1;
}
__syncthreads();
*p_n = Sqrt(s_buffer[0]);
*g_n = Sqrt(rescale_pow * s_buffer[1]);
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册