{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple Binary Classification with defaults\n",
"\n",
"In this notebook we will use the Adult Census dataset. Download the data from [here](https://www.kaggle.com/wenruliu/adult-income-dataset/downloads/adult.csv/2). The code below reads it from `data/adult/adult.csv.zip`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, DeepDenseResnet, WideDeep\n",
"from pytorch_widedeep.metrics import Accuracy, Precision"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" workclass | \n",
" fnlwgt | \n",
" education | \n",
" educational-num | \n",
" marital-status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" gender | \n",
" capital-gain | \n",
" capital-loss | \n",
" hours-per-week | \n",
" native-country | \n",
" income | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 25 | \n",
" Private | \n",
" 226802 | \n",
" 11th | \n",
" 7 | \n",
" Never-married | \n",
" Machine-op-inspct | \n",
" Own-child | \n",
" Black | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" <=50K | \n",
"
\n",
" \n",
" 1 | \n",
" 38 | \n",
" Private | \n",
" 89814 | \n",
" HS-grad | \n",
" 9 | \n",
" Married-civ-spouse | \n",
" Farming-fishing | \n",
" Husband | \n",
" White | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 50 | \n",
" United-States | \n",
" <=50K | \n",
"
\n",
" \n",
" 2 | \n",
" 28 | \n",
" Local-gov | \n",
" 336951 | \n",
" Assoc-acdm | \n",
" 12 | \n",
" Married-civ-spouse | \n",
" Protective-serv | \n",
" Husband | \n",
" White | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" >50K | \n",
"
\n",
" \n",
" 3 | \n",
" 44 | \n",
" Private | \n",
" 160323 | \n",
" Some-college | \n",
" 10 | \n",
" Married-civ-spouse | \n",
" Machine-op-inspct | \n",
" Husband | \n",
" Black | \n",
" Male | \n",
" 7688 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" >50K | \n",
"
\n",
" \n",
" 4 | \n",
" 18 | \n",
" ? | \n",
" 103497 | \n",
" Some-college | \n",
" 10 | \n",
" Never-married | \n",
" ? | \n",
" Own-child | \n",
" White | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 30 | \n",
" United-States | \n",
" <=50K | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt education educational-num marital-status \\\n",
"0 25 Private 226802 11th 7 Never-married \n",
"1 38 Private 89814 HS-grad 9 Married-civ-spouse \n",
"2 28 Local-gov 336951 Assoc-acdm 12 Married-civ-spouse \n",
"3 44 Private 160323 Some-college 10 Married-civ-spouse \n",
"4 18 ? 103497 Some-college 10 Never-married \n",
"\n",
" occupation relationship race gender capital-gain capital-loss \\\n",
"0 Machine-op-inspct Own-child Black Male 0 0 \n",
"1 Farming-fishing Husband White Male 0 0 \n",
"2 Protective-serv Husband White Male 0 0 \n",
"3 Machine-op-inspct Husband Black Male 7688 0 \n",
"4 ? Own-child White Female 0 0 \n",
"\n",
" hours-per-week native-country income \n",
"0 40 United-States <=50K \n",
"1 50 United-States <=50K \n",
"2 40 United-States >50K \n",
"3 40 United-States >50K \n",
"4 30 United-States <=50K "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# raw Adult Census data (see the download link above)\n",
"adult_csv_path = 'data/adult/adult.csv.zip'\n",
"df = pd.read_csv(adult_csv_path)\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" age | \n",
" workclass | \n",
" fnlwgt | \n",
" education | \n",
" educational_num | \n",
" marital_status | \n",
" occupation | \n",
" relationship | \n",
" race | \n",
" gender | \n",
" capital_gain | \n",
" capital_loss | \n",
" hours_per_week | \n",
" native_country | \n",
" income_label | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 25 | \n",
" Private | \n",
" 226802 | \n",
" 11th | \n",
" 7 | \n",
" Never-married | \n",
" Machine-op-inspct | \n",
" Own-child | \n",
" Black | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" 0 | \n",
"
\n",
" \n",
" 1 | \n",
" 38 | \n",
" Private | \n",
" 89814 | \n",
" HS-grad | \n",
" 9 | \n",
" Married-civ-spouse | \n",
" Farming-fishing | \n",
" Husband | \n",
" White | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 50 | \n",
" United-States | \n",
" 0 | \n",
"
\n",
" \n",
" 2 | \n",
" 28 | \n",
" Local-gov | \n",
" 336951 | \n",
" Assoc-acdm | \n",
" 12 | \n",
" Married-civ-spouse | \n",
" Protective-serv | \n",
" Husband | \n",
" White | \n",
" Male | \n",
" 0 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" 1 | \n",
"
\n",
" \n",
" 3 | \n",
" 44 | \n",
" Private | \n",
" 160323 | \n",
" Some-college | \n",
" 10 | \n",
" Married-civ-spouse | \n",
" Machine-op-inspct | \n",
" Husband | \n",
" Black | \n",
" Male | \n",
" 7688 | \n",
" 0 | \n",
" 40 | \n",
" United-States | \n",
" 1 | \n",
"
\n",
" \n",
" 4 | \n",
" 18 | \n",
" ? | \n",
" 103497 | \n",
" Some-college | \n",
" 10 | \n",
" Never-married | \n",
" ? | \n",
" Own-child | \n",
" White | \n",
" Female | \n",
" 0 | \n",
" 0 | \n",
" 30 | \n",
" United-States | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" age workclass fnlwgt education educational_num marital_status \\\n",
"0 25 Private 226802 11th 7 Never-married \n",
"1 38 Private 89814 HS-grad 9 Married-civ-spouse \n",
"2 28 Local-gov 336951 Assoc-acdm 12 Married-civ-spouse \n",
"3 44 Private 160323 Some-college 10 Married-civ-spouse \n",
"4 18 ? 103497 Some-college 10 Never-married \n",
"\n",
" occupation relationship race gender capital_gain capital_loss \\\n",
"0 Machine-op-inspct Own-child Black Male 0 0 \n",
"1 Farming-fishing Husband White Male 0 0 \n",
"2 Protective-serv Husband White Male 0 0 \n",
"3 Machine-op-inspct Husband Black Male 7688 0 \n",
"4 ? Own-child White Female 0 0 \n",
"\n",
" hours_per_week native_country income_label \n",
"0 40 United-States 0 \n",
"1 50 United-States 0 \n",
"2 40 United-States 1 \n",
"3 40 United-States 1 \n",
"4 30 United-States 0 "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# For convenience, we'll replace '-' with '_' in the column names\n",
"df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n",
"# binary target: 1 when the income string contains '>50K', 0 otherwise.\n",
"# `.str.contains(..., regex=False)` is the vectorised form of `'>50K' in x`.\n",
"df['income_label'] = df['income'].str.contains('>50K', regex=False).astype(int)\n",
"# drop the raw text column by reassignment instead of `inplace=True`\n",
"df = df.drop(columns='income')\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preparing the data\n",
"\n",
"Have a look at notebooks one and two if you want a good understanding of the next few lines of code (although this is not required in order to use the package)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# columns fed to the wide (linear) component, plus the crossed-feature pairs\n",
"wide_cols = [\n",
"    'education', 'relationship', 'workclass',\n",
"    'occupation', 'native_country', 'gender',\n",
"]\n",
"crossed_cols = [\n",
"    ('education', 'occupation'),\n",
"    ('native_country', 'occupation'),\n",
"]\n",
"# (column, embedding dimension) pairs for the deep component\n",
"cat_embed_cols = [\n",
"    ('education', 16),\n",
"    ('relationship', 8),\n",
"    ('workclass', 16),\n",
"    ('occupation', 16),\n",
"    ('native_country', 16),\n",
"]\n",
"continuous_cols = ['age', 'hours_per_week']\n",
"target_col = 'income_label'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# WIDE component input: wide + crossed columns\n",
"preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n",
"X_wide = preprocess_wide.fit_transform(df)\n",
"\n",
"# DEEP component input: embedding columns + continuous columns\n",
"preprocess_deep = DensePreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
"X_deep = preprocess_deep.fit_transform(df)\n",
"\n",
"# TARGET as a numpy array\n",
"target = df[target_col].values"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 1 17 23 ... 89 91 316]\n",
" [ 2 18 23 ... 89 92 317]\n",
" [ 3 18 24 ... 89 93 318]\n",
" ...\n",
" [ 2 20 23 ... 90 103 323]\n",
" [ 2 17 23 ... 89 103 323]\n",
" [ 2 21 29 ... 90 115 324]]\n",
"(48842, 8)\n"
]
}
],
"source": [
"# encoded wide matrix followed by its shape\n",
"print(X_wide, X_wide.shape, sep='\\n')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0. 0. 0. ... 0. -0.99512893\n",
" -0.03408696]\n",
" [ 1. 1. 0. ... 0. -0.04694151\n",
" 0.77292975]\n",
" [ 2. 1. 1. ... 0. -0.77631645\n",
" -0.03408696]\n",
" ...\n",
" [ 1. 3. 0. ... 0. 1.41180837\n",
" -0.03408696]\n",
" [ 1. 0. 0. ... 0. -1.21394141\n",
" -1.64812038]\n",
" [ 1. 4. 6. ... 0. 0.97418341\n",
" -0.03408696]]\n",
"(48842, 7)\n"
]
}
],
"source": [
"# deep-component matrix followed by its shape\n",
"print(X_deep, X_deep.shape, sep='\\n')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Defining the model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# wide: linear component sized by the number of distinct values in X_wide\n",
"n_wide_values = np.unique(X_wide).shape[0]\n",
"wide = Wide(wide_dim=n_wide_values, pred_dim=1)\n",
"# deep: dense layers [64, 32] on top of the embeddings and continuous columns\n",
"deepdense = DeepDense(\n",
"    hidden_layers=[64, 32],\n",
"    column_idx=preprocess_deep.column_idx,\n",
"    embed_input=preprocess_deep.embeddings_input,\n",
"    continuous_cols=continuous_cols,\n",
")\n",
"model = WideDeep(wide=wide, deepdense=deepdense)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"WideDeep(\n",
" (wide): Wide(\n",
" (wide_linear): Embedding(797, 1, padding_idx=0)\n",
" )\n",
" (deepdense): Sequential(\n",
" (0): DeepDense(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_education): Embedding(17, 16)\n",
" (emb_layer_native_country): Embedding(43, 16)\n",
" (emb_layer_occupation): Embedding(16, 16)\n",
" (emb_layer_relationship): Embedding(7, 8)\n",
" (emb_layer_workclass): Embedding(10, 16)\n",
" )\n",
" (embed_dropout): Dropout(p=0.0, inplace=False)\n",
" (dense): Sequential(\n",
" (dense_layer_0): Sequential(\n",
" (0): Linear(in_features=74, out_features=64, bias=True)\n",
" (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" (2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (dense_layer_1): Sequential(\n",
" (0): Linear(in_features=64, out_features=32, bias=True)\n",
" (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
" (2): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (1): Linear(in_features=32, out_features=1, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# display the architecture of the combined model\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, the model is not particularly complex. In mathematical terms (Eq 3 in the [original paper](https://arxiv.org/pdf/1606.07792.pdf)): \n",
"\n",
"$$\n",
"pred = \\sigma(W^{T}_{wide}[x, \\phi(x)] + W^{T}_{deep}a_{deep}^{(l_f)} + b) \n",
"$$ \n",
"\n",
"\n",
"The architecture above outputs the first and second terms inside the parentheses. `WideDeep` then adds them and applies an activation function (`sigmoid` in this case). For more details, please refer to the paper."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compiling and Running/Fitting\n",
"Once the model is built, we just need to compile it and run it."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# binary classification; accuracy and precision are reported during training\n",
"model.compile(\n",
"    method='binary',\n",
"    metrics=[Accuracy, Precision],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/611 [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:06<00:00, 101.71it/s, loss=0.448, metrics={'acc': 0.792, 'prec': 0.5728}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 171.00it/s, loss=0.366, metrics={'acc': 0.7991, 'prec': 0.5907}]\n",
"epoch 2: 100%|██████████| 611/611 [00:06<00:00, 101.69it/s, loss=0.361, metrics={'acc': 0.8324, 'prec': 0.6817}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 169.36it/s, loss=0.357, metrics={'acc': 0.8328, 'prec': 0.6807}]\n",
"epoch 3: 100%|██████████| 611/611 [00:05<00:00, 102.65it/s, loss=0.352, metrics={'acc': 0.8366, 'prec': 0.691}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 171.49it/s, loss=0.352, metrics={'acc': 0.8361, 'prec': 0.6867}]\n",
"epoch 4: 100%|██████████| 611/611 [00:06<00:00, 101.52it/s, loss=0.347, metrics={'acc': 0.8389, 'prec': 0.6956}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 163.49it/s, loss=0.349, metrics={'acc': 0.8383, 'prec': 0.6906}]\n",
"epoch 5: 100%|██████████| 611/611 [00:07<00:00, 84.91it/s, loss=0.343, metrics={'acc': 0.8405, 'prec': 0.6987}] \n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 142.83it/s, loss=0.347, metrics={'acc': 0.8399, 'prec': 0.6946}]\n"
]
}
],
"source": [
"# 5 epochs, batch size 64, 20% of the rows held out for validation\n",
"model.fit(\n",
"    X_wide=X_wide,\n",
"    X_deep=X_deep,\n",
"    target=target,\n",
"    n_epochs=5,\n",
"    batch_size=64,\n",
"    val_split=0.2,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, you can run a wide and deep model in just a few lines of code. \n",
"\n",
"Next, let's use `DeepDenseResnet` as the `deepdense` component:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# a fresh wide component, sized by the number of distinct values in X_wide\n",
"n_wide_values = np.unique(X_wide).shape[0]\n",
"wide = Wide(wide_dim=n_wide_values, pred_dim=1)\n",
"# deep component: DeepDenseResnet with blocks [64, 32]\n",
"deepdense = DeepDenseResnet(\n",
"    blocks=[64, 32],\n",
"    column_idx=preprocess_deep.column_idx,\n",
"    embed_input=preprocess_deep.embeddings_input,\n",
"    continuous_cols=continuous_cols,\n",
")\n",
"model = WideDeep(wide=wide, deepdense=deepdense)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# same compilation settings as the first model\n",
"model.compile(\n",
"    method='binary',\n",
"    metrics=[Accuracy, Precision],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/611 [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:07<00:00, 77.46it/s, loss=0.387, metrics={'acc': 0.8192, 'prec': 0.6576}]\n",
"valid: 100%|██████████| 153/153 [00:01<00:00, 147.78it/s, loss=0.36, metrics={'acc': 0.8216, 'prec': 0.6617}] \n",
"epoch 2: 100%|██████████| 611/611 [00:08<00:00, 74.99it/s, loss=0.358, metrics={'acc': 0.8313, 'prec': 0.6836}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 158.26it/s, loss=0.355, metrics={'acc': 0.8321, 'prec': 0.6848}]\n",
"epoch 3: 100%|██████████| 611/611 [00:08<00:00, 76.28it/s, loss=0.351, metrics={'acc': 0.8345, 'prec': 0.6889}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 154.84it/s, loss=0.354, metrics={'acc': 0.8347, 'prec': 0.6887}]\n",
"epoch 4: 100%|██████████| 611/611 [00:07<00:00, 76.71it/s, loss=0.346, metrics={'acc': 0.8374, 'prec': 0.6946}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.80it/s, loss=0.353, metrics={'acc': 0.8369, 'prec': 0.6935}]\n",
"epoch 5: 100%|██████████| 611/611 [00:08<00:00, 73.25it/s, loss=0.343, metrics={'acc': 0.8386, 'prec': 0.6966}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 157.05it/s, loss=0.352, metrics={'acc': 0.8382, 'prec': 0.6961}]\n"
]
}
],
"source": [
"# same training schedule as the first model\n",
"model.fit(\n",
"    X_wide=X_wide,\n",
"    X_deep=X_deep,\n",
"    target=target,\n",
"    n_epochs=5,\n",
"    batch_size=64,\n",
"    val_split=0.2,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is also worth mentioning that one can build a model using the individual components independently. For example, a model comprised only of the `wide` component is simply a linear model. This can be built with just:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Build a fresh `Wide` component: the `wide` object defined above was already\n",
"# trained as part of the previous WideDeep model, so reusing it would not give\n",
"# a linear model trained from scratch.\n",
"wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
"model = WideDeep(wide=wide)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# compile the wide-only (linear) model with the same settings\n",
"model.compile(\n",
"    method='binary',\n",
"    metrics=[Accuracy, Precision],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0/611 [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 611/611 [00:03<00:00, 188.59it/s, loss=0.482, metrics={'acc': 0.771, 'prec': 0.5633}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 236.13it/s, loss=0.423, metrics={'acc': 0.7747, 'prec': 0.5819}]\n",
"epoch 2: 100%|██████████| 611/611 [00:03<00:00, 190.62it/s, loss=0.399, metrics={'acc': 0.8131, 'prec': 0.686}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 221.47it/s, loss=0.387, metrics={'acc': 0.8138, 'prec': 0.6879}]\n",
"epoch 3: 100%|██████████| 611/611 [00:03<00:00, 190.28it/s, loss=0.378, metrics={'acc': 0.8267, 'prec': 0.7149}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 241.12it/s, loss=0.374, metrics={'acc': 0.8255, 'prec': 0.7128}]\n",
"epoch 4: 100%|██████████| 611/611 [00:03<00:00, 183.27it/s, loss=0.37, metrics={'acc': 0.8304, 'prec': 0.7073}] \n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 227.46it/s, loss=0.369, metrics={'acc': 0.8294, 'prec': 0.7061}]\n",
"epoch 5: 100%|██████████| 611/611 [00:03<00:00, 184.28it/s, loss=0.366, metrics={'acc': 0.8315, 'prec': 0.7006}]\n",
"valid: 100%|██████████| 153/153 [00:00<00:00, 239.87it/s, loss=0.366, metrics={'acc': 0.8303, 'prec': 0.6999}]\n"
]
}
],
"source": [
"# only the wide input is needed for a wide-only model\n",
"model.fit(\n",
"    X_wide=X_wide,\n",
"    target=target,\n",
"    n_epochs=5,\n",
"    batch_size=64,\n",
"    val_split=0.2,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}