Unverified commit 017a6164, authored by Jeng Bai-Cheng, committed by GitHub

Bugfix, fast layer norm, OOB (#55639)

* Fix LayerNormForward perf issue

* Bugfix, fast_layer_norm OOB

* apply pre-commit

---------
Co-authored-by: Shijie Wang <jaywan@nvidia.com>
Parent: f9e1b2d2
@@ -217,8 +217,13 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
   Vec_scale beta[LDGS];
 #pragma unroll
   for (int it = 0, col = c; it < LDGS; it++) {
+    if (col < cols) {
       phi::Load<ScaleT, VecSize>(gamma_ptr + col * VecSize, &gamma[it]);
       phi::Load<ScaleT, VecSize>(beta_ptr + col * VecSize, &beta[it]);
+    } else {
+      gamma[it] = Vec_scale{};
+      beta[it] = Vec_scale{};
+    }
     col += THREADS_PER_ROW;
   }
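A note on the else branch above: `Vec_scale{}` (and `Vec{}` in the next hunk) relies on brace value-initialization zeroing an aggregate, so lanes whose columns fall past the end of the row hold well-defined zeros instead of stale register contents. A tiny standalone illustration, using a stand-in aggregate rather than the real phi::AlignedVector:

// Illustration only: brace value-initialization of an aggregate zeroes its members.
#include <cstdio>

template <typename T, int N>
struct AlignedVector {  // stand-in with the same shape as a plain vector wrapper
  T val[N];
};

int main() {
  AlignedVector<float, 4> v{};  // value-initialization -> val is all zeros
  printf("%g %g %g %g\n", v.val[0], v.val[1], v.val[2], v.val[3]);  // prints: 0 0 0 0
  return 0;
}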
@@ -227,7 +232,12 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
   Vec x[LDGS];
 #pragma unroll
   for (int it = 0, col = c; it < LDGS; it++) {
-    phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize, &x[it]);
+    if (col < cols) {
+      phi::Load<T, VecSize>(x_ptr + row * ELTS_PER_ROW + col * VecSize,
+                            &x[it]);
+    } else {
+      x[it] = Vec{};
+    }
     col += THREADS_PER_ROW;
   }
   U xf[LDGS * VecSize];
@@ -324,7 +334,10 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(
 #pragma unroll
   for (int it = 0, col = c; it < LDGS; it++) {
-    phi::Store<T, VecSize>(x[it], y_ptr + row * ELTS_PER_ROW + col * VecSize);
+    if (col < cols) {
+      phi::Store<T, VecSize>(x[it],
+                             y_ptr + row * ELTS_PER_ROW + col * VecSize);
+    }
     col += THREADS_PER_ROW;
   }
 }
......
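The three hunks above all apply the same fix: every vectorized load or store in the per-row loops of fast_ln_fwd_kernel is wrapped in a `col < cols` guard, with out-of-range iterations zero-filling their registers (loads) or skipping the access entirely (stores). Below is a minimal, self-contained CUDA sketch of that pattern, not the Paddle kernel itself; the kernel name, the Vec stand-in, and the sizes are all illustrative.

// Minimal sketch of the bounds-guarded vectorized load/store pattern.
// Everything here (guarded_row_copy, Vec, the sizes) is illustrative, not Paddle code.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int VecSize = 4;          // elements per vectorized access
constexpr int THREADS_PER_ROW = 8;  // threads cooperating on one row
constexpr int LDGS = 2;             // vector accesses per thread per row

struct Vec {                        // stand-in for an aligned vector wrapper
  float val[VecSize];
};

// Copies one row of `cols` vectors from x to y. Threads whose column index
// runs past `cols` neither read nor write; their registers are zero-filled.
__global__ void guarded_row_copy(const float* x, float* y, int cols) {
  const int row = blockIdx.x;
  const int c = threadIdx.x;        // starting column, counted in units of Vec
  Vec buf[LDGS];

#pragma unroll
  for (int it = 0, col = c; it < LDGS; it++) {
    if (col < cols) {               // OOB guard on the load side
      // one 16-byte access per Vec, similar in spirit to phi::Load
      buf[it] = *reinterpret_cast<const Vec*>(x + (row * cols + col) * VecSize);
    } else {
      buf[it] = Vec{};              // well-defined zeros for out-of-range lanes
    }
    col += THREADS_PER_ROW;
  }

#pragma unroll
  for (int it = 0, col = c; it < LDGS; it++) {
    if (col < cols) {               // OOB guard on the store side
      *reinterpret_cast<Vec*>(y + (row * cols + col) * VecSize) = buf[it];
    }
    col += THREADS_PER_ROW;
  }
}

int main() {
  const int rows = 2, cols = 5;     // 5 < LDGS * THREADS_PER_ROW, so some iterations are OOB
  const int n = rows * cols * VecSize;
  float *x = nullptr, *y = nullptr;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  for (int i = 0; i < n; i++) { x[i] = static_cast<float>(i); y[i] = -1.0f; }

  guarded_row_copy<<<rows, THREADS_PER_ROW>>>(x, y, cols);
  cudaDeviceSynchronize();

  printf("y[0] = %.0f, y[%d] = %.0f\n", y[0], n - 1, y[n - 1]);  // expect 0 and 39
  cudaFree(x);
  cudaFree(y);
  return 0;
}

Without the guards, threads whose column lies past the end of the row would read and write beyond the row's data whenever the row length is not a multiple of LDGS * THREADS_PER_ROW vectors, which is the kind of out-of-bounds access the hunks above prevent.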
@@ -578,7 +578,8 @@ void LayerNormKernel(const Context &dev_ctx,
                     VecSize,                                             \
                     WARPS_M,                                             \
                     WARPS_N,                                             \
-                    BYTES_PER_LDG>                                       \
+                    BYTES_PER_LDG,                                       \
+                    feature_size>                                        \
       <<<grid, THREADS_PER_CTA, 0, stream>>>(                            \
           batch_size,                                                    \
           feature_size,                                                  \
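The macro change above threads feature_size through as an additional template argument, so each fast-kernel instantiation sees its column count as a compile-time constant. A minimal sketch of that dispatch style, assuming illustrative names (launch_fast_ln, dispatch_layer_norm) rather than the real Paddle macro:

// Sketch of forwarding a runtime feature size into a compile-time template
// parameter; names and supported sizes here are illustrative, not Paddle's.
#include <cstdio>

template <int kFeatureSize>
void launch_fast_ln() {
  // Stand-in for the kernel launch; inside a real kernel the constant is
  // available at compile time for this instantiation.
  printf("fast LayerNorm instantiated for feature_size = %d\n", kFeatureSize);
}

void dispatch_layer_norm(int feature_size) {
  switch (feature_size) {
    case 768:  launch_fast_ln<768>();  break;
    case 1024: launch_fast_ln<1024>(); break;
    case 1280: launch_fast_ln<1280>(); break;
    case 4096: launch_fast_ln<4096>(); break;
    default:
      printf("feature_size = %d: fall back to the generic kernel\n", feature_size);
  }
}

int main() {
  dispatch_layer_norm(1024);  // hits a specialized instantiation
  dispatch_layer_norm(3000);  // falls back
  return 0;
}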
@@ -605,8 +606,7 @@ void LayerNormKernel(const Context &dev_ctx,
   if ((feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0 ||
        feature_size == 4096) &&
       scale != nullptr && bias != nullptr) {
-    // can_call_fast_kernel = true;
-    can_call_fast_kernel = false;
+    can_call_fast_kernel = true;
   }
   if (can_call_fast_kernel) {
......
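This hunk replaces the hard-coded `can_call_fast_kernel = false` (with the intended assignment left commented out) by `can_call_fast_kernel = true`, re-enabling the fast path now that the kernel's loads and stores are bounds-guarded. The surrounding condition relies on && binding tighter than ||: the fast kernel is chosen when feature_size is between 768 and 2048 and a multiple of 256, or exactly 4096, and both scale and bias are provided. A small paraphrase of that predicate; the function and parameter names are mine, not Paddle identifiers:

#include <cstdio>

// Paraphrase of the dispatch condition in the hunk above; illustrative names only.
bool can_use_fast_layer_norm(int feature_size, bool has_scale, bool has_bias) {
  const bool supported_width =
      (feature_size >= 768 && feature_size <= 2048 && feature_size % 256 == 0) ||
      feature_size == 4096;
  return supported_width && has_scale && has_bias;
}

int main() {
  printf("%d %d %d\n",
         can_use_fast_layer_norm(1024, true, true),   // 1: in [768, 2048] and a multiple of 256
         can_use_fast_layer_norm(4096, true, true),   // 1: explicit special case
         can_use_fast_layer_norm(1000, true, true));  // 0: not a multiple of 256
  return 0;
}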
@@ -515,7 +515,13 @@ class TestLayerNormOp(unittest.TestCase):
         self.use_cudnn = True

     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
-        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
+        np.testing.assert_allclose(
+            np.array(tensor).flatten(),
+            np_array.flatten(),
+            rtol=1e-3,
+            atol=atol,
+            err_msg=msg,
+        )

     def check_forward_backward(
         self,
......