Unverified commit 79692af1, authored by Arash Bakhtiari, committed by GitHub

Extend residual_add kernel tests to cover pre_attn_norm (#2354)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent: b450da4f
@@ -23,37 +23,60 @@ def inference_module():
     return InferenceBuilder().load()
 
 
-def run_residual_add_reference(hidden_state,
-                               residual,
-                               attention_output,
-                               attention_output_bias,
-                               final_bias,
-                               mlp_after_attn,
-                               add_bias,
-                               mp_size=1):
-    residual_scaled = residual / mp_size
-    final_bias_scaled = final_bias / mp_size
-    attention_output_scaled = attention_output / mp_size
-    attention_output_bias_scaled = attention_output_bias / mp_size
-
-    hidden_state = hidden_state + residual_scaled + final_bias_scaled
-    # in case that mlp_after_attn = True, we additionally need to scale attention_output as well
-    if mlp_after_attn:
-        hidden_state += attention_output_scaled
-    else:
-        hidden_state += attention_output
-
-    # TODO: The `add_bias` flag is used only for `launch_gptj_residual_add` kernel (aka, mlp_after_attn is False).
-    # This is a hack to get the parametarized add_bias to work. We need to fix this after refactoring the kernels.
-    add_bias = True if mlp_after_attn else add_bias
-    if add_bias:
-        hidden_state = hidden_state + attention_output_bias_scaled
-
-    return hidden_state
+def res_add_bias_ref(hidden_state,
+                     residual,
+                     attn_output,
+                     attn_bias,
+                     final_bias,
+                     mp_size=1,
+                     pre_attn_norm=True):
+    if pre_attn_norm:
+        hidden_state += (residual + final_bias + attn_output + attn_bias) / mp_size
+    else:
+        hidden_state += residual + final_bias
+
+    return hidden_state
+
+
+def res_add_bias_ref_gptj(hidden_state,
+                          residual,
+                          attn_output,
+                          attn_bias,
+                          final_bias,
+                          add_attn_bias,
+                          mp_size):
+    hidden_state += attn_output + (residual + final_bias) / mp_size
+    if add_attn_bias:
+        hidden_state += attn_bias / mp_size
+
+    return hidden_state
+
+
+def run_residual_add_reference(hidden_state,
+                               residual,
+                               attn_output,
+                               attn_bias,
+                               final_bias,
+                               mlp_after_attn,
+                               add_attn_bias,
+                               mp_size,
+                               pre_attn_norm):
+    if mlp_after_attn:
+        return res_add_bias_ref(hidden_state,
+                                residual,
+                                attn_output,
+                                attn_bias,
+                                final_bias,
+                                mp_size,
+                                pre_attn_norm)
+    else:
+        return res_add_bias_ref_gptj(hidden_state,
+                                     residual,
+                                     attn_output,
+                                     attn_bias,
+                                     final_bias,
+                                     add_attn_bias,
+                                     mp_size)
 
 
 @pytest.mark.inference
 @pytest.mark.parametrize("batch", [1, 2])
 @pytest.mark.parametrize("sequence", [1, 128, 255])
@@ -62,7 +85,7 @@ def run_residual_add_reference(hidden_state,
 @pytest.mark.parametrize("mlp_after_attn", [True, False])
 @pytest.mark.parametrize("add_bias", [True, False])
 @pytest.mark.parametrize("mp_size", [1, 2])
-# @pytest.mark.parametrize("preln", [True]) # TODO: add support for preln
+@pytest.mark.parametrize("pre_attn_norm", [True, False])
 def test_residual_add(inference_module,
                       batch,
                       sequence,
@@ -70,38 +93,35 @@ def test_residual_add(inference_module,
                       dtype,
                       mlp_after_attn,
                       add_bias,
-                      mp_size):
-    preln = True
+                      mp_size,
+                      pre_attn_norm):
     ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
     residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
-    attention_output = torch.randn((batch,
-                                    sequence,
-                                    hidden_dim),
-                                   dtype=dtype,
-                                   device='cuda')
+    attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
     final_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')
-    attention_output_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')
+    attn_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')
 
     ref_out = ds_out.clone()
     ref_out = run_residual_add_reference(ref_out,
                                          residual,
-                                         attention_output,
-                                         attention_output_bias,
+                                         attn_output,
+                                         attn_bias,
                                          final_bias,
                                          mlp_after_attn,
                                          add_bias,
-                                         mp_size)
+                                         mp_size,
+                                         pre_attn_norm)
 
     res_add_args = [
         ds_out,
         residual,
-        attention_output,
-        attention_output_bias,
+        attn_output,
+        attn_bias,
         final_bias,
         mp_size,
         mlp_after_attn,
         add_bias,
-        preln
+        pre_attn_norm
     ]
 
     if dtype == torch.float16:
...
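With the new pre_attn_norm parametrization in place, each (batch, sequence, dtype, mlp_after_attn, add_bias, mp_size, pre_attn_norm) combination becomes its own test case. Below is a hedged runner sketch; the file path is an assumption about where this test lives in the repository and may need adjusting, and a CUDA device is required because the test allocates its tensors with device='cuda'.

# Hedged runner sketch: the test file path is an assumption and may differ
# in your checkout; a CUDA-capable GPU is required because the test creates
# its tensors with device='cuda'.
import sys
import pytest

sys.exit(
    pytest.main([
        "tests/unit/ops/test_residual_add.py",  # assumed location of the test above
        "-k", "test_residual_add",              # run every parametrized combination
        "-v",                                   # print one id per parametrization
    ]))

With -v, the pre_attn_norm value shows up as True/False in each test id, which makes it easy to see which normalization path a failing case exercised.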