Commit d3657c32 authored by Javier

Added an example of flash and linear attention. Fixed some small bugs in one example. Adjusted all new functionality for GPU usage.
Parent b6362d1d
from time import time
from sklearn.model_selection import train_test_split
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import WideDeep, TabTransformer
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult
from pytorch_widedeep.preprocessing import TabPreprocessor
# use_cuda = torch.cuda.is_available()
df = load_adult(as_frame=True)
df.columns = [c.replace("-", "_") for c in df.columns]
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
target_colname = "income_label"
cat_embed_cols = []
for col in df.columns:
    if (df[col].dtype == "O" or df[col].nunique() < 200) and col != target_colname:
        cat_embed_cols.append(col)
train, test = train_test_split(
    df, test_size=0.1, random_state=1, stratify=df[[target_colname]]
)
with_cls_token = True
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, with_attention=True, with_cls_token=with_cls_token
)
X_tab_train = tab_preprocessor.fit_transform(train)
X_tab_test = tab_preprocessor.transform(test)
target = train[target_colname].values
tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
)
linear_tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
    use_linear_attention=True,
)
flash_tab_transformer = TabTransformer(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    input_dim=16,
    n_heads=2,
    n_blocks=2,
    use_flash_attention=True,
)
s_model = WideDeep(deeptabular=tab_transformer)
l_model = WideDeep(deeptabular=linear_tab_transformer)
f_model = WideDeep(deeptabular=flash_tab_transformer)
for name, model in [("standard", s_model), ("linear", l_model), ("flash", f_model)]:
    trainer = Trainer(
        model,
        objective="binary",
        metrics=[Accuracy],
    )
    s = time()
    trainer.fit(
        X_tab=X_tab_train,
        target=target,
        n_epochs=1,
        batch_size=64,
        val_split=0.2,
    )
    e = time() - s
    print(f"{name} attention time: {round(e, 3)} secs")
......@@ -10,7 +10,7 @@ from sklearn.model_selection import train_test_split
from pytorch_widedeep.datasets import load_movielens100k
data, user, items = load_movielens100k(as_frame=True)
data, users, items = load_movielens100k(as_frame=True)
# Alternatively, as specified in the docs: 'The last 19 fields are the genres' so:
# list_of_genres = items.columns.tolist()[-19:]
......@@ -37,7 +37,7 @@ list_of_genres = [
]
# adding a column with the number of movies watched per user
# adding a column with the number of movies watched per users
dataset = data.sort_values(["user_id", "timestamp"]).reset_index(drop=True)
dataset["one"] = 1
dataset["num_watched"] = dataset.groupby("user_id")["one"].cumsum()
......@@ -61,6 +61,9 @@ dataset["prev_movies"] = (
)
dataset["prev_movies"] = dataset["prev_movies"].apply(lambda x: x.split())
# Adding user feats
dataset = dataset.merge(users, on="user_id", how="left")
# Adding a genre_rate as the mean of all movies rated for a given genre per
# user
dataset = dataset.merge(items[["movie_id"] + list_of_genres], on="movie_id", how="left")
......
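Purely as an illustration of the comment above (the example's actual computation sits in the elided lines and may differ), one way to build a per-user mean rating for a given genre with pandas is:

# hypothetical sketch, not the committed code
genre = list_of_genres[0]  # any genre column from the list defined above
genre_rate = (
    dataset.loc[dataset[genre] == 1]
    .groupby("user_id")["rating"]
    .mean()
    .rename(f"{genre}_rate")
    .reset_index()
)
dataset = dataset.merge(genre_rate, on="user_id", how="left")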
......@@ -47,7 +47,7 @@ train_movies_sequences = df_train.prev_movies.apply(
).to_list()
y_train = df_train.target.values.astype(int)
df_test_user_item = df_train[["user_id", "movie_id", "rating"]]
df_test_user_item = df_test[["user_id", "movie_id", "rating"]]
test_movies_sequences = df_test.prev_movies.apply(
lambda x: [int(el) for el in x]
).to_list()
......@@ -89,7 +89,7 @@ X_test_text = np.array(
tab_mlp = TabMlp(
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
mlp_hidden_dims=[1024, 512, 256],
mlp_hidden_dims=[512, 256],
mlp_activation="relu",
)
......@@ -124,7 +124,7 @@ trainer.fit(
"X_text": X_test_text,
"target": y_test,
},
n_epochs=10,
batch_size=521,
n_epochs=2,
batch_size=32,
shuffle=False,
)
......@@ -4,16 +4,13 @@ https://github.com/lucidrains
"""
import math
import warnings
from enum import Enum
from typing import ContextManager
import torch
import einops
import torch.nn.functional as F
from torch import nn, einsum
from pytorch_widedeep.wdtypes import List, Tuple, Tensor, Optional
from pytorch_widedeep.wdtypes import Tuple, Tensor, Optional
from pytorch_widedeep.models._get_activation_fn import get_activation_fn
......@@ -67,88 +64,6 @@ class AddNorm(nn.Module):
return self.ln(X + self.dropout(sublayer(X)))
def _standard_attention(
q: Tensor, k: Tensor, v: Tensor, head_dim: int, dropout: float
) -> Tuple[Tensor, Tensor]:
"""'Standard' multihead attention implemenation from [Attention Is All You
Need](https://arxiv.org/abs/1706.03762)
"""
# b: batch size
# s: seq length
# l: target sequence length
# m: used to refer interchangeably to s or l
# h: number of attention heads,
# d and e: head_dim
# Normalised Query, Key dot product + softmax. Fraction in brackets in
# their Eq 1
scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(head_dim)
attn_weights = scores.softmax(dim=-1)
# Attention(Q, K, V ) (with dropout) in their Eq 1
attn_output = einsum(
"b h s l, b h l d -> b h s d", nn.Dropout(dropout)(attn_weights), v
)
return attn_weights, attn_output
def _linear_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
"""Liner attention implemenation from [Transformers are RNNs: Fast
Autoregressive Transformers with Linear Attention]
(https://arxiv.org/abs/2006.16236)
"""
# b: batch size
# s: seq length
# l: target sequence length
# m: used to refer interchangeably to s or l
# h: number of attention heads,
# d and e: head_dim
q, k = (
nn.functional.elu(q) + 1,
nn.functional.elu(k) + 1,
)
# Term within the summation in the numerator of their Eq 5
scores = einsum("b h s e, b h l d -> b h e d", k, v)
# The denominator in their Eq 5
z = 1 / (torch.einsum("b h m d, b h d -> b h m", q, k.sum(dim=2)) + 1e-6)
# Their Eq 5
attn_output = torch.einsum("b h m d, b h e d, b h m -> b h m d", q, scores, z)
return attn_output
class SDPBackend(Enum):
FLASH: int = 0
MEM_EFFICIENT: int = 1
def _flash_kernel_setup(enabled_flash_backends: List[SDPBackend]) -> ContextManager:
assert (
torch.cuda.is_available()
), "optimized kernels can only be used if CUDA is available."
warnings.warn(
"Note that FlashAttention is beta and subject to change",
RuntimeWarning,
)
# Setting backends as suggested at https://discuss.pytorch.org/t/flash-attention/174955
enable_flash = True if SDPBackend.FLASH in enabled_flash_backends else False
enable_mem_eff = (
True if SDPBackend.MEM_EFFICIENT in enabled_flash_backends else False
)
return torch.backends.cuda.sdp_kernel(
enable_flash=enable_flash,
enable_mem_efficient=enable_mem_eff,
enable_math=False,
)
class MultiHeadedAttention(nn.Module):
def __init__(
self,
......@@ -159,10 +74,6 @@ class MultiHeadedAttention(nn.Module):
query_dim: Optional[int] = None,
use_linear_attention: bool = False,
use_flash_attention: bool = False,
enabled_flash_backends: List[SDPBackend] = [
SDPBackend.FLASH,
SDPBackend.MEM_EFFICIENT,
], # in the next release we will offer the possibility of setting up the backend
):
super(MultiHeadedAttention, self).__init__()
......@@ -170,7 +81,6 @@ class MultiHeadedAttention(nn.Module):
self.use_linear_attention = use_linear_attention
self.use_flash_attention = use_flash_attention
self.enabled_flash_backends = enabled_flash_backends
self.head_dim = input_dim // n_heads
self.n_heads = n_heads
......@@ -200,26 +110,20 @@ class MultiHeadedAttention(nn.Module):
)
if self.use_flash_attention:
# in the future we will offer the possibility of setting up the
# backend. For the time being this context manager
# is 'redundant'
with _flash_kernel_setup(self.enabled_flash_backends):
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=self.dropout if self.training else 0,
is_causal=False,
)
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=None,
dropout_p=self.dropout if self.training else 0,
is_causal=False,
)
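# Note: F.scaled_dot_product_attention does not return the attention matrix,
# so no attention weights can be stored when the flash path is used.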
self.attn_weights: Optional[Tensor] = None
elif self.use_linear_attention:
attn_output = _linear_attention(q, k, v)
attn_output = self._linear_attention(q, k, v)
self.attn_weights = None
else:
self.attn_weights, attn_output = _standard_attention(
q, k, v, self.head_dim, self.dropout
)
self.attn_weights, attn_output = self._standard_attention(q, k, v)
output = einops.rearrange(attn_output, "b h s d -> b s (h d)", h=self.n_heads)
......@@ -228,6 +132,59 @@ class MultiHeadedAttention(nn.Module):
return output
def _standard_attention(
self, q: Tensor, k: Tensor, v: Tensor
) -> Tuple[Tensor, Tensor]:
"""'Standard' multihead attention implemenation from [Attention Is All You
Need](https://arxiv.org/abs/1706.03762)
"""
# b: batch size
# s: seq length
# l: target sequence length
# m: used to refer interchangeably to s or l
# h: number of attention heads,
# d and e: head_dim
# Normalised Query, Key dot product + softmax. Fraction in brackets in
# their Eq 1
scores = einsum("b h s d, b h l d -> b h s l", q, k) / math.sqrt(self.head_dim)
attn_weights = scores.softmax(dim=-1)
# Attention(Q, K, V ) (with dropout) in their Eq 1
attn_output = einsum(
"b h s l, b h l d -> b h s d", nn.Dropout(self.dropout)(attn_weights), v
)
return attn_weights, attn_output
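# For reference, each head computes softmax(Q K^T / sqrt(head_dim)) V
# (Eq 1 in "Attention Is All You Need"), with dropout applied to the softmax weights.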
@staticmethod
def _linear_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
"""Liner attention implemenation from [Transformers are RNNs: Fast
Autoregressive Transformers with Linear Attention]
(https://arxiv.org/abs/2006.16236)
"""
# b: batch size
# s: seq length
# l: target sequence length
# m: used to refer interchangeably to s or l
# h: number of attention heads,
# d and e: head_dim
q, k = (
nn.functional.elu(q) + 1,
nn.functional.elu(k) + 1,
)
# Term within the summation in the numerator of their Eq 5
scores = einsum("b h s e, b h l d -> b h e d", k, v)
# The denominator in their Eq 5
z = 1 / (torch.einsum("b h m d, b h d -> b h m", q, k.sum(dim=2)) + 1e-6)
# Their Eq 5
attn_output = torch.einsum("b h m d, b h e d, b h m -> b h m d", q, scores, z)
return attn_output
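# For reference, with feature map phi(x) = elu(x) + 1, Eq 5 of the linear
# attention paper reads:
# out_i = phi(q_i)^T (sum_j phi(k_j) v_j^T) / (phi(q_i)^T sum_j phi(k_j)).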
class LinearAttentionLinformer(nn.Module):
"""Linear Attention implementation from [Linformer: Self-Attention with
......
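Since _flash_kernel_setup and the SDPBackend enum are removed in this commit, restricting which scaled-dot-product kernels PyTorch may use is now left to user code. A minimal sketch of doing so directly (an assumption, not part of the commit; it requires a CUDA device and a PyTorch version that still exposes torch.backends.cuda.sdp_kernel, which newer releases deprecate in favour of torch.nn.attention.sdpa_kernel):

import torch
import torch.nn.functional as F

# allow only the flash and memory-efficient kernels, mirroring the removed helper
q = k = v = torch.randn(8, 4, 64, 16, device="cuda", dtype=torch.float16)
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_mem_efficient=True, enable_math=False
):
    out = F.scaled_dot_product_attention(q, k, v)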
......@@ -3,12 +3,9 @@ import timeit
import torch
import pytest
import torch.backends.cuda as tcud
from pytorch_widedeep.models.tabular.transformers._attention_layers import (
SDPBackend,
MultiHeadedAttention,
_flash_kernel_setup,
)
torch.backends.cudnn.deterministic = True
......@@ -59,42 +56,6 @@ for module in ["q_proj", "kv_proj", "out_proj"]:
X = torch.randn(128, 100, input_dim).to(device)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA to run")
@pytest.mark.parametrize(
"backends, active",
[
(
[SDPBackend.FLASH],
{
tcud.flash_sdp_enabled: True,
tcud.mem_efficient_sdp_enabled: False,
tcud.math_sdp_enabled: False,
},
),
(
[SDPBackend.MEM_EFFICIENT],
{
tcud.flash_sdp_enabled: False,
tcud.mem_efficient_sdp_enabled: True,
tcud.math_sdp_enabled: False,
},
),
(
[SDPBackend.FLASH, SDPBackend.MEM_EFFICIENT],
{
tcud.flash_sdp_enabled: True,
tcud.mem_efficient_sdp_enabled: True,
tcud.math_sdp_enabled: False,
},
),
],
)
def test_cdp_context_managment(backends, active):
ctx = _flash_kernel_setup(backends)
with ctx:
assert all([f() == v for f, v in active.items()])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA to run")
def test_flash_standard_shapes():
# Check that shapes of output are the same
......