Unverified · Commit 1e44d48d authored by Zhun, committed by GitHub

Fix potential random layout inconsistency issues in sparse attention modules (#534)

* 1) Register the layout as a buffer of the module so that it is saved and loaded with checkpoints; 2) broadcast the layout at the beginning so that different processes use a consistent layout during distributed training (see the sketch after the commit message).

* Add docstring for max_seq_length argument in SparseSelfAttention
Co-authored-by: Zhun Liu <zhunliu@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
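
The fix combines two standard PyTorch mechanisms: `register_buffer`, which makes the layout part of the module's `state_dict` so it survives checkpoint save/load, and a one-time `torch.distributed.broadcast` from global rank 0, which guarantees that every process works from the same (possibly randomly generated) layout. A minimal sketch of that pattern, using a hypothetical `LayoutHolder` module rather than the actual `SparseSelfAttention` class:

```python
# Minimal sketch of the buffer-plus-broadcast pattern (illustrative only;
# LayoutHolder and its random layout are hypothetical, not DeepSpeed code).
import torch
import torch.nn as nn
from torch import distributed as dist


class LayoutHolder(nn.Module):
    def __init__(self, max_seq_length=2048, block=16):
        super().__init__()
        self.block = block
        # Build a (possibly random) block-level layout once, up front.
        num_blocks = max_seq_length // block
        layout = (torch.rand(num_blocks, num_blocks) < 0.5).long()
        # Registering it as a buffer puts it into state_dict(), so the exact
        # same layout is restored when a checkpoint is loaded.
        self.register_buffer("master_layout", layout)
        self._need_layout_synchronization = True

    def get_layout(self, seq_len):
        # Broadcast once from global rank 0 so all ranks agree on the layout,
        # even if each rank generated a different random one locally.
        if self._need_layout_synchronization and dist.is_initialized():
            dist.broadcast(self.master_layout, src=0)
            self._need_layout_synchronization = False
        if seq_len % self.block != 0:
            raise ValueError(f"{seq_len} is not a multiple of block size {self.block}")
        n = seq_len // self.block
        return self.master_layout[:n, :n].cpu()
```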
Parent ff58fa7e
@@ -5,6 +5,7 @@ Copyright 2020 The Microsoft DeepSpeed Team
import torch.nn as nn
from torch.nn.functional import *
import torch
from torch import distributed as dist
from collections import namedtuple
from deepspeed.ops.sparse_attention import MatMul, Softmax, SparsityConfig
import sys
@@ -22,29 +23,50 @@ class SparseSelfAttention(nn.Module):
# SparsityConfig parameters need to be set accordingly
sparsity_config=SparsityConfig(num_heads=4),
key_padding_mask_mode='add',
attn_mask_mode='mul'):
attn_mask_mode='mul',
max_seq_length=2048):
"""Initialize the sparse self attention layer.
Arguments:
sparsity_config: optional: this parameter determines the sparsity pattern configuration; it is based on the SparsityConfig class.
key_padding_mask_mode: optional: a string determining if key padding mask needs to be added, `add`, or be multiplied, `mul`.
attn_mask_mode: optional: a string determining if attention mask needs to be added, `add`, or be multiplied, `mul`.
max_seq_length: optional: the maximum sequence length this sparse attention module will be applied to; it controls the size of the master_layout.
"""
super().__init__()
# sparsity information
self.sparsity_config = sparsity_config
# initialize sparse layout and register as buffer
master_layout = self.sparsity_config.make_layout(max_seq_length)
self.register_buffer("master_layout", master_layout)
self._need_layout_synchronization = True
# mask modes
self.key_padding_mask_mode = key_padding_mask_mode
self.attn_mask_mode = attn_mask_mode
ops = dict()
def get_layout(self, L):
# if layout is never synchronized across GPUs, broadcast the layout from global rank 0
if self._need_layout_synchronization and dist.is_initialized():
dist.broadcast(self.master_layout, src=0)
self._need_layout_synchronization = False
if (L % self.sparsity_config.block != 0):
raise ValueError(
f'Sequence Length, {L}, needs to be divisible by Block size {self.sparsity_config.block}!'
)
num_blocks = L // self.sparsity_config.block
return self.master_layout[..., :num_blocks, :num_blocks].cpu() # layout needs to be a CPU tensor
# add to cache
def get_ops(self, H, L):
import sys
if L not in SparseSelfAttention.ops:
sparsity_layout = self.sparsity_config.make_layout(L)
sparsity_layout = self.get_layout(L)
sparse_dot_sdd_nt = MatMul(sparsity_layout,
self.sparsity_config.block,
'sdd',
......
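
With the layout held in a registered buffer, it travels with the module's `state_dict`, so a checkpoint written by one run restores the exact same sparsity pattern in the next. A brief usage sketch (assuming a DeepSpeed install with sparse-attention support and that `FixedSparsityConfig` is importable as in the package's `__init__`; the config values and file path are chosen only for illustration):

```python
import torch
from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig

# Build the module with an explicit max_seq_length so the master_layout
# buffer covers every sequence length we expect to see.
attn = SparseSelfAttention(sparsity_config=FixedSparsityConfig(num_heads=4),
                           max_seq_length=2048)

# The layout now lives in the state_dict alongside the rest of the model...
assert "master_layout" in attn.state_dict()

# ...so saving and reloading a checkpoint round-trips the exact layout.
torch.save(attn.state_dict(), "sparse_attn.pt")
attn.load_state_dict(torch.load("sparse_attn.pt"))
```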