<p align="center">
  <img width="450" src="docs/figures/widedeep_logo.png">
</p>

# pytorch-widedeep

A flexible package to combine tabular data with text and images using wide and
deep models.

### Introduction

`pytorch-widedeep` is based on Google's Wide and Deep Algorithm. Details of
the original algorithm can be found
[here](https://www.tensorflow.org/tutorials/wide_and_deep), and the nice
research paper can be found [here](https://arxiv.org/abs/1606.07792).

In general terms, `pytorch-widedeep` is a package to use deep learning with
tabular data. In particular, it is intended to facilitate the combination of text
and images with corresponding tabular data using wide and deep models. With
that in mind there are two architectures that can be implemented with just a
few lines of code.

### Architectures

**Architecture 1**:

<p align="center">
  <img width="600" src="docs/figures/architecture_1.png">
</p>

Architecture 1 combines the `Wide` one-hot encoded features with the outputs
from the `DeepDense`, `DeepText` and `DeepImage` components connected to a
final output neuron or neurons, depending on whether we are performing a
binary classification or regression, or a multi-class classification. The
components within the faded-pink rectangles are concatenated.

In math terms, and following the notation in the [paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated as:

<p align="center">
  <img width="500" src="docs/figures/architecture_1_math.png">
</p>

Where *'W'* are the weight matrices applied to the wide model and to the final
activations of the deep models, *'a'* are these final activations, and
&phi;(x) are the cross-product transformations of the original features *'x'*.
In case you are wondering what *"cross-product transformations"* are, here is
a quote taken directly from the paper: *"For binary features, a cross-product
transformation (e.g., “AND(gender=female, language=en)”) is 1 if and only if
the constituent features (“gender=female” and “language=en”) are all 1, and 0
otherwise".*
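
As a toy illustration of that quote (plain `pandas`, not part of the package's API):

```python
import pandas as pd

# the crossed feature is 1 only when both constituent binary features are 1
df = pd.DataFrame({"gender=female": [1, 1, 0, 0], "language=en": [1, 0, 1, 0]})
df["AND(gender=female, language=en)"] = df["gender=female"] & df["language=en"]
print(df)
```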

**Architecture 2**:

<p align="center">
  <img width="600" src="docs/figures/architecture_2.png">
</p>

Architecture 2 combines the `Wide` one-hot encoded features with the Deep
components of the model connected to the output neuron(s), after the different
Deep components have themselves been combined through a FC-Head (that I refer
to as `deephead`).

In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), Architecture 2 can be formulated
as:

<p align="center">
  <img width="300" src="docs/figures/architecture_2_math.png">
</p>
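
For illustration, such a FC-Head is just a stack of fully connected layers applied to the concatenated activations of the Deep components. The sizes below are made up, and how the head is actually plugged into `WideDeep` is shown in the `examples` folder:

```python
from torch import nn

# illustrative FC-Head ("deephead"); 96 is a made-up size standing for the
# concatenated output_dim of the DeepDense, DeepText and DeepImage components
deephead = nn.Sequential(
    nn.Linear(96, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
)
```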

When using `pytorch-widedeep`, the assumption is that the so-called `Wide` and
`DeepDense` components in the figures are **always** present, while `DeepText`
and `DeepImage` are optional. `pytorch-widedeep` includes standard text (stack
of LSTMs) and image (pre-trained ResNets or stack of CNNs) models. However,
the user can use any custom model as long as it has an attribute called
`output_dim` with the size of the last layer of activations, so that
`WideDeep` can be constructed. See the examples folder for more information.
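
For instance, a minimal sketch of a custom text component might look as follows. The only requirement stated above is the `output_dim` attribute; everything else here (the embedding, the LSTM, the sizes) is purely illustrative:

```python
import torch
from torch import nn

class MyDeepText(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int = 32, hidden_dim: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # required attribute: size of the last layer of activations
        self.output_dim = hidden_dim

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # X: (batch_size, seq_len) of token indices
        _, (h, _) = self.rnn(self.embed(X))
        return h[-1]
```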

### Installation

#### OSX and Ubuntu (with and without CUDA)

Install directly from github

```bash
pip install git+https://github.com/jrzaurin/pytorch-widedeep.git
```

#### If you are using Conda with CUDA

The `Pytorch` installation command with `pip` is different from that of
`conda` with CUDA. The latter installs the CUDA toolkit, see
[here](https://pytorch.org/). Therefore, if you are using `conda` and have
already installed `torch` and `torchvision`, or do not want to use `pip`, I
recommend cloning the repository, removing the `torch` and `torchvision`
dependencies from the `setup.py` file and then running `pip install .`:

```bash
# Clone the repository
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep

# remove torch and torchvision dependencies from setup.py and then run:
pip install .

# or dev mode
pip install -e .
```

Note that installing `pytorch-widedeep` directly from github will still work
(especially if you do not have CUDA), but the CUDA toolkit is recommended for
better performance.


### Examples

There are 4 main notebooks in the `examples` folder plus some additional
files. These notebooks cover most of the utilities of this package and can
also act as documentation. In the likely case that github does not render the
notebooks, or renders them with some parts missing, they are saved as markdown
files in the `docs` folder.

### Quick start

Binary classification with the [adult
dataset](https://www.kaggle.com/wenruliu/adult-income-dataset/downloads/adult.csv/2)
using `Wide` and `DeepDense` with default settings.

```python
import pandas as pd
from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import BinaryAccuracy

# these next 3 lines are not directly related to pytorch-widedeep. I assume
# you have downloaded the dataset and placed it in a dir called data/adult/
df = pd.read_csv('data/adult/adult.csv.zip')
df['income_label'] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop('income', axis=1, inplace=True)

# prepare wide, crossed, embedding and continuous columns
wide_cols  = ['education', 'relationship', 'workclass', 'occupation','native-country', 'gender']
cross_cols = [('education', 'occupation'), ('native-country', 'occupation')]
embed_cols = [('education',16), ('workclass',16), ('occupation',16),('native-country',32)]
cont_cols  = ["age", "hours-per-week"]
target_col = 'income_label'

# target
target = df[target_col].values

# wide
preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = preprocess_wide.fit_transform(df)
wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)

# deepdense
preprocess_deep = DeepPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
X_deep = preprocess_deep.fit_transform(df)
deepdense = DeepDense(hidden_layers=[64,32],
                      deep_column_idx=preprocess_deep.deep_column_idx,
                      embed_input=preprocess_deep.embeddings_input,
                      continuous_cols=cont_cols)

# build, compile, fit and predict
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method='binary', metrics=[BinaryAccuracy])
model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=256, val_split=0.2)
model.predict(X_wide=X_wide_te, X_deep=X_deep_te)
```
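
Note that `X_wide_te` and `X_deep_te` in the last line stand for a held-out test set. Assuming the preprocessors expose a scikit-learn style `transform` method (only `fit_transform` is shown above, so treat this as an assumption), they would be built along these lines:

```python
# df_test: a hypothetical held-out DataFrame with the same columns as df
X_wide_te = preprocess_wide.transform(df_test)
X_deep_te = preprocess_deep.transform(df_test)
```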

Of course, one can do much more, such as using different initializations,
optimizers or learning rate schedulers for each component of the overall
model, adding FC-Heads to the Text and Image components, etc. See the
`examples` folder for a better understanding of the content of the package
and its functionalities.
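
A sketch of what the per-component setup could look like is shown below; the exact keyword arguments accepted by `compile` should be checked in the `examples` notebooks, so take the names here as an assumption:

```python
import torch

# one optimizer and scheduler per model component (sketch; see examples for the exact API)
wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
deep_opt = torch.optim.Adam(model.deepdense.parameters(), lr=0.001)
wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)
deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)

model.compile(
    method='binary',
    optimizers={'wide': wide_opt, 'deepdense': deep_opt},
    lr_schedulers={'wide': wide_sch, 'deepdense': deep_sch},
    metrics=[BinaryAccuracy],
)
```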

### Testing

```bash
cd test
pytest --ignore=test_data_utils/test_du_deep_image.py
cd test_data_utils
pytest test_du_deep_image.py
```