提交 3e963e78 编写于 作者: J jrzaurin

Adapted the examples to the small code changes. Replace add_ with add in...

Adapted the examples to the small code changes. Replace add_ with add in WideDeep to avoid annoying warnings. Replace output_dim with pred_dim in Wide for consistency
上级 e4f6ad2d
...@@ -8,11 +8,11 @@ ...@@ -8,11 +8,11 @@
"\n", "\n",
"The 5 main components of a `WideDeep` model are:\n", "The 5 main components of a `WideDeep` model are:\n",
"\n", "\n",
"1. `Wide (Class)`\n", "1. `Wide`\n",
"2. `DeepDense (Class)`\n", "2. `DeepDense`\n",
"3. `DeepText (Class)`\n", "3. `DeepText`\n",
"4. `DeepImage (Class)`\n", "4. `DeepImage`\n",
"5. `deephead (WideDeep Class parameter)`\n", "5. `deephead`\n",
"\n", "\n",
"The first 4 of them will be collected and combined by the `WideDeep` collector class, while the 5th one can be optionally added to the `WideDeep` model through its corresponding parameters: `deephead` or alternatively `head_layers`, `head_dropout` and `head_batchnorm`" "The first 4 of them will be collected and combined by the `WideDeep` collector class, while the 5th one can be optionally added to the `WideDeep` model through its corresponding parameters: `deephead` or alternatively `head_layers`, `head_dropout` and `head_batchnorm`"
] ]
...@@ -170,11 +170,11 @@ ...@@ -170,11 +170,11 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"tensor([[ 1.9317, -0.0000, 1.3663, -0.3984, -0.0000, -0.0000, -0.0000, -1.2662],\n", "tensor([[-0.0000, -1.0061, -0.0000, -0.9828, -0.0000, -0.0000, -0.9944, -1.0133],\n",
" [ 0.0000, -1.5337, -0.0000, 0.0726, -0.4231, 3.9977, -0.0000, -0.0000],\n", " [-0.0000, -0.9996, 0.0000, -1.0374, 0.0000, -0.0000, -1.0313, -0.0000],\n",
" [-0.0000, -1.5839, 3.2978, -1.7084, -1.0877, -0.9574, 0.0000, -0.0000],\n", " [-0.8576, -1.0017, -0.0000, -0.9881, -0.0000, 0.0000, -0.0000, -0.0000],\n",
" [-0.0000, 1.6664, -1.6006, 0.0000, -0.0000, -0.9844, -0.0000, -0.0521],\n", " [ 3.9816, 0.0000, 0.0000, 0.0000, 3.7309, 1.1728, 0.0000, -1.1160],\n",
" [ 2.4249, 0.0000, -0.0000, -0.0000, 0.0000, -0.0000, 2.6460, 0.0000]],\n", " [-1.1339, -0.9925, -0.0000, -0.0000, -0.0000, 0.0000, -0.9638, 0.0000]],\n",
" grad_fn=<MulBackward0>)" " grad_fn=<MulBackward0>)"
] ]
}, },
...@@ -484,10 +484,10 @@ ...@@ -484,10 +484,10 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"tensor([[ 8.4865e-02, -3.4401e-03, -9.1973e-04, 3.4269e-01, 3.2816e-02,\n", "tensor([[-1.4630e-04, -6.1540e-04, -2.4541e-04, 2.7543e-01, 1.2993e-01,\n",
" 1.9682e-02, -8.0740e-04, 9.4898e-03],\n", " -1.6553e-03, 6.7002e-02, 2.3974e-01],\n",
" [ 1.5473e-01, -6.2664e-03, -9.3413e-05, 3.8768e-01, -1.9963e-03,\n", " [-9.9619e-04, -1.9412e-03, 1.2113e-01, 1.0122e-01, 2.9080e-01,\n",
" 1.1729e-01, -2.7111e-03, 1.8670e-01]], grad_fn=<LeakyReluBackward1>)" " -2.0852e-03, -1.8016e-04, 2.7996e-02]], grad_fn=<LeakyReluBackward1>)"
] ]
}, },
"execution_count": 18, "execution_count": 18,
...@@ -517,7 +517,7 @@ ...@@ -517,7 +517,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 25, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -563,7 +563,7 @@ ...@@ -563,7 +563,7 @@
")" ")"
] ]
}, },
"execution_count": 25, "execution_count": 20,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -587,7 +587,7 @@ ...@@ -587,7 +587,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 21,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -620,7 +620,7 @@ ...@@ -620,7 +620,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.6.5" "version": "3.7.7"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -19,14 +19,14 @@ ...@@ -19,14 +19,14 @@
"import pandas as pd\n", "import pandas as pd\n",
"import torch\n", "import torch\n",
"\n", "\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor\n", "from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n", "from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
"from pytorch_widedeep.metrics import BinaryAccuracy" "from pytorch_widedeep.metrics import BinaryAccuracy"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -69,7 +69,7 @@ ...@@ -69,7 +69,7 @@
" </thead>\n", " </thead>\n",
" <tbody>\n", " <tbody>\n",
" <tr>\n", " <tr>\n",
" <td>0</td>\n", " <th>0</th>\n",
" <td>25</td>\n", " <td>25</td>\n",
" <td>Private</td>\n", " <td>Private</td>\n",
" <td>226802</td>\n", " <td>226802</td>\n",
...@@ -87,7 +87,7 @@ ...@@ -87,7 +87,7 @@
" <td>&lt;=50K</td>\n", " <td>&lt;=50K</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>1</td>\n", " <th>1</th>\n",
" <td>38</td>\n", " <td>38</td>\n",
" <td>Private</td>\n", " <td>Private</td>\n",
" <td>89814</td>\n", " <td>89814</td>\n",
...@@ -105,7 +105,7 @@ ...@@ -105,7 +105,7 @@
" <td>&lt;=50K</td>\n", " <td>&lt;=50K</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>2</td>\n", " <th>2</th>\n",
" <td>28</td>\n", " <td>28</td>\n",
" <td>Local-gov</td>\n", " <td>Local-gov</td>\n",
" <td>336951</td>\n", " <td>336951</td>\n",
...@@ -123,7 +123,7 @@ ...@@ -123,7 +123,7 @@
" <td>&gt;50K</td>\n", " <td>&gt;50K</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>3</td>\n", " <th>3</th>\n",
" <td>44</td>\n", " <td>44</td>\n",
" <td>Private</td>\n", " <td>Private</td>\n",
" <td>160323</td>\n", " <td>160323</td>\n",
...@@ -141,7 +141,7 @@ ...@@ -141,7 +141,7 @@
" <td>&gt;50K</td>\n", " <td>&gt;50K</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>4</td>\n", " <th>4</th>\n",
" <td>18</td>\n", " <td>18</td>\n",
" <td>?</td>\n", " <td>?</td>\n",
" <td>103497</td>\n", " <td>103497</td>\n",
...@@ -185,7 +185,7 @@ ...@@ -185,7 +185,7 @@
"4 30 United-States <=50K " "4 30 United-States <=50K "
] ]
}, },
"execution_count": 2, "execution_count": 6,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -197,7 +197,7 @@ ...@@ -197,7 +197,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -240,7 +240,7 @@ ...@@ -240,7 +240,7 @@
" </thead>\n", " </thead>\n",
" <tbody>\n", " <tbody>\n",
" <tr>\n", " <tr>\n",
" <td>0</td>\n", " <th>0</th>\n",
" <td>25</td>\n", " <td>25</td>\n",
" <td>Private</td>\n", " <td>Private</td>\n",
" <td>226802</td>\n", " <td>226802</td>\n",
...@@ -258,7 +258,7 @@ ...@@ -258,7 +258,7 @@
" <td>0</td>\n", " <td>0</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>1</td>\n", " <th>1</th>\n",
" <td>38</td>\n", " <td>38</td>\n",
" <td>Private</td>\n", " <td>Private</td>\n",
" <td>89814</td>\n", " <td>89814</td>\n",
...@@ -276,7 +276,7 @@ ...@@ -276,7 +276,7 @@
" <td>0</td>\n", " <td>0</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>2</td>\n", " <th>2</th>\n",
" <td>28</td>\n", " <td>28</td>\n",
" <td>Local-gov</td>\n", " <td>Local-gov</td>\n",
" <td>336951</td>\n", " <td>336951</td>\n",
...@@ -294,7 +294,7 @@ ...@@ -294,7 +294,7 @@
" <td>1</td>\n", " <td>1</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>3</td>\n", " <th>3</th>\n",
" <td>44</td>\n", " <td>44</td>\n",
" <td>Private</td>\n", " <td>Private</td>\n",
" <td>160323</td>\n", " <td>160323</td>\n",
...@@ -312,7 +312,7 @@ ...@@ -312,7 +312,7 @@
" <td>1</td>\n", " <td>1</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>4</td>\n", " <th>4</th>\n",
" <td>18</td>\n", " <td>18</td>\n",
" <td>?</td>\n", " <td>?</td>\n",
" <td>103497</td>\n", " <td>103497</td>\n",
...@@ -356,7 +356,7 @@ ...@@ -356,7 +356,7 @@
"4 30 United-States 0 " "4 30 United-States 0 "
] ]
}, },
"execution_count": 3, "execution_count": 7,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -381,7 +381,7 @@ ...@@ -381,7 +381,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -394,7 +394,7 @@ ...@@ -394,7 +394,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -406,13 +406,13 @@ ...@@ -406,13 +406,13 @@
"X_wide = preprocess_wide.fit_transform(df)\n", "X_wide = preprocess_wide.fit_transform(df)\n",
"\n", "\n",
"# DEEP\n", "# DEEP\n",
"preprocess_deep = DeepPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n", "preprocess_deep = DensePreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
"X_deep = preprocess_deep.fit_transform(df)" "X_deep = preprocess_deep.fit_transform(df)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -437,7 +437,7 @@ ...@@ -437,7 +437,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -475,11 +475,11 @@ ...@@ -475,11 +475,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)\n", "wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
"deepdense = DeepDense(hidden_layers=[64,32], \n", "deepdense = DeepDense(hidden_layers=[64,32], \n",
" deep_column_idx=preprocess_deep.deep_column_idx,\n", " deep_column_idx=preprocess_deep.deep_column_idx,\n",
" embed_input=preprocess_deep.embeddings_input,\n", " embed_input=preprocess_deep.embeddings_input,\n",
...@@ -489,7 +489,7 @@ ...@@ -489,7 +489,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -502,11 +502,11 @@ ...@@ -502,11 +502,11 @@
" (deepdense): Sequential(\n", " (deepdense): Sequential(\n",
" (0): DeepDense(\n", " (0): DeepDense(\n",
" (embed_layers): ModuleDict(\n", " (embed_layers): ModuleDict(\n",
" (emb_layer_education): Embedding(16, 16)\n", " (emb_layer_education): Embedding(17, 16)\n",
" (emb_layer_native_country): Embedding(42, 16)\n", " (emb_layer_native_country): Embedding(43, 16)\n",
" (emb_layer_occupation): Embedding(15, 16)\n", " (emb_layer_occupation): Embedding(16, 16)\n",
" (emb_layer_relationship): Embedding(6, 8)\n", " (emb_layer_relationship): Embedding(7, 8)\n",
" (emb_layer_workclass): Embedding(9, 16)\n", " (emb_layer_workclass): Embedding(10, 16)\n",
" )\n", " )\n",
" (embed_dropout): Dropout(p=0.0, inplace=False)\n", " (embed_dropout): Dropout(p=0.0, inplace=False)\n",
" (dense): Sequential(\n", " (dense): Sequential(\n",
...@@ -527,7 +527,7 @@ ...@@ -527,7 +527,7 @@
")" ")"
] ]
}, },
"execution_count": 9, "execution_count": 15,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -560,7 +560,7 @@ ...@@ -560,7 +560,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 16,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -569,7 +569,7 @@ ...@@ -569,7 +569,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -591,16 +591,16 @@ ...@@ -591,16 +591,16 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 153/153 [00:02<00:00, 56.52it/s, loss=0.412, metrics={'acc': 0.7993}]\n", "epoch 1: 100%|██████████| 153/153 [00:02<00:00, 64.79it/s, loss=0.435, metrics={'acc': 0.7901}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 123.12it/s, loss=0.352, metrics={'acc': 0.8071}]\n", "valid: 100%|██████████| 39/39 [00:00<00:00, 124.97it/s, loss=0.358, metrics={'acc': 0.799}]\n",
"epoch 2: 100%|██████████| 153/153 [00:02<00:00, 59.55it/s, loss=0.351, metrics={'acc': 0.8351}]\n", "epoch 2: 100%|██████████| 153/153 [00:02<00:00, 71.36it/s, loss=0.352, metrics={'acc': 0.8352}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.98it/s, loss=0.346, metrics={'acc': 0.8359}]\n", "valid: 100%|██████████| 39/39 [00:00<00:00, 124.33it/s, loss=0.349, metrics={'acc': 0.8358}]\n",
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 59.82it/s, loss=0.346, metrics={'acc': 0.8377}]\n", "epoch 3: 100%|██████████| 153/153 [00:02<00:00, 72.24it/s, loss=0.345, metrics={'acc': 0.8383}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.88it/s, loss=0.344, metrics={'acc': 0.8384}]\n", "valid: 100%|██████████| 39/39 [00:00<00:00, 121.07it/s, loss=0.345, metrics={'acc': 0.8389}]\n",
"epoch 4: 100%|██████████| 153/153 [00:02<00:00, 58.97it/s, loss=0.342, metrics={'acc': 0.8392}]\n", "epoch 4: 100%|██████████| 153/153 [00:02<00:00, 70.39it/s, loss=0.341, metrics={'acc': 0.8404}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 122.20it/s, loss=0.342, metrics={'acc': 0.84}] \n", "valid: 100%|██████████| 39/39 [00:00<00:00, 123.29it/s, loss=0.343, metrics={'acc': 0.8406}]\n",
"epoch 5: 100%|██████████| 153/153 [00:02<00:00, 58.28it/s, loss=0.34, metrics={'acc': 0.8406}] \n", "epoch 5: 100%|██████████| 153/153 [00:02<00:00, 71.14it/s, loss=0.339, metrics={'acc': 0.8423}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 116.57it/s, loss=0.341, metrics={'acc': 0.8413}]\n" "valid: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.342, metrics={'acc': 0.8426}]\n"
] ]
} }
], ],
...@@ -632,7 +632,7 @@ ...@@ -632,7 +632,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.6.5" "version": "3.7.7"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
"import os\n", "import os\n",
"import torch\n", "import torch\n",
"\n", "\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor, TextPreprocessor, ImagePreprocessor\n", "from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor, TextPreprocessor, ImagePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, DeepText, DeepImage, WideDeep\n", "from pytorch_widedeep.models import Wide, DeepDense, DeepText, DeepImage, WideDeep\n",
"from pytorch_widedeep.initializers import *\n", "from pytorch_widedeep.initializers import *\n",
"from pytorch_widedeep.callbacks import *\n", "from pytorch_widedeep.callbacks import *\n",
...@@ -162,7 +162,7 @@ ...@@ -162,7 +162,7 @@
" </thead>\n", " </thead>\n",
" <tbody>\n", " <tbody>\n",
" <tr>\n", " <tr>\n",
" <td>0</td>\n", " <th>0</th>\n",
" <td>13913.jpg</td>\n", " <td>13913.jpg</td>\n",
" <td>54730</td>\n", " <td>54730</td>\n",
" <td>My bright double bedroom with a large window has a relaxed feeling! It comfortably fits one or t...</td>\n", " <td>My bright double bedroom with a large window has a relaxed feeling! It comfortably fits one or t...</td>\n",
...@@ -266,7 +266,7 @@ ...@@ -266,7 +266,7 @@
" <td>12.00</td>\n", " <td>12.00</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>1</td>\n", " <th>1</th>\n",
" <td>15400.jpg</td>\n", " <td>15400.jpg</td>\n",
" <td>60302</td>\n", " <td>60302</td>\n",
" <td>Lots of windows and light. St Luke's Gardens are at the end of the block, and the river not too...</td>\n", " <td>Lots of windows and light. St Luke's Gardens are at the end of the block, and the river not too...</td>\n",
...@@ -370,7 +370,7 @@ ...@@ -370,7 +370,7 @@
" <td>109.50</td>\n", " <td>109.50</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>2</td>\n", " <th>2</th>\n",
" <td>17402.jpg</td>\n", " <td>17402.jpg</td>\n",
" <td>67564</td>\n", " <td>67564</td>\n",
" <td>Open from June 2018 after a 3-year break, we are delighted to be welcoming guests again to this ...</td>\n", " <td>Open from June 2018 after a 3-year break, we are delighted to be welcoming guests again to this ...</td>\n",
...@@ -474,7 +474,7 @@ ...@@ -474,7 +474,7 @@
" <td>149.65</td>\n", " <td>149.65</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>3</td>\n", " <th>3</th>\n",
" <td>24328.jpg</td>\n", " <td>24328.jpg</td>\n",
" <td>41759</td>\n", " <td>41759</td>\n",
" <td>Artist house, bright high ceiling rooms, private parking and a communal garden in a conservation...</td>\n", " <td>Artist house, bright high ceiling rooms, private parking and a communal garden in a conservation...</td>\n",
...@@ -578,7 +578,7 @@ ...@@ -578,7 +578,7 @@
" <td>215.60</td>\n", " <td>215.60</td>\n",
" </tr>\n", " </tr>\n",
" <tr>\n", " <tr>\n",
" <td>4</td>\n", " <th>4</th>\n",
" <td>25023.jpg</td>\n", " <td>25023.jpg</td>\n",
" <td>102813</td>\n", " <td>102813</td>\n",
" <td>Large, all comforts, 2-bed flat; first floor; lift; pretty communal gardens + off-street parking...</td>\n", " <td>Large, all comforts, 2-bed flat; first floor; lift; pretty communal gardens + off-street parking...</td>\n",
...@@ -1016,20 +1016,20 @@ ...@@ -1016,20 +1016,20 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"deep_preprocessor = DeepPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n", "deep_preprocessor = DensePreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
"X_deep = deep_preprocessor.fit_transform(df)" "X_deep = deep_preprocessor.fit_transform(df)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"The vocabulary contains 6400 tokens\n", "The vocabulary contains 2192 tokens\n",
"Indexing word vectors...\n", "Indexing word vectors...\n",
"Loaded 400000 word vectors\n", "Loaded 400000 word vectors\n",
"Preparing embeddings matrix...\n", "Preparing embeddings matrix...\n",
...@@ -1044,7 +1044,7 @@ ...@@ -1044,7 +1044,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1058,7 +1058,7 @@ ...@@ -1058,7 +1058,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
" 9%|▊ | 87/1001 [00:00<00:02, 428.89it/s]" " 4%|▍ | 42/1001 [00:00<00:02, 411.81it/s]"
] ]
}, },
{ {
...@@ -1072,7 +1072,7 @@ ...@@ -1072,7 +1072,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"100%|██████████| 1001/1001 [00:02<00:00, 426.92it/s]\n" "100%|██████████| 1001/1001 [00:02<00:00, 402.73it/s]\n"
] ]
}, },
{ {
...@@ -1097,12 +1097,12 @@ ...@@ -1097,12 +1097,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Linear model\n", "# Linear model\n",
"wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)\n", "wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
"# DeepDense: 2 Dense layers\n", "# DeepDense: 2 Dense layers\n",
"deepdense = DeepDense(hidden_layers=[128,64], dropout=[0.5, 0.5], \n", "deepdense = DeepDense(hidden_layers=[128,64], dropout=[0.5, 0.5], \n",
" deep_column_idx=deep_preprocessor.deep_column_idx,\n", " deep_column_idx=deep_preprocessor.deep_column_idx,\n",
...@@ -1125,7 +1125,7 @@ ...@@ -1125,7 +1125,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1141,7 +1141,7 @@ ...@@ -1141,7 +1141,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1150,7 +1150,7 @@ ...@@ -1150,7 +1150,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 12,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1172,8 +1172,8 @@ ...@@ -1172,8 +1172,8 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 25/25 [02:06<00:00, 5.05s/it, loss=118]\n", "epoch 1: 100%|██████████| 25/25 [01:13<00:00, 2.93s/it, loss=135]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.01s/it, loss=226]\n" "valid: 100%|██████████| 7/7 [00:14<00:00, 2.10s/it, loss=124] \n"
] ]
} }
], ],
...@@ -1186,18 +1186,18 @@ ...@@ -1186,18 +1186,18 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Regression with varying parameters and a FC-Head receiving the full deep side\n", "### Regression with varying parameters and a fully connected head (FC-Head) receiving the full deep side\n",
"\n", "\n",
"This would be the second architecture shown in the README file" "This would be the second architecture shown in the README file"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)\n", "wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
"deepdense = DeepDense(hidden_layers=[128,64], dropout=[0.5, 0.5], \n", "deepdense = DeepDense(hidden_layers=[128,64], dropout=[0.5, 0.5], \n",
" deep_column_idx=deep_preprocessor.deep_column_idx,\n", " deep_column_idx=deep_preprocessor.deep_column_idx,\n",
" embed_input=deep_preprocessor.embeddings_input,\n", " embed_input=deep_preprocessor.embeddings_input,\n",
...@@ -1217,7 +1217,7 @@ ...@@ -1217,7 +1217,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1233,7 +1233,7 @@ ...@@ -1233,7 +1233,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1245,15 +1245,15 @@ ...@@ -1245,15 +1245,15 @@
" )\n", " )\n",
" (deepdense): DeepDense(\n", " (deepdense): DeepDense(\n",
" (embed_layers): ModuleDict(\n", " (embed_layers): ModuleDict(\n",
" (emb_layer_accommodates_catg): Embedding(3, 16)\n", " (emb_layer_accommodates_catg): Embedding(4, 16)\n",
" (emb_layer_bathrooms_catg): Embedding(3, 16)\n", " (emb_layer_bathrooms_catg): Embedding(4, 16)\n",
" (emb_layer_bedrooms_catg): Embedding(4, 16)\n", " (emb_layer_bedrooms_catg): Embedding(5, 16)\n",
" (emb_layer_beds_catg): Embedding(4, 16)\n", " (emb_layer_beds_catg): Embedding(5, 16)\n",
" (emb_layer_cancellation_policy): Embedding(5, 16)\n", " (emb_layer_cancellation_policy): Embedding(6, 16)\n",
" (emb_layer_guests_included_catg): Embedding(3, 16)\n", " (emb_layer_guests_included_catg): Embedding(4, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(4, 16)\n", " (emb_layer_host_listings_count_catg): Embedding(5, 16)\n",
" (emb_layer_minimum_nights_catg): Embedding(3, 16)\n", " (emb_layer_minimum_nights_catg): Embedding(4, 16)\n",
" (emb_layer_neighbourhood_cleansed): Embedding(32, 64)\n", " (emb_layer_neighbourhood_cleansed): Embedding(33, 64)\n",
" )\n", " )\n",
" (embed_dropout): Dropout(p=0.0, inplace=False)\n", " (embed_dropout): Dropout(p=0.0, inplace=False)\n",
" (dense): Sequential(\n", " (dense): Sequential(\n",
...@@ -1386,7 +1386,7 @@ ...@@ -1386,7 +1386,7 @@
")" ")"
] ]
}, },
"execution_count": 18, "execution_count": 15,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -1406,7 +1406,7 @@ ...@@ -1406,7 +1406,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 16,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1420,7 +1420,7 @@ ...@@ -1420,7 +1420,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1433,7 +1433,7 @@ ...@@ -1433,7 +1433,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 18,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1446,7 +1446,7 @@ ...@@ -1446,7 +1446,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 19,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -1472,18 +1472,9 @@ ...@@ -1472,18 +1472,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 23, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/javier/pytorch-widedeep/pytorch_widedeep/initializers.py:31: UserWarning: No initializer found for deephead\n",
" warnings.warn(\"No initializer found for {}\".format(name))\n"
]
}
],
"source": [ "source": [
"model.compile(method='regression', initializers=initializers, optimizers=optimizers,\n", "model.compile(method='regression', initializers=initializers, optimizers=optimizers,\n",
" lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)" " lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)"
...@@ -1491,7 +1482,7 @@ ...@@ -1491,7 +1482,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 21,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1503,15 +1494,15 @@ ...@@ -1503,15 +1494,15 @@
" )\n", " )\n",
" (deepdense): DeepDense(\n", " (deepdense): DeepDense(\n",
" (embed_layers): ModuleDict(\n", " (embed_layers): ModuleDict(\n",
" (emb_layer_accommodates_catg): Embedding(3, 16)\n", " (emb_layer_accommodates_catg): Embedding(4, 16)\n",
" (emb_layer_bathrooms_catg): Embedding(3, 16)\n", " (emb_layer_bathrooms_catg): Embedding(4, 16)\n",
" (emb_layer_bedrooms_catg): Embedding(4, 16)\n", " (emb_layer_bedrooms_catg): Embedding(5, 16)\n",
" (emb_layer_beds_catg): Embedding(4, 16)\n", " (emb_layer_beds_catg): Embedding(5, 16)\n",
" (emb_layer_cancellation_policy): Embedding(5, 16)\n", " (emb_layer_cancellation_policy): Embedding(6, 16)\n",
" (emb_layer_guests_included_catg): Embedding(3, 16)\n", " (emb_layer_guests_included_catg): Embedding(4, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(4, 16)\n", " (emb_layer_host_listings_count_catg): Embedding(5, 16)\n",
" (emb_layer_minimum_nights_catg): Embedding(3, 16)\n", " (emb_layer_minimum_nights_catg): Embedding(4, 16)\n",
" (emb_layer_neighbourhood_cleansed): Embedding(32, 64)\n", " (emb_layer_neighbourhood_cleansed): Embedding(33, 64)\n",
" )\n", " )\n",
" (embed_dropout): Dropout(p=0.0, inplace=False)\n", " (embed_dropout): Dropout(p=0.0, inplace=False)\n",
" (dense): Sequential(\n", " (dense): Sequential(\n",
...@@ -1644,7 +1635,7 @@ ...@@ -1644,7 +1635,7 @@
")" ")"
] ]
}, },
"execution_count": 24, "execution_count": 21,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -1655,7 +1646,7 @@ ...@@ -1655,7 +1646,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 25, "execution_count": 22,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1677,8 +1668,8 @@ ...@@ -1677,8 +1668,8 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"epoch 1: 100%|██████████| 25/25 [02:04<00:00, 4.97s/it, loss=127]\n", "epoch 1: 100%|██████████| 25/25 [02:02<00:00, 4.88s/it, loss=128]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.02s/it, loss=94] \n" "valid: 100%|██████████| 7/7 [00:14<00:00, 2.09s/it, loss=94.5]\n"
] ]
} }
], ],
...@@ -1696,7 +1687,7 @@ ...@@ -1696,7 +1687,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 26, "execution_count": 23,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
...@@ -1721,7 +1712,7 @@ ...@@ -1721,7 +1712,7 @@
" 'lr_deephead_0': [0.001, 0.001]}" " 'lr_deephead_0': [0.001, 0.001]}"
] ]
}, },
"execution_count": 26, "execution_count": 23,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
...@@ -1747,7 +1738,7 @@ ...@@ -1747,7 +1738,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.6.5" "version": "3.7.7"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
...@@ -4,7 +4,7 @@ import pandas as pd ...@@ -4,7 +4,7 @@ import pandas as pd
from pytorch_widedeep.models import Wide, WideDeep, DeepDense from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.metrics import CategoricalAccuracy from pytorch_widedeep.metrics import CategoricalAccuracy
from pytorch_widedeep.preprocessing import DeepPreprocessor, WidePreprocessor from pytorch_widedeep.preprocessing import DensePreprocessor, WidePreprocessor
use_cuda = torch.cuda.is_available() use_cuda = torch.cuda.is_available()
...@@ -35,7 +35,7 @@ if __name__ == "__main__": ...@@ -35,7 +35,7 @@ if __name__ == "__main__":
prepare_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols) prepare_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = prepare_wide.fit_transform(df) X_wide = prepare_wide.fit_transform(df)
prepare_deep = DeepPreprocessor( prepare_deep = DensePreprocessor(
embed_cols=cat_embed_cols, continuous_cols=continuous_cols embed_cols=cat_embed_cols, continuous_cols=continuous_cols
) )
X_deep = prepare_deep.fit_transform(df) X_deep = prepare_deep.fit_transform(df)
...@@ -47,7 +47,7 @@ if __name__ == "__main__": ...@@ -47,7 +47,7 @@ if __name__ == "__main__":
embed_input=prepare_deep.embeddings_input, embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols, continuous_cols=continuous_cols,
) )
model = WideDeep(wide=wide, deepdense=deepdense, output_dim=3) model = WideDeep(wide=wide, deepdense=deepdense, pred_dim=3)
model.compile(method="multiclass", metrics=[CategoricalAccuracy]) model.compile(method="multiclass", metrics=[CategoricalAccuracy])
model.fit( model.fit(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册