未验证 提交 d9f8636c 编写于 作者: Z Zhang Zheng 提交者: GitHub

Supoort more dimensions in forward fast layer_norm kernel (#43226)

上级 264de612
......@@ -481,10 +481,12 @@ void LaunchLayernormResidualDropoutBias(
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1536); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(1792); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(2048); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(3072); \
LAUNCH_FUSED_FAST_LN_KERNEL_BASE(4096)
bool can_call_fast_ln_kernel = false;
if (((cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 4096) &&
if (((cols >= 768 && cols <= 2048 && cols % 256 == 0) || cols == 3072 ||
cols == 4096) &&
scale != nullptr && layernorm_bias != nullptr) {
can_call_fast_ln_kernel = true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册