未验证 提交 017a6164 编写于 作者: J Jeng Bai-Cheng 提交者: GitHub

Bugfix, fast layer norm, OOB (#55639)

* Fix LayerNormForward perf issue

* Bugfix, fast_layer_norm OOB

* apply pre-commit

---------
Co-authored-by: NShijie Wang <jaywan@nvidia.com>
上级 f9e1b2d2
......@@ -217,8 +217,13 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
Vec_scale beta[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
phi::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
if (col < cols) {
phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
phi::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
} else {
gamma[it] = Vec_scale{};
beta[it] = Vec_scale{};
}
col += THREADS_PER_ROW;
}
......@@ -227,7 +232,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
Vec x[LDGS];
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
if (col < cols) {
phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize,
&x[it]);
} else {
x[it] = Vec{};
}
col += THREADS_PER_ROW;
}
U xf[LDGS * VecSize];
......@@ -324,7 +334,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
#pragma unroll
for (int it = 0, col = c; it < LDGS; it++) {
phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
if (col < cols) {
phi::Store<T, VecSize>(x[it],
y_ptr + row * ELTS_PER_ROW + col * VecSize);
}
col += THREADS_PER_ROW;
}
}
......
......@@ -578,7 +578,8 @@ void LayerNormKernel(const Context &dev_ctx,
VecSize, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG> \
BYTES_PER_LDG, \
feature_size> \
<<<grid, THREADS_PER_CTA, 0, stream>>>( \
batch_size, \
feature_size, \
......@@ -605,8 +606,7 @@ void LayerNormKernel(const Context &dev_ctx,
if ((feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0 ||
feature_size == 4096) &&
scale != nullptr && bias != nullptr) {
// can_call_fast_kernel = true;
can_call_fast_kernel = false;
can_call_fast_kernel = true;
}
if (can_call_fast_kernel) {
......
......@@ -515,7 +515,13 @@ class TestLayerNormOp(unittest.TestCase):
self.use_cudnn = True
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
np.testing.assert_allclose(
np.array(tensor).flatten(),
np_array.flatten(),
rtol=1e-3,
atol=atol,
err_msg=msg,
)
def check_forward_backward(
self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册