<p align="center">
3
  <img width="300" src="docs/figures/widedeep_logo.png">
</p>

[![PyPI version](https://badge.fury.io/py/pytorch-widedeep.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Python 3.7 3.8 3.9](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Build Status](https://github.com/jrzaurin/pytorch-widedeep/actions/workflows/build.yml/badge.svg)](https://github.com/jrzaurin/pytorch-widedeep/actions)
[![Documentation Status](https://readthedocs.org/projects/pytorch-widedeep/badge/?version=latest)](https://pytorch-widedeep.readthedocs.io/en/latest/?badge=latest)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![Slack](https://img.shields.io/badge/slack-chat-green.svg?logo=slack)](https://join.slack.com/t/pytorch-widedeep/shared_invite/zt-soss7stf-iXpVuLeKZz8lGTnxxtHtTw)

# pytorch-widedeep

A flexible package to use Deep Learning with tabular data, text and images
using wide and deep models.

**Documentation:** [https://pytorch-widedeep.readthedocs.io](https://pytorch-widedeep.readthedocs.io/en/latest/index.html)

**Companion posts and tutorials:** [infinitoml](https://jrzaurin.github.io/infinitoml/)

**Experiments and comparison with `LightGBM`**: [TabularDL vs LightGBM](https://github.com/jrzaurin/tabulardl-benchmark)

The content of this document is organized as follows:

1. [Introduction](#introduction)
2. [The deeptabular component](#the-deeptabular-component)
3. [Installation](#installation)
4. [Quick start (tl;dr)](#quick-start)

### Introduction

``pytorch-widedeep`` is based on Google's [Wide and Deep Algorithm](https://arxiv.org/abs/1606.07792).

In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of
text and images with corresponding tabular data using wide and deep models.
With that in mind, there are a number of architectures that can be
implemented with just a few lines of code. The main components of those
architectures are shown in the Figure below:


<p align="center">
  <img width="750" src="docs/figures/widedeep_arch.png">
</p>

The dashed boxes in the figure represent optional, overall components, and
the dashed lines/arrows indicate the corresponding connections, depending on
whether or not certain components are present. For example, the dashed blue
lines indicate that the ``deeptabular``, ``deeptext`` and ``deepimage``
components are connected directly to the output neuron or neurons (depending
on whether we are performing a binary classification or regression, or a
multi-class classification) if the optional ``deephead`` is not present.
Finally, the components within the faded-pink rectangle are concatenated.

Note that it is not possible to illustrate the number of possible
architectures and components available in ``pytorch-widedeep`` in one Figure.
Therefore, for more details on possible architectures (and more) please see
the
[documentation](https://pytorch-widedeep.readthedocs.io/en/latest/index.html),
or the Examples folder and the notebooks there.

In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), the expression for the architecture
without a ``deephead`` component can be formulated as:

<p align="center">
  <img width="500" src="docs/figures/architecture_1_math.png">
</p>
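
In case the image above does not render, the expression reads roughly as
follows (our own transcription following the paper's notation, with one term
per deep component that is present):

```latex
pred = \sigma \Big( W_{wide}^{T} \, [x, \phi(x)]
     + \textstyle\sum_{d \in \{deeptabular,\; deeptext,\; deepimage\}} W_{d}^{T} \, a_{d}
     + b \Big)
```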


Where *'W'* are the weight matrices applied to the wide model and to the
final activations of the deep models, *'a'* are these final activations, and
&phi;(x) are the cross product transformations of the original features *'x'*.
In case you are wondering what *"cross product transformations"* are, here is
a quote taken directly from the paper: *"For binary features, a cross-product
transformation (e.g., “AND(gender=female, language=en)”) is 1 if and only if
the constituent features (“gender=female” and “language=en”) are all 1, and 0
otherwise".*


If, instead, there is a ``deephead`` component, the previous expression turns
into:

<p align="center">
  <img width="300" src="docs/figures/architecture_2_math.png">
</p>

I recommend using the ``wide`` and ``deeptabular`` models in
``pytorch-widedeep``. However, it is very likely that users will want to use
their own models for the ``deeptext`` and ``deepimage`` components. That is
perfectly possible as long as the custom models have an attribute called
``output_dim`` with the size of the last layer of activations, so that
``WideDeep`` can be constructed. Again, examples on how to use custom
components can be found in the Examples folder. Just in case,
``pytorch-widedeep`` includes standard text (stack of LSTMs) and image
(pre-trained ResNets or stack of CNNs) models.
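
As a rough illustration, a minimal custom ``deeptext`` component could look
like the sketch below. This is a hypothetical model, not one shipped with the
library: `MyDeepText`, `vocab_size` and the layer sizes are made-up values,
and the exact requirements are described in the documentation.

```python
from torch import nn


class MyDeepText(nn.Module):
    """A hypothetical, minimal text component."""

    def __init__(self, vocab_size: int = 5000, embed_dim: int = 64, hidden_dim: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # ``WideDeep`` reads this attribute to wire the component in
        self.output_dim = hidden_dim

    def forward(self, X):
        # X is a batch of padded sequences of token indices
        _, (h, _) = self.rnn(self.embed(X))
        return h[-1]
```

Such a model would then be passed as
`WideDeep(wide=wide, deeptabular=deeptabular, deeptext=MyDeepText())`.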

### The ``deeptabular`` component

It is important to emphasize that **each individual component, `wide`,
`deeptabular`, `deeptext` and `deepimage`, can be used independently** and in
isolation. For example, one could use only `wide`, which is simply a linear
model. In fact, one of the most interesting functionalities in
``pytorch-widedeep`` is the use of the ``deeptabular`` component on its own,
i.e. what one might normally refer to as Deep Learning for Tabular Data (a
short sketch of this use is shown after the model list below). Currently,
``pytorch-widedeep`` offers the following models for that component:


1. **TabMlp**: a simple MLP that receives embeddings representing the
categorical features, concatenated with the continuous features.
2. **TabResnet**: similar to the previous model but the embeddings are
passed through a series of ResNet blocks built with dense layers.
3. **TabNet**: details on TabNet can be found in
[TabNet: Attentive Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442).

And the ``Tabformer`` family, i.e. Transformers for tabular data:

4. **TabTransformer**: details on the TabTransformer can be found in
[TabTransformer: Tabular Data Modeling Using Contextual Embeddings](https://arxiv.org/pdf/2012.06678.pdf).
5. **SAINT**: Details on SAINT can be found in
[SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342).
6. **FT-Transformer**: details on the FT-Transformer can be found in
[Revisiting Deep Learning Models for Tabular Data](https://arxiv.org/abs/2106.11959).
7. **TabFastFormer**: adaptation of the FastFormer for tabular data. Details
on the FastFormer can be found in
[FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382).
8. **TabPerceiver**: adaptation of the Perceiver for tabular data. Details on
the Perceiver can be found in
[Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206).

Note that while there are scientific publications for the TabTransformer,
SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own
adaptations of those algorithms for tabular data.

For details on these models and their options, please see the examples in the
Examples folder and the documentation.
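
As a sketch of the "``deeptabular`` on its own" use mentioned above, one can
simply build ``WideDeep`` without a ``wide`` component. The snippet below
assumes that `tab_preprocessor`, `cont_cols`, `X_tab` and `target` have been
prepared exactly as in the quick start example further down:

```python
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy

# `tab_preprocessor`, `cont_cols`, `X_tab` and `target` as in the
# quick start example below
deeptabular = TabMlp(
    mlp_hidden_dims=[64, 32],
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=cont_cols,
)
model = WideDeep(deeptabular=deeptabular)  # no wide, deeptext or deepimage
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(X_tab=X_tab, target=target, n_epochs=5, batch_size=256)
```

Swapping `TabMlp` for any of the models listed above follows the same
pattern, although the preprocessor and model arguments differ slightly; see
the documentation for the exact parameters of each model.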

### Installation

Install using pip:

```bash
pip install pytorch-widedeep
```

Or install directly from GitHub:

```bash
pip install git+https://github.com/jrzaurin/pytorch-widedeep.git
```

#### Developer Install

```bash
# Clone the repository
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep

# Install in dev mode
pip install -e .
```

**Important note for Mac users**: at the time of writing the latest `torch`
release is `1.9`. Some past [issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
when running on Mac, present in previous versions, persist in this release,
and the data-loaders will not run in parallel. In addition, since `python
3.8`, [the `multiprocessing` library start method changed from `'fork'` to `'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will
not run in parallel. Therefore, for Mac users I recommend using `python 3.7`
and `torch <= 1.6` (with the corresponding, consistent version of
`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
versioning in the `setup.py` file since I expect that all these issues will
be fixed in the future. Therefore, after installing `pytorch-widedeep` via
pip or directly from GitHub, downgrade `torch` and `torchvision` manually:

```bash
pip install pytorch-widedeep
pip install torch==1.6.0 torchvision==0.7.0
```

None of these issues affect Linux users.

### Quick start

Binary classification with the [adult
dataset](https://www.kaggle.com/wenruliu/adult-income-dataset)
using `Wide` and `TabMlp` and default settings.

Building a wide (linear) and deep model with ``pytorch-widedeep``:

```python
import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy

# the following 4 lines are not directly related to ``pytorch-widedeep``. I
# assume you have downloaded the dataset and placed it in a dir called
# data/adult/
df = pd.read_csv("data/adult/adult.csv.zip")
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)

# prepare wide, crossed, embedding and continuous columns
wide_cols = [
    "education",
    "relationship",
    "workclass",
    "occupation",
    "native-country",
    "gender",
]
cross_cols = [("education", "occupation"), ("native-country", "occupation")]
embed_cols = [
    ("education", 16),
    ("workclass", 16),
    ("occupation", 16),
    ("native-country", 32),
]
cont_cols = ["age", "hours-per-week"]
target_col = "income_label"

# target
target = df_train[target_col].values

# wide
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = wide_preprocessor.fit_transform(df_train)
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)

# deeptabular
tab_preprocessor = TabPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
X_tab = tab_preprocessor.fit_transform(df_train)
deeptabular = TabMlp(
    mlp_hidden_dims=[64, 32],
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=cont_cols,
)

# wide and deep
model = WideDeep(wide=wide, deeptabular=deeptabular)

# train the model
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=5,
    batch_size=256,
    val_split=0.1,
)

# predict
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)

# Save and load

# Option 1: this will also save training history and lr history if the
# LRHistory callback is used
trainer.save(path="model_weights", save_state_dict=True)

# Option 2: save as any other torch model
torch.save(model.state_dict(), "model_weights/wd_model.pt")

# From here onwards, Options 1 and 2 are the same. I assume the user has
# prepared the data and defined the new model components:
# 1. Build the model
model_new = WideDeep(wide=wide, deeptabular=deeptabular)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. Instantiate the trainer
trainer_new = Trainer(
    model_new,
    objective="binary",
)

# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab)
```

Of course, one can do **much more**. See the Examples folder, the
documentation or the companion posts for a better understanding of the content
of the package and its functionalities.
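
For instance, the `LRHistory` callback mentioned in the save/load snippet
above is set up through the `Trainer`'s `callbacks` argument. Here is a
hedged sketch reusing the quick start objects (`model`, `X_wide`, `X_tab` and
`target`); the callback classes live in `pytorch_widedeep.callbacks`, but
please check the documentation for their exact arguments:

```python
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

from pytorch_widedeep import Trainer
from pytorch_widedeep.callbacks import EarlyStopping, LRHistory, ModelCheckpoint
from pytorch_widedeep.metrics import Accuracy

# `model`, `X_wide`, `X_tab` and `target` as in the quick start above
optimizer = Adam(model.parameters(), lr=0.01)
lr_scheduler = StepLR(optimizer, step_size=2)

trainer = Trainer(
    model,
    objective="binary",
    optimizers=optimizer,
    lr_schedulers=lr_scheduler,
    callbacks=[
        LRHistory(n_epochs=5),  # records the lr history saved by trainer.save
        EarlyStopping(patience=3),
        ModelCheckpoint(filepath="model_weights/wd"),
    ],
    metrics=[Accuracy],
)
trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=5, batch_size=256)
```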

### Testing

```bash
pytest tests
```

### How to Contribute

Check the [CONTRIBUTING](https://github.com/jrzaurin/pytorch-widedeep/blob/master/CONTRIBUTING.MD) page.

### Acknowledgments

This library borrows from a series of other libraries, so I think it is just
fair to mention them here in the README (specific mentions are also included
in the code).

The `Callbacks` and `Initializers` structure and code are inspired by the
[`torchsample`](https://github.com/ncullen93/torchsample) library, which in
itself is partially inspired by [`Keras`](https://keras.io/).

The `TextProcessor` class in this library uses
[`fastai`](https://docs.fast.ai/text.transform.html#BaseTokenizer.tokenizer)'s
`Tokenizer` and `Vocab`. The code at `utils.fastai_transforms` is a minor
adaptation of their code so it functions within this library. In my
experience, their `Tokenizer` is best in class.

The `ImageProcessor` class in this library uses code from the fantastic [Deep
Learning for Computer
Vision](https://www.pyimagesearch.com/deep-learning-computer-vision-python-book/)
(DL4CV) book by Adrian Rosebrock.