test_sparse_attention_op.py 18.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
19
from paddle.static import Program, program_guard
20
import paddle
21 22 23
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.nn.functional as F
24 25
import os
import re
26
import copy
27 28 29 30 31 32 33 34 35 36 37 38 39 40


def get_cuda_version():
    result = os.popen("nvcc --version").read()
    regex = r'release (\S+),'
    match = re.search(regex, result)
    if match:
        num = str(match.group(1))
        integer, decimal = num.split('.')
        return int(integer) * 1000 + int(float(decimal) * 10)
    else:
        return -1


41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
def masked_fill(x):
    row, col = x.shape[0], x.shape[1]
    for i in range(row):
        for j in range(col):
            if x[i][j] == 0:
                x[i][j] = float('-inf')
    return x


def init_mask(x):
    row, col = x.shape[0], x.shape[1]
    for i in range(row):
        for j in range(col):
            if x[i][j] == 0 and (j < 0.8 * col):
                x[i][j] = 1
    return x


def softmax(x, kp_mask=None, attn_mask=None, bsz=None):
    if kp_mask is None and attn_mask is None:
        max = np.max(x, axis=1, keepdims=True)
        e_x = np.exp(x - max)
        sum = np.sum(e_x, axis=1, keepdims=True)
        f_x = e_x / sum
        return f_x
    else:
        # kp_mask
        current_kp_mask = kp_mask[bsz]
        row = current_kp_mask.shape[0]
        current_kp_mask = np.expand_dims(current_kp_mask, 0).repeat(row, axis=0)
        # attn_mask
        current_attn_mask = copy.deepcopy(attn_mask)
        current_attn_mask = masked_fill(current_attn_mask)
        current_kp_mask = masked_fill(current_kp_mask)
        x = x + current_kp_mask
        x = x + current_attn_mask
        max = np.max(x, axis=1, keepdims=True)
        e_x = np.exp(x - max)
        sum = np.sum(e_x, axis=1, keepdims=True)
        f_x = e_x / sum
        return f_x
82 83 84 85 86 87 88 89 90 91 92 93 94 95


def get_csr_value(mat, layout, nnz):
    row, col = mat.shape[0], mat.shape[1]
    value = np.zeros(nnz)
    ptr = 0
    for i in range(row):
        for j in range(col):
            if layout[i][j] == 1:
                value[ptr] = mat[i][j]
                ptr += 1
    return value


96 97 98 99 100 101 102 103
def ref_sparse_attention(q,
                         k,
                         v,
                         offset,
                         columns,
                         kp_mask=None,
                         attn_mask=None,
                         bsz=None):
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
    row, col, nnz = q.shape[0], q.shape[1], columns.shape[0]
    mat = np.zeros((row, row))
    for cur_row in range(row):
        start_ptr = int(offset[cur_row])
        end_ptr = int(offset[cur_row + 1])
        for ptr in range(start_ptr, end_ptr):
            cur_col = int(columns[ptr])
            mat[cur_row][cur_col] = 1
    a = np.dot(q, k.T) * mat
    a_value = get_csr_value(a, mat, nnz)
    scaling = float(col)**-0.5
    a = scaling * a
    for i in range(row):
        for j in range(row):
            if mat[i][j] == 0:
                a[i][j] = float('-inf')
120 121 122 123 124
    # softmax
    if kp_mask is None and attn_mask is None:
        b = softmax(a)
    else:
        b = softmax(a, kp_mask=kp_mask, attn_mask=attn_mask, bsz=bsz)
125 126 127 128 129
    b_value = get_csr_value(b, mat, nnz)
    result = np.dot(b, v)
    return result, a_value, b_value


130 131 132 133 134 135 136
def ref_batch_sparse_attention(q,
                               k,
                               v,
                               offset,
                               columns,
                               kp_mask=None,
                               attn_mask=None):
137 138 139 140 141 142 143 144 145
    batch_size, num_heads, row, col = q.shape
    nnz = columns.shape[2]
    result = np.zeros((batch_size, num_heads, row, col))
    result_sdd = np.zeros((batch_size, num_heads, nnz))
    result_softmax = np.zeros((batch_size, num_heads, nnz))
    for i in range(batch_size):
        for j in range(num_heads):
            cur_q, cur_k, cur_v, = q[i][j], k[i][j], v[i][j]
            cur_offset, cur_columns = offset[i][j], columns[i][j]
146 147 148 149 150 151 152 153 154 155 156 157 158
            if kp_mask is None and attn_mask is None:
                cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
                    cur_q, cur_k, cur_v, cur_offset, cur_columns)
            else:
                cur_result, cur_sdd, cur_softmax = ref_sparse_attention(
                    cur_q,
                    cur_k,
                    cur_v,
                    cur_offset,
                    cur_columns,
                    kp_mask=kp_mask,
                    attn_mask=attn_mask,
                    bsz=i)
159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
            result[i][j] = cur_result
            result_sdd[i][j], result_softmax[i][j] = cur_sdd, cur_softmax
    return result, result_sdd, result_softmax


def init_csr_format(batch_size, num_heads, rows, blocksize):
    block_num, block_last = rows / blocksize, rows % blocksize
    nnz_num = block_num * blocksize * blocksize + block_last * block_last
    offset = np.zeros(rows + 1)
    columns = np.zeros(int(nnz_num))
    mat = np.zeros((rows, rows))
    for i in range(0, rows, blocksize):
        for x in range(blocksize):
            for y in range(blocksize):
                p_x, p_y = i + x, i + y
                if (p_x < rows) and (p_y < rows):
                    mat[p_x][p_y] = 1
    p_offset, p_column, count = 0, 0, 0
    for i in range(rows):
        for j in range(rows):
            if mat[i][j] != 0:
                count += 1
                columns[p_column] = j
                p_column += 1
        p_offset += 1
        offset[p_offset] = count
    offset = np.expand_dims(np.expand_dims(offset, 0), 0)
    offset = offset.repeat(num_heads, axis=1)
    offset = offset.repeat(batch_size, axis=0)
    columns = np.expand_dims(np.expand_dims(columns, 0), 0)
    columns = columns.repeat(num_heads, axis=1)
    columns = columns.repeat(batch_size, axis=0)
    return offset, columns


@unittest.skipIf(
195 196
    not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
197
)
198
class TestSparseAttentionOp(OpTest):
199

200
    def config(self):
201 202
        self.shape = (1, 1, 16, 16)
        self.blocksize = 4
203
        self.dtype = "float64"
204
        self.use_mask = True
205 206 207 208 209 210 211 212 213

    def setUp(self):
        paddle.enable_static()
        self.config()
        self.op_type = "sparse_attention"
        self.place = paddle.CUDAPlace(0)
        self.q = np.random.random(self.shape).astype(self.dtype)
        self.k = np.random.random(self.shape).astype(self.dtype)
        self.v = np.random.random(self.shape).astype(self.dtype)
214
        # init CSR tensor
215 216 217 218
        offset, columns = init_csr_format(self.shape[0], self.shape[1],
                                          self.shape[2], self.blocksize)
        self.offset = offset.astype('int32')
        self.columns = columns.astype('int32')
219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
        # init mask tensor
        key_padding_mask_shape = (self.shape[0], self.shape[2])
        attn_mask_shape = (self.shape[2], self.shape[2])
        key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape)
        attn_mask = np.random.randint(0, 2, size=attn_mask_shape)
        key_padding_mask = init_mask(key_padding_mask)
        attn_mask = init_mask(attn_mask)

        self.key_padding_mask = key_padding_mask.astype(self.dtype)
        self.attn_mask = attn_mask.astype(self.dtype)
        if self.use_mask == True:
            result, result_sdd, result_softmax = ref_batch_sparse_attention(
                self.q,
                self.k,
                self.v,
                self.offset,
                self.columns,
                kp_mask=self.key_padding_mask,
                attn_mask=self.attn_mask)
        else:
            result, result_sdd, result_softmax = ref_batch_sparse_attention(
                self.q, self.k, self.v, self.offset, self.columns)

        if self.use_mask == True:
            self.inputs = {
                'Q': self.q,
                'K': self.k,
                'V': self.v,
                'Offset': self.offset,
                'Columns': self.columns,
                'KeyPaddingMask': self.key_padding_mask,
                'AttnMask': self.attn_mask,
            }
        else:
            self.inputs = {
                'Q': self.q,
                'K': self.k,
                'V': self.v,
                'Offset': self.offset,
                'Columns': self.columns,
            }
260 261
        self.outputs = {
            'Out': result.astype(self.dtype),
262 263
            'SparseDotSdd': result_sdd.astype(self.dtype),
            'Softmax': result_softmax.astype(self.dtype)
264 265 266 267 268 269 270 271 272 273 274 275
        }

    def test_check_output(self):
        self.check_output_with_place(self.place)

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['Q'], 'Out')
        self.check_grad_with_place(self.place, ['K'], 'Out')
        self.check_grad_with_place(self.place, ['V'], 'Out')


class TestSparseAttentionOpFp32Test(TestSparseAttentionOp):
276

277 278 279 280
    def config(self):
        self.shape = (1, 1, 8, 16)
        self.blocksize = 2
        self.dtype = "float32"
281
        self.use_mask = False
282 283 284


class TestSparseAttentionOpShapeTest(TestSparseAttentionOp):
285

286 287 288 289
    def config(self):
        self.shape = (2, 2, 32, 8)
        self.blocksize = 8
        self.dtype = "float64"
290
        self.use_mask = False
291 292


293
@unittest.skipIf(
294 295
    not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
296 297
)
class TestSparseAttentionAPI(unittest.TestCase):
298

299 300 301 302 303
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (1, 1, 8, 4)
        self.blocksize = 2
        self.dtype = 'float64'
304
        self.use_mask = True
305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320

    def test_static_graph(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            Q = paddle.static.data(name="Q", shape=self.shape, dtype=self.dtype)
            K = paddle.static.data(name="K", shape=self.shape, dtype=self.dtype)
            V = paddle.static.data(name="V", shape=self.shape, dtype=self.dtype)

            batch_size, num_heads, rows = self.shape[0], self.shape[
                1], self.shape[2]
            block_num = rows / self.blocksize
            block_last = rows % self.blocksize
            sparse_nnz_num = block_num * self.blocksize * self.blocksize + block_last * block_last
            offset_shape = (batch_size, num_heads, rows + 1)
            columns_shape = (batch_size, num_heads, int(sparse_nnz_num))

321 322 323 324 325 326
            offset = paddle.static.data(name="Offset",
                                        shape=offset_shape,
                                        dtype="int32")
            columns = paddle.static.data(name="Columns",
                                         shape=columns_shape,
                                         dtype="int32")
327 328 329 330 331 332 333
            key_padding_mask_shape = (self.shape[0], self.shape[2])
            attn_mask_shape = (self.shape[2], self.shape[2])
            if self.use_mask == True:
                key_padding_mask = paddle.static.data(
                    name="KeyPaddingMask",
                    shape=key_padding_mask_shape,
                    dtype=self.dtype)
334 335 336 337 338 339 340 341 342 343
                attn_mask = paddle.static.data(name="AttnMask",
                                               shape=attn_mask_shape,
                                               dtype=self.dtype)
                Out = F.sparse_attention(Q,
                                         K,
                                         V,
                                         offset,
                                         columns,
                                         key_padding_mask=key_padding_mask,
                                         attn_mask=attn_mask)
344 345
            else:
                Out = F.sparse_attention(Q, K, V, offset, columns)
346 347 348 349

            Q_np = np.random.random(self.shape).astype(self.dtype)
            K_np = np.random.random(self.shape).astype(self.dtype)
            V_np = np.random.random(self.shape).astype(self.dtype)
350 351 352 353
            offset_np, columns_np = init_csr_format(self.shape[0],
                                                    self.shape[1],
                                                    self.shape[2],
                                                    self.blocksize)
354 355 356
            offset_np = offset_np.astype('int32')
            columns_np = columns_np.astype('int32')

357
            # init mask tensor
358 359 360
            key_padding_mask_np = np.random.randint(0,
                                                    2,
                                                    size=key_padding_mask_shape)
361 362 363 364 365 366
            attn_mask_np = np.random.randint(0, 2, size=attn_mask_shape)
            key_padding_mask_np = init_mask(key_padding_mask_np)
            attn_mask_np = init_mask(attn_mask_np)
            key_padding_mask_np = key_padding_mask_np.astype(self.dtype)
            attn_mask_np = attn_mask_np.astype(self.dtype)

367
            exe = fluid.Executor(self.place)
368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397
            if self.use_mask == True:
                fetches_result = exe.run(feed={
                    "Q": Q_np,
                    "K": K_np,
                    "V": V_np,
                    "Offset": offset_np,
                    "Columns": columns_np,
                    'KeyPaddingMask': key_padding_mask_np,
                    'AttnMask': attn_mask_np
                },
                                         fetch_list=[Out])
                expected_result, __, __ = ref_batch_sparse_attention(
                    Q_np,
                    K_np,
                    V_np,
                    offset_np,
                    columns_np,
                    kp_mask=key_padding_mask_np,
                    attn_mask=attn_mask_np)
            else:
                fetches_result = exe.run(feed={
                    "Q": Q_np,
                    "K": K_np,
                    "V": V_np,
                    "Offset": offset_np,
                    "Columns": columns_np
                },
                                         fetch_list=[Out])
                expected_result, __, __ = ref_batch_sparse_attention(
                    Q_np, K_np, V_np, offset_np, columns_np)
398

399 400 401 402
            np.testing.assert_allclose(fetches_result,
                                       expected_result,
                                       rtol=1e-05,
                                       atol=1e-05)
403 404 405 406 407 408 409 410 411 412

    def test_dygraph(self):
        paddle.disable_static()
        offset, columns = init_csr_format(self.shape[0], self.shape[1],
                                          self.shape[2], self.blocksize)
        offset = offset.astype('int32')
        columns = columns.astype('int32')
        query = np.random.random(self.shape).astype(self.dtype)
        key = np.random.random(self.shape).astype(self.dtype)
        value = np.random.random(self.shape).astype(self.dtype)
413 414 415 416 417 418 419 420 421
        # init mask tensor
        key_padding_mask_shape = (self.shape[0], self.shape[2])
        attn_mask_shape = (self.shape[2], self.shape[2])
        key_padding_mask = np.random.randint(0, 2, size=key_padding_mask_shape)
        attn_mask = np.random.randint(0, 2, size=attn_mask_shape)
        key_padding_mask = init_mask(key_padding_mask)
        attn_mask = init_mask(attn_mask)
        key_padding_mask = key_padding_mask.astype(self.dtype)
        attn_mask = attn_mask.astype(self.dtype)
422 423 424 425 426 427

        paddle_query = paddle.to_tensor(query, place=self.place)
        paddle_key = paddle.to_tensor(key, place=self.place)
        paddle_value = paddle.to_tensor(value, place=self.place)
        paddle_offset = paddle.to_tensor(offset, place=self.place)
        paddle_colunmns = paddle.to_tensor(columns, place=self.place)
428 429 430 431
        paddle_kp_mask = paddle.to_tensor(key_padding_mask, place=self.place)
        paddle_attn_mask = paddle.to_tensor(attn_mask, place=self.place)

        if self.use_mask == True:
432 433 434 435 436 437 438
            paddle_result = F.sparse_attention(paddle_query,
                                               paddle_key,
                                               paddle_value,
                                               paddle_offset,
                                               paddle_colunmns,
                                               key_padding_mask=paddle_kp_mask,
                                               attn_mask=paddle_attn_mask)
439 440 441 442 443 444 445 446 447 448 449 450 451 452 453

            numpy_result, __, __ = ref_batch_sparse_attention(
                query,
                key,
                value,
                offset,
                columns,
                kp_mask=key_padding_mask,
                attn_mask=attn_mask)
            numpy_result = numpy_result.astype(self.dtype)
        else:
            paddle_result = F.sparse_attention(paddle_query, paddle_key,
                                               paddle_value, paddle_offset,
                                               paddle_colunmns)

454 455
            numpy_result, __, __ = ref_batch_sparse_attention(
                query, key, value, offset, columns)
456
            numpy_result = numpy_result.astype(self.dtype)
457

458 459 460 461
        np.testing.assert_allclose(paddle_result.numpy(),
                                   numpy_result,
                                   rtol=1e-05,
                                   atol=1e-05)
462 463 464


class TestSparseAttentionAPITestFloat(TestSparseAttentionAPI):
465

466 467 468 469 470
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (2, 2, 8, 4)
        self.blocksize = 2
        self.dtype = 'float32'
471
        self.use_mask = False
472 473 474


class TestSparseAttentionAPITestShape1(TestSparseAttentionAPI):
475

476 477 478 479 480
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (2, 2, 64, 32)
        self.blocksize = 2
        self.dtype = 'float64'
481
        self.use_mask = False
482 483 484


class TestSparseAttentionAPITestShape2(TestSparseAttentionAPI):
485

486 487 488 489 490
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (2, 1, 64, 32)
        self.blocksize = 2
        self.dtype = 'float64'
491
        self.use_mask = False
492 493 494


class TestSparseAttentionAPITestShape3(TestSparseAttentionAPI):
495

496 497 498 499 500
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (4, 4, 128, 32)
        self.blocksize = 8
        self.dtype = 'float64'
501
        self.use_mask = False
502 503 504


class TestSparseAttentionAPITestShape4(TestSparseAttentionAPI):
505

506 507 508 509 510
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (3, 3, 35, 15)
        self.blocksize = 3
        self.dtype = 'float64'
511
        self.use_mask = False
512 513


514 515
if __name__ == '__main__':
    unittest.main()