diff --git a/python/paddle/fluid/tests/unittests/test_attn_bias.py b/python/paddle/fluid/tests/unittests/test_attn_bias.py
new file mode 100644
index 0000000000000000000000000000000000000000..29ed79900431d5fb4816649e0bc65a1caf4df81d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_attn_bias.py
@@ -0,0 +1,412 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.incubate.nn.attn_bias import (
+    BlockDiagonalCausalMask,
+    BlockDiagonalCausalWithOffsetPaddedKeysMask,
+    BlockDiagonalMask,
+    LowerTriangularMask,
+    LowerTriangularMaskWithTensorBias,
+    PaddedSeqLenInfo,
+    SeqLenInfo,
+)
+
+
+def all_dtypes():
+    dtypes = [paddle.float32, paddle.float64]
+    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
+        dtypes.append(paddle.float16)
+        prop = paddle.device.cuda.get_device_properties()
+        if prop.major >= 8:
+            dtypes.append(paddle.bfloat16)
+    return dtypes
+
+
+class TestLowerTriangularMask(unittest.TestCase):
+    @paddle.no_grad()
+    def check_materialize(self, shape, dtype, has_bias=False):
+        assert len(shape) >= 2
+        if has_bias:
+            bias = paddle.rand(shape=shape, dtype=dtype)
+            mask = LowerTriangularMaskWithTensorBias(bias)
+        else:
+            mask = LowerTriangularMask()
+
+        mask = mask.materialize(shape=shape, dtype=dtype)
+        self.assertEqual(mask.dtype, dtype)
+        self.assertEqual(mask.shape, shape)
+        dst_shape = [-1, mask.shape[-2], mask.shape[-1]]
+
+        mask = mask.reshape(dst_shape).astype(paddle.float64).numpy()
+        if has_bias:
+            bias = bias.reshape(dst_shape).astype(paddle.float64).numpy()
+
+        for i in range(mask.shape[0]):
+            for j in range(mask.shape[1]):
+                for k in range(mask.shape[2]):
+                    value = mask[i][j][k]
+                    if j >= k:
+                        if has_bias:
+                            self.assertEqual(value, bias[i][j][k])
+                        else:
+                            self.assertEqual(value, 0)
+                    else:
+                        self.assertEqual(value, float('-inf'))
+
+    def test_materialize(self):
+        shape = [5, 6, 7]
+        for dtype in all_dtypes():
+            for has_bias in [False, True]:
+                self.check_materialize(shape, dtype, has_bias)
+
+
+def check_split_tensor_without_batch_sizes(seqinfo, extra_shape):
+    seqlens = []
+    for i in range(len(seqinfo.seqstart_py) - 1):
+        seqlens.append(seqinfo.seqstart_py[i + 1] - seqinfo.seqstart_py[i])
+    shape = [1, seqinfo.seqstart_py[-1]] + list(extra_shape)
+
+    x = paddle.rand(shape)
+    tensors = seqinfo.split(x)
+    for i, t in enumerate(tensors):
+        assert t.shape[0] == 1
+        assert t.shape[1] == seqlens[i]
+        assert t.shape[2:] == x.shape[2:]
+    concated_x = paddle.concat(tensors, axis=1)
+    np.testing.assert_equal(x.numpy(), concated_x.numpy())
+    return x, tensors
+
+
+def check_split_tensor_with_batch_sizes(seqinfo, extra_shape, batch_sizes):
+    seqlens = []
+    for i in range(len(seqinfo.seqstart_py) - 1):
+        seqlens.append(seqinfo.seqstart_py[i + 1] - seqinfo.seqstart_py[i])
+
+    cumsum_bs = 0
+    uniq_seqlens = []
+    for bs in batch_sizes:
+        start = cumsum_bs
+        end = cumsum_bs + bs
+        for s in seqlens[start:end]:
+            assert s == seqlens[start]
+        cumsum_bs += bs
+        uniq_seqlens.append(seqlens[start])
+
+    x = paddle.rand(shape=[1, sum(seqlens)] + extra_shape)
+    tensors = seqinfo.split(x, batch_sizes)
+    assert len(tensors) == len(batch_sizes)
+    for i, t in enumerate(tensors):
+        shape = t.shape
+        assert len(shape) == 2 + len(extra_shape)
+        assert shape[0] == batch_sizes[i]
+        assert shape[1] == uniq_seqlens[i]
+
+    concated_tensor = paddle.concat(
+        [t.reshape([-1, *t.shape[2:]]) for t in tensors]
+    ).unsqueeze(0)
+    np.testing.assert_equal(x.numpy(), concated_tensor.numpy())
+    return x, tensors
+
+
+def check_split_tensor(seqinfo, extra_shape, batch_sizes):
+    if batch_sizes is None:
+        return check_split_tensor_without_batch_sizes(seqinfo, extra_shape)
+    else:
+        return check_split_tensor_with_batch_sizes(
+            seqinfo, extra_shape, batch_sizes
+        )
+
+
+def check_same_tensor_list(tensors1, tensors2):
+    assert len(tensors1) == len(tensors2)
+    for t1, t2 in zip(tensors1, tensors2):
+        assert t1.shape == t2.shape
+        assert t1.dtype == t2.dtype
+        np.testing.assert_equal(t1.numpy(), t2.numpy())
+
+
+class TestSeqLenInfo(unittest.TestCase):
+    def test_seq_len_info(self):
+        n = 100
+        seqlens = np.random.randint(2, 100, size=[n]).tolist()
+        cumsum_seqlens = [0] + np.cumsum(seqlens).tolist()
+        info = SeqLenInfo.from_seqlens(seqlens)
+        self.assertEqual(max(seqlens), info.max_seqlen)
+        np.testing.assert_equal(cumsum_seqlens, info.seqstart.numpy())
+        np.testing.assert_equal(cumsum_seqlens, info.seqstart_py)
+        intervals = list(info.intervals())
+        self.assertEqual(n, len(intervals))
+        for i in range(n):
+            self.assertEqual(cumsum_seqlens[i], intervals[i][0])
+            self.assertEqual(cumsum_seqlens[i + 1], intervals[i][1])
+
+        check_split_tensor_without_batch_sizes(info, [8, 9])
+
+    def test_split_with_batch_sizes(self):
+        n_tensor = 10
+        extra_shape = [3, 4]
+        batch_sizes = np.random.randint(10, 200, size=[n_tensor]).tolist()
+        seqlens = []
+        uniq_seqlens = []
+        for bs in batch_sizes:
+            tmp_seqlen = np.random.randint(10, 200, size=[1])[0]
+            uniq_seqlens.append(tmp_seqlen)
+            seqlens.extend([tmp_seqlen] * bs)
+        info = SeqLenInfo.from_seqlens(seqlens)
+
+        check_split_tensor_with_batch_sizes(info, extra_shape, batch_sizes)
+
+
+class TestPaddedSeqLenInfo(unittest.TestCase):
+    def test_padded_seq_len_info(self):
+        n = 100
+        padding = 200
+        seqlens = np.random.randint(2, padding, size=[n]).tolist()
+        info = PaddedSeqLenInfo.from_seqlens_padded(seqlens, padding)
+        self.assertEqual(max(seqlens), info.max_seqlen)
+        np.testing.assert_equal(info.seqstart.numpy(), info.seqstart_py)
+        self.assertEqual(len(info.seqstart_py), n + 1)
+        self.assertEqual(info.seqstart_py[0], 0)
+        self.assertTrue(np.all(np.diff(info.seqstart_py) == padding))
+        intervals = list(info.intervals())
+        self.assertEqual(len(intervals), n)
+        for i in range(n):
+            interval = intervals[i]
+            self.assertEqual(interval[0], padding * i)
+            self.assertEqual(interval[1] - interval[0], seqlens[i])
+
+
+class TestBlockDiagonalMask(unittest.TestCase):
+    def setUp(self):
+        self.mask_class = BlockDiagonalMask
+        self.q_n = 10
+        self.qkv_same_length = True
+        self.config()
+
+    def config(self):
+        pass
+
+    def test_from_seq_lens(self):
+        q_seqlen = np.random.randint(2, 100, self.q_n).tolist()
+        if self.qkv_same_length:
+            kv_seqlen = q_seqlen
+        else:
+            kv_seqlen = np.random.randint(2, 100, self.q_n).tolist()
+
+        mask = self.mask_class.from_seqlens(q_seqlen, kv_seqlen)
+        self.check_main(mask, q_seqlen, kv_seqlen, [3, 4])
+
+    def test_from_tensor_list(self):
+        shapes = [[2, 3], [7, 9], [11, 5]]
+        extra_shape = [13, 19]
+        tensors = []
+        seqlens = []
+        for s in shapes:
+            tmp_s = s + extra_shape
+            tensors.append(paddle.rand(tmp_s))
+            seqlens.extend([tmp_s[1]] * tmp_s[0])
+        mask, concated_tensor = self.mask_class.from_tensor_list(tensors)
+        self.check_main(mask, seqlens, seqlens, extra_shape)
+
+    def test_from_tensor_lists_qk(self):
+        self.check_from_tensor_lists_qkv()
+
+    def test_from_tensor_lists_qkv(self):
+        self.check_from_tensor_lists_qkv(has_value=True)
+
+    def check_from_tensor_lists_qkv(self, has_value=False):
+        batch_sizes = [2, 3, 4]
+        q_uniq_seqlens = [5, 6, 7]
+        k_uniq_seqlens = [8, 9, 10]
+        extra_shape = [13, 19]
+
+        tensors_q = []
+        tensors_k = []
+        tensors_v = [] if has_value else None
+        q_seqlens = []
+        kv_seqlens = []
+        for i, bs in enumerate(batch_sizes):
+            q_shape = [bs, q_uniq_seqlens[i]] + extra_shape
+            kv_shape = [bs, k_uniq_seqlens[i]] + extra_shape
+            tensors_q.append(paddle.rand(q_shape))
+            tensors_k.append(paddle.rand(kv_shape))
+            q_seqlens.extend([q_shape[1]] * q_shape[0])
+            kv_seqlens.extend([kv_shape[1]] * kv_shape[0])
+            if has_value:
+                tensors_v.append(paddle.rand(kv_shape))
+
+        mask, q, k, v = self.mask_class.from_tensor_lists_qkv(
+            tensors_q, tensors_k, tensors_v
+        )
+        self.check_main(
+            mask,
+            q_seqlens,
+            kv_seqlens,
+            extra_shape,
+            check_same_shape_split=False,
+        )
+
+    def check_main(
+        self,
+        mask,
+        q_seqlen,
+        kv_seqlen,
+        extra_shape,
+        check_same_shape_split=True,
+    ):
+        total_q_tokens = sum(q_seqlen)
+        total_kv_tokens = sum(kv_seqlen)
+        shape = extra_shape + [total_q_tokens, total_kv_tokens]
+        mask_value = mask.materialize(shape=shape)
+        self.assertEqual(mask_value.shape, shape)
+
+        mask_value = mask_value.numpy()
+        mask_value = mask_value.reshape([-1, *mask_value.shape[-2:]])
+        for i in range(1, mask_value.shape[0]):
+            np.testing.assert_equal(mask_value[i], mask_value[0])
+        mask_value = mask_value[0]
+        self.check_mask(
+            mask_value,
+            list(mask.q_seqinfo.intervals()),
+            list(mask.k_seqinfo.intervals()),
+        )
+
+        x, tensors = check_split_tensor(
+            mask.q_seqinfo, extra_shape, mask._batch_sizes
+        )
+        check_same_tensor_list(mask.split_queries(x), tensors)
+
+        x, tensors = check_split_tensor(
+            mask.k_seqinfo, extra_shape, mask._batch_sizes
+        )
+        check_same_tensor_list(mask.split_kv(x), tensors)
+
+        if self.qkv_same_length and check_same_shape_split:
+            x, tensors = check_split_tensor(
+                mask.q_seqinfo, extra_shape, mask._batch_sizes
+            )
+            check_same_tensor_list(mask.split(x), tensors)
+
+        if self.mask_class == BlockDiagonalMask:
+            self.assertEqual(type(mask.make_causal()), BlockDiagonalCausalMask)
+
+    def check_mask(self, mask, q_intervals, k_intervals):
+        self.assertEqual(len(mask.shape), 2)
+        m, n = mask.shape
+        self.assertEqual(len(q_intervals), len(k_intervals))
+        for (q_start, q_end), (k_start, k_end) in zip(q_intervals, k_intervals):
+            if k_start > 0:
+                self.assertTrue(
+                    np.all(mask[q_start:q_end, 0:k_start] == float('-inf'))
+                )
+            if k_end < n:
+                self.assertTrue(
+                    np.all(mask[q_start:q_end, k_end:] == float('-inf'))
+                )
+            block_mask = mask[q_start:q_end, k_start:k_end]
+            self.check_block_mask(block_mask)
+
+    def check_block_mask(self, block_mask):
+        self.assertTrue(np.all(block_mask == 0))
+
+
+class TestBlockDiagonalMaskQKVDiffLength(TestBlockDiagonalMask):
+    def config(self):
+        self.qkv_same_length = False
+
+
+class TestBlockDiagonalCausalMask(TestBlockDiagonalMask):
+    def config(self):
+        self.mask_class = BlockDiagonalCausalMask
+
+    def check_block_mask(self, block_mask):
+        self.assertEqual(len(block_mask.shape), 2)
+        m, n = block_mask.shape
+        for i in range(m):
+            for j in range(n):
+                if i >= j:
+                    self.assertEqual(block_mask[i][j], 0)
+                else:
+                    self.assertEqual(block_mask[i][j], float('-inf'))
+
+
+class TestBlockDiagonalCausalMaskQKVDiffLength(TestBlockDiagonalCausalMask):
+    def config(self):
+        self.mask_class = BlockDiagonalCausalMask
+        self.qkv_same_length = False
+
+
+class TestBlockDiagonalCausalWithOffsetPaddedKeysMask(unittest.TestCase):
+    def test_main(self):
+        kv_padding = 20
+        n = 10
+        extra_shape = [3, 4]
+
+        q_seqlen = np.random.randint(0, kv_padding, size=[n]).tolist()
+        kv_seqlen = np.random.randint(0, kv_padding, size=[n]).tolist()
+
+        q_ntokens = sum(q_seqlen)
+        kv_ntokens = n * kv_padding
+        max_causal_diagonal = min(q_ntokens, kv_ntokens) - 2
+        causal_diagonal_np = np.random.randint(
+            0, max_causal_diagonal, size=[n]
+        ).astype(np.int32)
+        causal_diagonal = paddle.to_tensor(causal_diagonal_np)
+
+        mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
+            q_seqlen, kv_padding, kv_seqlen, causal_diagonal
+        )
+
+        shape = extra_shape + [q_ntokens, kv_ntokens]
+        mask_np = mask.materialize(shape).numpy()
+        self.assertEqual(list(mask_np.shape[: len(extra_shape)]), extra_shape)
+        mask_np = mask_np.reshape([-1, *mask_np.shape[2:]])
+        for i in range(1, mask_np.shape[0]):
+            np.testing.assert_equal(mask_np[i], mask_np[0])
+
+        mask_np = mask_np[0]
+        q_intervals = list(mask.q_seqinfo.intervals())
+        k_intervals = list(mask.k_seqinfo.intervals())
+        self.assertEqual(len(q_intervals), len(k_intervals))
+        for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
+            zip(q_intervals, k_intervals)
+        ):
+            if k_start != 0:
+                np.testing.assert_equal(
+                    mask_np[q_start:q_end, 0:k_start], float('-inf')
+                )
+
+            np.testing.assert_equal(
+                mask_np[q_start:q_end, k_start:k_end],
+                self.create_numpy_block_mask(
+                    (q_end - q_start, k_end - k_start), causal_diagonal_np[i]
+                ),
+            )
+            if k_end != kv_ntokens:
+                np.testing.assert_equal(
+                    mask_np[q_start:q_end, k_end:kv_ntokens], float('-inf')
+                )
+
+    def create_numpy_block_mask(self, shape, offset, dtype=np.float32):
+        t = np.full(shape, dtype=dtype, fill_value=float('-inf'))
+        return np.triu(t, 1 + offset)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/incubate/nn/attn_bias.py b/python/paddle/incubate/nn/attn_bias.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbbb016df4e2d127e1077b317e121f18f8cd4038
--- /dev/null
+++ b/python/paddle/incubate/nn/attn_bias.py
@@ -0,0 +1,265 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# The following codes are from https://github.com/facebookresearch/xformers
+
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
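+
+"""Attention bias classes for memory-efficient attention.
+
+A minimal usage sketch (illustrative only; every name below is defined in
+this module):
+
+    import paddle
+    from paddle.incubate.nn.attn_bias import LowerTriangularMask
+
+    # Additive causal mask: 0 on and below the diagonal, -inf above it.
+    bias = LowerTriangularMask().materialize(shape=[2, 8, 8])
+"""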
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Optional, Sequence
+
+import paddle
+
+
+class AttentionBias(ABC):
+    @abstractmethod
+    def materialize(self, shape, dtype=paddle.float32):
+        raise NotImplementedError()
+
+
+class LowerTriangularMask(AttentionBias):
+    def materialize(self, shape, dtype=paddle.float32):
+        # paddle.triu may not support bfloat16, so build the mask in
+        # float32 first and cast back to the requested dtype.
+        create_as = dtype if dtype is not paddle.bfloat16 else paddle.float32
+        tensor = paddle.full(
+            shape=shape, fill_value=float("-inf"), dtype=create_as
+        )
+        return paddle.triu(tensor, diagonal=1).astype(dtype)
+
+    def add_bias(self, bias):
+        return LowerTriangularMaskWithTensorBias(bias)
+
+
+class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
+    def __init__(self, bias):
+        self._bias = bias
+
+    def materialize(self, shape, dtype=paddle.float32):
+        return super().materialize(shape, dtype) + self._bias
+
+
+@dataclass
+class SeqLenInfo:
+    seqstart: paddle.Tensor
+    max_seqlen: int
+    seqstart_py: List[int]
+
+    def intervals(self):
+        yield from zip(self.seqstart_py, self.seqstart_py[1:])
+
+    @classmethod
+    def from_seqlens(cls, seqlens):
+        seqstart_py = [0]
+        max_seqlen = -1
+        for seqlen in seqlens:
+            max_seqlen = max(max_seqlen, seqlen)
+            seqstart_py.append(seqstart_py[-1] + seqlen)
+        seqstart = paddle.to_tensor(seqstart_py, dtype=paddle.int32)
+        return cls(
+            max_seqlen=max_seqlen, seqstart=seqstart, seqstart_py=seqstart_py
+        )
+
+    def split(self, x, batch_sizes=None):
+        assert self.seqstart_py[-1] == x.shape[1] and x.shape[0] == 1
+        if batch_sizes is None:
+            batch_sizes = [1] * (len(self.seqstart_py) - 1)
+        split_chunks = []
+        it = 0
+        for batch_size in batch_sizes:
+            split_chunks.append(
+                self.seqstart_py[it + batch_size] - self.seqstart_py[it]
+            )
+            it += batch_size
+        return [
+            tensor.reshape([bs, -1, *tensor.shape[2:]])
+            for bs, tensor in zip(batch_sizes, x.split(split_chunks, axis=1))
+        ]
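+
+
+# Example (sketch): SeqLenInfo.from_seqlens([2, 3]) yields seqstart_py == [0, 2, 5],
+# max_seqlen == 3, and intervals() -> (0, 2), (2, 5). split() reverses the packing
+# of several variable-length sequences into one [1, total_tokens, ...] tensor.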
+
+
+@dataclass
+class PaddedSeqLenInfo(SeqLenInfo):
+    seqlen: paddle.Tensor
+    seqlen_py: Sequence[int]
+
+    def intervals(self):
+        for (start, _), length in zip(super().intervals(), self.seqlen_py):
+            yield start, start + length
+
+    @classmethod
+    def from_seqlens(cls, seqlens):
+        raise NotImplementedError(
+            "Please use SeqLenInfo.from_seqlens() or PaddedSeqLenInfo.from_seqlens_padded()."
+        )
+
+    @classmethod
+    def from_seqlens_padded(cls, seqlens, padding):
+        assert all(seqlen <= padding for seqlen in seqlens)
+        seqstart_py = list(range(0, len(seqlens) * padding + 1, padding))
+        return cls(
+            seqlen=paddle.to_tensor(seqlens, dtype=paddle.int32),
+            seqlen_py=seqlens,
+            max_seqlen=max(seqlens),
+            seqstart=paddle.to_tensor(seqstart_py, dtype=paddle.int32),
+            seqstart_py=seqstart_py,
+        )
+
+    def split(self, x, batch_sizes=None):
+        raise NotImplementedError()
+
+
+@dataclass
+class BlockDiagonalMask(AttentionBias):
+    q_seqinfo: SeqLenInfo
+    k_seqinfo: SeqLenInfo
+    _batch_sizes: Optional[Sequence[int]] = None
+
+    def _create_block_mask(self, shape, dtype=paddle.float32):
+        return paddle.zeros(shape=shape, dtype=dtype)
+
+    def materialize(self, shape, dtype=paddle.float32):
+        assert shape[-1] == self.k_seqinfo.seqstart_py[-1]
+        assert shape[-2] == self.q_seqinfo.seqstart_py[-1]
+        mask = paddle.full(shape[-2:], fill_value=float('-inf'), dtype=dtype)
+        for (q_start, q_end), (k_start, k_end) in zip(
+            self.q_seqinfo.intervals(), self.k_seqinfo.intervals()
+        ):
+            sub_shape = [q_end - q_start, k_end - k_start]
+            mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
+                sub_shape, dtype
+            )
+        for _ in range(len(shape) - 2):
+            mask = mask.unsqueeze(0)
+        return mask.expand(shape)
+
+    @classmethod
+    def from_seqlens(cls, q_seqlen, kv_seqlen=None):
+        assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
+        q_seqinfo = SeqLenInfo.from_seqlens(q_seqlen)
+        if kv_seqlen is None or q_seqlen == kv_seqlen:
+            k_seqinfo = q_seqinfo
+        else:
+            k_seqinfo = SeqLenInfo.from_seqlens(kv_seqlen)
+        return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
+
+    @classmethod
+    def from_tensor_list(cls, tensors):
+        batch_sizes = [tensor.shape[0] for tensor in tensors]
+        seqlens = []
+        for x in tensors:
+            for _ in range(x.shape[0]):
+                seqlens.append(x.shape[1])
+        block_diag = cls.from_seqlens(seqlens)
+        block_diag._batch_sizes = batch_sizes
+        concated_tensor = paddle.concat(
+            [x.reshape([1, -1, *x.shape[2:]]) for x in tensors], axis=1
+        )
+        return block_diag, concated_tensor
+
+    @classmethod
+    def from_tensor_lists_qkv(cls, tensors_q, tensors_k, tensors_v=None):
+        assert len(tensors_q) == len(tensors_k)
+        assert tensors_v is None or len(tensors_v) == len(tensors_q)
+        batch_sizes = [tensor.shape[0] for tensor in tensors_q]
+        q_seqlens, kv_seqlens = [], []
+        for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
+            assert q.shape[0] == k.shape[0]
+            q_seqlens.extend([q.shape[1]] * q.shape[0])
+            kv_seqlens.extend([k.shape[1]] * k.shape[0])
+            assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
+        block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
+        block_diag._batch_sizes = batch_sizes
+        return (
+            block_diag,
+            paddle.concat(
+                [x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], axis=1
+            ),
+            paddle.concat(
+                [x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], axis=1
+            ),
+            paddle.concat(
+                [x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], axis=1
+            )
+            if tensors_v is not None
+            else None,
+        )
+
+    def split_queries(self, tensor):
+        return self.q_seqinfo.split(tensor, self._batch_sizes)
+
+    def split_kv(self, tensor):
+        return self.k_seqinfo.split(tensor, self._batch_sizes)
+
+    def split(self, tensor):
+        assert self.q_seqinfo is self.k_seqinfo
+        return self.q_seqinfo.split(tensor, self._batch_sizes)
+
+    def make_causal(self):
+        return BlockDiagonalCausalMask(
+            q_seqinfo=self.q_seqinfo,
+            k_seqinfo=self.k_seqinfo,
+            _batch_sizes=self._batch_sizes,
+        )
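+
+
+# Example (sketch): BlockDiagonalMask.from_seqlens([2, 3]).materialize(shape=[5, 5])
+# is zero inside the two diagonal (q, k) blocks and -inf everywhere else, so each
+# packed sequence attends only to itself; make_causal() additionally restricts
+# every block to its lower triangle.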
+
+
+@dataclass
+class BlockDiagonalCausalMask(BlockDiagonalMask):
+    def _create_block_mask(self, shape, dtype=paddle.float32):
+        return LowerTriangularMask().materialize(shape=shape, dtype=dtype)
+
+
+@dataclass
+class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
+    q_seqinfo: SeqLenInfo
+    k_seqinfo: PaddedSeqLenInfo
+    causal_diagonal: Optional[paddle.Tensor] = None
+
+    def _create_block_mask(self, shape, offset=0, dtype=paddle.float32):
+        create_as = dtype if dtype is not paddle.bfloat16 else paddle.float32
+        tensor = paddle.full(shape, dtype=create_as, fill_value=float('-inf'))
+        return paddle.triu(tensor, diagonal=1 + offset).astype(dtype)
+
+    def materialize(self, shape, dtype=paddle.float32):
+        assert shape[-1] == self.k_seqinfo.seqstart_py[-1]
+        assert shape[-2] == self.q_seqinfo.seqstart_py[-1]
+        mask = paddle.full(shape[-2:], dtype=dtype, fill_value=float('-inf'))
+        for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
+            zip(self.q_seqinfo.intervals(), self.k_seqinfo.intervals())
+        ):
+            mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
+                (q_end - q_start, k_end - k_start),
+                offset=0
+                if self.causal_diagonal is None
+                else int(self.causal_diagonal[i].item()),
+                dtype=dtype,
+            )
+        for _ in range(len(shape) - 2):
+            mask = mask.unsqueeze(0)
+        return mask.expand(shape)
+
+    @classmethod
+    def from_seqlens(
+        cls, q_seqlen, kv_padding, kv_seqlen, causal_diagonal=None
+    ):
+        assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
+        q_seqinfo = SeqLenInfo.from_seqlens(q_seqlen)
+        k_seqinfo = PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
+        return cls(
+            q_seqinfo=q_seqinfo,
+            k_seqinfo=k_seqinfo,
+            causal_diagonal=causal_diagonal,
+        )
diff --git a/python/paddle/incubate/nn/memory_efficient_attention.py b/python/paddle/incubate/nn/memory_efficient_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..2591d70fb3e8e4a3360e82940b0a29b90dfd4af2
--- /dev/null
+++ b/python/paddle/incubate/nn/memory_efficient_attention.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The following codes are from https://github.com/facebookresearch/xformers
+
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
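+
+"""Python front end for memory-efficient attention.
+
+This module validates the ``attn_bias`` argument and unpacks the
+sequence-length metadata consumed by the (forthcoming) C++ kernel; see the
+TODO at the bottom of ``memory_efficient_attention``.
+"""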
+
+import paddle
+
+from .attn_bias import (
+    BlockDiagonalCausalMask,
+    BlockDiagonalCausalWithOffsetPaddedKeysMask,
+    BlockDiagonalMask,
+    LowerTriangularMask,
+    LowerTriangularMaskWithTensorBias,
+)
+
+SUPPORTED_ATTN_BIAS_TYPES = {
+    type(None),
+    paddle.Tensor,
+    LowerTriangularMask,
+    LowerTriangularMaskWithTensorBias,
+    BlockDiagonalMask,
+    BlockDiagonalCausalMask,
+    BlockDiagonalCausalWithOffsetPaddedKeysMask,
+}
+
+
+def _get_seqlen_info(attn_bias):
+    if isinstance(
+        attn_bias,
+        (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask),
+    ):
+        return (
+            attn_bias.k_seqinfo.seqstart,
+            attn_bias.q_seqinfo.seqstart,
+            attn_bias.q_seqinfo.max_seqlen,
+            attn_bias.k_seqinfo.max_seqlen,
+        )
+    else:
+        return None, None, -1, -1
+
+
+def _get_tensor_bias(attn_bias):
+    if isinstance(attn_bias, paddle.Tensor):
+        return attn_bias
+    elif isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
+        return attn_bias._bias
+    else:
+        return None
+
+
+def memory_efficient_attention(
+    query, key, value, attn_bias, p=0.0, scale=None, training=True
+):
+    assert type(attn_bias) in SUPPORTED_ATTN_BIAS_TYPES
+    causal = isinstance(
+        attn_bias,
+        (
+            LowerTriangularMask,
+            BlockDiagonalCausalMask,
+            BlockDiagonalCausalWithOffsetPaddedKeysMask,
+        ),
+    )
+    seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(attn_bias)
+    # NOTE: compute_logsumexp = training
+    is_test = not training
+    causal_diagonal = (
+        attn_bias.causal_diagonal
+        if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+        else None
+    )
+    seqlen_k = (
+        attn_bias.k_seqinfo.seqlen
+        if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
+        else None
+    )
+    attn_bias = _get_tensor_bias(attn_bias)
+    # TODO(zhangdanyang): add C++ codes here
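+    # NOTE: the kernel launch is still pending. The values gathered above
+    # (causal, seqstart_q/seqstart_k, max_seqlen_q, causal_diagonal, seqlen_k,
+    # is_test, the tensor bias, p and scale) are the inputs the C++ op is
+    # expected to consume; the op name and outputs are not final yet.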