{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Simple Binary Classification with defaults\n", "\n", "In this notebook we will use the Adult Census dataset. Download the data from [here](https://www.kaggle.com/wenruliu/adult-income-dataset/downloads/adult.csv/2)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import torch\n", "\n", "from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor\n", "from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n", "from pytorch_widedeep.metrics import BinaryAccuracy" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age workclass fnlwgt education educational-num marital-status \\\n", "0 25 Private 226802 11th 7 Never-married \n", "1 38 Private 89814 HS-grad 9 Married-civ-spouse \n", "2 28 Local-gov 336951 Assoc-acdm 12 Married-civ-spouse \n", "3 44 Private 160323 Some-college 10 Married-civ-spouse \n", "4 18 ? 103497 Some-college 10 Never-married \n", "\n", " occupation relationship race gender capital-gain capital-loss \\\n", "0 Machine-op-inspct Own-child Black Male 0 0 \n", "1 Farming-fishing Husband White Male 0 0 \n", "2 Protective-serv Husband White Male 0 0 \n", "3 Machine-op-inspct Husband Black Male 7688 0 \n", "4 ? Own-child White Female 0 0 \n", "\n", " hours-per-week native-country income \n", "0 40 United-States <=50K \n", "1 50 United-States <=50K \n", "2 40 United-States >50K \n", "3 40 United-States >50K \n", "4 30 United-States <=50K " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('data/adult/adult.csv.zip')\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " age workclass fnlwgt education educational_num marital_status \\\n", "0 25 Private 226802 11th 7 Never-married \n", "1 38 Private 89814 HS-grad 9 Married-civ-spouse \n", "2 28 Local-gov 336951 Assoc-acdm 12 Married-civ-spouse \n", "3 44 Private 160323 Some-college 10 Married-civ-spouse \n", "4 18 ? 103497 Some-college 10 Never-married \n", "\n", " occupation relationship race gender capital_gain capital_loss \\\n", "0 Machine-op-inspct Own-child Black Male 0 0 \n", "1 Farming-fishing Husband White Male 0 0 \n", "2 Protective-serv Husband White Male 0 0 \n", "3 Machine-op-inspct Husband Black Male 7688 0 \n", "4 ? Own-child White Female 0 0 \n", "\n", " hours_per_week native_country income_label \n", "0 40 United-States 0 \n", "1 50 United-States 0 \n", "2 40 United-States 1 \n", "3 40 United-States 1 \n", "4 30 United-States 0 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# For convenience, we'll replace '-' with '_'\n", "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n", "# binary target\n", "df['income_label'] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n", "df.drop('income', axis=1, inplace=True)\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preparing the data\n", "\n", "Have a look at notebooks one and two if you want a good understanding of the next few lines of code (although this is not required to use the package)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "wide_cols = ['education', 'relationship','workclass','occupation','native_country','gender']\n", "crossed_cols = [('education', 'occupation'), ('native_country', 'occupation')]\n", "cat_embed_cols = [('education',16), ('relationship',8), ('workclass',16),\n", " ('occupation',16),('native_country',16)]\n", "continuous_cols = [\"age\",\"hours_per_week\"]\n", "target_col = 'income_label'" ] }, { "cell_type": "code", 
"execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# TARGET\n", "target = df[target_col].values\n", "\n", "# WIDE\n", "preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n", "X_wide = preprocess_wide.fit_transform(df)\n", "\n", "# DEEP\n", "preprocess_deep = DeepPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n", "X_deep = preprocess_deep.fit_transform(df)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 1. 0. ... 0. 0. 0.]\n", " [0. 0. 0. ... 0. 0. 0.]\n", " [0. 0. 0. ... 0. 0. 0.]\n", " ...\n", " [0. 0. 0. ... 0. 0. 0.]\n", " [0. 0. 0. ... 0. 0. 0.]\n", " [0. 0. 0. ... 0. 0. 0.]]\n", "(48842, 796)\n" ] } ], "source": [ "print(X_wide)\n", "print(X_wide.shape)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0. 0. 0. ... 0. -0.99512893\n", " -0.03408696]\n", " [ 1. 1. 0. ... 0. -0.04694151\n", " 0.77292975]\n", " [ 2. 1. 1. ... 0. -0.77631645\n", " -0.03408696]\n", " ...\n", " [ 1. 3. 0. ... 0. 1.41180837\n", " -0.03408696]\n", " [ 1. 0. 0. ... 0. -1.21394141\n", " -1.64812038]\n", " [ 1. 4. 6. ... 0. 
0.97418341\n", " -0.03408696]]\n", "(48842, 7)\n" ] } ], "source": [ "print(X_deep)\n", "print(X_deep.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining the model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)\n", "deepdense = DeepDense(hidden_layers=[64,32], \n", " deep_column_idx=preprocess_deep.deep_column_idx,\n", " embed_input=preprocess_deep.embeddings_input,\n", " continuous_cols=continuous_cols)\n", "model = WideDeep(wide=wide, deepdense=deepdense)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "WideDeep(\n", " (wide): Wide(\n", " (wide_linear): Linear(in_features=796, out_features=1, bias=True)\n", " )\n", " (deepdense): Sequential(\n", " (0): DeepDense(\n", " (embed_layers): ModuleDict(\n", " (emb_layer_education): Embedding(16, 16)\n", " (emb_layer_native_country): Embedding(42, 16)\n", " (emb_layer_occupation): Embedding(15, 16)\n", " (emb_layer_relationship): Embedding(6, 8)\n", " (emb_layer_workclass): Embedding(9, 16)\n", " )\n", " (embed_dropout): Dropout(p=0.0, inplace=False)\n", " (dense): Sequential(\n", " (dense_layer_0): Sequential(\n", " (0): Linear(in_features=74, out_features=64, bias=True)\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " (2): Dropout(p=0.0, inplace=False)\n", " )\n", " (dense_layer_1): Sequential(\n", " (0): Linear(in_features=64, out_features=32, bias=True)\n", " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", " (2): Dropout(p=0.0, inplace=False)\n", " )\n", " )\n", " )\n", " (1): Linear(in_features=32, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, the model is not particularly complex. 
In mathematical terms (Eq 3 in the [original paper](https://arxiv.org/pdf/1606.07792.pdf)): \n", "\n", "$$\n", "pred = \\sigma(W^{T}_{wide}[x, \\phi(x)] + W^{T}_{deep}a_{deep}^{(l_f)} + b) \n", "$$ \n", "\n", "\n", "The architecture above outputs the first and second terms inside the parentheses. `WideDeep` then adds them and applies an activation function (`sigmoid` in this case). For more details, please refer to the paper." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compiling and Running/Fitting\n", "Once the model is built, we just need to compile it and run it." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "model.compile(method='binary', metrics=[BinaryAccuracy])" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\r", " 0%| | 0/153 [00:00