test_sparse_attention.py 9.5 KB
Newer Older
1 2
'''Copyright The Microsoft DeepSpeed Team'''

3 4 5 6 7 8 9
# DeepSpeed note, some parts of code taken & adapted from commit c368a9fd1b2c9dee4cc94de9a6bb0be3d447be41
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/test_softmax.py
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/test_matmul.py
# https://github.com/ptillet/torch-blocksparse/blob/master/tests/utils

import pytest
import torch
10
import deepspeed
11
from deepspeed.accelerator import get_accelerator
12
from deepspeed.ops.op_builder import SparseAttnBuilder
L
Logan Adams 已提交
13
from unit.util import skip_on_arch, skip_on_cuda
14

15 16 17
# Skip every test in this module up front when DeepSpeed's op compatibility
# table marks the sparse attention builder as unusable on this system.
if not deepspeed.ops.__compatible_ops__[SparseAttnBuilder.NAME]:
    pytest.skip(
        "sparse attention op is not compatible on this system",
        allow_module_level=True,
    )
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48


def dense_to_sparse(w, mask, block):
    """Converts dense matrix with explicit zeros to sparse matrix

    Gathers, for every non-zero entry ``(h, i, j)`` of ``mask``, the
    ``block x block`` tile ``w[:, h, i*block:(i+1)*block, j*block:(j+1)*block]``.

    Args:
        w: 4-D tensor indexed as ``w[z, h, row, col]``.
        mask: 3-D block layout indexed as ``mask[h, block_row, block_col]``;
            non-zero entries mark the blocks to keep.
        block: edge length of one square block.

    Returns:
        Tensor of shape ``(w.size(0), nnz, block, block)`` with the dtype and
        device of ``w``. Blocks are ordered as returned by ``mask.nonzero()``.
    """
    Z = w.size(0)
    nnz = mask.nonzero()
    # Size the output by the number of non-zero blocks. For a 0/1 layout this
    # equals mask.sum(), and it never leaves uninitialized blocks behind.
    ret = torch.empty((Z, nnz.size(0), block, block), dtype=w.dtype, device=w.device)
    # One conversion to Python ints instead of unpacking 0-dim tensors per block.
    for idx, (hh, ii, jj) in enumerate(nnz.tolist()):
        # The same block is taken for every batch entry, so copy the whole
        # batch dimension at once rather than looping over it.
        ret[:, idx] = w[:, hh, ii * block:(ii + 1) * block, jj * block:(jj + 1) * block]
    return ret


def sparse_to_dense(w, mask, block, zero=0):
    """Converts sparse matrix to dense matrix with explicit zeros

    Returns a copy of ``w`` in which every ``block x block`` tile whose layout
    entry ``mask[h, block_row, block_col]`` equals 0 is overwritten with
    ``zero``. ``w`` itself is not modified.

    Args:
        w: 4-D tensor indexed as ``w[z, h, row, col]``.
        mask: 3-D block layout indexed as ``mask[h, block_row, block_col]``.
        block: edge length of one square block.
        zero: fill value written into masked-out blocks (default 0).

    Returns:
        The masked copy of ``w``.
    """
    maskedw = w.clone()
    # Evaluate the layout once as nested Python lists instead of performing a
    # tensor comparison for every block.
    is_masked_out = (mask == 0).tolist()
    for wh in range(w.size(1)):
        for bi, wi in enumerate(range(0, w.size(2), block)):
            for bj, wj in enumerate(range(0, w.size(3), block)):
                if is_masked_out[wh][bi][bj]:
                    # The layout does not depend on the batch index, so the
                    # block is filled for all batch entries at once.
                    maskedw[:, wh, wi:wi + block, wj:wj + block] = zero
    return maskedw


def allclose(x, y):
    """Returns whether ``x`` and ``y`` agree within dtype-specific tolerances.

    Both tensors must share a dtype (checked with ``assert``). Tolerances are
    defined for ``torch.float32`` and ``torch.float16`` only; any other dtype
    raises ``KeyError`` from the lookup table.
    """
    assert x.dtype == y.dtype
    # (rtol, atol) per dtype; the float16 tolerances are the looser pair.
    rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
    return torch.allclose(x, y, rtol=rtol, atol=atol)


def make_layout(rho, shape):
    """Samples a random 0/1 layout tensor of the requested shape.

    Each entry is drawn independently from a two-outcome categorical
    distribution: value 0 with probability ``rho`` and value 1 with
    probability ``1 - rho``.
    """
    outcome_probs = torch.Tensor([rho, 1 - rho])
    return torch.distributions.categorical.Categorical(outcome_probs).sample(shape)


def run_softmax_reference(x, scale, dx, kp_mask, attn_mask, layout, block):
    """Computes the softmax forward/backward reference with dense PyTorch ops.

    Blocks of ``x`` that the layout masks out are set to ``-inf`` before the
    softmax over the last dimension. When ``kp_mask`` is given, positions where
    ``attn_mask`` is 0 are also set to ``-inf`` and ``kp_mask`` is added to the
    scaled scores.

    Returns:
        Tuple ``(y, dx)``: the softmax output and the gradient with respect to
        the densified input, both converted to block-sparse form.
    """
    dense_x = sparse_to_dense(x, layout, block, zero=float('-inf'))
    dense_x.retain_grad()
    if kp_mask is None:
        scores = dense_x * scale
    else:
        # Broadcast the (N, N) attention mask up to the full shape of dense_x.
        full_attn_mask = attn_mask[None, None, :, :] + torch.zeros_like(dense_x)
        dense_x[full_attn_mask == 0] = float('-inf')
        scores = dense_x * scale + kp_mask[:, None, None, :]
    probs = torch.softmax(scores, -1)
    probs.backward(dx)
    grad_x = dense_to_sparse(dense_x.grad.clone(), layout, block)
    sparse_probs = dense_to_sparse(probs, layout, block)
    return sparse_probs, grad_x


def run_softmax_sparse(x, scale, dx, kp_mask, attn_mask, layout, block):
    """Runs the DeepSpeed block-sparse softmax forward and backward passes.

    ``x`` and ``dx`` are converted to block-sparse form with
    ``dense_to_sparse``. The op is then applied with the key-padding mask in
    'add' mode and the attention mask in 'mul' mode.

    Returns:
        Tuple ``(x, dx)``: the block-sparse input tensor after the op has run,
        and a clone of the gradient accumulated on it.
    """
    # Imported inside the function rather than at module import time.
    from deepspeed.ops.sparse_attention.softmax import Softmax

    sparse_softmax = Softmax(layout, block, bench=False)

    dx = dense_to_sparse(dx, layout, block)
    x = dense_to_sparse(x, layout, block)
    x.retain_grad()
    y = sparse_softmax(x,
                       scale=scale,
                       key_padding_mask=kp_mask,
                       key_padding_mask_mode='add',
                       attn_mask=attn_mask,
                       attn_mask_mode='mul')
    y.backward(dx)
    dx = x.grad.clone()
    x.grad.zero_()
    # NOTE(review): `x`, not `y`, is returned as the forward result. Presumably
    # the sparse softmax writes its output into its input in place; confirm.
    return x, dx


def init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, dense_x=True, layout=None):
    """Creates random inputs for the softmax tests on the accelerator device.

    Args:
        Z, H, M, N: sizes used for the dense input shape ``(Z, H, M, N)``.
        scale: accepted for signature compatibility; not used here.
        rho: passed to ``make_layout`` when no layout is supplied.
        block: block edge length.
        dtype: dtype of ``x``, ``dx``, the float attention mask and ``kp_mask``.
        dense_x: when True ``x`` has shape ``(Z, H, M, N)``; otherwise it has
            the block-sparse shape ``(Z, layout.sum(), block, block)``.
        layout: optional pre-built layout; sampled with shape
            ``(H, M // block, N // block)`` when None.

    Returns:
        Tuple ``(layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask)``.
        ``kp_mask`` holds 0 or ``-inf``.
    """
    device = get_accelerator().device_name()
    if layout is None:
        layout = make_layout(rho, (H, M // block, N // block))
    x_shape = (Z, H, M, N) if dense_x else (Z, layout.sum(), block, block)
    x = torch.rand(x_shape, dtype=dtype, requires_grad=True, device=device)
    dx = torch.rand_like(x)
    bool_attn_mask = torch.randint(low=0,
                                   high=2,
                                   size=(N, N),
                                   dtype=torch.bool,
                                   requires_grad=False,
                                   device=device)
    fp_attn_mask = bool_attn_mask.type(dtype)
    kp_mask = torch.randint(low=0,
                            high=2,
                            size=(Z, N),
                            dtype=dtype,
                            requires_grad=False,
                            device=device)
    # Entries drawn as 1 become -inf, so the mask holds only 0 or -inf.
    kp_mask[kp_mask == 1.] = float('-inf')
    return layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask


@pytest.mark.parametrize("block", [16, 32])
@pytest.mark.parametrize("width", [256, 576])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_softmax(block, width, dtype):
    """Checks the block-sparse softmax against the dense PyTorch reference.

    The forward output and the input gradient of both implementations must
    agree within the tolerances of ``allclose``.
    """
    # Environment gating is delegated to the unit-test helpers; presumably they
    # skip on unsupported GPU architectures / CUDA versions (see unit.util).
    valid_cuda_versions = [101, 102, 110, 111]
    skip_on_arch(min_arch=7)
    skip_on_cuda(valid_cuda=valid_cuda_versions)

    Z = 2
    H = 4
    scale = 0.4
    rho = 0.4
    M = N = width
    layout, x, dx, bool_attn_mask, fp_attn_mask, kp_mask = init_softmax_inputs(Z, H, M, N, scale, rho, block, dtype, layout=None)
    # The reference gets the boolean attention mask, the sparse op the float one.
    ref_y, ref_dx = run_softmax_reference(x, scale, dx, kp_mask, bool_attn_mask, layout, block)
    st_y, st_dx = run_softmax_sparse(x, scale, dx, kp_mask, fp_attn_mask, layout, block)

    assert allclose(ref_y, st_y)
    assert allclose(ref_dx, st_dx)


def run_matmul_reference(x, w, mode, trans_a, trans_b, layout, block, dy):
    """Computes the matmul forward/backward reference with dense PyTorch ops.

    ``mode`` selects which operand is treated as block-sparse: 'dsd' masks
    ``x``, 'dds' masks ``w`` and 'sdd' masks the product. The masked operand's
    result (``dx``, ``dw`` or ``y`` respectively) is returned in block-sparse
    form; the others stay dense.

    Returns:
        Tuple ``(y, dx, dw)``.
    """
    if mode == 'dsd':
        x = sparse_to_dense(x, layout, block)
    if mode == 'dds':
        w = sparse_to_dense(w, layout, block)
    x.retain_grad()
    w.retain_grad()

    lhs = x.transpose(2, 3) if trans_a else x
    rhs = w.transpose(2, 3) if trans_b else w
    y = torch.matmul(lhs, rhs)
    if mode == 'sdd':
        y = sparse_to_dense(y, layout, block)
    y.backward(dy)

    # Snapshot the gradients before resetting the accumulators.
    dx, dw = x.grad.clone(), w.grad.clone()
    x.grad.zero_()
    w.grad.zero_()

    if mode == 'sdd':
        y = dense_to_sparse(y, layout, block)
    if mode == 'dsd':
        dx = dense_to_sparse(dx, layout, block)
    if mode == 'dds':
        dw = dense_to_sparse(dw, layout, block)
    return y, dx, dw


def run_matmul_sparse(x, w, mode, trans_a, trans_b, layout, block, dy):
    """Runs the DeepSpeed block-sparse matmul forward and backward passes.

    Depending on ``mode``, ``x`` ('dsd'), ``w`` ('dds') or the upstream
    gradient ``dy`` ('sdd') is converted to block-sparse form with
    ``dense_to_sparse`` before the op is applied.

    Returns:
        Tuple ``(y, dx, dw)``: the op output and clones of the gradients
        accumulated on ``x`` and ``w``.
    """
    # Imported inside the function rather than at module import time.
    from deepspeed.ops.sparse_attention.matmul import MatMul

    x = dense_to_sparse(x, layout, block) if mode == 'dsd' else x
    w = dense_to_sparse(w, layout, block) if mode == 'dds' else w
    dy = dense_to_sparse(dy, layout, block) if mode == 'sdd' else dy
    op = MatMul(layout, block, mode, trans_a=trans_a, trans_b=trans_b)
    x.retain_grad()
    w.retain_grad()
    y = op(x, w)
    y.backward(dy)
    dx = x.grad.clone()
    dw = w.grad.clone()
    # Reset both accumulators, as run_matmul_reference does. The returned
    # gradients are clones, so they are unaffected.
    x.grad.zero_()
    w.grad.zero_()
    return y, dx, dw


def init_matmul_inputs(Z, H, M, N, K, rho, mode, trans_a, trans_b, block, dtype, layout):
    """Creates seeded random inputs for the matmul tests on the accelerator.

    ``x`` has shape ``(Z, H, AS0, AS1)`` and ``w`` has ``(Z, H, BS0, BS1)``.
    The trailing dims are ``(M, K)`` / ``(K, N)``, swapped when ``trans_a`` /
    ``trans_b`` is set. ``dy`` has shape ``(Z, H, M, N)``.

    The layout covers whichever operand ``mode`` marks as sparse: the output
    for 'sdd', ``x`` for 'dsd', ``w`` for 'dds'. A supplied ``layout`` must
    have shape ``(H, shape[0] // block, shape[1] // block)``; when None, one
    is sampled with ``make_layout``.

    Returns:
        Tuple ``(x, w, dy, shape, layout)``.
    """
    # Fixed seed so every parametrized case sees the same random data.
    torch.manual_seed(1)
    device = get_accelerator().device_name()
    AS0 = K if trans_a else M
    AS1 = M if trans_a else K
    BS0 = N if trans_b else K
    BS1 = K if trans_b else N
    shape = {'sdd': (M, N), 'dsd': (AS0, AS1), 'dds': (BS0, BS1)}[mode]
    x = torch.rand((Z, H, AS0, AS1), dtype=dtype, requires_grad=True, device=device)
    w = torch.rand((Z, H, BS0, BS1), dtype=dtype, requires_grad=True, device=device)
    dy = torch.rand((Z, H, M, N), dtype=dtype, device=device)
    if layout is None:
        layout = make_layout(rho, (H, shape[0] // block, shape[1] // block))
    else:
        assert list(layout.shape) == [H, shape[0] // block, shape[1] // block]
    x.retain_grad()
    w.retain_grad()
    return x, w, dy, shape, layout

# Parameter tuples (block, dtype, mode, trans_a, trans_b) for test_matmul.
# Four groups, concatenated in this order:
#   1. float16, block 16, 'sdd'/'dds' with and without trans_b
#   2. float16, block 16, 'dsd' with and without trans_a
#   3. float32, block 16, every mode, no transposes
#   4. float16, blocks 16/32/64, every mode, no transposes
testdata = [
    (16, dtype, mode, trans_a, trans_b)
    for dtype in [torch.float16]
    for mode in ['sdd', 'dds']
    for trans_a in [False]
    for trans_b in [False, True]
] + [
    (16, dtype, mode, trans_a, trans_b)
    for dtype in [torch.float16]
    for mode in ['dsd']
    for trans_a in [False, True]
    for trans_b in [False]
] + [
    (16, dtype, mode, trans_a, trans_b)
    for dtype in [torch.float32]
    for mode in ['sdd', 'dsd', 'dds']
    for trans_a in [False]
    for trans_b in [False]
] + [
    (block, torch.float16, mode, False, False)
    for block in [16, 32, 64]
    for mode in ['sdd', 'dsd', 'dds']
]


@pytest.mark.parametrize("block, dtype, mode, trans_a, trans_b", testdata)
def test_matmul(block, dtype, mode, trans_a, trans_b):
    """Checks the block-sparse matmul against the dense PyTorch reference.

    The forward output and both operand gradients of the two implementations
    must agree within the tolerances of ``allclose``.
    """
    # Environment gating is delegated to the unit-test helpers; presumably they
    # skip on unsupported GPU architectures / CUDA versions (see unit.util).
    valid_cuda_versions = [101, 102, 110, 111]
    skip_on_arch(min_arch=7)
    skip_on_cuda(valid_cuda=valid_cuda_versions)

    Z = 3
    H = 2
    M = 128
    N = 256
    K = 192
    rho = 0.5
    x, w, dy, shape, layout = init_matmul_inputs(Z, H, M, N, K, rho, mode, trans_a, trans_b, block, dtype, layout=None)
    # Each implementation gets its own clones so gradients do not mix.
    ref_y, ref_dx, ref_dw = run_matmul_reference(x.clone(), w.clone(), mode, trans_a, trans_b, layout, block, dy)
    st_y, st_dx, st_dw = run_matmul_sparse(x.clone(), w.clone(), mode, trans_a, trans_b, layout, block, dy)

    assert allclose(ref_y, st_y)
    assert allclose(ref_dx, st_dx)
    assert allclose(ref_dw, st_dw)