From d3657c32d9a352940eab1d4436794b57a1432751 Mon Sep 17 00:00:00 2001
From: Javier
Date: Fri, 4 Aug 2023 13:12:37 +0100
Subject: [PATCH] Added an example of flash and linear attention. Fixed some
 small bugs in one example. Adjusted all new functionality for GPU usage
---
 ...adult_census_linear_and_flash_attention.py |  84 +++++++++
 .../ml100k_data_preparation.py                |   7 +-
 .../pytorch_wide_deep_pt2.py                  |   8 +-
 .../tabular/transformers/_attention_layers.py | 171 +++++++-----------
 .../test_mc_attn_layers.py                    |  39 ----
 5 files changed, 157 insertions(+), 152 deletions(-)
 create mode 100644 examples/scripts/adult_census_linear_and_flash_attention.py

diff --git a/examples/scripts/adult_census_linear_and_flash_attention.py b/examples/scripts/adult_census_linear_and_flash_attention.py
new file mode 100644
index 0000000..585e0c0
--- /dev/null
+++ b/examples/scripts/adult_census_linear_and_flash_attention.py
@@ -0,0 +1,84 @@
+from time import time
+
+from sklearn.model_selection import train_test_split
+
+from pytorch_widedeep import Trainer
+from pytorch_widedeep.models import WideDeep, TabTransformer
+from pytorch_widedeep.metrics import Accuracy
+from pytorch_widedeep.datasets import load_adult
+from pytorch_widedeep.preprocessing import TabPreprocessor
+
+# use_cuda = torch.cuda.is_available()
+
+df = load_adult(as_frame=True)
+df.columns = [c.replace("-", "_") for c in df.columns]
+df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
+df.drop("income", axis=1, inplace=True)
+target_colname = "income_label"
+
+cat_embed_cols = []
+for col in df.columns:
+    if (df[col].dtype == "O" or df[col].nunique() < 200) and col != target_colname:
+        cat_embed_cols.append(col)
+
+train, test = train_test_split(
+    df, test_size=0.1, random_state=1, stratify=df[[target_colname]]
+)
+
+with_cls_token = True
+tab_preprocessor = TabPreprocessor(
+    cat_embed_cols=cat_embed_cols, with_attention=True, with_cls_token=with_cls_token
+)
+
+X_tab_train = tab_preprocessor.fit_transform(train)
+X_tab_test = tab_preprocessor.transform(test)
+target = train[target_colname].values
+
+
+tab_transformer = TabTransformer(
+    column_idx=tab_preprocessor.column_idx,
+    cat_embed_input=tab_preprocessor.cat_embed_input,
+    input_dim=16,
+    n_heads=2,
+    n_blocks=2,
+)
+
+linear_tab_transformer = TabTransformer(
+    column_idx=tab_preprocessor.column_idx,
+    cat_embed_input=tab_preprocessor.cat_embed_input,
+    input_dim=16,
+    n_heads=2,
+    n_blocks=2,
+    use_linear_attention=True,
+)
+
+flash_tab_transformer = TabTransformer(
+    column_idx=tab_preprocessor.column_idx,
+    cat_embed_input=tab_preprocessor.cat_embed_input,
+    input_dim=16,
+    n_heads=2,
+    n_blocks=2,
+    use_flash_attention=True,
+)
+
+s_model = WideDeep(deeptabular=tab_transformer)
+l_model = WideDeep(deeptabular=linear_tab_transformer)
+f_model = WideDeep(deeptabular=flash_tab_transformer)
+
+for name, model in [("standard", s_model), ("linear", l_model), ("flash", f_model)]:
+    trainer = Trainer(
+        model,
+        objective="binary",
+        metrics=[Accuracy],
+    )
+
+    s = time()
+    trainer.fit(
+        X_tab=X_tab_train,
+        target=target,
+        n_epochs=1,
+        batch_size=64,
+        val_split=0.2,
+    )
+    e = time() - s
+    print(f"{name} attention time: {round(e, 3)} secs")
diff --git a/examples/scripts/wide_deep_for_recsys/ml100k_data_preparation.py b/examples/scripts/wide_deep_for_recsys/ml100k_data_preparation.py
index f701ce1..fa1dd6f 100644
--- a/examples/scripts/wide_deep_for_recsys/ml100k_data_preparation.py
+++ b/examples/scripts/wide_deep_for_recsys/ml100k_data_preparation.py
@@ -10,7 +10,7 @@ from
sklearn.model_selection import train_test_split from pytorch_widedeep.datasets import load_movielens100k -data, user, items = load_movielens100k(as_frame=True) +data, users, items = load_movielens100k(as_frame=True) # Alternatively, as specified in the docs: 'The last 19 fields are the genres' so: # list_of_genres = items.columns.tolist()[-19:] @@ -37,7 +37,7 @@ list_of_genres = [ ] -# adding a column with the number of movies watched per user +# adding a column with the number of movies watched per users dataset = data.sort_values(["user_id", "timestamp"]).reset_index(drop=True) dataset["one"] = 1 dataset["num_watched"] = dataset.groupby("user_id")["one"].cumsum() @@ -61,6 +61,9 @@ dataset["prev_movies"] = ( ) dataset["prev_movies"] = dataset["prev_movies"].apply(lambda x: x.split()) +# Adding user feats +dataset = dataset.merge(users, on="user_id", how="left") + # Adding a genre_rate as the mean of all movies rated for a given genre per # user dataset = dataset.merge(items[["movie_id"] + list_of_genres], on="movie_id", how="left") diff --git a/examples/scripts/wide_deep_for_recsys/pytorch_wide_deep_pt2.py b/examples/scripts/wide_deep_for_recsys/pytorch_wide_deep_pt2.py index 053a7f0..98bf7c3 100644 --- a/examples/scripts/wide_deep_for_recsys/pytorch_wide_deep_pt2.py +++ b/examples/scripts/wide_deep_for_recsys/pytorch_wide_deep_pt2.py @@ -47,7 +47,7 @@ train_movies_sequences = df_train.prev_movies.apply( ).to_list() y_train = df_train.target.values.astype(int) -df_test_user_item = df_train[["user_id", "movie_id", "rating"]] +df_test_user_item = df_test[["user_id", "movie_id", "rating"]] test_movies_sequences = df_test.prev_movies.apply( lambda x: [int(el) for el in x] ).to_list() @@ -89,7 +89,7 @@ X_test_text = np.array( tab_mlp = TabMlp( column_idx=tab_preprocessor.column_idx, cat_embed_input=tab_preprocessor.cat_embed_input, - mlp_hidden_dims=[1024, 512, 256], + mlp_hidden_dims=[512, 256], mlp_activation="relu", ) @@ -124,7 +124,7 @@ trainer.fit( "X_text": X_test_text, "target": y_test, }, - n_epochs=10, - batch_size=521, + n_epochs=2, + batch_size=32, shuffle=False, ) diff --git a/pytorch_widedeep/models/tabular/transformers/_attention_layers.py b/pytorch_widedeep/models/tabular/transformers/_attention_layers.py index f5ee884..e1b5110 100644 --- a/pytorch_widedeep/models/tabular/transformers/_attention_layers.py +++ b/pytorch_widedeep/models/tabular/transformers/_attention_layers.py @@ -4,16 +4,13 @@ https://github.com/lucidrains """ import math -import warnings -from enum import Enum -from typing import ContextManager import torch import einops import torch.nn.functional as F from torch import nn, einsum -from pytorch_widedeep.wdtypes import List, Tuple, Tensor, Optional +from pytorch_widedeep.wdtypes import Tuple, Tensor, Optional from pytorch_widedeep.models._get_activation_fn import get_activation_fn @@ -67,88 +64,6 @@ class AddNorm(nn.Module): return self.ln(X + self.dropout(sublayer(X))) -def _standard_attention( - q: Tensor, k: Tensor, v: Tensor, head_dim: int, dropout: float -) -> Tuple[Tensor, Tensor]: - """'Standard' multihead attention implemenation from [Attention Is All You - Need](https://arxiv.org/abs/1706.03762) - """ - # b: batch size - # s: seq length - # l: target sequence length - # m: used to refer indistinctively to s or l - # h: number of attention heads, - # d and e: head_dim - - # Normalised Query, Key dot product + softmax. 
Fraction in brackets in - # their Eq 1 - scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(head_dim) - attn_weights = scores.softmax(dim=-1) - - # Attention(Q, K, V ) (with dropout) in their Eq 1 - attn_output = einsum( - "b h s l, b h l d -> b h s d", nn.Dropout(dropout)(attn_weights), v - ) - - return attn_weights, attn_output - - -def _linear_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor: - """Liner attention implemenation from [Transformers are RNNs: Fast - Autoregressive Transformers with Linear Attention] - (https://arxiv.org/abs/2006.16236) - """ - # b: batch size - # s: seq length - # l: target sequence length - # m: used to refer indistinctively to s or l - # h: number of attention heads, - # d and e: head_dim - q, k = ( - nn.functional.elu(q) + 1, - nn.functional.elu(k) + 1, - ) - - # Term within the summation in the numerator of their Eq 5 - scores = einsum("b h s e, b h l d -> b h e d", k, v) - - # The denominator in their Eq 5 - z = 1 / (torch.einsum("b h m d, b h d -> b h m", q, k.sum(dim=2)) + 1e-6) - - # Their Eq 5 - attn_output = torch.einsum("b h m d, b h e d, b h m -> b h m d", q, scores, z) - - return attn_output - - -class SDPBackend(Enum): - FLASH: int = 0 - MEM_EFFICIENT: int = 1 - - -def _flash_kernel_setup(enabled_flash_backends: List[SDPBackend]) -> ContextManager: - assert ( - torch.cuda.is_available() - ), "optimized kernels can only be used if CUDA is available." - - warnings.warn( - "Note that FlashAttention is beta and subject to change", - RuntimeWarning, - ) - - # Setting backends as suggested at https://discuss.pytorch.org/t/flash-attention/174955 - enable_flash = True if SDPBackend.FLASH in enabled_flash_backends else False - enable_mem_eff = ( - True if SDPBackend.MEM_EFFICIENT in enabled_flash_backends else False - ) - - return torch.backends.cuda.sdp_kernel( - enable_flash=enable_flash, - enable_mem_efficient=enable_mem_eff, - enable_math=False, - ) - - class MultiHeadedAttention(nn.Module): def __init__( self, @@ -159,10 +74,6 @@ class MultiHeadedAttention(nn.Module): query_dim: Optional[int] = None, use_linear_attention: bool = False, use_flash_attention: bool = False, - enabled_flash_backends: List[SDPBackend] = [ - SDPBackend.FLASH, - SDPBackend.MEM_EFFICIENT, - ], # in the next release we will offer the posibility of setting up the backend ): super(MultiHeadedAttention, self).__init__() @@ -170,7 +81,6 @@ class MultiHeadedAttention(nn.Module): self.use_linear_attention = use_linear_attention self.use_flash_attention = use_flash_attention - self.enabled_flash_backends = enabled_flash_backends self.head_dim = input_dim // n_heads self.n_heads = n_heads @@ -200,26 +110,20 @@ class MultiHeadedAttention(nn.Module): ) if self.use_flash_attention: - # in the future we will offer the possibility of setting up the - # backend. 
For the time being this context manager
-            # is 'redundant'
-            with _flash_kernel_setup(self.enabled_flash_backends):
-                attn_output = F.scaled_dot_product_attention(
-                    q,
-                    k,
-                    v,
-                    attn_mask=None,
-                    dropout_p=self.dropout if self.training else 0,
-                    is_causal=False,
-                )
+            attn_output = F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=None,
+                dropout_p=self.dropout if self.training else 0,
+                is_causal=False,
+            )
             self.attn_weights: Optional[Tensor] = None
         elif self.use_linear_attention:
-            attn_output = _linear_attention(q, k, v)
+            attn_output = self._linear_attention(q, k, v)
             self.attn_weights = None
         else:
-            self.attn_weights, attn_output = _standard_attention(
-                q, k, v, self.head_dim, self.dropout
-            )
+            self.attn_weights, attn_output = self._standard_attention(q, k, v)
 
         output = einops.rearrange(attn_output, "b h s d -> b s (h d)", h=self.n_heads)
 
@@ -228,6 +132,59 @@ class MultiHeadedAttention(nn.Module):
 
         return output
 
+    def _standard_attention(
+        self, q: Tensor, k: Tensor, v: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        """'Standard' multihead attention implementation from [Attention Is All You
+        Need](https://arxiv.org/abs/1706.03762)
+        """
+        # b: batch size
+        # s: seq length
+        # l: target sequence length
+        # m: used to refer interchangeably to s or l
+        # h: number of attention heads,
+        # d and e: head_dim
+
+        # Normalised Query, Key dot product + softmax. Fraction in brackets in
+        # their Eq 1
+        scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(self.head_dim)
+        attn_weights = scores.softmax(dim=-1)
+
+        # Attention(Q, K, V ) (with dropout) in their Eq 1
+        attn_output = einsum(
+            "b h s l, b h l d -> b h s d", nn.Dropout(self.dropout)(attn_weights), v
+        )
+
+        return attn_weights, attn_output
+
+    @staticmethod
+    def _linear_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        """Linear attention implementation from [Transformers are RNNs: Fast
+        Autoregressive Transformers with Linear Attention]
+        (https://arxiv.org/abs/2006.16236)
+        """
+        # b: batch size
+        # s: seq length
+        # l: target sequence length
+        # m: used to refer interchangeably to s or l
+        # h: number of attention heads,
+        # d and e: head_dim
+        q, k = (
+            nn.functional.elu(q) + 1,
+            nn.functional.elu(k) + 1,
+        )
+
+        # Term within the summation in the numerator of their Eq 5
+        scores = einsum("b h s e, b h l d -> b h e d", k, v)
+
+        # The denominator in their Eq 5
+        z = 1 / (torch.einsum("b h m d, b h d -> b h m", q, k.sum(dim=2)) + 1e-6)
+
+        # Their Eq 5
+        attn_output = torch.einsum("b h m d, b h e d, b h m -> b h m d", q, scores, z)
+
+        return attn_output
+
 
 class LinearAttentionLinformer(nn.Module):
     """Linear Attention implementation from [Linformer: Self-Attention with
diff --git a/tests/test_model_components/test_mc_attn_layers.py b/tests/test_model_components/test_mc_attn_layers.py
index 5d1c0b6..e510699 100644
--- a/tests/test_model_components/test_mc_attn_layers.py
+++ b/tests/test_model_components/test_mc_attn_layers.py
@@ -3,12 +3,9 @@ import timeit
 
 import torch
 import pytest
-import torch.backends.cuda as tcud
 
 from pytorch_widedeep.models.tabular.transformers._attention_layers import (
-    SDPBackend,
     MultiHeadedAttention,
-    _flash_kernel_setup,
 )
 
 torch.backends.cudnn.deterministic = True
@@ -59,42 +56,6 @@ for module in ["q_proj", "kv_proj", "out_proj"]:
 X = torch.randn(128, 100, input_dim).to(device)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA to run")
-@pytest.mark.parametrize(
-    "backends, active",
-    [
-        (
-            [SDPBackend.FLASH],
-            {
-                tcud.flash_sdp_enabled: True,
-
tcud.mem_efficient_sdp_enabled: False, - tcud.math_sdp_enabled: False, - }, - ), - ( - [SDPBackend.MEM_EFFICIENT], - { - tcud.flash_sdp_enabled: False, - tcud.mem_efficient_sdp_enabled: True, - tcud.math_sdp_enabled: False, - }, - ), - ( - [SDPBackend.FLASH, SDPBackend.MEM_EFFICIENT], - { - tcud.flash_sdp_enabled: True, - tcud.mem_efficient_sdp_enabled: True, - tcud.math_sdp_enabled: False, - }, - ), - ], -) -def test_cdp_context_managment(backends, active): - ctx = _flash_kernel_setup(backends) - with ctx: - assert all([f() == v for f, v in active.items()]) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA to run") def test_flash_standard_shapes(): # Check that shapes of output are the same -- GitLab
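
Note (reviewer addition, not part of the patch): when `use_flash_attention=True`, the forward pass above swaps the einsum-based `_standard_attention` for `torch.nn.functional.scaled_dot_product_attention`. The minimal sketch below, assuming PyTorch >= 2.0 and with dropout disabled so the comparison is deterministic, checks that the two formulations agree numerically. On CPU the fused call falls back to the math backend; on CUDA it can dispatch to the flash or memory-efficient kernels, which is where the speed-up in the new example comes from.

    import math

    import torch
    import torch.nn.functional as F
    from torch import einsum

    torch.manual_seed(0)
    b, h, s, d = 2, 2, 8, 16  # batch size, n_heads, seq length, head_dim
    q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

    # einsum formulation kept in MultiHeadedAttention._standard_attention
    scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(d)
    manual = einsum("b h s l, b h l d -> b h s d", scores.softmax(dim=-1), v)

    # fused kernel used when use_flash_attention=True (dropout_p=0 for the comparison)
    fused = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
    )

    print(torch.allclose(manual, fused, atol=1e-5))  # expected: True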