Unverified commit 2fe4b49a authored by Javier, committed by GitHub

Merge pull request #30 from jrzaurin/fix_image_format

Fix image format 
<p align="center">
<img width="450" src="docs/figures/widedeep_logo.png">
<img width="300" src="docs/figures/widedeep_logo.png">
</p>
[![Build Status](https://travis-ci.org/jrzaurin/pytorch-widedeep.svg?branch=master)](https://travis-ci.org/jrzaurin/pytorch-widedeep)
......@@ -9,11 +9,7 @@
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
Platform | Version Support
---------|:---------------
OSX | [![Python 3.6 3.7](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://www.python.org/)
Linux | [![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
[![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
# pytorch-widedeep
......@@ -88,15 +84,23 @@ as:
<img width="300" src="docs/figures/architecture_2_math.png">
</p>
When using `pytorch-widedeep`, the assumption is that the so-called `Wide` and
`deep dense` (this can be either `DeepDense` or `DeepDenseResnet`; see the
documentation and examples folder for more details) components in the figures
are **always** present, while `DeepText` and `DeepImage` are optional.
Note that each individual component, `wide`, `deepdense` (either `DeepDense`
or `DeepDenseResnet`), `deeptext` and `deepimage`, can be used independently
and in isolation. For example, one could use only `wide`, which is simply a
linear model.
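A minimal sketch of that wide-only case, using the same calls that appear in the quick start notebook changed in this PR (the toy data below is made up):

```python
import numpy as np

from pytorch_widedeep.models import Wide, WideDeep
from pytorch_widedeep.metrics import Accuracy, Precision

# made-up toy data: 32 rows of label-encoded wide features and a binary target
X_wide = np.random.choice(50, (32, 10))
target = np.random.choice(2, 32)

wide = Wide(np.unique(X_wide).shape[0], 1)
model = WideDeep(wide=wide)  # only the wide component: effectively a linear model
model.compile(method="binary", metrics=[Accuracy, Precision])
model.fit(X_wide=X_wide, target=target, n_epochs=5, batch_size=16, val_split=0.2)
```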
On the other hand, while I recommend using the `Wide` and `DeepDense` (or
`DeepDenseResnet`) classes in `pytorch-widedeep` to build the `wide` and
`deepdense` component, it is very likely that users will want to use their own
models in the case of the `deeptext` and `deepimage` components. That is
perfectly possible as long as the custom models have an attribute called
`output_dim` with the size of the last layer of activations, so that
`WideDeep` can be constructed.
`pytorch-widedeep` includes standard text (stack of LSTMs) and image
(pre-trained ResNets or stack of CNNs) models. However, the user can use any
custom model as long as it has an attribute called `output_dim` with the size
of the last layer of activations, so that `WideDeep` can be constructed. See
the examples folder or the docs for more information.
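As a hypothetical sketch of such a custom component (the model below is made up; the only requirement taken from the text above is the `output_dim` attribute holding the size of the last layer of activations):

```python
from torch import nn


class MyDeepText(nn.Module):
    """A made-up custom text component: an embedding followed by an LSTM."""

    def __init__(self, vocab_size, embed_dim=32, hidden_dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # the attribute WideDeep looks for: size of the last layer of activations
        self.output_dim = hidden_dim

    def forward(self, X):
        _, (h, _) = self.lstm(self.embed(X.long()))
        return h[-1]


deeptext = MyDeepText(vocab_size=100)  # exposes .output_dim == 64
# passed to WideDeep alongside the wide and deepdense components, e.g.
# model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext)
```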
### Installation
......@@ -124,6 +128,28 @@ cd pytorch-widedeep
pip install -e .
```
**Important note for Mac users**: at the time of writing (Dec-2020) the latest
`torch` release is `1.7`. This release has some
[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
when running on Mac and the data-loaders will not run in parallel. In
addition, since `python 3.8`, [the `multiprocessing` library start method
changed from `'fork'` to
`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will not
run in parallel. Therefore, for Mac users I recommend using `python 3.6` or
`3.7` and `torch <= 1.6` (with the corresponding, consistent version of
`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
versioning in the `setup.py` file since I expect these issues to be fixed in
the future. Therefore, after installing `pytorch-widedeep` via pip or directly
from GitHub, downgrade `torch` and `torchvision` manually:
```bash
pip install pytorch-widedeep
pip install torch==1.6.0 torchvision==0.7.0
```
None of these issues affect Linux users.
### Quick start
Binary classification with the [adult
......
0.4.6
\ No newline at end of file
0.4.7
\ No newline at end of file
docs/figures/widedeep_logo.png: image replaced (72.8 KB → 38.0 KB)
......@@ -130,7 +130,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"if we simply numerically encode (label encode or `le`) the values, starting from 1 (we will save 0 for padding, i.e. unseen values)"
"if we simply numerically encode (label encode or `le`) the values:"
]
},
{
......@@ -146,7 +146,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"now, let's see if the two implementations are equivalent"
"Note that in the functioning implementation of the package we start from 1, saving 0 for padding, i.e. unseen values. \n",
"\n",
"Now, let's see if the two implementations are equivalent"
]
},
{
......@@ -261,7 +263,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that even though the input dim is 10, the Embedding layer has 11 weights. This is because we save 0 for padding, which is used for unseen values during the encoding process"
"Note that even though the input dim is 10, the Embedding layer has 11 weights. Again, this is because we save 0 for padding, which is used for unseen values during the encoding process"
]
},
{
......
......@@ -591,16 +591,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:05<00:00, 115.33it/s, loss=0.743, metrics={'acc': 0.6205, 'prec': 0.2817}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 168.06it/s, loss=0.545, metrics={'acc': 0.6452, 'prec': 0.3014}]\n",
"epoch 2: 100%|██████████| 611/611 [00:04<00:00, 122.57it/s, loss=0.486, metrics={'acc': 0.7765, 'prec': 0.5517}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 158.84it/s, loss=0.44, metrics={'acc': 0.783, 'prec': 0.573}] \n",
"epoch 3: 100%|██████████| 611/611 [00:04<00:00, 124.89it/s, loss=0.419, metrics={'acc': 0.8129, 'prec': 0.6753}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 158.10it/s, loss=0.402, metrics={'acc': 0.815, 'prec': 0.6816}] \n",
"epoch 4: 100%|██████████| 611/611 [00:04<00:00, 126.35it/s, loss=0.393, metrics={'acc': 0.8228, 'prec': 0.7047}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 160.72it/s, loss=0.385, metrics={'acc': 0.8233, 'prec': 0.7024}]\n",
"epoch 5: 100%|██████████| 611/611 [00:04<00:00, 124.33it/s, loss=0.38, metrics={'acc': 0.826, 'prec': 0.702}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 163.43it/s, loss=0.376, metrics={'acc': 0.8264, 'prec': 0.7}] \n"
"epoch 1: 100%|██████████| 611/611 [00:06<00:00, 101.71it/s, loss=0.448, metrics={'acc': 0.792, 'prec': 0.5728}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 171.00it/s, loss=0.366, metrics={'acc': 0.7991, 'prec': 0.5907}]\n",
"epoch 2: 100%|██████████| 611/611 [00:06<00:00, 101.69it/s, loss=0.361, metrics={'acc': 0.8324, 'prec': 0.6817}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 169.36it/s, loss=0.357, metrics={'acc': 0.8328, 'prec': 0.6807}]\n",
"epoch 3: 100%|██████████| 611/611 [00:05<00:00, 102.65it/s, loss=0.352, metrics={'acc': 0.8366, 'prec': 0.691}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 171.49it/s, loss=0.352, metrics={'acc': 0.8361, 'prec': 0.6867}]\n",
"epoch 4: 100%|██████████| 611/611 [00:06<00:00, 101.52it/s, loss=0.347, metrics={'acc': 0.8389, 'prec': 0.6956}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 163.49it/s, loss=0.349, metrics={'acc': 0.8383, 'prec': 0.6906}]\n",
"epoch 5: 100%|██████████| 611/611 [00:07<00:00, 84.91it/s, loss=0.343, metrics={'acc': 0.8405, 'prec': 0.6987}] \n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 142.83it/s, loss=0.347, metrics={'acc': 0.8399, 'prec': 0.6946}]\n"
]
}
],
......@@ -664,22 +664,88 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:05<00:00, 108.62it/s, loss=0.894, metrics={'acc': 0.5182, 'prec': 0.2037}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 154.44it/s, loss=0.604, metrics={'acc': 0.5542, 'prec': 0.2135}]\n",
"epoch 2: 100%|██████████| 611/611 [00:05<00:00, 106.49it/s, loss=0.51, metrics={'acc': 0.751, 'prec': 0.4614}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.79it/s, loss=0.452, metrics={'acc': 0.7581, 'prec': 0.4898}]\n",
"epoch 3: 100%|██████████| 611/611 [00:05<00:00, 106.66it/s, loss=0.425, metrics={'acc': 0.8031, 'prec': 0.6618}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 160.73it/s, loss=0.405, metrics={'acc': 0.806, 'prec': 0.6686}] \n",
"epoch 4: 100%|██████████| 611/611 [00:05<00:00, 106.58it/s, loss=0.394, metrics={'acc': 0.8185, 'prec': 0.6966}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.55it/s, loss=0.385, metrics={'acc': 0.8196, 'prec': 0.6994}]\n",
"epoch 5: 100%|██████████| 611/611 [00:05<00:00, 107.28it/s, loss=0.38, metrics={'acc': 0.8236, 'prec': 0.7004}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 155.37it/s, loss=0.375, metrics={'acc': 0.8244, 'prec': 0.7017}]\n"
"epoch 1: 100%|██████████| 611/611 [00:07<00:00, 77.46it/s, loss=0.387, metrics={'acc': 0.8192, 'prec': 0.6576}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 147.78it/s, loss=0.36, metrics={'acc': 0.8216, 'prec': 0.6617}] \n",
"epoch 2: 100%|██████████| 611/611 [00:08<00:00, 74.99it/s, loss=0.358, metrics={'acc': 0.8313, 'prec': 0.6836}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 158.26it/s, loss=0.355, metrics={'acc': 0.8321, 'prec': 0.6848}]\n",
"epoch 3: 100%|██████████| 611/611 [00:08<00:00, 76.28it/s, loss=0.351, metrics={'acc': 0.8345, 'prec': 0.6889}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 154.84it/s, loss=0.354, metrics={'acc': 0.8347, 'prec': 0.6887}]\n",
"epoch 4: 100%|██████████| 611/611 [00:07<00:00, 76.71it/s, loss=0.346, metrics={'acc': 0.8374, 'prec': 0.6946}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.80it/s, loss=0.353, metrics={'acc': 0.8369, 'prec': 0.6935}]\n",
"epoch 5: 100%|██████████| 611/611 [00:08<00:00, 73.25it/s, loss=0.343, metrics={'acc': 0.8386, 'prec': 0.6966}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.05it/s, loss=0.352, metrics={'acc': 0.8382, 'prec': 0.6961}]\n"
]
}
],
"source": [
"model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=64, val_split=0.2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Also mentioning that one could build a model with the individual components independently. For example, a model comprised only by the `wide` component would be simply a linear model. This could be attained by just:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"model = WideDeep(wide=wide)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"model.compile(method='binary', metrics=[Accuracy, Precision])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/611 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:03<00:00, 188.59it/s, loss=0.482, metrics={'acc': 0.771, 'prec': 0.5633}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 236.13it/s, loss=0.423, metrics={'acc': 0.7747, 'prec': 0.5819}]\n",
"epoch 2: 100%|██████████| 611/611 [00:03<00:00, 190.62it/s, loss=0.399, metrics={'acc': 0.8131, 'prec': 0.686}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 221.47it/s, loss=0.387, metrics={'acc': 0.8138, 'prec': 0.6879}]\n",
"epoch 3: 100%|██████████| 611/611 [00:03<00:00, 190.28it/s, loss=0.378, metrics={'acc': 0.8267, 'prec': 0.7149}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 241.12it/s, loss=0.374, metrics={'acc': 0.8255, 'prec': 0.7128}]\n",
"epoch 4: 100%|██████████| 611/611 [00:03<00:00, 183.27it/s, loss=0.37, metrics={'acc': 0.8304, 'prec': 0.7073}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 227.46it/s, loss=0.369, metrics={'acc': 0.8294, 'prec': 0.7061}]\n",
"epoch 5: 100%|██████████| 611/611 [00:03<00:00, 184.28it/s, loss=0.366, metrics={'acc': 0.8315, 'prec': 0.7006}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 239.87it/s, loss=0.366, metrics={'acc': 0.8303, 'prec': 0.6999}]\n"
]
}
],
"source": [
"model.fit(X_wide=X_wide, target=target, n_epochs=5, batch_size=64, val_split=0.2)"
]
}
],
"metadata": {
......
......@@ -4,11 +4,7 @@
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/jrzaurin/pytorch-widedeep/graphs/commit-activity)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/jrzaurin/pytorch-widedeep/issues)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
Platform | Version Support
---------|:---------------
OSX | [![Python 3.6 3.7](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://www.python.org/)
Linux | [![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
[![Python 3.6 3.7 3.8](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8-blue.svg)](https://www.python.org/)
# pytorch-widedeep
......@@ -57,6 +53,28 @@ cd pytorch-widedeep
pip install -e .
```
**Important note for Mac users**: at the time of writing (Dec-2020) the latest
`torch` release is `1.7`. This release has some
[issues](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206)
when running on Mac and the data-loaders will not run in parallel. In
addition, since `python 3.8`, [the `multiprocessing` library start method
changed from `'fork'` to
`'spawn'`](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods).
This also affects the data-loaders (for any `torch` version) and they will not
run in parallel. Therefore, for Mac users I recommend using `python 3.6` or
`3.7` and `torch <= 1.6` (with the corresponding, consistent version of
`torchvision`, e.g. `0.7.0` for `torch 1.6`). I do not want to force this
versioning in the `setup.py` file since I expect these issues to be fixed in
the future. Therefore, after installing `pytorch-widedeep` via pip or directly
from GitHub, downgrade `torch` and `torchvision` manually:
```bash
pip install pytorch-widedeep
pip install torch==1.6.0 torchvision==0.7.0
```
None of these issues affect Linux users.
### Quick start
Binary classification with the [adult
......
......@@ -27,11 +27,11 @@ class WideDeepDataset(Dataset):
def __init__(
self,
X_wide: np.ndarray,
X_deep: np.ndarray,
target: Optional[np.ndarray] = None,
X_wide: Optional[np.ndarray] = None,
X_deep: Optional[np.ndarray] = None,
X_text: Optional[np.ndarray] = None,
X_img: Optional[np.ndarray] = None,
target: Optional[np.ndarray] = None,
transforms: Optional[Any] = None,
):
......@@ -48,10 +48,12 @@ class WideDeepDataset(Dataset):
self.transforms_names = []
self.Y = target
def __getitem__(self, idx: int):
# X_wide and X_deep are assumed to be *always* present
X = Bunch(wide=self.X_wide[idx])
X.deepdense = self.X_deep[idx]
def __getitem__(self, idx: int): # noqa: C901
X = Bunch()
if self.X_wide is not None:
X.wide = self.X_wide[idx]
if self.X_deep is not None:
X.deepdense = self.X_deep[idx]
if self.X_text is not None:
X.deeptext = self.X_text[idx]
if self.X_img is not None:
......@@ -68,6 +70,8 @@ class WideDeepDataset(Dataset):
# then we need to replicate what Tensor() does -> transpose axis
# and normalize if necessary
if not self.transforms or "ToTensor" not in self.transforms_names:
if xdi.ndim == 2:
xdi = xdi[:, :, None]
xdi = xdi.transpose(2, 0, 1)
if "int" in str(xdi.dtype):
xdi = (xdi / xdi.max()).astype("float32")
......@@ -87,4 +91,11 @@ class WideDeepDataset(Dataset):
return X
def __len__(self):
return len(self.X_deep)
if self.X_wide is not None:
return len(self.X_wide)
if self.X_deep is not None:
return len(self.X_deep)
if self.X_text is not None:
return len(self.X_text)
if self.X_img is not None:
return len(self.X_img)
__version__ = "0.4.6"
__version__ = "0.4.7"
......@@ -33,9 +33,10 @@ extras["docs"] = [
]
extras["quality"] = [
"black",
"isort @ git+git://github.com/timothycrosley/isort.git@e63ae06ec7d70b06df9e528357650281a3d3ec22#egg=isort",
"isort",
"flake8",
]
extras["all"] = extras["test"] + extras["docs"] + extras["quality"]
# main setup kw args
setup_kwargs = {
......@@ -62,7 +63,7 @@ setup_kwargs = {
"torch",
"torchvision",
],
"extra_requires": extras,
"extras_require": extras,
"python_requires": ">=3.6.0",
"classifiers": [
dev_status[majorminor],
......
......@@ -55,7 +55,7 @@ def test_history_callback(deepcomponent, component_name):
def test_deephead_and_head_layers():
deephead = nn.Sequential(nn.Linear(32, 16), nn.Linear(16, 8))
with pytest.warns(UserWarning):
with pytest.raises(ValueError):
model = WideDeep( # noqa: F841
wide=wide, deepdense=deepdense, head_layers=[16, 8], deephead=deephead
)
......
......@@ -2,6 +2,7 @@ import string
import numpy as np
import pytest
from torch import nn
from torchvision.transforms import ToTensor, Normalize
from sklearn.model_selection import train_test_split
......@@ -67,11 +68,16 @@ std = [0.225, 0.224, 0.229] # BGR
transforms1 = [ToTensor, Normalize(mean=mean, std=std)]
transforms2 = [Normalize(mean=mean, std=std)]
deephead_ds = nn.Sequential(nn.Linear(16, 8), nn.Linear(8, 4))
deephead_dt = nn.Sequential(nn.Linear(64, 8), nn.Linear(8, 4))
deephead_di = nn.Sequential(nn.Linear(512, 8), nn.Linear(8, 4))
##############################################################################
# #############################################################################
# Test many possible scenarios of data inputs I can think of. Surely users
# will input something unexpected
##############################################################################
# #############################################################################
@pytest.mark.parametrize(
"X_wide, X_deep, X_text, X_img, X_train, X_val, target, val_split, transforms, nepoch, null",
[
......@@ -266,3 +272,141 @@ def test_widedeep_inputs(
model.history.epoch[0] == nepoch
and model.history._history["train_loss"] is not null
)
@pytest.mark.parametrize(
"X_wide, X_deep, X_text, X_img, X_train, X_val, target",
[
(
X_wide,
X_deep,
X_text,
X_img,
None,
{
"X_wide": X_wide_val,
"X_deep": X_deep_val,
"X_text": X_text_val,
"X_img": X_img_val,
"target": y_val,
},
target,
),
],
)
def test_xtrain_xval_assertion(
X_wide,
X_deep,
X_text,
X_img,
X_train,
X_val,
target,
):
model = WideDeep(
wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage
)
model.compile(method="binary", verbose=0)
with pytest.raises(AssertionError):
model.fit(
X_wide=X_wide,
X_deep=X_deep,
X_text=X_text,
X_img=X_img,
X_train=X_train,
X_val=X_val,
target=target,
batch_size=16,
)
@pytest.mark.parametrize(
"wide, deepdense, deeptext, deepimage, X_wide, X_deep, X_text, X_img, target",
[
(wide, None, None, None, X_wide, None, None, None, target),
(None, deepdense, None, None, None, X_deep, None, None, target),
(None, None, deeptext, None, None, None, X_text, None, target),
(None, None, None, deepimage, None, None, None, X_img, target),
],
)
def test_individual_inputs(
wide, deepdense, deeptext, deepimage, X_wide, X_deep, X_text, X_img, target
):
model = WideDeep(
wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage
)
model.compile(method="binary", verbose=0)
model.fit(
X_wide=X_wide,
X_deep=X_deep,
X_text=X_text,
X_img=X_img,
target=target,
batch_size=16,
)
# check it has run successfully
assert len(model.history._history) == 1
###############################################################################
#  test deephead is not None and individual components
###############################################################################
@pytest.mark.parametrize(
"deepdense, deeptext, deepimage, X_deep, X_text, X_img, deephead, target",
[
(deepdense, None, None, X_deep, None, None, deephead_ds, target),
(None, deeptext, None, None, X_text, None, deephead_dt, target),
(None, None, deepimage, None, None, X_img, deephead_di, target),
],
)
def test_deephead_individual_components(
deepdense, deeptext, deepimage, X_deep, X_text, X_img, deephead, target
):
model = WideDeep(
deepdense=deepdense, deeptext=deeptext, deepimage=deepimage, deephead=deephead
) # noqa: F841
model.compile(method="binary", verbose=0)
model.fit(
X_wide=X_wide,
X_deep=X_deep,
X_text=X_text,
X_img=X_img,
target=target,
batch_size=16,
)
# check it has run successfully
assert len(model.history._history) == 1
###############################################################################
#  test deephead is None and head_layers is not None and individual components
###############################################################################
@pytest.mark.parametrize(
"deepdense, deeptext, deepimage, X_deep, X_text, X_img, target",
[
(deepdense, None, None, X_deep, None, None, target),
(None, deeptext, None, None, X_text, None, target),
(None, None, deepimage, None, None, X_img, target),
],
)
def test_head_layers_individual_components(
deepdense, deeptext, deepimage, X_deep, X_text, X_img, target
):
model = WideDeep(
deepdense=deepdense, deeptext=deeptext, deepimage=deepimage, head_layers=[8, 4]
) # noqa: F841
model.compile(method="binary", verbose=0)
model.fit(
X_wide=X_wide,
X_deep=X_deep,
X_text=X_text,
X_img=X_img,
target=target,
batch_size=16,
)
# check it has run successfully
assert len(model.history._history) == 1
import string
import numpy as np
import torch
import pytest
from sklearn.model_selection import train_test_split
from pytorch_widedeep.models import (
Wide,
DeepText,
WideDeep,
DeepDense,
DeepImage,
)
from pytorch_widedeep.metrics import Accuracy, Precision
from pytorch_widedeep.callbacks import EarlyStopping
# Wide array
X_wide = np.random.choice(50, (32, 10))
# Deep Array
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 32) for _ in range(5)]
embed_input = [(u, i, j) for u, i, j in zip(colnames[:5], [5] * 5, [16] * 5)]
cont_cols = [np.random.rand(32) for _ in range(5)]
X_deep = np.vstack(embed_cols + cont_cols).transpose()
#  Text Array
padded_sequences = np.random.choice(np.arange(1, 100), (32, 48))
X_text = np.hstack((np.repeat(np.array([[0, 0]]), 32, axis=0), padded_sequences))
vocab_size = 100
#  Image Array
X_img = np.random.choice(256, (32, 224, 224, 3))
X_img_norm = X_img / 255.0
# Target
target = np.random.choice(2, 32)
target_multi = np.random.choice(3, 32)
# train/validation split
(
X_wide_tr,
X_wide_val,
X_deep_tr,
X_deep_val,
X_text_tr,
X_text_val,
X_img_tr,
X_img_val,
y_train,
y_val,
) = train_test_split(X_wide, X_deep, X_text, X_img, target)
# build model components
wide = Wide(np.unique(X_wide).shape[0], 1)
deepdense = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[-5:],
)
deeptext = DeepText(vocab_size=vocab_size, embed_dim=32, padding_idx=0)
deepimage = DeepImage(pretrained=True)
###############################################################################
#  test consistency between optimizers and lr_schedulers format
###############################################################################
def test_optimizer_scheduler_format():
model = WideDeep(deepdense=deepdense)
optimizers = {"deepdense": torch.optim.Adam(model.deepdense.parameters(), lr=0.01)}
schedulers = torch.optim.lr_scheduler.StepLR(optimizers["deepdense"], step_size=3)
with pytest.raises(ValueError):
model.compile(
method="binary",
optimizers=optimizers,
lr_schedulers=schedulers,
)
###############################################################################
#  test that callbacks are properly initialised internally
###############################################################################
def test_non_instantiated_callbacks():
model = WideDeep(wide=wide, deepdense=deepdense)
callbacks = [EarlyStopping]
model.compile(method="binary", callbacks=callbacks)
assert model.callbacks[1].__class__.__name__ == "EarlyStopping"
###############################################################################
#  test that multiple metrics are properly constructed internally
###############################################################################
def test_multiple_metrics():
model = WideDeep(wide=wide, deepdense=deepdense)
metrics = [Accuracy, Precision]
model.compile(method="binary", metrics=metrics)
assert (
model.metric._metrics[0].__class__.__name__ == "Accuracy"
and model.metric._metrics[1].__class__.__name__ == "Precision"
)
###############################################################################
#  test the train step with metrics runs well for a binary prediction
###############################################################################
def test_basic_run_with_metrics_binary():
model = WideDeep(wide=wide, deepdense=deepdense)
model.compile(method="binary", metrics=[Accuracy], verbose=False)
model.fit(
X_wide=X_wide,
X_deep=X_deep,
target=target,
n_epochs=1,
batch_size=16,
val_split=0.2,
)
assert (
"train_loss" in model.history._history.keys()
and "train_acc" in model.history._history.keys()
)
###############################################################################
#  test the train step with metrics runs well for a muticlass prediction
###############################################################################
def test_basic_run_with_metrics_multiclass():
wide = Wide(np.unique(X_wide).shape[0], 3)
deepdense = DeepDense(
hidden_layers=[32, 16],
dropout=[0.5, 0.5],
deep_column_idx={k: v for v, k in enumerate(colnames)},
embed_input=embed_input,
continuous_cols=colnames[-5:],
)
model = WideDeep(wide=wide, deepdense=deepdense, pred_dim=3)
model.compile(method="multiclass", metrics=[Accuracy], verbose=False)
model.fit(
X_wide=X_wide,
X_deep=X_deep,
target=target_multi,
n_epochs=1,
batch_size=16,
val_split=0.2,
)
assert (
"train_loss" in model.history._history.keys()
and "train_acc" in model.history._history.keys()
)
###############################################################################
#  test predict method for individual components
###############################################################################
@pytest.mark.parametrize(
"wide, deepdense, deeptext, deepimage, X_wide, X_deep, X_text, X_img, target",
[
(wide, None, None, None, X_wide, None, None, None, target),
(None, deepdense, None, None, None, X_deep, None, None, target),
(None, None, deeptext, None, None, None, X_text, None, target),
(None, None, None, deepimage, None, None, None, X_img, target),
],
)
def test_predict_with_individual_component(
wide, deepdense, deeptext, deepimage, X_wide, X_deep, X_text, X_img, target
):
model = WideDeep(
wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage
)
model.compile(method="binary", verbose=0)
model.fit(
X_wide=X_wide,
X_deep=X_deep,
X_text=X_text,
X_img=X_img,
target=target,
batch_size=16,
)
# simply checking that runs and produces outputs
preds = model.predict(X_wide=X_wide, X_deep=X_deep, X_text=X_text, X_img=X_img)
assert preds.shape[0] == 32 and "train_loss" in model.history._history
......@@ -161,7 +161,7 @@ def test_warm_all(model, modelname, loader, n_epochs, max_lr):
has_run = True
try:
warmer.warm_all(model, modelname, loader, n_epochs, max_lr)
except:
except Exception:
has_run = False
assert has_run
......@@ -182,6 +182,6 @@ def test_warm_gradual(model, modelname, loader, max_lr, layers, routine):
has_run = True
try:
warmer.warm_gradual(model, modelname, loader, max_lr, layers, routine)
except:
except Exception:
has_run = False
assert has_run