Unverified commit 54331f1a, authored by sneaxiy, committed by GitHub

Add attn_bias.py of xformers (#51387)

* add attn_bias.py

* add Python interface

* add license

* add test_attn_bias.py

* fix CPU test error

* fix ci error
Parent 24258c27
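
A minimal usage sketch of the new API (illustrative only; the import path matches the test file below):

    import paddle
    from paddle.incubate.nn.attn_bias import BlockDiagonalMask, LowerTriangularMask

    # Causal additive bias: 0 on and below the diagonal, -inf strictly above it.
    bias = LowerTriangularMask().materialize(shape=[1, 4, 4])

    # Block-diagonal bias packing two sequences of lengths 2 and 3.
    block = BlockDiagonalMask.from_seqlens([2, 3]).materialize(shape=[5, 5])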
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.incubate.nn.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
PaddedSeqLenInfo,
SeqLenInfo,
)
def all_dtypes():
dtypes = [paddle.float32, paddle.float64]
if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
dtypes.append(paddle.float16)
prop = paddle.device.cuda.get_device_properties()
if prop.major >= 8:
dtypes.append(paddle.bfloat16)
return dtypes
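
# Note: float16 is only exercised on CUDA (non-ROCm) builds, and bfloat16
# additionally requires a GPU with compute capability >= 8.0 (e.g. A100);
# CPU-only builds test just float32 and float64.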
class TestLowerTriangularMask(unittest.TestCase):
@paddle.no_grad()
def check_materialize(self, shape, dtype, has_bias=False):
assert len(shape) >= 2
if has_bias:
bias = paddle.rand(shape=shape, dtype=dtype)
mask = LowerTriangularMaskWithTensorBias(bias)
else:
mask = LowerTriangularMask()
mask = mask.materialize(shape=shape, dtype=dtype)
self.assertEqual(mask.dtype, dtype)
self.assertEqual(mask.shape, shape)
dst_shape = [-1, mask.shape[-2], mask.shape[-1]]
mask = mask.reshape(dst_shape).astype(paddle.float64).numpy()
if has_bias:
bias = bias.reshape(dst_shape).astype(paddle.float64).numpy()
for i in range(mask.shape[0]):
for j in range(mask.shape[1]):
for k in range(mask.shape[2]):
value = mask[i][j][k]
if j >= k:
if has_bias:
self.assertEqual(value, bias[i][j][k])
else:
self.assertEqual(value, 0)
else:
self.assertEqual(value, float('-inf'))
def test_materialize(self):
shape = [5, 6, 7]
for dtype in all_dtypes():
for has_bias in [False, True]:
self.check_materialize(shape, dtype, has_bias)
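
# Illustrative layout (sketch): for shape [1, 3, 3] and no bias, materialize()
# returns
#     [[0,   -inf, -inf],
#      [0,    0,   -inf],
#      [0,    0,    0  ]]
# i.e. entry (j, k) is kept (0, or bias[j][k] when a bias is given) for j >= k
# and masked with -inf otherwise, which is exactly the branch checked above.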
def check_split_tensor_without_batch_sizes(seqinfo, extra_shape):
seqlens = []
for i in range(len(seqinfo.seqstart_py) - 1):
seqlens.append(seqinfo.seqstart_py[i + 1] - seqinfo.seqstart_py[i])
shape = [1, seqinfo.seqstart_py[-1]] + list(extra_shape)
x = paddle.rand(shape)
tensors = seqinfo.split(x)
for i, t in enumerate(tensors):
assert t.shape[0] == 1
assert t.shape[1] == seqlens[i]
assert t.shape[2:] == x.shape[2:]
concated_x = paddle.concat(tensors, axis=1)
np.testing.assert_equal(x.numpy(), concated_x.numpy())
return x, tensors
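
# Example (sketch): for seqlens [2, 3], seqstart_py is [0, 2, 5], so an input of
# shape [1, 5, ...] splits into tensors of shape [1, 2, ...] and [1, 3, ...];
# concatenating them back along axis=1 recovers the original tensor.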
def check_split_tensor_with_batch_sizes(seqinfo, extra_shape, batch_sizes):
seqlens = []
for i in range(len(seqinfo.seqstart_py) - 1):
seqlens.append(seqinfo.seqstart_py[i + 1] - seqinfo.seqstart_py[i])
cumsum_bs = 0
uniq_seqlens = []
for bs in batch_sizes:
start = cumsum_bs
end = cumsum_bs + bs
for s in seqlens[start:end]:
assert s == seqlens[start]
cumsum_bs += bs
uniq_seqlens.append(seqlens[start])
x = paddle.rand(shape=[1, sum(seqlens)] + extra_shape)
tensors = seqinfo.split(x, batch_sizes)
assert len(tensors) == len(batch_sizes)
for i, t in enumerate(tensors):
shape = t.shape
assert len(shape) == 2 + len(extra_shape)
assert shape[0] == batch_sizes[i]
assert shape[1] == uniq_seqlens[i]
concated_tensor = paddle.concat(
[t.reshape([-1, *t.shape[2:]]) for t in tensors]
).unsqueeze(0)
np.testing.assert_equal(x.numpy(), concated_tensor.numpy())
return x, tensors
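
# Example (sketch): with batch_sizes [2, 3], all sequences within a batch must
# share one length; for seqlens [4, 4, 7, 7, 7] the split produces tensors of
# shape [2, 4, ...] and [3, 7, ...].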
def check_split_tensor(seqinfo, extra_shape, batch_sizes):
if batch_sizes is None:
return check_split_tensor_without_batch_sizes(seqinfo, extra_shape)
else:
return check_split_tensor_with_batch_sizes(
seqinfo, extra_shape, batch_sizes
)
def check_same_tensor_list(tensors1, tensors2):
assert len(tensors1) == len(tensors2)
for t1, t2 in zip(tensors1, tensors2):
assert t1.shape == t2.shape
assert t1.dtype == t2.dtype
np.testing.assert_equal(t1.numpy(), t2.numpy())
class TestSeqLenInfo(unittest.TestCase):
def test_seq_len_info(self):
n = 100
seqlens = np.random.randint(2, 100, size=[n]).tolist()
cumsum_seqlens = [0] + np.cumsum(seqlens).tolist()
info = SeqLenInfo.from_seqlens(seqlens)
self.assertEqual(max(seqlens), info.max_seqlen)
np.testing.assert_equal(cumsum_seqlens, info.seqstart.numpy())
np.testing.assert_equal(cumsum_seqlens, info.seqstart_py)
intervals = list(info.intervals())
self.assertEqual(n, len(intervals))
for i in range(n):
self.assertEqual(cumsum_seqlens[i], intervals[i][0])
self.assertEqual(cumsum_seqlens[i + 1], intervals[i][1])
check_split_tensor_without_batch_sizes(info, [8, 9])
def test_split_with_batch_sizes(self):
n_tensor = 10
extra_shape = [3, 4]
batch_sizes = np.random.randint(10, 200, size=[n_tensor]).tolist()
seqlens = []
uniq_seqlens = []
for bs in batch_sizes:
tmp_seqlen = np.random.randint(10, 200, size=[1])[0]
uniq_seqlens.append(tmp_seqlen)
seqlens.extend([tmp_seqlen] * bs)
info = SeqLenInfo.from_seqlens(seqlens)
check_split_tensor_with_batch_sizes(info, extra_shape, batch_sizes)
class TestPaddedSeqLenInfo(unittest.TestCase):
def test_padded_seq_len_info(self):
n = 100
padding = 200
seqlens = np.random.randint(2, padding, size=[n]).tolist()
info = PaddedSeqLenInfo.from_seqlens_padded(seqlens, padding)
self.assertEqual(max(seqlens), info.max_seqlen)
np.testing.assert_equal(info.seqstart.numpy(), info.seqstart_py)
self.assertEqual(len(info.seqstart_py), n + 1)
self.assertEqual(info.seqstart_py[0], 0)
self.assertTrue(np.all(np.diff(info.seqstart_py) == padding))
intervals = list(info.intervals())
self.assertEqual(len(intervals), n)
for i in range(n):
interval = intervals[i]
self.assertEqual(interval[0], padding * i)
self.assertEqual(interval[1] - interval[0], seqlens[i])
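
# Example (sketch): for seqlens [3, 5] and padding 8, seqstart_py is [0, 8, 16]
# while intervals() yields (0, 3) and (8, 13): every sequence starts at a
# padded offset but only spans its true length.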
class TestBlockDiagonalMask(unittest.TestCase):
def setUp(self):
self.mask_class = BlockDiagonalMask
self.q_n = 10
self.qkv_same_length = True
self.config()
def config(self):
pass
def test_from_seq_lens(self):
q_seqlen = np.random.randint(2, 100, self.q_n).tolist()
if self.qkv_same_length:
kv_seqlen = q_seqlen
else:
            kv_seqlen = np.random.randint(2, 100, self.q_n).tolist()
mask = self.mask_class.from_seqlens(q_seqlen, kv_seqlen)
self.check_main(mask, q_seqlen, kv_seqlen, [3, 4])
def test_from_tensor_list(self):
shapes = [[2, 3], [7, 9], [11, 5]]
extra_shape = [13, 19]
tensors = []
seqlens = []
for s in shapes:
tmp_s = s + extra_shape
tensors.append(paddle.rand(tmp_s))
seqlens.extend([tmp_s[1]] * tmp_s[0])
mask, concated_tensor = self.mask_class.from_tensor_list(tensors)
self.check_main(mask, seqlens, seqlens, extra_shape)
def test_from_tensor_lists_qk(self):
self.check_from_tensor_lists_qkv()
def test_from_tensor_lists_qkv(self):
self.check_from_tensor_lists_qkv(has_value=True)
def check_from_tensor_lists_qkv(self, has_value=False):
batch_sizes = [2, 3, 4]
q_uniq_seqlens = [5, 6, 7]
k_uniq_seqlens = [8, 9, 10]
extra_shape = [13, 19]
tensors_q = []
tensors_k = []
tensors_v = [] if has_value else None
q_seqlens = []
kv_seqlens = []
for i, bs in enumerate(batch_sizes):
q_shape = [bs, q_uniq_seqlens[i]] + extra_shape
kv_shape = [bs, k_uniq_seqlens[i]] + extra_shape
tensors_q.append(paddle.rand(q_shape))
tensors_k.append(paddle.rand(kv_shape))
q_seqlens.extend([q_shape[1]] * q_shape[0])
kv_seqlens.extend([kv_shape[1]] * kv_shape[0])
if has_value:
tensors_v.append(paddle.rand(kv_shape))
mask, q, k, v = self.mask_class.from_tensor_lists_qkv(
tensors_q, tensors_k, tensors_v
)
self.check_main(
mask,
q_seqlens,
kv_seqlens,
extra_shape,
check_same_shape_split=False,
)
def check_main(
self,
mask,
q_seqlen,
kv_seqlen,
extra_shape,
check_same_shape_split=True,
):
total_q_tokens = sum(q_seqlen)
total_kv_tokens = sum(kv_seqlen)
shape = extra_shape + [total_q_tokens, total_kv_tokens]
mask_value = mask.materialize(shape=shape)
self.assertEqual(mask_value.shape, shape)
mask_value = mask_value.numpy()
mask_value = mask_value.reshape([-1, *mask_value.shape[-2:]])
for i in range(1, mask_value.shape[0]):
np.testing.assert_equal(mask_value[i], mask_value[0])
mask_value = mask_value[0]
self.check_mask(
mask_value,
list(mask.q_seqinfo.intervals()),
list(mask.k_seqinfo.intervals()),
)
x, tensors = check_split_tensor(
mask.q_seqinfo, extra_shape, mask._batch_sizes
)
check_same_tensor_list(mask.split_queries(x), tensors)
x, tensors = check_split_tensor(
mask.k_seqinfo, extra_shape, mask._batch_sizes
)
check_same_tensor_list(mask.split_kv(x), tensors)
if self.qkv_same_length and check_same_shape_split:
x, tensors = check_split_tensor(
mask.q_seqinfo, extra_shape, mask._batch_sizes
)
check_same_tensor_list(mask.split(x), tensors)
if self.mask_class == BlockDiagonalMask:
self.assertEqual(type(mask.make_causal()), BlockDiagonalCausalMask)
def check_mask(self, mask, q_intervals, k_intervals):
self.assertEqual(len(mask.shape), 2)
m, n = mask.shape
self.assertEqual(len(q_intervals), len(k_intervals))
for (q_start, q_end), (k_start, k_end) in zip(q_intervals, k_intervals):
if k_start > 0:
self.assertTrue(
np.all(mask[q_start:q_end, 0:k_start] == float('-inf'))
)
if k_end < n:
self.assertTrue(
np.all(mask[q_start:q_end, k_end:] == float('-inf'))
)
block_mask = mask[q_start:q_end, k_start:k_end]
self.check_block_mask(block_mask)
def check_block_mask(self, block_mask):
self.assertTrue(np.all(block_mask == 0))
class TestBlockDiagonalMaskQKVDiffLength(TestBlockDiagonalMask):
def config(self):
self.qkv_same_length = False
class TestBlockDiagonalCausalMask(TestBlockDiagonalMask):
def config(self):
self.mask_class = BlockDiagonalCausalMask
def check_block_mask(self, block_mask):
self.assertEqual(len(block_mask.shape), 2)
m, n = block_mask.shape
for i in range(m):
for j in range(n):
if i >= j:
self.assertEqual(block_mask[i][j], 0)
else:
self.assertEqual(block_mask[i][j], float('-inf'))
class TestBlockDiagonalCausalMaskQKVDiffLength(TestBlockDiagonalCausalMask):
def config(self):
self.mask_class = BlockDiagonalCausalMask
self.qkv_same_length = False
class TestBlockDiagonalCausalWithOffsetPaddedKeysMask(unittest.TestCase):
def test_main(self):
kv_padding = 20
n = 10
extra_shape = [3, 4]
q_seqlen = np.random.randint(0, kv_padding, size=[n]).tolist()
kv_seqlen = np.random.randint(0, kv_padding, size=[n]).tolist()
q_ntokens = sum(q_seqlen)
kv_ntokens = n * kv_padding
max_causal_diagonal = min(q_ntokens, kv_ntokens) - 2
causal_diagonal_np = np.random.randint(
0, max_causal_diagonal, size=[n]
).astype(np.int32)
causal_diagonal = paddle.to_tensor(causal_diagonal_np)
mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
q_seqlen, kv_padding, kv_seqlen, causal_diagonal
)
shape = extra_shape + [q_ntokens, kv_ntokens]
mask_np = mask.materialize(shape).numpy()
self.assertEqual(list(mask_np.shape[: len(extra_shape)]), extra_shape)
mask_np = mask_np.reshape([-1, *mask_np.shape[2:]])
for i in range(1, mask_np.shape[0]):
np.testing.assert_equal(mask_np[i], mask_np[0])
mask_np = mask_np[0]
q_intervals = list(mask.q_seqinfo.intervals())
k_intervals = list(mask.k_seqinfo.intervals())
self.assertEqual(len(q_intervals), len(k_intervals))
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
zip(q_intervals, k_intervals)
):
if k_start != 0:
np.testing.assert_equal(
mask_np[q_start:q_end, 0:k_start], float('-inf')
)
np.testing.assert_equal(
mask_np[q_start:q_end, k_start:k_end],
self.create_numpy_block_mask(
(q_end - q_start, k_end - k_start), causal_diagonal_np[i]
),
)
if k_end != kv_ntokens:
np.testing.assert_equal(
mask_np[q_start:q_end, k_end:kv_ntokens], float('-inf')
)
def create_numpy_block_mask(self, shape, offset, dtype=np.float32):
t = np.full(shape, dtype=dtype, fill_value=float('-inf'))
return np.triu(t, 1 + offset)
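
# Sketch of create_numpy_block_mask: np.triu(t, 1 + offset) keeps -inf only
# strictly above the (1 + offset)-th diagonal, so for shape (3, 3) row 0 is
# [0, -inf, -inf] with offset 0 and [0, 0, -inf] with offset 1.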
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The following codes are from https://github.com/facebookresearch/xformers
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Sequence
import paddle
class AttentionBias(ABC):
@abstractmethod
def materialize(self, shape, dtype=paddle.float32):
raise NotImplementedError()
class LowerTriangularMask(AttentionBias):
def materialize(self, shape, dtype=paddle.float32):
create_as = dtype if dtype is not paddle.bfloat16 else paddle.float32
tensor = paddle.full(
shape=shape, fill_value=float("-inf"), dtype=create_as
)
return paddle.triu(tensor, diagonal=1).astype(dtype)
def add_bias(self, bias):
return LowerTriangularMaskWithTensorBias(bias)
class LowerTriangularMaskWithTensorBias(LowerTriangularMask):
def __init__(self, bias):
self._bias = bias
def materialize(self, shape, dtype=paddle.float32):
return super().materialize(shape, dtype) + self._bias
@dataclass
class SeqLenInfo:
seqstart: paddle.Tensor
max_seqlen: int
seqstart_py: List[int]
def intervals(self):
yield from zip(self.seqstart_py, self.seqstart_py[1:])
@classmethod
def from_seqlens(cls, seqlens):
seqstart_py = [0]
max_seqlen = -1
for seqlen in seqlens:
max_seqlen = max(max_seqlen, seqlen)
seqstart_py.append(seqstart_py[-1] + seqlen)
seqstart = paddle.to_tensor(seqstart_py, dtype=paddle.int32)
return cls(
max_seqlen=max_seqlen, seqstart=seqstart, seqstart_py=seqstart_py
)
def split(self, x, batch_sizes=None):
assert self.seqstart_py[-1] == x.shape[1] and x.shape[0] == 1
if batch_sizes is None:
batch_sizes = [1] * (len(self.seqstart_py) - 1)
split_chunks = []
it = 0
for batch_size in batch_sizes:
split_chunks.append(
self.seqstart_py[it + batch_size] - self.seqstart_py[it]
)
it += batch_size
return [
tensor.reshape([bs, -1, *tensor.shape[2:]])
for bs, tensor in zip(batch_sizes, x.split(split_chunks, axis=1))
]
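
# Example (sketch, dynamic graph mode):
#     info = SeqLenInfo.from_seqlens([2, 3])
#     x = paddle.rand([1, 5, 4])
#     a, b = info.split(x)  # shapes [1, 2, 4] and [1, 3, 4]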
@dataclass
class PaddedSeqLenInfo(SeqLenInfo):
seqlen: paddle.Tensor
seqlen_py: Sequence[int]
def intervals(self):
for (start, _), length in zip(super().intervals(), self.seqlen_py):
yield start, start + length
@classmethod
def from_seqlens(cls, seqlens):
raise NotImplementedError(
"Please use SeqLenInfo.from_seq_lens() or PaddedSeqLenInfo.from_seq_lens_padded()."
)
@classmethod
def from_seqlens_padded(cls, seqlens, padding):
assert all([seqlen <= padding for seqlen in seqlens])
seqstart_py = list(range(0, len(seqlens) * padding + 1, padding))
return cls(
seqlen=paddle.to_tensor(seqlens, dtype=paddle.int32),
seqlen_py=seqlens,
max_seqlen=max(seqlens),
seqstart=paddle.to_tensor(seqstart_py, dtype=paddle.int32),
seqstart_py=seqstart_py,
)
def split(self, x, batch_sizes=None):
raise NotImplementedError()
@dataclass
class BlockDiagonalMask(AttentionBias):
q_seqinfo: SeqLenInfo
k_seqinfo: SeqLenInfo
_batch_sizes: Optional[Sequence[int]] = None
def _create_block_mask(self, shape, dtype=paddle.float32):
return paddle.zeros(shape=shape, dtype=dtype)
def materialize(self, shape, dtype=paddle.float32):
assert shape[-1] == self.k_seqinfo.seqstart_py[-1]
assert shape[-2] == self.q_seqinfo.seqstart_py[-1]
mask = paddle.full(shape[-2:], fill_value=float('-inf'), dtype=dtype)
for (q_start, q_end), (k_start, k_end) in zip(
self.q_seqinfo.intervals(), self.k_seqinfo.intervals()
):
sub_shape = [q_end - q_start, k_end - k_start]
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
sub_shape, dtype
)
for _ in range(len(shape) - 2):
mask = mask.unsqueeze(0)
return mask.expand(shape)
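
    # Example (sketch): for q/k seqlens [2, 3] the materialized [5, 5] mask is
    # 0 on the two diagonal blocks (rows 0-1 x cols 0-1, rows 2-4 x cols 2-4)
    # and -inf elsewhere, so tokens attend only within their own sequence.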
@classmethod
def from_seqlens(cls, q_seqlen, kv_seqlen=None):
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
q_seqinfo = SeqLenInfo.from_seqlens(q_seqlen)
if kv_seqlen is None or q_seqlen == kv_seqlen:
k_seqinfo = q_seqinfo
else:
k_seqinfo = SeqLenInfo.from_seqlens(kv_seqlen)
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)
@classmethod
def from_tensor_list(cls, tensors):
batch_sizes = [tensor.shape[0] for tensor in tensors]
seqlens = []
for x in tensors:
for _ in range(x.shape[0]):
seqlens.append(x.shape[1])
block_diag = cls.from_seqlens(seqlens)
block_diag._batch_sizes = batch_sizes
concated_tensor = paddle.concat(
[x.reshape([1, -1, *x.shape[2:]]) for x in tensors], axis=1
)
return block_diag, concated_tensor
@classmethod
def from_tensor_lists_qkv(cls, tensors_q, tensors_k, tensors_v=None):
assert len(tensors_q) == len(tensors_k)
assert tensors_v is None or len(tensors_v) == len(tensors_q)
batch_sizes = [tensor.shape[0] for tensor in tensors_q]
q_seqlens, kv_seqlens = [], []
for i, (q, k) in enumerate(zip(tensors_q, tensors_k)):
assert q.shape[0] == k.shape[0]
q_seqlens.extend([q.shape[1]] * q.shape[0])
kv_seqlens.extend([k.shape[1]] * k.shape[0])
assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2]
block_diag = cls.from_seqlens(q_seqlens, kv_seqlens)
block_diag._batch_sizes = [x.shape[0] for x in tensors_q]
return (
block_diag,
paddle.concat(
[x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], axis=1
),
paddle.concat(
[x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], axis=1
),
paddle.concat(
[x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], axis=1
)
if tensors_v is not None
else None,
)
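
    # Example (sketch): for tensors_q of shapes [2, 5, ...] and [3, 6, ...] and
    # tensors_k of shapes [2, 8, ...] and [3, 9, ...], q_seqlens becomes
    # [5, 5, 6, 6, 6] and kv_seqlens [8, 8, 9, 9, 9]; the returned q/k/v are
    # packed to shape [1, sum(seqlens), ...].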
def split_queries(self, tensor):
return self.q_seqinfo.split(tensor, self._batch_sizes)
def split_kv(self, tensor):
return self.k_seqinfo.split(tensor, self._batch_sizes)
def split(self, tensor):
assert self.q_seqinfo is self.k_seqinfo
return self.q_seqinfo.split(tensor, self._batch_sizes)
def make_causal(self):
return BlockDiagonalCausalMask(
q_seqinfo=self.q_seqinfo,
k_seqinfo=self.k_seqinfo,
_batch_sizes=self._batch_sizes,
)
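
# Example (sketch): make_causal() keeps the same sequence layout but swaps each
# zero block for a lower-triangular one:
#     mask, x = BlockDiagonalMask.from_tensor_list([paddle.rand([2, 3, 4])])
#     causal = mask.make_causal()  # BlockDiagonalCausalMask, same seqinfo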
@dataclass
class BlockDiagonalCausalMask(BlockDiagonalMask):
def _create_block_mask(self, shape, dtype=paddle.float32):
return LowerTriangularMask().materialize(shape=shape, dtype=dtype)
@dataclass
class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias):
q_seqinfo: SeqLenInfo
k_seqinfo: PaddedSeqLenInfo
causal_diagonal: Optional[paddle.Tensor] = None
def _create_block_mask(self, shape, offset=0, dtype=paddle.float32):
create_as = dtype if dtype is not paddle.bfloat16 else paddle.float32
tensor = paddle.full(shape, dtype=create_as, fill_value=float('-inf'))
return paddle.triu(tensor, diagonal=1 + offset).astype(dtype)
def materialize(self, shape, dtype=paddle.float32):
assert shape[-1] == self.k_seqinfo.seqstart_py[-1]
assert shape[-2] == self.q_seqinfo.seqstart_py[-1]
mask = paddle.full(shape[-2:], dtype=dtype, fill_value=float('-inf'))
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
zip(self.q_seqinfo.intervals(), self.k_seqinfo.intervals())
):
mask[q_start:q_end, k_start:k_end] = self._create_block_mask(
(q_end - q_start, k_end - k_start),
offset=0
if self.causal_diagonal is None
else int(self.causal_diagonal[i].item()),
dtype=dtype,
)
for _ in range(len(shape) - 2):
mask = mask.unsqueeze(0)
return mask.expand(shape)
@classmethod
def from_seqlens(
cls, q_seqlen, kv_padding, kv_seqlen, causal_diagonal=None
):
assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen)
q_seqinfo = SeqLenInfo.from_seqlens(q_seqlen)
k_seqinfo = PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
return cls(
q_seqinfo=q_seqinfo,
k_seqinfo=k_seqinfo,
causal_diagonal=causal_diagonal,
)
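
# Usage sketch (illustrative): queries are packed densely while keys/values sit
# in fixed-size padded slots, e.g.
#     mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
#         q_seqlen=[2, 2], kv_padding=4, kv_seqlen=[3, 4]
#     )
#     bias = mask.materialize(shape=[4, 8])  # 4 query tokens, 2 * 4 padded keys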
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The following codes are from https://github.com/facebookresearch/xformers
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import paddle
from .attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
SUPPORTED_ATTN_BIAS_TYPES = {
type(None),
paddle.Tensor,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
BlockDiagonalMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
}
def _get_seqlen_info(attn_bias):
if isinstance(
attn_bias,
(BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask),
):
return (
attn_bias.k_seqinfo.seqstart,
attn_bias.q_seqinfo.seqstart,
attn_bias.q_seqinfo.max_seqlen,
attn_bias.k_seqinfo.max_seqlen,
)
else:
return None, None, -1, -1
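
# Example (sketch): for a BlockDiagonalMask built from q/kv seqlens [2, 3] this
# returns (k seqstart [0, 2, 5], q seqstart [0, 2, 5], 3, 3); for a plain
# tensor bias it returns (None, None, -1, -1).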
def _get_tensor_bias(attn_bias):
if isinstance(attn_bias, paddle.Tensor):
return attn_bias
elif isinstance(attn_bias, LowerTriangularMaskWithTensorBias):
return attn_bias._bias
else:
return None
def memory_efficient_attention(
query, key, value, attn_bias, p=0.0, scale=None, training=True
):
assert type(attn_bias) in SUPPORTED_ATTN_BIAS_TYPES
causal = isinstance(
attn_bias,
(
LowerTriangularMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
),
)
seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(attn_bias)
# NOTE: compute_logsumexp = training
is_test = not training
causal_diagonal = (
attn_bias.causal_diagonal
if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
else None
)
seqlen_k = (
attn_bias.k_seqinfo.seqlen
if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
else None
)
attn_bias = _get_tensor_bias(attn_bias)
# TODO(zhangdanyang): add C++ codes here