<p align="center">
  <img width="300" src="docs/figures/widedeep_logo.png">
</p>

[![PyPI version](https://badge.fury.io/py/pytorch-widedeep.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Python 3.7 3.8 3.9 3.10](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Build Status](https://github.com/jrzaurin/pytorch-widedeep/actions/workflows/build.yml/badge.svg)](https://github.com/jrzaurin/pytorch-widedeep/actions)
[![Documentation Status](https://readthedocs.org/projects/pytorch-widedeep/badge/?version=latest)](https://pytorch-widedeep.readthedocs.io/en/latest/?badge=latest)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)
[![DOI](https://joss.theoj.org/papers/10.21105/joss.05027/status.svg)](https://doi.org/10.21105/joss.05027)

# pytorch-widedeep

A flexible package for multimodal deep learning, combining tabular data with
text and images using Wide and Deep models in PyTorch

**Documentation:** [https://pytorch-widedeep.readthedocs.io](https://pytorch-widedeep.readthedocs.io/en/latest/index.html)

**Companion posts and tutorials:** [infinitoml](https://jrzaurin.github.io/infinitoml/)

**Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)

**Slack**: if you want to contribute or just want to chat with us, join [slack](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)

The content of this document is organized as follows:

- [pytorch-widedeep](#pytorch-widedeep)
    - [Introduction](#introduction)
    - [The ``deeptabular`` component](#the-deeptabular-component)
    - [Installation](#installation)
      - [Developer Install](#developer-install)
    - [Quick start](#quick-start)
    - [Testing](#testing)
    - [How to Contribute](#how-to-contribute)
    - [Acknowledgments](#acknowledgments)
    - [License](#license)
    - [Cite](#cite)
      - [BibTex](#bibtex)
      - [APA](#apa)

### Introduction

``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792),
adjusted for multi-modal datasets.

In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of text
and images with corresponding tabular data using wide and deep models. With
that in mind there are a number of architectures that can be implemented with
just a few lines of code. The main components of those architectures are shown
in the Figure below:


<p align="center">
  <img width="750" src="docs/figures/widedeep_arch.png">
</p>

The dashed boxes in the figure represent optional, overall components, and the
dashed lines/arrows indicate the corresponding connections, depending on
whether or not certain components are present. For example, the dashed,
blue-lines indicate that the ``deeptabular``, ``deeptext`` and ``deepimage``
components are connected directly to the output neuron or neurons (depending
on whether we are performing a binary classification or regression, or a
multi-class classification) if the optional ``deephead`` is not present.
Finally, the components within the faded-pink rectangle are concatenated.

Note that it is not possible to illustrate the number of possible
architectures and components available in ``pytorch-widedeep`` in one Figure.
Therefore, for more details on possible architectures (and more) please see
the
[documentation](https://pytorch-widedeep.readthedocs.io/en/latest/index.html),
or the Examples folder and the notebooks there.

In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), the expression for the architecture
without a ``deephead`` component can be formulated as:

<p align="center">
  <img width="500" src="docs/figures/architecture_1_math.png">
</p>


Where &sigma; is the sigmoid function, *'W'* are the weight matrices applied to the wide model and to the final
activations of the deep models, *'a'* are these final activations,
&phi;(x) are the cross product transformations of the original features *'x'*,
and *'b'* is the bias term.
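In plain text, and as a sketch rather than a verbatim copy of the figure above
(the grouping of one weight matrix per deep component is our transcription of
how the paper's expression generalizes), this reads roughly as:

```latex
% Wide & Deep prediction without a deephead (transcription following the
% paper's notation; one weight matrix per deep component is assumed)
P(Y = 1 \mid \mathbf{x}) =
  \sigma\Big(
    \mathbf{W}_{wide}^{T}\,[\mathbf{x}, \phi(\mathbf{x})]
    + \sum_{d \in \{tab,\, text,\, img\}} \mathbf{W}_{d}^{T}\,\mathbf{a}_{d}
    + b
  \Big)
```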
In case you are wondering what *"cross product transformations"* are, here is
a quote taken directly from the paper: *"For binary features, a cross-product
transformation (e.g., “AND(gender=female, language=en)”) is 1 if and only if
the constituent features (“gender=female” and “language=en”) are all 1, and 0
otherwise".*


While if there is a ``deephead`` component, the previous expression turns
into:

<p align="center">
  <img width="300" src="docs/figures/architecture_2_math.png">
</p>

It is perfectly possible to use custom models (and not necessarily those in
the library) as long as the custom models have an attribute called
``output_dim`` with the size of the last layer of activations, so that
``WideDeep`` can be constructed. Examples of how to use custom components can
be found in the Examples folder.
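For example, a minimal sketch of a custom ``deeptext`` component (a
hypothetical model written here for illustration, not one of the library's
built-in models) only needs to expose ``output_dim``:

```python
import torch
from torch import nn

from pytorch_widedeep.models import WideDeep


class MyTextEncoder(nn.Module):
    """Hypothetical custom text component: any nn.Module can be used as long
    as it exposes the size of its last layer of activations via `output_dim`."""

    def __init__(self, vocab_size: int, embed_dim: int = 32, hidden_dim: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.output_dim = hidden_dim  # required so WideDeep can build the output layers

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        _, h = self.rnn(self.embed(X.long()))
        return h[-1]


# combined with a deeptabular model built as in the quick start below, e.g.:
# model = WideDeep(deeptabular=tab_mlp, deeptext=MyTextEncoder(vocab_size=2000))
```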

### The ``deeptabular`` component

It is important to emphasize that **each individual component, `wide`,
`deeptabular`, `deeptext` and `deepimage`, can be used independently** and in
isolation. For example, one could use only `wide`, which is simply a
linear model. In fact, one of the most interesting functionalities
in ``pytorch-widedeep`` would be the use of the ``deeptabular`` component on
its own, i.e. what one might normally refer to as Deep Learning for Tabular
Data. Currently, ``pytorch-widedeep`` offers the following different models
for that component:

0. **Wide**: a simple linear model where the nonlinearities are captured via
cross-product transformations, as explained before.
1. **TabMlp**: a simple MLP that receives embeddings representing the
categorical features, concatenated with the continuous features, which can
also be embedded.
2. **TabResnet**: similar to the previous model but the embeddings are
passed through a series of ResNet blocks built with dense layers.
3. **TabNet**: details on TabNet can be found in
[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442)

Two simpler attention-based models that we call:

4. **ContextAttentionMLP**: MLP with an attention mechanism "on top" that is based on
    [Hierarchical Attention Networks for Document Classification](https://www.cs.cmu.edu/~./hovy/papers/16HLT-hierarchical-attention-networks.pdf)
5. **SelfAttentionMLP**: MLP with an attention mechanism that is a simplified
    version of a transformer block that we refer to as "query-key self-attention".

The ``Tabformer`` family, i.e. Transformers for Tabular data:

6. **TabTransformer**: details on the TabTransformer can be found in
[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
7. **SAINT**: Details on SAINT can be found in
[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
8. **FT-Transformer**: details on the FT-Transformer can be found in
[Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959).
9. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details
on the FastFormer can be found in
[FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382)
10. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on
the Perceiver can be found in
[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)

And probabilistic DL models for tabular data based on
[Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424):

11. **BayesianWide**: Probabilistic adaptation of the `Wide` model.
12. **BayesianTabMlp**: Probabilistic adaptation of the `TabMlp` model.

Note that while there are scientific publications for the TabTransformer,
SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own
adaptations of those algorithms for tabular data.
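All of these models are drop-in choices for the ``deeptabular`` component. As
a rough sketch (reusing the preprocessing objects from the quick start below
and leaving every other argument at its default; the attention-based models
additionally need the ``TabPreprocessor`` set up for them as described in the
documentation), swapping the ``TabMlp`` for, say, a ``TabResnet`` would look
like this:

```python
from pytorch_widedeep.models import TabResnet, WideDeep

# tab_preprocessor and continuous_cols as defined in the quick start below
tab_resnet = TabResnet(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
)
model = WideDeep(deeptabular=tab_resnet)
```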

In addition, Self-Supervised pre-training can be used for all `deeptabular`
models, with the exception of the `TabPerceiver`. Self-Supervised
pre-training can be used via two methods or routines which we refer to as the
encoder-decoder method and the contrastive-denoising method. Please see the
documentation and the examples for details on this functionality, and all
other options in the library.
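As a minimal sketch of the encoder-decoder routine (class and method names are
taken from the documentation and assumed here, with everything else left at
its defaults; treat this as a pointer to the docs rather than a complete
recipe):

```python
from pytorch_widedeep.self_supervised_training import EncoderDecoderTrainer

# tab_mlp and X_tab as built in the quick start below
ssl_trainer = EncoderDecoderTrainer(encoder=tab_mlp)
ssl_trainer.pretrain(X_tab, n_epochs=5, batch_size=256)

# the pre-trained encoder can then be passed to WideDeep as the deeptabular
# component and fine-tuned on the supervised task as usual
```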

### Installation

Install using pip:

```bash
pip install pytorch-widedeep
```

Or install directly from GitHub:

```bash
pip install git+https://github.com/jrzaurin/pytorch-widedeep.git
```

#### Developer Install

```bash
# Clone the repository
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep

# Install in dev mode
pip install -e .
```

### Quick start

Binary classification with the
[adult dataset](https://www.kaggle.com/wenruliu/adult-income-dataset)
using `Wide` and `TabMlp` and default settings.

Building a wide (linear) and deep model with ``pytorch-widedeep``:

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy
from pytorch_widedeep.datasets import load_adult


df = load_adult(as_frame=True)
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)

# Define the 'column set up'
wide_cols = [
    "education",
    "relationship",
    "workclass",
    "occupation",
    "native-country",
    "gender",
]
crossed_cols = [("education", "occupation"), ("native-country", "occupation")]

cat_embed_cols = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "gender",
    "capital-gain",
    "capital-loss",
    "native-country",
]
continuous_cols = ["age", "hours-per-week"]
target = "income_label"
target = df_train[target].values

# prepare the data
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = wide_preprocessor.fit_transform(df_train)

tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_embed_cols, continuous_cols=continuous_cols  # type: ignore[arg-type]
)
X_tab = tab_preprocessor.fit_transform(df_train)

# build the model
wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1)
tab_mlp = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    cat_embed_input=tab_preprocessor.cat_embed_input,
    continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deeptabular=tab_mlp)

# train and validate
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=5,
    batch_size=256,
)

# predict on test
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

# Save and load

# Option 1: this will also save training history and lr history if the
# LRHistory callback is used
trainer.save(path="model_weights", save_state_dict=True)

# Option 2: save as any other torch model
torch.save(model.state_dict(), "model_weights/wd_model.pt")

# From here onwards, Option 1 or 2 are the same. I assume the user has
# prepared the data and defined the new model components:
# 1. Build the model
model_new = WideDeep(wide=wide, deeptabular=tab_mlp)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. Instantiate the trainer
trainer_new = Trainer(model_new, objective="binary")

# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab)
```
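
For instance, assuming `predict` returns class labels (consistent with the
binary objective used above), the test-set performance can be checked with
scikit-learn:

```python
from sklearn.metrics import accuracy_score

# df_test, X_wide_te and X_tab_te as prepared in the quick start above
preds_te = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)
print(accuracy_score(df_test.income_label, preds_te))
```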

Of course, one can do **much more**. See the Examples folder, the
documentation or the companion posts for a better understanding of the content
of the package and its functionalities.

### Testing

```bash
pytest tests
```

### How to Contribute

Check the [CONTRIBUTING](https://github.com/jrzaurin/pytorch-widedeep/blob/master/CONTRIBUTING.MD) page.

### Acknowledgments

This library borrows from a series of other libraries, so I think it is only
fair to mention them here in the README (specific mentions are also included
in the code).

The `Callbacks` and `Initializers` structure and code are inspired by the
[`torchsample`](https://github.com/ncullen93/torchsample) library, which is
itself partially inspired by [`Keras`](https://keras.io/).

The `TextProcessor` class in this library uses the
[`fastai`](https://docs.fast.ai/text.transform.html#BaseTokenizer.tokenizer)'s
`Tokenizer` and `Vocab`. The code at `utils.fastai_transforms` is a minor
adaptation of their code so that it functions within this library. In my
experience, their `Tokenizer` is the best in class.

The `ImageProcessor` class in this library uses code from the fantastic [Deep
Learning for Computer
Vision](https://www.pyimagesearch.com/deep-learning-computer-vision-python-book/)
(DL4CV) book by Adrian Rosebrock.

### License

This work is dual-licensed under Apache 2.0 and MIT (or any later version).
You can choose either of them if you use this work.

`SPDX-License-Identifier: Apache-2.0 AND MIT`

### Cite

#### BibTex

```
@article{Zaurin_pytorch-widedeep_A_flexible_2023,
author = {Zaurin, Javier Rodriguez and Mulinka, Pavol},
doi = {10.21105/joss.05027},
journal = {Journal of Open Source Software},
month = jun,
number = {86},
pages = {5027},
title = {{pytorch-widedeep: A flexible package for multimodal deep learning}},
url = {https://joss.theoj.org/papers/10.21105/joss.05027},
volume = {8},
year = {2023}
}
```

#### APA

```
Zaurin, J. R., & Mulinka, P. (2023). pytorch-widedeep: A flexible package for
multimodal deep learning. Journal of Open Source Software, 8(86), 5027.
https://doi.org/10.21105/joss.05027
```