\n",
" \n",
" 0 | \n",
- " 0.225108 | \n",
- " [0.9270001649856567, 0.8861073851585388] | \n",
- " 0.906365 | \n",
- " [0.9270001649856567, 0.8861073851585388] | \n",
- " [0.9074797034263611, 0.9052220582962036] | \n",
- " 0.103954 | \n",
- " [0.962550163269043, 0.875] | \n",
- " 0.961784 | \n",
- " [0.962550163269043, 0.875] | \n",
- " [0.9803645014762878, 0.28863346576690674] | \n",
+ " 0.232111 | \n",
+ " [0.9225957989692688, 0.8899886012077332] | \n",
+ " 0.906075 | \n",
+ " [0.9225957989692688, 0.8899886012077332] | \n",
+ " [0.9064720273017883, 0.9056749939918518] | \n",
+ " 0.059772 | \n",
+ " [0.9809635877609253, 0.8671875] | \n",
+ " 0.979966 | \n",
+ " [0.9809635877609253, 0.8671875] | \n",
+ " [0.989802360534668, 0.4341084957122803] | \n",
"
\n",
" \n",
" 1 | \n",
- " 0.152386 | \n",
- " [0.9471125602722168, 0.9297876358032227] | \n",
- " 0.938380 | \n",
- " [0.9471125602722168, 0.9297876358032227] | \n",
- " [0.9384452104568481, 0.9383144974708557] | \n",
- " 0.091541 | \n",
- " [0.9680188298225403, 0.890625] | \n",
- " 0.967341 | \n",
- " [0.9680188298225403, 0.890625] | \n",
- " [0.9832653999328613, 0.3257790505886078] | \n",
+ " 0.160898 | \n",
+ " [0.9417101144790649, 0.9261900186538696] | \n",
+ " 0.933944 | \n",
+ " [0.9417101144790649, 0.9261900186538696] | \n",
+ " [0.9344058036804199, 0.9334757328033447] | \n",
+ " 0.122636 | \n",
+ " [0.954935610294342, 0.921875] | \n",
+ " 0.954648 | \n",
+ " [0.954935610294342, 0.921875] | \n",
+ " [0.9766026139259338, 0.2647385895252228] | \n",
"
\n",
" \n",
" 2 | \n",
- " 0.134425 | \n",
- " [0.9490371346473694, 0.9407152533531189] | \n",
- " 0.944841 | \n",
- " [0.9490371346473694, 0.9407152533531189] | \n",
- " [0.9446273446083069, 0.9450528621673584] | \n",
- " 0.084717 | \n",
- " [0.967949628829956, 0.8828125] | \n",
- " 0.967204 | \n",
- " [0.967949628829956, 0.8828125] | \n",
- " [0.9831950664520264, 0.3229461908340454] | \n",
+ " 0.138879 | \n",
+ " [0.9479646682739258, 0.9413018226623535] | \n",
+ " 0.944648 | \n",
+ " [0.9479646682739258, 0.9413018226623535] | \n",
+ " [0.945061206817627, 0.9442285299301147] | \n",
+ " 0.072151 | \n",
+ " [0.9743181467056274, 0.8984375] | \n",
+ " 0.973653 | \n",
+ " [0.9743181467056274, 0.8984375] | \n",
+ " [0.9865424036979675, 0.3766234219074249] | \n",
"
\n",
" \n",
"\n",
@@ -973,29 +975,29 @@
],
"text/plain": [
" train_loss train_Accuracy train_Precision \\\n",
- "0 0.225108 [0.9270001649856567, 0.8861073851585388] 0.906365 \n",
- "1 0.152386 [0.9471125602722168, 0.9297876358032227] 0.938380 \n",
- "2 0.134425 [0.9490371346473694, 0.9407152533531189] 0.944841 \n",
+ "0 0.232111 [0.9225957989692688, 0.8899886012077332] 0.906075 \n",
+ "1 0.160898 [0.9417101144790649, 0.9261900186538696] 0.933944 \n",
+ "2 0.138879 [0.9479646682739258, 0.9413018226623535] 0.944648 \n",
"\n",
" train_Recall \\\n",
- "0 [0.9270001649856567, 0.8861073851585388] \n",
- "1 [0.9471125602722168, 0.9297876358032227] \n",
- "2 [0.9490371346473694, 0.9407152533531189] \n",
+ "0 [0.9225957989692688, 0.8899886012077332] \n",
+ "1 [0.9417101144790649, 0.9261900186538696] \n",
+ "2 [0.9479646682739258, 0.9413018226623535] \n",
"\n",
" train_F1 val_loss \\\n",
- "0 [0.9074797034263611, 0.9052220582962036] 0.103954 \n",
- "1 [0.9384452104568481, 0.9383144974708557] 0.091541 \n",
- "2 [0.9446273446083069, 0.9450528621673584] 0.084717 \n",
+ "0 [0.9064720273017883, 0.9056749939918518] 0.059772 \n",
+ "1 [0.9344058036804199, 0.9334757328033447] 0.122636 \n",
+ "2 [0.945061206817627, 0.9442285299301147] 0.072151 \n",
"\n",
- " val_Accuracy val_Precision \\\n",
- "0 [0.962550163269043, 0.875] 0.961784 \n",
- "1 [0.9680188298225403, 0.890625] 0.967341 \n",
- "2 [0.967949628829956, 0.8828125] 0.967204 \n",
+ " val_Accuracy val_Precision \\\n",
+ "0 [0.9809635877609253, 0.8671875] 0.979966 \n",
+ "1 [0.954935610294342, 0.921875] 0.954648 \n",
+ "2 [0.9743181467056274, 0.8984375] 0.973653 \n",
"\n",
- " val_Recall val_F1 \n",
- "0 [0.962550163269043, 0.875] [0.9803645014762878, 0.28863346576690674] \n",
- "1 [0.9680188298225403, 0.890625] [0.9832653999328613, 0.3257790505886078] \n",
- "2 [0.967949628829956, 0.8828125] [0.9831950664520264, 0.3229461908340454] "
+ " val_Recall val_F1 \n",
+ "0 [0.9809635877609253, 0.8671875] [0.989802360534668, 0.4341084957122803] \n",
+ "1 [0.954935610294342, 0.921875] [0.9766026139259338, 0.2647385895252228] \n",
+ "2 [0.9743181467056274, 0.8984375] [0.9865424036979675, 0.3766234219074249] "
]
},
"execution_count": 14,
@@ -1016,7 +1018,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "predict: 100%|██████████| 292/292 [00:00<00:00, 294.65it/s]\n"
+ "predict: 100%|██████████| 292/292 [00:01<00:00, 286.78it/s]\n"
]
},
{
@@ -1025,15 +1027,15 @@
"text": [
" precision recall f1-score support\n",
"\n",
- " 0 1.00 0.97 0.99 14446\n",
- " 1 0.23 0.93 0.36 130\n",
+ " 0 1.00 0.98 0.99 14446\n",
+ " 1 0.27 0.93 0.41 130\n",
"\n",
- " accuracy 0.97 14576\n",
- " macro avg 0.61 0.95 0.67 14576\n",
- "weighted avg 0.99 0.97 0.98 14576\n",
+ " accuracy 0.98 14576\n",
+ " macro avg 0.63 0.95 0.70 14576\n",
+ "weighted avg 0.99 0.98 0.98 14576\n",
"\n",
"Actual predicted values:\n",
- "(array([0, 1]), array([14039, 537]))\n"
+ "(array([0, 1]), array([14122, 454]))\n"
]
}
],
diff --git a/examples/10_The_Transformer_Family.ipynb b/examples/10_The_Transformer_Family.ipynb
index 245b47013acf9442aba28fba2a6974a064dc0045..c6f922b9ce14a21a405c397ff1405cbd0e4c0518 100644
--- a/examples/10_The_Transformer_Family.ipynb
+++ b/examples/10_The_Transformer_Family.ipynb
@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "At the moment there are 3 transformer-based algorithms available. \n",
+ "At the moment there are 5 transformer-based algorithms available. \n",
"\n",
"Here are examples of how to use them\n",
"\n",
@@ -34,7 +34,7 @@
"\n",
"from pytorch_widedeep.preprocessing import TabPreprocessor\n",
"from pytorch_widedeep.training import Trainer\n",
- "from pytorch_widedeep.models import TabTransformer, SAINT, WideDeep\n",
+ "from pytorch_widedeep.models import TabTransformer, SAINT, FTTransformer, TabFastFormer, TabPerceiver, WideDeep\n",
"from pytorch_widedeep.metrics import Accuracy"
]
},
@@ -417,13 +417,6 @@
"X_tab = tab_preprocessor.fit_transform(df)"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
{
"cell_type": "code",
"execution_count": 6,
@@ -435,7 +428,8 @@
"tab_transformer = TabTransformer(column_idx=tab_preprocessor.column_idx,\n",
" embed_input=tab_preprocessor.embeddings_input,\n",
" continuous_cols=tab_preprocessor.continuous_cols, \n",
- " cont_norm_layer=\"layernorm\"\n",
+ " cont_norm_layer=\"batchnorm\", \n",
+ " n_blocks=4, n_heads=4 \n",
" )"
]
},
@@ -443,64 +437,27 @@
"cell_type": "code",
"execution_count": 7,
"metadata": {
- "scrolled": false
+ "scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"TabTransformer(\n",
- " (cat_embed): Embedding(103, 32, padding_idx=0)\n",
- " (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
- " (cont_norm): LayerNorm((5,), eps=1e-05, elementwise_affine=True)\n",
- " (transformer_blks): Sequential(\n",
- " (block0): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
- " )\n",
- " (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
- " )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " )\n",
- " (block1): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
- " )\n",
- " (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
- " )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
+ " (cat_and_cont_embed): CatAndContEmbeddings(\n",
+ " (cat_embed): CategoricalEmbeddings(\n",
+ " (embed): Embedding(103, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
- " (block2): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (cont_norm): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " (transformer_blks): Sequential(\n",
+ " (transformer_block0): TransformerEncoder(\n",
+ " (attn): MultiHeadedAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
@@ -509,7 +466,7 @@
" (activation): GELU()\n",
" )\n",
" (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (ff_addnorm): AddNorm(\n",
@@ -517,11 +474,12 @@
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block3): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (transformer_block1): TransformerEncoder(\n",
+ " (attn): MultiHeadedAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
@@ -530,7 +488,7 @@
" (activation): GELU()\n",
" )\n",
" (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (ff_addnorm): AddNorm(\n",
@@ -538,11 +496,12 @@
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block4): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (transformer_block2): TransformerEncoder(\n",
+ " (attn): MultiHeadedAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
@@ -551,7 +510,7 @@
" (activation): GELU()\n",
" )\n",
" (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (ff_addnorm): AddNorm(\n",
@@ -559,11 +518,12 @@
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block5): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (transformer_block3): TransformerEncoder(\n",
+ " (attn): MultiHeadedAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
@@ -572,7 +532,7 @@
" (activation): GELU()\n",
" )\n",
" (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (ff_addnorm): AddNorm(\n",
@@ -634,8 +594,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 153/153 [00:23<00:00, 6.54it/s, loss=0.375, metrics={'acc': 0.8247}]\n",
- "valid: 100%|██████████| 39/39 [00:01<00:00, 19.74it/s, loss=0.356, metrics={'acc': 0.8375}]\n"
+ "epoch 1: 100%|██████████| 153/153 [00:14<00:00, 10.56it/s, loss=0.356, metrics={'acc': 0.8321}]\n",
+ "valid: 100%|██████████| 39/39 [00:01<00:00, 36.52it/s, loss=0.336, metrics={'acc': 0.8465}]\n"
]
}
],
@@ -647,7 +607,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "We can also choose to use the `FT-Transformer` (just set `embed_continuous=True`), where continuous cols are also represented by \"Embeddings\", via a 1 layer MLP (with or without activation function). When using the `FT-Transformer` we can choose to use the `[CLS]` token as a pooling method or concatenate the output from the transformer blocks, as we did before. Let's use here the `[CLS]` token."
+ "We can also choose to use the `FT-Transformer`, where continuous cols are also represented by \"Embeddings\", via a 1 layer MLP (with or without activation function). When using the `FT-Transformer` we can choose to use the `[CLS]` token as a pooling method or concatenate the output from the transformer blocks, as we did before. Let's use here the `[CLS]` token. Also note that under the hood, the `FT-Transformer` uses Linear Attention. See [Linformer: Self-Attention with Linear Complexity](https://arxiv.org/pdf/2006.04768.pdf)"
]
},
{
@@ -670,170 +630,94 @@
"metadata": {},
"outputs": [],
"source": [
- "# here all categorical columns will be encoded as 32 dim embeddings, then passed through the transformer \n",
- "# blocks, concatenated with the continuous and finally through an MLP\n",
- "ft_transformer = TabTransformer(column_idx=tab_preprocessor.column_idx,\n",
- " embed_input=tab_preprocessor.embeddings_input,\n",
- " continuous_cols=tab_preprocessor.continuous_cols, \n",
- " embed_continuous=True,\n",
- " embed_continuous_activation=None, \n",
- " )"
+ "ft_transformer = FTTransformer(column_idx=tab_preprocessor.column_idx,\n",
+ " embed_input=tab_preprocessor.embeddings_input,\n",
+ " continuous_cols=tab_preprocessor.continuous_cols, \n",
+ " n_blocks=3, n_heads=6, input_dim=36\n",
+ " )"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
- "scrolled": false
+ "scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
- "TabTransformer(\n",
- " (cat_embed): Embedding(104, 32, padding_idx=0)\n",
- " (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
- " (cont_norm): Identity()\n",
- " (cont_embed): ContinuousEmbeddings()\n",
- " (transformer_blks): Sequential(\n",
- " (block0): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
- " )\n",
- " (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
- " )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
+ "FTTransformer(\n",
+ " (cat_and_cont_embed): CatAndContEmbeddings(\n",
+ " (cat_embed): CategoricalEmbeddings(\n",
+ " (embed): Embedding(104, 36, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
- " (block1): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (cont_norm): Identity()\n",
+ " (cont_embed): ContinuousEmbeddings()\n",
+ " )\n",
+ " (transformer_blks): Sequential(\n",
+ " (fttransformer_block0): FTTransformerEncoder(\n",
+ " (attn): LinearAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (qkv_proj): Linear(in_features=36, out_features=108, bias=False)\n",
+ " (out_proj): Linear(in_features=36, out_features=36, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
+ " (w_1): Linear(in_features=36, out_features=94, bias=True)\n",
+ " (w_2): Linear(in_features=47, out_features=36, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
+ " (activation): REGLU()\n",
" )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn_normadd): NormAdd(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
- " (ff_addnorm): AddNorm(\n",
+ " (ff_normadd): NormAdd(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
+ " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block2): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (fttransformer_block1): FTTransformerEncoder(\n",
+ " (attn): LinearAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (qkv_proj): Linear(in_features=36, out_features=108, bias=False)\n",
+ " (out_proj): Linear(in_features=36, out_features=36, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
+ " (w_1): Linear(in_features=36, out_features=94, bias=True)\n",
+ " (w_2): Linear(in_features=47, out_features=36, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
+ " (activation): REGLU()\n",
" )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn_normadd): NormAdd(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
- " (ff_addnorm): AddNorm(\n",
+ " (ff_normadd): NormAdd(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
+ " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block3): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (fttransformer_block2): FTTransformerEncoder(\n",
+ " (attn): LinearAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (qkv_proj): Linear(in_features=36, out_features=108, bias=False)\n",
+ " (out_proj): Linear(in_features=36, out_features=36, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
+ " (w_1): Linear(in_features=36, out_features=94, bias=True)\n",
+ " (w_2): Linear(in_features=47, out_features=36, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
+ " (activation): REGLU()\n",
" )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
+ " (attn_normadd): NormAdd(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
- " (ff_addnorm): AddNorm(\n",
+ " (ff_normadd): NormAdd(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " )\n",
- " (block4): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
- " )\n",
- " (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
- " )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " )\n",
- " (block5): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
- " )\n",
- " (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
- " )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " )\n",
- " )\n",
- " (transformer_mlp): MLP(\n",
- " (mlp): Sequential(\n",
- " (dense_layer_0): Sequential(\n",
- " (0): Linear(in_features=32, out_features=128, bias=True)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (dense_layer_1): Sequential(\n",
- " (0): Linear(in_features=128, out_features=64, bias=True)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Dropout(p=0.1, inplace=False)\n",
+ " (ln): LayerNorm((36,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" )\n",
@@ -876,8 +760,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 153/153 [00:38<00:00, 4.02it/s, loss=0.412, metrics={'acc': 0.8002}]\n",
- "valid: 100%|██████████| 39/39 [00:02<00:00, 14.69it/s, loss=0.326, metrics={'acc': 0.8524}]\n"
+ "epoch 1: 100%|██████████| 153/153 [00:15<00:00, 9.62it/s, loss=0.382, metrics={'acc': 0.8167}]\n",
+ "valid: 100%|██████████| 39/39 [00:01<00:00, 28.84it/s, loss=0.317, metrics={'acc': 0.8566}]\n"
]
}
],
@@ -894,7 +778,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@@ -902,56 +786,61 @@
" embed_input=tab_preprocessor.embeddings_input,\n",
" continuous_cols=tab_preprocessor.continuous_cols, \n",
" transformer_activation=\"geglu\",\n",
- " embed_continuous=True,\n",
- " embed_continuous_activation=None, \n",
+ " n_blocks=2, n_heads=4, \n",
" )"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 18,
"metadata": {
- "scrolled": false
+ "scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"SAINT(\n",
- " (cat_embed): Embedding(104, 32, padding_idx=0)\n",
- " (embedding_dropout): Dropout(p=0.1, inplace=False)\n",
- " (cont_norm): LayerNorm((5,), eps=1e-05, elementwise_affine=True)\n",
- " (cont_embed): ContinuousEmbeddings()\n",
+ " (cat_and_cont_embed): CatAndContEmbeddings(\n",
+ " (cat_embed): CategoricalEmbeddings(\n",
+ " (embed): Embedding(104, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (cont_norm): Identity()\n",
+ " (cont_embed): ContinuousEmbeddings()\n",
+ " )\n",
" (transformer_blks): Sequential(\n",
- " (block0): SaintEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
+ " (saint_block0): SaintEncoder(\n",
+ " (col_attn): MultiHeadedAttention(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
- " (self_attn_ff): PositionwiseFF(\n",
+ " (col_attn_ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=256, bias=True)\n",
" (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (activation): GEGLU()\n",
" )\n",
- " (self_attn_addnorm): AddNorm(\n",
+ " (col_attn_addnorm): AddNorm(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
- " (self_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (col_attn_ff_addnorm): AddNorm(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (row_attn): MultiHeadedAttention(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n",
- " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n",
+ " (q_proj): Linear(in_features=448, out_features=448, bias=False)\n",
+ " (kv_proj): Linear(in_features=448, out_features=896, bias=False)\n",
+ " (out_proj): Linear(in_features=448, out_features=448, bias=False)\n",
" )\n",
" (row_attn_ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=448, out_features=3584, bias=True)\n",
" (w_2): Linear(in_features=1792, out_features=448, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (activation): GEGLU()\n",
" )\n",
" (row_attn_addnorm): AddNorm(\n",
@@ -959,39 +848,41 @@
" (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (row_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block1): SaintEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
+ " (saint_block1): SaintEncoder(\n",
+ " (col_attn): MultiHeadedAttention(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
- " (self_attn_ff): PositionwiseFF(\n",
+ " (col_attn_ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=256, bias=True)\n",
" (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (activation): GEGLU()\n",
" )\n",
- " (self_attn_addnorm): AddNorm(\n",
+ " (col_attn_addnorm): AddNorm(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
- " (self_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (col_attn_ff_addnorm): AddNorm(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (row_attn): MultiHeadedAttention(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n",
- " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n",
+ " (q_proj): Linear(in_features=448, out_features=448, bias=False)\n",
+ " (kv_proj): Linear(in_features=448, out_features=896, bias=False)\n",
+ " (out_proj): Linear(in_features=448, out_features=448, bias=False)\n",
" )\n",
" (row_attn_ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=448, out_features=3584, bias=True)\n",
" (w_2): Linear(in_features=1792, out_features=448, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (activation): GEGLU()\n",
" )\n",
" (row_attn_addnorm): AddNorm(\n",
@@ -999,169 +890,150 @@
" (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (row_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block2): SaintEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
- " )\n",
- " (self_attn_ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=256, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GEGLU()\n",
- " )\n",
- " (self_attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (self_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (row_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n",
- " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n",
- " )\n",
- " (row_attn_ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n",
- " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GEGLU()\n",
- " )\n",
- " (row_attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
+ " )\n",
+ " (transformer_mlp): MLP(\n",
+ " (mlp): Sequential(\n",
+ " (dense_layer_0): Sequential(\n",
+ " (0): Linear(in_features=32, out_features=128, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
" )\n",
- " (row_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
+ " (dense_layer_1): Sequential(\n",
+ " (0): Linear(in_features=128, out_features=64, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
- " (block3): SaintEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
- " )\n",
- " (self_attn_ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=256, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GEGLU()\n",
- " )\n",
- " (self_attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (self_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (row_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n",
- " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n",
- " )\n",
- " (row_attn_ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n",
- " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GEGLU()\n",
- " )\n",
- " (row_attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (row_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "saint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "epoch 1: 100%|██████████| 306/306 [00:47<00:00, 6.42it/s, loss=0.377, metrics={'acc': 0.8224}]\n",
+ "valid: 100%|██████████| 77/77 [00:02<00:00, 32.20it/s, loss=0.338, metrics={'acc': 0.8529}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = WideDeep(deeptabular=saint)\n",
+ "trainer = Trainer(model, objective='binary', metrics=[Accuracy])\n",
+ "trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=128, val_split=0.2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The previous models have all been published. The following two are adaptations of existing Transformer models for tabular data and, at the time of writing, they are only available in this library. If I have the time I will write a post about their implementation. Nonetheless, all the details can be found in the [docs](https://pytorch-widedeep.readthedocs.io/en/latest/index.html).\n",
+ "\n",
+ "The first one is an adaptation of [Fastformer: Additive Attention Can Be All You Need](https://arxiv.org/pdf/2108.09084.pdf). I have mixed feelings about that paper, which I will not be covering here, but you can go and watch [Yannic's video](https://www.youtube.com/watch?v=qgUegkefocg&t=1s) since most of my opinions are also explained there. Nonetheless, the reason to bring this model to the library is that, in essence, the `FastFormer` is an \"elaborated MLP\" with an \"interesting\" additive attention mechanism. Since MLPs work really well for tabular data compared to other, more complex models, why not add it to the library? \n",
+ "\n",
+ "To use it, just follow the same routine as with any other transformer-based model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tabfastformer = TabFastFormer(column_idx=tab_preprocessor.column_idx,\n",
+ " embed_input=tab_preprocessor.embeddings_input,\n",
+ " continuous_cols=tab_preprocessor.continuous_cols, \n",
+ " n_blocks=2, n_heads=4, \n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TabFastFormer(\n",
+ " (cat_and_cont_embed): CatAndContEmbeddings(\n",
+ " (cat_embed): CategoricalEmbeddings(\n",
+ " (embed): Embedding(104, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
- " (block4): SaintEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
+ " (cont_norm): Identity()\n",
+ " (cont_embed): ContinuousEmbeddings()\n",
+ " )\n",
+ " (transformer_blks): Sequential(\n",
+ " (fastformer_block0): FastFormerEncoder(\n",
+ " (attn): AdditiveAttention(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (v_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (k_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (W_q): Linear(in_features=8, out_features=1, bias=False)\n",
+ " (W_k): Linear(in_features=8, out_features=1, bias=False)\n",
+ " (r_out): Linear(in_features=8, out_features=8, bias=True)\n",
" )\n",
- " (self_attn_ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=256, bias=True)\n",
+ " (ff): PositionwiseFF(\n",
+ " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
" (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GEGLU()\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (activation): ReLU(inplace=True)\n",
" )\n",
- " (self_attn_addnorm): AddNorm(\n",
+ " (attn_addnorm): AddNorm(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
- " (self_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (ff_addnorm): AddNorm(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
- " (row_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n",
- " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n",
- " )\n",
- " (row_attn_ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n",
- " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GEGLU()\n",
- " )\n",
- " (row_attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (row_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
" )\n",
- " (block5): SaintEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
+ " (fastformer_block1): FastFormerEncoder(\n",
+ " (attn): AdditiveAttention(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (v_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (k_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (W_q): Linear(in_features=8, out_features=1, bias=False)\n",
+ " (W_k): Linear(in_features=8, out_features=1, bias=False)\n",
+ " (r_out): Linear(in_features=8, out_features=8, bias=True)\n",
" )\n",
- " (self_attn_ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=256, bias=True)\n",
+ " (ff): PositionwiseFF(\n",
+ " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
" (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GEGLU()\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (activation): ReLU(inplace=True)\n",
" )\n",
- " (self_attn_addnorm): AddNorm(\n",
+ " (attn_addnorm): AddNorm(\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
- " (self_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (ff_addnorm): AddNorm(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
- " (row_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=448, out_features=1344, bias=True)\n",
- " (out_proj): Linear(in_features=448, out_features=448, bias=True)\n",
- " )\n",
- " (row_attn_ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=448, out_features=3584, bias=True)\n",
- " (w_2): Linear(in_features=1792, out_features=448, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GEGLU()\n",
- " )\n",
- " (row_attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (row_attn_ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((448,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
" )\n",
" )\n",
" (transformer_mlp): MLP(\n",
@@ -1181,47 +1053,107 @@
")"
]
},
- "execution_count": 19,
+ "execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "saint"
+ "tabfastformer"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 306/306 [02:34<00:00, 1.98it/s, loss=0.555, metrics={'acc': 0.7607}]\n",
- "valid: 100%|██████████| 77/77 [00:08<00:00, 8.96it/s, loss=0.551, metrics={'acc': 0.7607}]\n"
+ "epoch 1: 100%|██████████| 153/153 [00:10<00:00, 14.58it/s, loss=0.46, metrics={'acc': 0.7867}] \n",
+ "valid: 100%|██████████| 39/39 [00:00<00:00, 48.19it/s, loss=0.342, metrics={'acc': 0.8443}]\n"
]
}
],
"source": [
- "model = WideDeep(deeptabular=saint)\n",
+ "model = WideDeep(deeptabular=tabfastformer)\n",
"trainer = Trainer(model, objective='binary', metrics=[Accuracy])\n",
- "trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=128, val_split=0.2)"
+ "trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "And finally, the last of the transformer-based models that are currently available in the library is DeepMind's [Perceiver](https://arxiv.org/pdf/2103.03206.pdf). The reason to add this model to the library is the following. The Perceiver is meant to be an architecture that is agnostic to the nature of the input data, i.e. it is meant to work with audio, images, text... So why not tabular, right? \n",
+ "\n",
+ "To use it...you guessed right! "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tab_preprocessor = TabPreprocessor(embed_cols=cat_cols, \n",
+ " continuous_cols=cont_cols, \n",
+ " for_transformer=True,\n",
+ " )\n",
+ "X_tab = tab_preprocessor.fit_transform(df)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tabperceiver = TabPerceiver(\n",
+ " column_idx=tab_preprocessor.column_idx,\n",
+ " embed_input=tab_preprocessor.embeddings_input,\n",
+ " continuous_cols=tab_preprocessor.continuous_cols, \n",
+ " n_perceiver_blocks=1, \n",
+ " n_latent_blocks=3, \n",
+ " n_latent_heads=2, \n",
+ " n_latents=6,\n",
+ " latent_dim=32,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "epoch 1: 100%|██████████| 153/153 [00:16<00:00, 9.45it/s, loss=0.4, metrics={'acc': 0.81}] \n",
+ "valid: 100%|██████████| 39/39 [00:01<00:00, 37.95it/s, loss=0.323, metrics={'acc': 0.8542}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = WideDeep(deeptabular=tabperceiver)\n",
+ "trainer = Trainer(model, objective='binary', metrics=[Accuracy])\n",
+ "trainer.fit(X_tab=X_tab, target=target, n_epochs=1, batch_size=256, val_split=0.2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "One final comment is that all 3 transformer-based model have the option of using the so called \"Shared Embeddings\". The idea behind the shared embeddings is explained in the original TabTransformer paper and also here in this [post](https://jrzaurin.github.io/infinitoml/2021/02/18/pytorch-widedeep_iii.html).\n",
+ "One final comment is that all transformer-based models have the option of using the so-called \"Shared Embeddings\". The idea behind the shared embeddings is explained in the original TabTransformer paper and also here in this [post](https://jrzaurin.github.io/infinitoml/2021/02/18/pytorch-widedeep_iii.html).\n",
"\n",
"For transformer-based models this implies a bit of a different data preparation process since each column will be encoded individually (programmatically is way easier to implement) and the use of shared embeddings needs to be specified at preprocessing stage"
]
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@@ -1236,7 +1168,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
@@ -1246,12 +1178,14 @@
" embed_continuous=True,\n",
" embed_continuous_activation=None, \n",
" shared_embed=True, \n",
+ " cont_norm_layer=\"batchnorm\", \n",
+ " n_blocks=4, n_heads=4 \n",
" )"
]
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 28,
"metadata": {
"scrolled": false
},
@@ -1260,73 +1194,57 @@
"data": {
"text/plain": [
"TabTransformer(\n",
- " (cat_embed): ModuleDict(\n",
- " (emb_layer_cls_token): SharedEmbeddings(\n",
- " (embed): Embedding(1, 32, padding_idx=0)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (emb_layer_education): SharedEmbeddings(\n",
- " (embed): Embedding(17, 32, padding_idx=0)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (emb_layer_gender): SharedEmbeddings(\n",
- " (embed): Embedding(3, 32, padding_idx=0)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (emb_layer_marital_status): SharedEmbeddings(\n",
- " (embed): Embedding(8, 32, padding_idx=0)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (emb_layer_native_country): SharedEmbeddings(\n",
- " (embed): Embedding(43, 32, padding_idx=0)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (emb_layer_occupation): SharedEmbeddings(\n",
- " (embed): Embedding(16, 32, padding_idx=0)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (emb_layer_race): SharedEmbeddings(\n",
- " (embed): Embedding(6, 32, padding_idx=0)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (emb_layer_relationship): SharedEmbeddings(\n",
- " (embed): Embedding(7, 32, padding_idx=0)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " (emb_layer_workclass): SharedEmbeddings(\n",
- " (embed): Embedding(10, 32, padding_idx=0)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (cat_and_cont_embed): CatAndContEmbeddings(\n",
+ " (cat_embed): CategoricalEmbeddings(\n",
+ " (embed): ModuleDict(\n",
+ " (emb_layer_cls_token): SharedEmbeddings(\n",
+ " (embed): Embedding(1, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (emb_layer_education): SharedEmbeddings(\n",
+ " (embed): Embedding(17, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (emb_layer_gender): SharedEmbeddings(\n",
+ " (embed): Embedding(3, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (emb_layer_marital_status): SharedEmbeddings(\n",
+ " (embed): Embedding(8, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (emb_layer_native_country): SharedEmbeddings(\n",
+ " (embed): Embedding(43, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (emb_layer_occupation): SharedEmbeddings(\n",
+ " (embed): Embedding(16, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (emb_layer_race): SharedEmbeddings(\n",
+ " (embed): Embedding(6, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (emb_layer_relationship): SharedEmbeddings(\n",
+ " (embed): Embedding(7, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (emb_layer_workclass): SharedEmbeddings(\n",
+ " (embed): Embedding(10, 32, padding_idx=0)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
" )\n",
+ " (cont_norm): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
+ " (cont_embed): ContinuousEmbeddings()\n",
" )\n",
- " (cont_norm): Identity()\n",
- " (cont_embed): ContinuousEmbeddings()\n",
" (transformer_blks): Sequential(\n",
- " (block0): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
- " )\n",
- " (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
- " )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " )\n",
- " (block1): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (transformer_block0): TransformerEncoder(\n",
+ " (attn): MultiHeadedAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
@@ -1335,7 +1253,7 @@
" (activation): GELU()\n",
" )\n",
" (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (ff_addnorm): AddNorm(\n",
@@ -1343,11 +1261,12 @@
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block2): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (transformer_block1): TransformerEncoder(\n",
+ " (attn): MultiHeadedAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
@@ -1356,7 +1275,7 @@
" (activation): GELU()\n",
" )\n",
" (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (ff_addnorm): AddNorm(\n",
@@ -1364,11 +1283,12 @@
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block3): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (transformer_block2): TransformerEncoder(\n",
+ " (attn): MultiHeadedAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
@@ -1377,7 +1297,7 @@
" (activation): GELU()\n",
" )\n",
" (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (ff_addnorm): AddNorm(\n",
@@ -1385,11 +1305,12 @@
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
- " (block4): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
+ " (transformer_block3): TransformerEncoder(\n",
+ " (attn): MultiHeadedAttention(\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
+ " (q_proj): Linear(in_features=32, out_features=32, bias=False)\n",
+ " (kv_proj): Linear(in_features=32, out_features=64, bias=False)\n",
+ " (out_proj): Linear(in_features=32, out_features=32, bias=False)\n",
" )\n",
" (ff): PositionwiseFF(\n",
" (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
@@ -1398,28 +1319,7 @@
" (activation): GELU()\n",
" )\n",
" (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " (ff_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
- " )\n",
- " )\n",
- " (block5): TransformerEncoder(\n",
- " (self_attn): MultiHeadedAttention(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (inp_proj): Linear(in_features=32, out_features=96, bias=True)\n",
- " (out_proj): Linear(in_features=32, out_features=32, bias=True)\n",
- " )\n",
- " (ff): PositionwiseFF(\n",
- " (w_1): Linear(in_features=32, out_features=128, bias=True)\n",
- " (w_2): Linear(in_features=128, out_features=32, bias=True)\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
- " (activation): GELU()\n",
- " )\n",
- " (attn_addnorm): AddNorm(\n",
- " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (dropout): Dropout(p=0.2, inplace=False)\n",
" (ln): LayerNorm((32,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (ff_addnorm): AddNorm(\n",
@@ -1445,7 +1345,7 @@
")"
]
},
- "execution_count": 23,
+ "execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
@@ -1456,15 +1356,15 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "epoch 1: 100%|██████████| 153/153 [00:35<00:00, 4.32it/s, loss=0.411, metrics={'acc': 0.8036}]\n",
- "valid: 100%|██████████| 39/39 [00:02<00:00, 15.03it/s, loss=0.319, metrics={'acc': 0.8583}]\n"
+ "epoch 1: 100%|██████████| 153/153 [00:20<00:00, 7.62it/s, loss=0.4, metrics={'acc': 0.8061}] \n",
+ "valid: 100%|██████████| 39/39 [00:01<00:00, 30.53it/s, loss=0.324, metrics={'acc': 0.8551}]\n"
]
}
],
diff --git a/examples/11_Extracting_Embeddings.ipynb b/examples/11_Extracting_Embeddings.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..5504f84dfb8f6c254e2d847cbf2afd7d02c43d67
--- /dev/null
+++ b/examples/11_Extracting_Embeddings.ipynb
@@ -0,0 +1,513 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/javier/.pyenv/versions/3.7.7/lib/python3.7/importlib/_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject\n",
+ " return f(*args, **kwds)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "\n",
+ "from pytorch_widedeep.preprocessing import TabPreprocessor\n",
+ "from pytorch_widedeep.training import Trainer\n",
+ "from pytorch_widedeep.models import FTTransformer, WideDeep\n",
+ "from pytorch_widedeep.metrics import Accuracy\n",
+ "from pytorch_widedeep import Tab2Vec"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "