提交 3e963e78 编写于 作者: J jrzaurin

Adapted the examples to the small code changes. Replace add_ with add in...

Adapted the examples to the small code changes. Replace add_ with add in WideDeep to avoid annoying warnings. Replace output_dim with pred_dim in Wide for consistency
上级 e4f6ad2d
......@@ -8,11 +8,11 @@
"\n",
"The 5 main components of a `WideDeep` model are:\n",
"\n",
"1. `Wide (Class)`\n",
"2. `DeepDense (Class)`\n",
"3. `DeepText (Class)`\n",
"4. `DeepImage (Class)`\n",
"5. `deephead (WideDeep Class parameter)`\n",
"1. `Wide`\n",
"2. `DeepDense`\n",
"3. `DeepText`\n",
"4. `DeepImage`\n",
"5. `deephead`\n",
"\n",
"The first 4 of them will be collected and combined by the `WideDeep` collector class, while the 5th one can be optionally added to the `WideDeep` model through its corresponding parameters: `deephead` or alternatively `head_layers`, `head_dropout` and `head_batchnorm`"
]
......@@ -170,11 +170,11 @@
{
"data": {
"text/plain": [
"tensor([[ 1.9317, -0.0000, 1.3663, -0.3984, -0.0000, -0.0000, -0.0000, -1.2662],\n",
" [ 0.0000, -1.5337, -0.0000, 0.0726, -0.4231, 3.9977, -0.0000, -0.0000],\n",
" [-0.0000, -1.5839, 3.2978, -1.7084, -1.0877, -0.9574, 0.0000, -0.0000],\n",
" [-0.0000, 1.6664, -1.6006, 0.0000, -0.0000, -0.9844, -0.0000, -0.0521],\n",
" [ 2.4249, 0.0000, -0.0000, -0.0000, 0.0000, -0.0000, 2.6460, 0.0000]],\n",
"tensor([[-0.0000, -1.0061, -0.0000, -0.9828, -0.0000, -0.0000, -0.9944, -1.0133],\n",
" [-0.0000, -0.9996, 0.0000, -1.0374, 0.0000, -0.0000, -1.0313, -0.0000],\n",
" [-0.8576, -1.0017, -0.0000, -0.9881, -0.0000, 0.0000, -0.0000, -0.0000],\n",
" [ 3.9816, 0.0000, 0.0000, 0.0000, 3.7309, 1.1728, 0.0000, -1.1160],\n",
" [-1.1339, -0.9925, -0.0000, -0.0000, -0.0000, 0.0000, -0.9638, 0.0000]],\n",
" grad_fn=<MulBackward0>)"
]
},
......@@ -484,10 +484,10 @@
{
"data": {
"text/plain": [
"tensor([[ 8.4865e-02, -3.4401e-03, -9.1973e-04, 3.4269e-01, 3.2816e-02,\n",
" 1.9682e-02, -8.0740e-04, 9.4898e-03],\n",
" [ 1.5473e-01, -6.2664e-03, -9.3413e-05, 3.8768e-01, -1.9963e-03,\n",
" 1.1729e-01, -2.7111e-03, 1.8670e-01]], grad_fn=<LeakyReluBackward1>)"
"tensor([[-1.4630e-04, -6.1540e-04, -2.4541e-04, 2.7543e-01, 1.2993e-01,\n",
" -1.6553e-03, 6.7002e-02, 2.3974e-01],\n",
" [-9.9619e-04, -1.9412e-03, 1.2113e-01, 1.0122e-01, 2.9080e-01,\n",
" -2.0852e-03, -1.8016e-04, 2.7996e-02]], grad_fn=<LeakyReluBackward1>)"
]
},
"execution_count": 18,
......@@ -517,7 +517,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 20,
"metadata": {},
"outputs": [
{
......@@ -563,7 +563,7 @@
")"
]
},
"execution_count": 25,
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
......@@ -587,7 +587,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
......@@ -620,7 +620,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.7.7"
}
},
"nbformat": 4,
......
......@@ -11,7 +11,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
......@@ -19,14 +19,14 @@
"import pandas as pd\n",
"import torch\n",
"\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
"from pytorch_widedeep.metrics import BinaryAccuracy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"metadata": {},
"outputs": [
{
......@@ -69,7 +69,7 @@
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <th>0</th>\n",
" <td>25</td>\n",
" <td>Private</td>\n",
" <td>226802</td>\n",
......@@ -87,7 +87,7 @@
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <th>1</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>89814</td>\n",
......@@ -105,7 +105,7 @@
" <td>&lt;=50K</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <th>2</th>\n",
" <td>28</td>\n",
" <td>Local-gov</td>\n",
" <td>336951</td>\n",
......@@ -123,7 +123,7 @@
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <th>3</th>\n",
" <td>44</td>\n",
" <td>Private</td>\n",
" <td>160323</td>\n",
......@@ -141,7 +141,7 @@
" <td>&gt;50K</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <th>4</th>\n",
" <td>18</td>\n",
" <td>?</td>\n",
" <td>103497</td>\n",
......@@ -185,7 +185,7 @@
"4 30 United-States <=50K "
]
},
"execution_count": 2,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
......@@ -197,7 +197,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"metadata": {},
"outputs": [
{
......@@ -240,7 +240,7 @@
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <th>0</th>\n",
" <td>25</td>\n",
" <td>Private</td>\n",
" <td>226802</td>\n",
......@@ -258,7 +258,7 @@
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <th>1</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>89814</td>\n",
......@@ -276,7 +276,7 @@
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <th>2</th>\n",
" <td>28</td>\n",
" <td>Local-gov</td>\n",
" <td>336951</td>\n",
......@@ -294,7 +294,7 @@
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <th>3</th>\n",
" <td>44</td>\n",
" <td>Private</td>\n",
" <td>160323</td>\n",
......@@ -312,7 +312,7 @@
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <th>4</th>\n",
" <td>18</td>\n",
" <td>?</td>\n",
" <td>103497</td>\n",
......@@ -356,7 +356,7 @@
"4 30 United-States 0 "
]
},
"execution_count": 3,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
......@@ -381,7 +381,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
......@@ -394,7 +394,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
......@@ -406,13 +406,13 @@
"X_wide = preprocess_wide.fit_transform(df)\n",
"\n",
"# DEEP\n",
"preprocess_deep = DeepPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
"preprocess_deep = DensePreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
"X_deep = preprocess_deep.fit_transform(df)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"metadata": {},
"outputs": [
{
......@@ -437,7 +437,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"outputs": [
{
......@@ -475,11 +475,11 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)\n",
"wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
"deepdense = DeepDense(hidden_layers=[64,32], \n",
" deep_column_idx=preprocess_deep.deep_column_idx,\n",
" embed_input=preprocess_deep.embeddings_input,\n",
......@@ -489,7 +489,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 15,
"metadata": {},
"outputs": [
{
......@@ -502,11 +502,11 @@
" (deepdense): Sequential(\n",
" (0): DeepDense(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_education): Embedding(16, 16)\n",
" (emb_layer_native_country): Embedding(42, 16)\n",
" (emb_layer_occupation): Embedding(15, 16)\n",
" (emb_layer_relationship): Embedding(6, 8)\n",
" (emb_layer_workclass): Embedding(9, 16)\n",
" (emb_layer_education): Embedding(17, 16)\n",
" (emb_layer_native_country): Embedding(43, 16)\n",
" (emb_layer_occupation): Embedding(16, 16)\n",
" (emb_layer_relationship): Embedding(7, 8)\n",
" (emb_layer_workclass): Embedding(10, 16)\n",
" )\n",
" (embed_dropout): Dropout(p=0.0, inplace=False)\n",
" (dense): Sequential(\n",
......@@ -527,7 +527,7 @@
")"
]
},
"execution_count": 9,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
......@@ -560,7 +560,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
......@@ -569,7 +569,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 17,
"metadata": {},
"outputs": [
{
......@@ -591,16 +591,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 153/153 [00:02<00:00, 56.52it/s, loss=0.412, metrics={'acc': 0.7993}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 123.12it/s, loss=0.352, metrics={'acc': 0.8071}]\n",
"epoch 2: 100%|██████████| 153/153 [00:02<00:00, 59.55it/s, loss=0.351, metrics={'acc': 0.8351}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.98it/s, loss=0.346, metrics={'acc': 0.8359}]\n",
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 59.82it/s, loss=0.346, metrics={'acc': 0.8377}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.88it/s, loss=0.344, metrics={'acc': 0.8384}]\n",
"epoch 4: 100%|██████████| 153/153 [00:02<00:00, 58.97it/s, loss=0.342, metrics={'acc': 0.8392}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 122.20it/s, loss=0.342, metrics={'acc': 0.84}] \n",
"epoch 5: 100%|██████████| 153/153 [00:02<00:00, 58.28it/s, loss=0.34, metrics={'acc': 0.8406}] \n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 116.57it/s, loss=0.341, metrics={'acc': 0.8413}]\n"
"epoch 1: 100%|██████████| 153/153 [00:02<00:00, 64.79it/s, loss=0.435, metrics={'acc': 0.7901}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.97it/s, loss=0.358, metrics={'acc': 0.799}]\n",
"epoch 2: 100%|██████████| 153/153 [00:02<00:00, 71.36it/s, loss=0.352, metrics={'acc': 0.8352}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 124.33it/s, loss=0.349, metrics={'acc': 0.8358}]\n",
"epoch 3: 100%|██████████| 153/153 [00:02<00:00, 72.24it/s, loss=0.345, metrics={'acc': 0.8383}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.07it/s, loss=0.345, metrics={'acc': 0.8389}]\n",
"epoch 4: 100%|██████████| 153/153 [00:02<00:00, 70.39it/s, loss=0.341, metrics={'acc': 0.8404}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 123.29it/s, loss=0.343, metrics={'acc': 0.8406}]\n",
"epoch 5: 100%|██████████| 153/153 [00:02<00:00, 71.14it/s, loss=0.339, metrics={'acc': 0.8423}]\n",
"valid: 100%|██████████| 39/39 [00:00<00:00, 121.12it/s, loss=0.342, metrics={'acc': 0.8426}]\n"
]
}
],
......@@ -632,7 +632,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.7.7"
}
},
"nbformat": 4,
......
......@@ -24,7 +24,7 @@
"import os\n",
"import torch\n",
"\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor, TextPreprocessor, ImagePreprocessor\n",
"from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor, TextPreprocessor, ImagePreprocessor\n",
"from pytorch_widedeep.models import Wide, DeepDense, DeepText, DeepImage, WideDeep\n",
"from pytorch_widedeep.initializers import *\n",
"from pytorch_widedeep.callbacks import *\n",
......@@ -162,7 +162,7 @@
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <th>0</th>\n",
" <td>13913.jpg</td>\n",
" <td>54730</td>\n",
" <td>My bright double bedroom with a large window has a relaxed feeling! It comfortably fits one or t...</td>\n",
......@@ -266,7 +266,7 @@
" <td>12.00</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <th>1</th>\n",
" <td>15400.jpg</td>\n",
" <td>60302</td>\n",
" <td>Lots of windows and light. St Luke's Gardens are at the end of the block, and the river not too...</td>\n",
......@@ -370,7 +370,7 @@
" <td>109.50</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <th>2</th>\n",
" <td>17402.jpg</td>\n",
" <td>67564</td>\n",
" <td>Open from June 2018 after a 3-year break, we are delighted to be welcoming guests again to this ...</td>\n",
......@@ -474,7 +474,7 @@
" <td>149.65</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <th>3</th>\n",
" <td>24328.jpg</td>\n",
" <td>41759</td>\n",
" <td>Artist house, bright high ceiling rooms, private parking and a communal garden in a conservation...</td>\n",
......@@ -578,7 +578,7 @@
" <td>215.60</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <th>4</th>\n",
" <td>25023.jpg</td>\n",
" <td>102813</td>\n",
" <td>Large, all comforts, 2-bed flat; first floor; lift; pretty communal gardens + off-street parking...</td>\n",
......@@ -1016,20 +1016,20 @@
"metadata": {},
"outputs": [],
"source": [
"deep_preprocessor = DeepPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
"deep_preprocessor = DensePreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
"X_deep = deep_preprocessor.fit_transform(df)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The vocabulary contains 6400 tokens\n",
"The vocabulary contains 2192 tokens\n",
"Indexing word vectors...\n",
"Loaded 400000 word vectors\n",
"Preparing embeddings matrix...\n",
......@@ -1044,7 +1044,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"metadata": {},
"outputs": [
{
......@@ -1058,7 +1058,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 9%|▊ | 87/1001 [00:00<00:02, 428.89it/s]"
" 4%|▍ | 42/1001 [00:00<00:02, 411.81it/s]"
]
},
{
......@@ -1072,7 +1072,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1001/1001 [00:02<00:00, 426.92it/s]\n"
"100%|██████████| 1001/1001 [00:02<00:00, 402.73it/s]\n"
]
},
{
......@@ -1097,12 +1097,12 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Linear model\n",
"wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)\n",
"wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
"# DeepDense: 2 Dense layers\n",
"deepdense = DeepDense(hidden_layers=[128,64], dropout=[0.5, 0.5], \n",
" deep_column_idx=deep_preprocessor.deep_column_idx,\n",
......@@ -1125,7 +1125,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
......@@ -1141,7 +1141,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
......@@ -1150,7 +1150,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 12,
"metadata": {},
"outputs": [
{
......@@ -1172,8 +1172,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [02:06<00:00, 5.05s/it, loss=118]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.01s/it, loss=226]\n"
"epoch 1: 100%|██████████| 25/25 [01:13<00:00, 2.93s/it, loss=135]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.10s/it, loss=124] \n"
]
}
],
......@@ -1186,18 +1186,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regression with varying parameters and a FC-Head receiving the full deep side\n",
"### Regression with varying parameters and a fully connected head (FC-Head) receiving the full deep side\n",
"\n",
"This would be the second architecture shown in the README file"
]
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"wide = Wide(wide_dim=X_wide.shape[1], output_dim=1)\n",
"wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)\n",
"deepdense = DeepDense(hidden_layers=[128,64], dropout=[0.5, 0.5], \n",
" deep_column_idx=deep_preprocessor.deep_column_idx,\n",
" embed_input=deep_preprocessor.embeddings_input,\n",
......@@ -1217,7 +1217,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
......@@ -1233,7 +1233,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 15,
"metadata": {},
"outputs": [
{
......@@ -1245,15 +1245,15 @@
" )\n",
" (deepdense): DeepDense(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_accommodates_catg): Embedding(3, 16)\n",
" (emb_layer_bathrooms_catg): Embedding(3, 16)\n",
" (emb_layer_bedrooms_catg): Embedding(4, 16)\n",
" (emb_layer_beds_catg): Embedding(4, 16)\n",
" (emb_layer_cancellation_policy): Embedding(5, 16)\n",
" (emb_layer_guests_included_catg): Embedding(3, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(4, 16)\n",
" (emb_layer_minimum_nights_catg): Embedding(3, 16)\n",
" (emb_layer_neighbourhood_cleansed): Embedding(32, 64)\n",
" (emb_layer_accommodates_catg): Embedding(4, 16)\n",
" (emb_layer_bathrooms_catg): Embedding(4, 16)\n",
" (emb_layer_bedrooms_catg): Embedding(5, 16)\n",
" (emb_layer_beds_catg): Embedding(5, 16)\n",
" (emb_layer_cancellation_policy): Embedding(6, 16)\n",
" (emb_layer_guests_included_catg): Embedding(4, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(5, 16)\n",
" (emb_layer_minimum_nights_catg): Embedding(4, 16)\n",
" (emb_layer_neighbourhood_cleansed): Embedding(33, 64)\n",
" )\n",
" (embed_dropout): Dropout(p=0.0, inplace=False)\n",
" (dense): Sequential(\n",
......@@ -1386,7 +1386,7 @@
")"
]
},
"execution_count": 18,
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
......@@ -1406,7 +1406,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
......@@ -1420,7 +1420,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
......@@ -1433,7 +1433,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
......@@ -1446,7 +1446,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
......@@ -1472,18 +1472,9 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/javier/pytorch-widedeep/pytorch_widedeep/initializers.py:31: UserWarning: No initializer found for deephead\n",
" warnings.warn(\"No initializer found for {}\".format(name))\n"
]
}
],
"outputs": [],
"source": [
"model.compile(method='regression', initializers=initializers, optimizers=optimizers,\n",
" lr_schedulers=schedulers, callbacks=callbacks, transforms=transforms)"
......@@ -1491,7 +1482,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 21,
"metadata": {},
"outputs": [
{
......@@ -1503,15 +1494,15 @@
" )\n",
" (deepdense): DeepDense(\n",
" (embed_layers): ModuleDict(\n",
" (emb_layer_accommodates_catg): Embedding(3, 16)\n",
" (emb_layer_bathrooms_catg): Embedding(3, 16)\n",
" (emb_layer_bedrooms_catg): Embedding(4, 16)\n",
" (emb_layer_beds_catg): Embedding(4, 16)\n",
" (emb_layer_cancellation_policy): Embedding(5, 16)\n",
" (emb_layer_guests_included_catg): Embedding(3, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(4, 16)\n",
" (emb_layer_minimum_nights_catg): Embedding(3, 16)\n",
" (emb_layer_neighbourhood_cleansed): Embedding(32, 64)\n",
" (emb_layer_accommodates_catg): Embedding(4, 16)\n",
" (emb_layer_bathrooms_catg): Embedding(4, 16)\n",
" (emb_layer_bedrooms_catg): Embedding(5, 16)\n",
" (emb_layer_beds_catg): Embedding(5, 16)\n",
" (emb_layer_cancellation_policy): Embedding(6, 16)\n",
" (emb_layer_guests_included_catg): Embedding(4, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(5, 16)\n",
" (emb_layer_minimum_nights_catg): Embedding(4, 16)\n",
" (emb_layer_neighbourhood_cleansed): Embedding(33, 64)\n",
" )\n",
" (embed_dropout): Dropout(p=0.0, inplace=False)\n",
" (dense): Sequential(\n",
......@@ -1644,7 +1635,7 @@
")"
]
},
"execution_count": 24,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
......@@ -1655,7 +1646,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 22,
"metadata": {},
"outputs": [
{
......@@ -1677,8 +1668,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 25/25 [02:04<00:00, 4.97s/it, loss=127]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.02s/it, loss=94] \n"
"epoch 1: 100%|██████████| 25/25 [02:02<00:00, 4.88s/it, loss=128]\n",
"valid: 100%|██████████| 7/7 [00:14<00:00, 2.09s/it, loss=94.5]\n"
]
}
],
......@@ -1696,7 +1687,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 23,
"metadata": {},
"outputs": [
{
......@@ -1721,7 +1712,7 @@
" 'lr_deephead_0': [0.001, 0.001]}"
]
},
"execution_count": 26,
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
......@@ -1747,7 +1738,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
"version": "3.7.7"
}
},
"nbformat": 4,
......
......@@ -4,7 +4,7 @@ import pandas as pd
from pytorch_widedeep.models import Wide, WideDeep, DeepDense
from pytorch_widedeep.metrics import CategoricalAccuracy
from pytorch_widedeep.preprocessing import DeepPreprocessor, WidePreprocessor
from pytorch_widedeep.preprocessing import DensePreprocessor, WidePreprocessor
use_cuda = torch.cuda.is_available()
......@@ -35,7 +35,7 @@ if __name__ == "__main__":
prepare_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
X_wide = prepare_wide.fit_transform(df)
prepare_deep = DeepPreprocessor(
prepare_deep = DensePreprocessor(
embed_cols=cat_embed_cols, continuous_cols=continuous_cols
)
X_deep = prepare_deep.fit_transform(df)
......@@ -47,7 +47,7 @@ if __name__ == "__main__":
embed_input=prepare_deep.embeddings_input,
continuous_cols=continuous_cols,
)
model = WideDeep(wide=wide, deepdense=deepdense, output_dim=3)
model = WideDeep(wide=wide, deepdense=deepdense, pred_dim=3)
model.compile(method="multiclass", metrics=[CategoricalAccuracy])
model.fit(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册