From 79692af1ea287494849f754e79024a4417325aa6 Mon Sep 17 00:00:00 2001
From: Arash Bakhtiari
Date: Tue, 27 Sep 2022 09:34:51 -0700
Subject: [PATCH] Extend residual_add kernel tests to cover pre_attn_norm
 (#2354)

Co-authored-by: Jeff Rasley
---
 .../inference/test_residual_add.py | 96 +++++++++++--------
 1 file changed, 58 insertions(+), 38 deletions(-)

diff --git a/tests/unit/ops/transformer/inference/test_residual_add.py b/tests/unit/ops/transformer/inference/test_residual_add.py
index 2cc5a8eb..336008f5 100644
--- a/tests/unit/ops/transformer/inference/test_residual_add.py
+++ b/tests/unit/ops/transformer/inference/test_residual_add.py
@@ -23,35 +23,58 @@ def inference_module():
     return InferenceBuilder().load()


+def res_add_bias_ref(hidden_state,
+                     residual,
+                     attn_output,
+                     attn_bias,
+                     final_bias,
+                     mp_size=1,
+                     pre_attn_norm=True):
+    if pre_attn_norm:
+        hidden_state += (residual + final_bias + attn_output + attn_bias) / mp_size
+    else:
+        hidden_state += residual + final_bias
+    return hidden_state
+
+
+def res_add_bias_ref_gptj(hidden_state,
+                          residual,
+                          attn_output,
+                          attn_bias,
+                          final_bias,
+                          add_attn_bias,
+                          mp_size):
+    hidden_state += attn_output + (residual + final_bias) / mp_size
+    if add_attn_bias:
+        hidden_state += attn_bias / mp_size
+    return hidden_state
+
+
 def run_residual_add_reference(hidden_state,
                                residual,
-                               attention_output,
-                               attention_output_bias,
+                               attn_output,
+                               attn_bias,
                                final_bias,
                                mlp_after_attn,
-                               add_bias,
-                               mp_size=1):
-    residual_scaled = residual / mp_size
-    final_bias_scaled = final_bias / mp_size
-    attention_output_scaled = attention_output / mp_size
-    attention_output_bias_scaled = attention_output_bias / mp_size
-
-    hidden_state = hidden_state + residual_scaled + final_bias_scaled
-
-    # in case that mlp_after_attn = True, we additionally need to scale attention_output as well
+                               add_attn_bias,
+                               mp_size,
+                               pre_attn_norm):
     if mlp_after_attn:
-        hidden_state += attention_output_scaled
+        return res_add_bias_ref(hidden_state,
+                                residual,
+                                attn_output,
+                                attn_bias,
+                                final_bias,
+                                mp_size,
+                                pre_attn_norm)
     else:
-        hidden_state += attention_output
-
-    # TODO: The `add_bias` flag is used only for `launch_gptj_residual_add` kernel (aka, mlp_after_attn is False).
-    # This is a hack to get the parametarized add_bias to work. We need to fix this after refactoring the kernels.
-    add_bias = True if mlp_after_attn else add_bias
-
-    if add_bias:
-        hidden_state = hidden_state + attention_output_bias_scaled
-
-    return hidden_state
+        return res_add_bias_ref_gptj(hidden_state,
+                                     residual,
+                                     attn_output,
+                                     attn_bias,
+                                     final_bias,
+                                     add_attn_bias,
+                                     mp_size)


 @pytest.mark.inference
@@ -62,7 +85,7 @@ def run_residual_add_reference(hidden_state,
 @pytest.mark.parametrize("mlp_after_attn", [True, False])
 @pytest.mark.parametrize("add_bias", [True, False])
 @pytest.mark.parametrize("mp_size", [1, 2])
-# @pytest.mark.parametrize("preln", [True]) # TODO: add support for preln
+@pytest.mark.parametrize("pre_attn_norm", [True, False])
 def test_residual_add(inference_module,
                       batch,
                       sequence,
@@ -70,38 +93,35 @@ def test_residual_add(inference_module,
                       dtype,
                       mlp_after_attn,
                       add_bias,
-                      mp_size):
-    preln = True
+                      mp_size,
+                      pre_attn_norm):
     ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
     residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
-    attention_output = torch.randn((batch,
-                                    sequence,
-                                    hidden_dim),
-                                   dtype=dtype,
-                                   device='cuda')
+    attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
     final_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')
-    attention_output_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')
+    attn_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')

     ref_out = ds_out.clone()
     ref_out = run_residual_add_reference(ref_out,
                                          residual,
-                                         attention_output,
-                                         attention_output_bias,
+                                         attn_output,
+                                         attn_bias,
                                          final_bias,
                                          mlp_after_attn,
                                          add_bias,
-                                         mp_size)
+                                         mp_size,
+                                         pre_attn_norm)

     res_add_args = [
         ds_out,
         residual,
-        attention_output,
-        attention_output_bias,
+        attn_output,
+        attn_bias,
         final_bias,
         mp_size,
         mlp_after_attn,
         add_bias,
-        preln
+        pre_attn_norm
     ]

     if dtype == torch.float16:
--
GitLab
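For context, the standalone sketch below (not part of the patch) reproduces the two reference helpers introduced above and runs them on small CPU tensors, illustrating the pre-attention-norm path and the GPT-J-style path that the kernel test compares against. All shapes, values, and parameter choices here are arbitrary illustrative assumptions.

import torch

# Copies of the reference helpers from the patch, repeated so this sketch
# runs on its own (plain CPU tensors; shapes are arbitrary).
def res_add_bias_ref(hidden_state, residual, attn_output, attn_bias,
                     final_bias, mp_size=1, pre_attn_norm=True):
    if pre_attn_norm:
        hidden_state += (residual + final_bias + attn_output + attn_bias) / mp_size
    else:
        hidden_state += residual + final_bias
    return hidden_state

def res_add_bias_ref_gptj(hidden_state, residual, attn_output, attn_bias,
                          final_bias, add_attn_bias, mp_size):
    hidden_state += attn_output + (residual + final_bias) / mp_size
    if add_attn_bias:
        hidden_state += attn_bias / mp_size
    return hidden_state

batch, sequence, hidden_dim = 1, 2, 4
hidden = torch.randn(batch, sequence, hidden_dim)
residual = torch.randn(batch, sequence, hidden_dim)
attn_output = torch.randn(batch, sequence, hidden_dim)
attn_bias = torch.randn(hidden_dim)
final_bias = torch.randn(hidden_dim)

# mlp_after_attn=True path: residual, biases, and attention output are all
# folded into the hidden state and scaled by the model-parallel size.
out_pre = res_add_bias_ref(hidden.clone(), residual, attn_output, attn_bias,
                           final_bias, mp_size=2, pre_attn_norm=True)

# mlp_after_attn=False (GPT-J-style) path, with the attention bias included.
out_gptj = res_add_bias_ref_gptj(hidden.clone(), residual, attn_output, attn_bias,
                                 final_bias, add_attn_bias=True, mp_size=2)

print(out_pre.shape, out_gptj.shape)  # torch.Size([1, 2, 4]) for both

The test itself builds reference outputs like these with run_residual_add_reference and checks them against the DeepSpeed inference kernel across the parametrized batch, sequence, dtype, mlp_after_attn, add_bias, mp_size, and pre_attn_norm combinations.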