Unverified commit 79692af1, authored by Arash Bakhtiari, committed by GitHub

Extend residual_add kernel tests to cover pre_attn_norm (#2354)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent b450da4f
@@ -23,35 +23,58 @@ def inference_module():
     return InferenceBuilder().load()


+def res_add_bias_ref(hidden_state,
+                     residual,
+                     attn_output,
+                     attn_bias,
+                     final_bias,
+                     mp_size=1,
+                     pre_attn_norm=True):
+    if pre_attn_norm:
+        hidden_state += (residual + final_bias + attn_output + attn_bias) / mp_size
+    else:
+        hidden_state += residual + final_bias
+    return hidden_state
+
+
+def res_add_bias_ref_gptj(hidden_state,
+                          residual,
+                          attn_output,
+                          attn_bias,
+                          final_bias,
+                          add_attn_bias,
+                          mp_size):
+    hidden_state += attn_output + (residual + final_bias) / mp_size
+    if add_attn_bias:
+        hidden_state += attn_bias / mp_size
+    return hidden_state
 def run_residual_add_reference(hidden_state,
                                residual,
-                               attention_output,
-                               attention_output_bias,
+                               attn_output,
+                               attn_bias,
                                final_bias,
                                mlp_after_attn,
-                               add_bias,
-                               mp_size=1):
-    residual_scaled = residual / mp_size
-    final_bias_scaled = final_bias / mp_size
-    attention_output_scaled = attention_output / mp_size
-    attention_output_bias_scaled = attention_output_bias / mp_size
-    hidden_state = hidden_state + residual_scaled + final_bias_scaled
-    # in case that mlp_after_attn = True, we additionally need to scale attention_output as well
+                               add_attn_bias,
+                               mp_size,
+                               pre_attn_norm):
     if mlp_after_attn:
-        hidden_state += attention_output_scaled
+        return res_add_bias_ref(hidden_state,
+                                residual,
+                                attn_output,
+                                attn_bias,
+                                final_bias,
+                                mp_size,
+                                pre_attn_norm)
     else:
-        hidden_state += attention_output
-    # TODO: The `add_bias` flag is used only for `launch_gptj_residual_add` kernel (aka, mlp_after_attn is False).
-    # This is a hack to get the parametarized add_bias to work. We need to fix this after refactoring the kernels.
-    add_bias = True if mlp_after_attn else add_bias
-    if add_bias:
-        hidden_state = hidden_state + attention_output_bias_scaled
-    return hidden_state
+        return res_add_bias_ref_gptj(hidden_state,
+                                     residual,
+                                     attn_output,
+                                     attn_bias,
+                                     final_bias,
+                                     add_attn_bias,
+                                     mp_size)
 @pytest.mark.inference
@@ -62,7 +85,7 @@ def run_residual_add_reference(hidden_state,
@pytest.mark.parametrize("mlp_after_attn", [True, False])
@pytest.mark.parametrize("add_bias", [True, False])
@pytest.mark.parametrize("mp_size", [1, 2])
# @pytest.mark.parametrize("preln", [True]) # TODO: add support for preln
@pytest.mark.parametrize("pre_attn_norm", [True, False])
def test_residual_add(inference_module,
batch,
sequence,
@@ -70,38 +93,35 @@ def test_residual_add(inference_module,
                       dtype,
                       mlp_after_attn,
                       add_bias,
-                      mp_size):
-    preln = True
+                      mp_size,
+                      pre_attn_norm):
     ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
     residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
-    attention_output = torch.randn((batch,
-                                    sequence,
-                                    hidden_dim),
-                                   dtype=dtype,
-                                   device='cuda')
+    attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
     final_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')
-    attention_output_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')
+    attn_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')

     ref_out = ds_out.clone()
     ref_out = run_residual_add_reference(ref_out,
                                          residual,
-                                         attention_output,
-                                         attention_output_bias,
+                                         attn_output,
+                                         attn_bias,
                                          final_bias,
                                          mlp_after_attn,
                                          add_bias,
-                                         mp_size)
+                                         mp_size,
+                                         pre_attn_norm)

     res_add_args = [
         ds_out,
         residual,
-        attention_output,
-        attention_output_bias,
+        attn_output,
+        attn_bias,
         final_bias,
         mp_size,
         mlp_after_attn,
         add_bias,
-        preln
+        pre_attn_norm
     ]

     if dtype == torch.float16:
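
For readers who want to sanity-check the reference math introduced by this commit outside the test harness, here is a minimal sketch that exercises the two pure-PyTorch reference paths from the diff. It assumes res_add_bias_ref and res_add_bias_ref_gptj are defined as above (imported or pasted into the same script); since they are plain tensor arithmetic, CPU tensors and the toy shapes below are sufficient, and no GPU, InferenceBuilder, or fused CUDA kernel is involved.

# Minimal sketch (assumption: the reference functions above are in scope).
import torch

batch, sequence, hidden_dim = 1, 4, 8
hidden_state = torch.randn(batch, sequence, hidden_dim)
residual = torch.randn(batch, sequence, hidden_dim)
attn_output = torch.randn(batch, sequence, hidden_dim)
attn_bias = torch.randn(hidden_dim)
final_bias = torch.randn(hidden_dim)

# mlp_after_attn path: with pre_attn_norm=True, residual, biases, and
# attention output are folded into a single sum scaled by mp_size.
out_pre_norm = res_add_bias_ref(hidden_state.clone(),
                                residual,
                                attn_output,
                                attn_bias,
                                final_bias,
                                mp_size=2,
                                pre_attn_norm=True)

# GPT-J-style path (mlp_after_attn=False): the attention output is added
# unscaled, and the attention bias is added only when add_attn_bias is set.
out_gptj = res_add_bias_ref_gptj(hidden_state.clone(),
                                 residual,
                                 attn_output,
                                 attn_bias,
                                 final_bias,
                                 add_attn_bias=True,
                                 mp_size=2)

print(out_pre_norm.shape, out_gptj.shape)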