未验证 提交 5acd764d 编写于 作者: Z Zhang Zheng 提交者: GitHub

Fix the implementation of fused_fast_ln_fwd_kernel in test mode (#42527)

上级 2c5cecb1
...@@ -298,10 +298,16 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel( ...@@ -298,10 +298,16 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_fast_ln_fwd_kernel(
for (int it = 0, col = c; it < LDGS; it++) { for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<T, VecSize>( phi::Store<T, VecSize>(
x[it], residual_out_ptr + row * ELTS_PER_ROW + col * VecSize); x[it], residual_out_ptr + row * ELTS_PER_ROW + col * VecSize);
phi::Store<MaskType, VecSize>(
mask_vec[it], mask_out_ptr + row * ELTS_PER_ROW + col * VecSize);
col += THREADS_PER_ROW; col += THREADS_PER_ROW;
} }
if (!is_test) {
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<MaskType, VecSize>(
mask_vec[it], mask_out_ptr + row * ELTS_PER_ROW + col * VecSize);
col += THREADS_PER_ROW;
}
}
U mu_local = 0.f; U mu_local = 0.f;
#pragma unroll #pragma unroll
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册