From 1e44d48d5377a5f839c53ccefcd6951625479f51 Mon Sep 17 00:00:00 2001
From: Zhun
Date: Fri, 4 Dec 2020 14:58:10 -0800
Subject: [PATCH] Fix potential random layout inconsistency issues in sparse
 attention modules (#534)

* 1) Register layout as buffer of module so that we can save/load checkpoint;
  2) Add a broadcast of layout at the beginning to ensure different processes
  will have consistent layout during distributed training.

* Add docstring for max_seq_length argument in SparseSelfAttention

Co-authored-by: Zhun Liu
Co-authored-by: Jeff Rasley
---
 .../sparse_attention/sparse_self_attention.py | 26 +++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/deepspeed/ops/sparse_attention/sparse_self_attention.py b/deepspeed/ops/sparse_attention/sparse_self_attention.py
index 2e315604..6e7d8905 100644
--- a/deepspeed/ops/sparse_attention/sparse_self_attention.py
+++ b/deepspeed/ops/sparse_attention/sparse_self_attention.py
@@ -5,6 +5,7 @@ Copyright 2020 The Microsoft DeepSpeed Team
 import torch.nn as nn
 from torch.nn.functional import *
 import torch
+from torch import distributed as dist
 from collections import namedtuple
 from deepspeed.ops.sparse_attention import MatMul, Softmax, SparsityConfig
 import sys
@@ -22,29 +23,50 @@ class SparseSelfAttention(nn.Module):
                  # SparsityConfig parameters needs to be set accordingly
                  sparsity_config=SparsityConfig(num_heads=4),
                  key_padding_mask_mode='add',
-                 attn_mask_mode='mul'):
+                 attn_mask_mode='mul',
+                 max_seq_length=2048):
         """Initialize the sparse self attention layer.
         Arguments:
             sparsity_config: optional: this parameter determins sparsity pattern configuration; it is based on SparsityConfig class.
             key_padding_mask_mode: optional: a string determining if key padding mask needs to be added, `add`, or be multiplied, `mul`.
             attn_mask_mode: optional: a string determining if attention mask needs to be added, `add`, or be multiplied, `mul`.
+            max_seq_length: optional: the maximum sequence length this sparse attention module will be applied to; it controls the size of the master_layout.
         """
         super().__init__()
 
         # sparsity information
         self.sparsity_config = sparsity_config
 
+        # initialize sparse layout and register as buffer
+        master_layout = self.sparsity_config.make_layout(max_seq_length)
+        self.register_buffer("master_layout", master_layout)
+        self._need_layout_synchronization = True
+
         # mask modes
         self.key_padding_mask_mode = key_padding_mask_mode
         self.attn_mask_mode = attn_mask_mode
 
     ops = dict()
 
+    def get_layout(self, L):
+        # if layout is never synchronized across GPUs, broadcast the layout from global rank 0
+        if self._need_layout_synchronization and dist.is_initialized():
+            dist.broadcast(self.master_layout, src=0)
+            self._need_layout_synchronization = False
+
+        if (L % self.sparsity_config.block != 0):
+            raise ValueError(
+                f'Sequence Length, {L}, needs to be dividable by Block size {self.sparsity_config.block}!'
+            )
+
+        num_blocks = L // self.sparsity_config.block
+        return self.master_layout[..., :num_blocks, :num_blocks].cpu()  # layout needs to be a CPU tensor
+
     # add to cache
     def get_ops(self, H, L):
         import sys
         if L not in SparseSelfAttention.ops:
-            sparsity_layout = self.sparsity_config.make_layout(L)
+            sparsity_layout = self.get_layout(L)
             sparse_dot_sdd_nt = MatMul(sparsity_layout,
                                        self.sparsity_config.block,
                                        'sdd',
--
GitLab
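
Note (not part of the patch): below is a minimal, self-contained sketch of the two ideas this commit applies, namely registering the layout as a module buffer so it is saved and restored with checkpoints, and broadcasting it from global rank 0 so every process ends up with the same randomly generated layout. The LayoutHolder class and the random layout are hypothetical stand-ins for SparseSelfAttention and SparsityConfig.make_layout; torch.distributed is assumed to have been initialized elsewhere when running with more than one process.

# Hypothetical illustration of the pattern used in the patch; LayoutHolder and the
# random layout stand in for SparseSelfAttention and SparsityConfig.make_layout.
import torch
import torch.nn as nn
from torch import distributed as dist


class LayoutHolder(nn.Module):
    def __init__(self, max_seq_length=2048, block=16):
        super().__init__()
        num_blocks = max_seq_length // block
        # A randomly generated block-sparse layout; stands in for make_layout().
        layout = torch.randint(0, 2, (1, num_blocks, num_blocks), dtype=torch.int64)
        # register_buffer (instead of a plain attribute) puts the layout into
        # state_dict(), so checkpoints save and restore the exact layout used.
        self.register_buffer("master_layout", layout)
        self._need_layout_synchronization = True

    def get_layout(self, L, block=16):
        # Broadcast once from global rank 0 so all ranks use an identical layout,
        # even though each rank generated its own random tensor in __init__.
        if self._need_layout_synchronization and dist.is_initialized():
            dist.broadcast(self.master_layout, src=0)
            self._need_layout_synchronization = False
        num_blocks = L // block
        return self.master_layout[..., :num_blocks, :num_blocks].cpu()


if __name__ == "__main__":
    holder = LayoutHolder()
    # The buffer shows up in the checkpoint dict alongside any parameters.
    print("master_layout" in holder.state_dict())  # True
    print(holder.get_layout(512).shape)            # torch.Size([1, 32, 32])

Without the buffer, a module reloaded from a checkpoint would regenerate a different random layout; without the broadcast, each rank could train against its own layout, which is the inconsistency the patch title refers to.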