<p align="center">
  <img width="250" src="docs/figures/widedeep_logo.png">
</p>

# pytorch-widedeep

A flexible package to combine tabular data with text and images using wide and
deep models.

### Introduction

`pytorch-widedeep` is based on TensorFlow's Wide and Deep Algorithm. Details of
the original algorithm can be found
[here](https://www.tensorflow.org/tutorials/wide_and_deep) and the nice
research paper can be found [here](https://arxiv.org/abs/1606.07792).

`pytorch-widedeep` is a package intended to facilitate the combination of text
and images with corresponding tabular data using wide and deep models. With
that in mind there are two architectures that can be implemented with just a
few lines of code.

### Architectures

**Architecture 1**:

<p align="center">
  <img width="600" src="docs/figures/architecture_1.png">
</p>

Architecture 1 combines the `Wide` one-hot encoded features (a linear model)
with the outputs from the `DeepDense`, `DeepText` and `DeepImage` components
connected to a final output neuron or neurons, depending on whether we are
performing a binary classification or regression, or a multi-class
classification. The components within the faded-pink rectangles are
concatenated.
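
Schematically, and ignoring the internals of each component, the combination in Architecture 1 can be sketched with plain numpy. This is only an illustration of the wiring, not the library's implementation; all sizes and weights below are made up:

```python
import numpy as np

rng = np.random.default_rng(0)

# made-up final activations of the deep components, and one-hot wide features
a_dense = rng.normal(size=8)
a_text = rng.normal(size=16)
a_image = rng.normal(size=8)
x_wide = rng.integers(0, 2, size=10)

# the deep activations are concatenated (the faded-pink rectangles)
deep = np.concatenate([a_dense, a_text, a_image])

# a linear map over the wide features and the concatenated deep activations
# feeds a single output neuron (binary classification case)
W_wide = rng.normal(size=10)
W_deep = rng.normal(size=32)

logit = x_wide @ W_wide + deep @ W_deep
prob = 1.0 / (1.0 + np.exp(-logit))  # sigmoid
print(round(float(prob), 4))
```

For regression the sigmoid would be dropped, and for multi-class classification the single output neuron would become one neuron per class followed by a softmax.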

In math terms, and following the notation in the [paper](https://arxiv.org/abs/1606.07792), Architecture 1 can be formulated as:

<p align="center">
  <img width="500" src="docs/figures/architecture_1_math.png">
</p>


Where *'W'* are the weight matrices applied to the wide model and to the final
activations of the deep models, *'a'* are these final activations, and
&phi;(x) are the cross product transformations of the original features *'x'*.

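For intuition, the cross product transformation &phi;(x) is simply an AND over binary features: the crossed feature is 1 only when all of its constituent features are 1. A minimal sketch (the feature names are made up for illustration):

```python
def cross_product(features, cross):
    """Cross product transformation: 1 only if every constituent
    binary feature in `cross` is active, else 0."""
    return int(all(features[name] for name in cross))

features = {"education=Bachelors": 1, "occupation=Sales": 1, "gender=Female": 0}

print(cross_product(features, ("education=Bachelors", "occupation=Sales")))  # 1
print(cross_product(features, ("education=Bachelors", "gender=Female")))     # 0
```
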
**Architecture 2**

<p align="center">
  <img width="600" src="docs/figures/architecture_2.png">
</p>

Architecture 2 combines the `Wide` one-hot encoded features (a linear model)
with the Deep components of the model connected to the output neuron(s), after
the different Deep components have themselves been combined through an FC-Head
(that I refer to as `deephead`).

In math terms, and following the notation in the
[paper](https://arxiv.org/abs/1606.07792), Architecture 2 can be formulated
as:

<p align="center">
  <img width="300" src="docs/figures/architecture_2_math.png">
</p>
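
In the same schematic spirit as before, the difference in Architecture 2 is that the concatenated deep activations first pass through the `deephead` before reaching the output neuron(s). Again a purely illustrative numpy sketch with made-up sizes:

```python
import numpy as np

rng = np.random.default_rng(1)

# concatenated final activations of the deep components (size made up)
deep = rng.normal(size=32)

# unlike Architecture 1, the deep activations first go through an FC-Head
# (`deephead`): here a single hidden layer with a ReLU, for illustration
W1, W2 = rng.normal(size=(32, 16)), rng.normal(size=16)
deephead_out = np.maximum(deep @ W1, 0.0) @ W2

# the wide (linear) term and the head output then feed the output neuron(s)
x_wide = rng.integers(0, 2, size=10)
logit = float(x_wide @ rng.normal(size=10) + deephead_out)
print(type(logit).__name__)  # float
```
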

When using `pytorch-widedeep`, the assumption is that the so called `Wide` and
`DeepDense` components in the figures are **always** present, while `DeepText`
and `DeepImage` are optional. `pytorch-widedeep` includes some standard text
(stack of LSTMs) and image (pre-trained ResNets or stack of CNNs) models.
However, the user can use any custom model as long as it has an attribute
called `output_dim` with the size of the last layer of activations, so that
WideDeep can be constructed. See the examples folder for more information.
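
In other words, the contract for a custom component is just that it exposes `output_dim`. The class below is a schematic stand-in to show that contract only; a real component would subclass `torch.nn.Module` and implement `forward`:

```python
class MyCustomTextModel:
    """Schematic custom component: any model is valid as long as it
    exposes `output_dim`, the size of its last layer of activations,
    which WideDeep reads to size the final output layer(s)."""

    def __init__(self, embed_dim: int = 32, hidden_dim: int = 64):
        self.embed_dim = embed_dim
        self.output_dim = hidden_dim  # the required attribute

custom_text = MyCustomTextModel()
print(custom_text.output_dim)  # 64
```

Such a component would then be passed to `WideDeep` in place of the standard `DeepText` model.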


### Installation
Install directly from GitHub

```
pip install git+https://github.com/jrzaurin/pytorch-widedeep.git
```

Note that the `Pytorch` installation with `pip` is different from that with
`conda`: the latter also installs the CUDA toolkit. Therefore, if you are using
`conda` and have already installed `torch` and `torchvision`, I recommend
cloning the repository, removing the `torch` and `torchvision` dependencies
from the `setup.py` file and then running `pip install .`:

```
# Clone the repository
git clone git@github.com:jrzaurin/pytorch-widedeep.git
cd pytorch-widedeep

# remove the torch and torchvision dependencies from setup.py, then run:
pip install .

# or dev mode
pip install -e .
```

### Quick start

Binary classification with the [adult
dataset](https://www.kaggle.com/wenruliu/adult-income-dataset/downloads/adult.csv/2)
using `Wide` and `DeepDense` and default settings.

```python
import pandas as pd

from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import BinaryAccuracy

# these next 3 lines are not directly related to pytorch-widedeep
df = pd.read_csv('data/adult/adult.csv.zip')
df['income_label'] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop('income', axis=1, inplace=True)

# prepare wide, crossed, embedding and continuous columns
wide_cols  = ['education', 'relationship', 'workclass', 'occupation','native-country', 'gender']
cross_cols = [('education', 'occupation'), ('native-country', 'occupation')]
embed_cols = [('education',16), ('workclass',16), ('occupation',16),('native-country',32)]
cont_cols  = ["age", "hours-per-week"]
target_col = 'income_label'

# target
target = df[target_col].values

# wide
preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = preprocess_wide.fit_transform(df)
wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)

# deepdense
preprocess_deep = DeepPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
X_deep = preprocess_deep.fit_transform(df)
deepdense = DeepDense(hidden_layers=[64,32],
                      deep_column_idx=preprocess_deep.deep_column_idx,
                      embed_input=preprocess_deep.embeddings_input,
                      continuous_cols=cont_cols)

# build, compile, fit and predict
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method='binary', metrics=[BinaryAccuracy])
model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=256, val_split=0.2)
# X_wide_te and X_deep_te would be built from a test set with the fitted preprocessors
model.predict(X_wide=X_wide_te, X_deep=X_deep_te)
```

Of course, one can do much more, such as using different initializations,
optimizers or learning rate schedulers for each component of the overall
model, or adding FC-Heads to the Text and Image components. See the examples
folder for a better understanding of the content of the package and its
functionalities. In case GitHub does not render the notebooks, or renders them
missing some parts (e.g. colorful headlines), they are also saved as markdown
files in the docs folder.

### Testing

```
cd test
pytest --ignore=test_data_utils/test_du_deep_image.py
cd test_data_utils
pytest test_du_deep_image.py
```