test_residual_add.py
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""

import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
    pytest.skip("Inference ops are not available on this system",
                allow_module_level=True)


def allclose(x, y):
    assert x.dtype == y.dtype
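    # Per-dtype tolerances: fp16 gets much looser bounds than fp32, reflecting
    # the reduced precision of half-precision accumulation in the fused kernel.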
    rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-2)}[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)


@pytest.fixture(scope="module")
def inference_module():
    return InferenceBuilder().load()


def res_add_bias_ref(hidden_state,
                     residual,
                     attn_output,
                     attn_bias,
                     final_bias,
                     mp_size=1,
                     pre_attn_norm=True):
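    # Standard (MLP-after-attention) reference: with a pre-attention norm, all
    # residual/bias/attention terms are summed and scaled by 1/mp_size; without
    # it, only the residual and final bias are folded in.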
    if pre_attn_norm:
        hidden_state += (residual + final_bias + attn_output + attn_bias) / mp_size
    else:
        hidden_state += residual + final_bias
    return hidden_state


def res_add_bias_ref_gptj(hidden_state,
                          residual,
                          attn_output,
                          attn_bias,
                          final_bias,
                          add_attn_bias,
                          mp_size):
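    # GPT-J-style (parallel attention/MLP) reference: the residual and final
    # bias are scaled by 1/mp_size while the attention output is added
    # unscaled; the attention bias is added (also scaled) only when requested.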
    hidden_state += attn_output + (residual + final_bias) / mp_size
    if add_attn_bias:
        hidden_state += attn_bias / mp_size
    return hidden_state


def run_residual_add_reference(hidden_state,
                               residual,
                               attn_output,
                               attn_bias,
                               final_bias,
                               mlp_after_attn,
                               add_attn_bias,
                               mp_size,
                               pre_attn_norm):
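    # Dispatch to whichever pure-PyTorch reference matches the kernel variant
    # under test.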
    if mlp_after_attn:
        return res_add_bias_ref(hidden_state,
                                residual,
                                attn_output,
                                attn_bias,
                                final_bias,
                                mp_size,
                                pre_attn_norm)
    else:
        return res_add_bias_ref_gptj(hidden_state,
                                     residual,
                                     attn_output,
                                     attn_bias,
                                     final_bias,
                                     add_attn_bias,
                                     mp_size)


@pytest.mark.inference
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("hidden_dim", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("mlp_after_attn", [True, False])
@pytest.mark.parametrize("add_bias", [True, False])
@pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize("pre_attn_norm", [True, False])
def test_residual_add(inference_module,
                      batch,
                      sequence,
                      hidden_dim,
                      dtype,
                      mlp_after_attn,
                      add_bias,
                      mp_size,
                      pre_attn_norm):
    ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
    residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
    attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device='cuda')
    final_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')
    attn_bias = torch.randn((hidden_dim), dtype=dtype, device='cuda')

    ref_out = ds_out.clone()
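    # Apply the pure-PyTorch reference to the copy; ds_out itself is handed to
    # the fused kernel afterwards.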
    ref_out = run_residual_add_reference(ref_out,
                                         residual,
                                         attn_output,
                                         attn_bias,
                                         final_bias,
                                         mlp_after_attn,
                                         add_bias,
                                         mp_size,
                                         pre_attn_norm)

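    # Pack the arguments in the order the fused op is invoked with; note that
    # mp_size precedes mlp_after_attn here, unlike the reference helper above.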
    res_add_args = [
        ds_out,
        residual,
        attn_output,
        attn_bias,
        final_bias,
        mp_size,
        mlp_after_attn,
        add_bias,
        pre_attn_norm
    ]

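    # The fused residual-add op is exposed per dtype, so pick the variant
    # matching the parametrized precision.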
    if dtype == torch.float16:
        ds_out = inference_module.residual_add_bias_fp16(*res_add_args)
    elif dtype == torch.float32:
        ds_out = inference_module.residual_add_bias_fp32(*res_add_args)
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    assert allclose(ds_out, ref_out)