提交 74954645 编写于 作者: J jrzaurin

notebooks consistent with README and scripts

上级 66354228
因为 它太大了无法显示 source diff 。你可以改为 查看blob
因为 它太大了无法显示 source diff 。你可以改为 查看blob
......@@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
......@@ -499,7 +499,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
......@@ -541,22 +541,21 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading Images from data/airbnb/property_picture\n",
"BGR to RGB\n"
"Reading Images from data/airbnb/property_picture\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 1%| | 41/4999 [00:00<00:12, 406.51it/s]"
" 1%| | 41/5000 [00:00<00:12, 405.03it/s]"
]
},
{
......@@ -570,18 +569,18 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 4999/4999 [00:12<00:00, 389.79it/s]\n"
"100%|██████████| 5000/5000 [00:12<00:00, 384.31it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Our vocabulary contains 12434 words\n",
"Our vocabulary contains 12433 words\n",
"Indexing word vectors...\n",
"Loaded 400000 word vectors\n",
"Preparing embeddings matrix...\n",
"6827 words in our vocabulary had glove vectors and appear more than the min frequency\n",
"6776 words in our vocabulary had glove vectors and appear more than the min frequency\n",
"Wide and Deep airbnb data preparation completed.\n"
]
}
......@@ -619,7 +618,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
......@@ -640,7 +639,7 @@
" vocab_size = len(wd_dataset_airbnb['vocab'].itos),\n",
" embedding_dim = wd_dataset_airbnb['word_embeddings_matrix'].shape[1],\n",
" hidden_dim = 64,\n",
" n_layers = 3,\n",
" n_layers = 2,\n",
" rnn_dropout = 0.5,\n",
" spatial_dropout = 0.1,\n",
" padding_idx = 1,\n",
......@@ -665,7 +664,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
......@@ -675,7 +674,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 58,
"metadata": {},
"outputs": [
{
......@@ -687,13 +686,13 @@
" )\n",
" (deep_dense): DeepDense(\n",
" (emb_layer_neighbourhood_cleansed): Embedding(33, 64)\n",
" (emb_layer_beds_catg): Embedding(4, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(4, 16)\n",
" (emb_layer_bathrooms_catg): Embedding(3, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(4, 16)\n",
" (emb_layer_minimum_nights_catg): Embedding(3, 16)\n",
" (emb_layer_beds_catg): Embedding(4, 16)\n",
" (emb_layer_bedrooms_catg): Embedding(4, 16)\n",
" (emb_layer_accommodates_catg): Embedding(3, 16)\n",
" (emb_layer_guests_included_catg): Embedding(3, 16)\n",
" (emb_layer_accommodates_catg): Embedding(3, 16)\n",
" (dense): Sequential(\n",
" (dense_layer_0): Sequential(\n",
" (0): Linear(in_features=180, out_features=64, bias=True)\n",
......@@ -710,8 +709,8 @@
" )\n",
" (deep_text): DeepText(\n",
" (embedding_dropout): Dropout2d(p=0.1)\n",
" (embedding): Embedding(7165, 300, padding_idx=1)\n",
" (rnn): GRU(300, 64, num_layers=3, batch_first=True, dropout=0.5)\n",
" (embedding): Embedding(7093, 300, padding_idx=1)\n",
" (rnn): GRU(300, 64, num_layers=2, batch_first=True, dropout=0.5)\n",
" (dtlinear): Linear(in_features=64, out_features=1, bias=True)\n",
" )\n",
" (deep_img): DeepImage(\n",
......@@ -859,7 +858,7 @@
")"
]
},
"execution_count": 9,
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
......@@ -877,7 +876,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
......@@ -901,7 +900,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
......@@ -910,7 +909,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
......@@ -919,12 +918,15 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"mean=[0.485, 0.456, 0.406] #RGB\n",
"std=[0.229, 0.224, 0.225] # RGB\n",
"# cv2 reads bgr\n",
"# mean=[0.485, 0.456, 0.406] #RGB\n",
"# std=[0.229, 0.224, 0.225] #RGB\n",
"mean=[0.406, 0.456, 0.485] #RGB\n",
"std=[0.225, 0.224, 0.229] #RGB\n",
"transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=mean, std=std)\n",
......@@ -933,32 +935,32 @@
"valid_set = WideDeepLoader(wd_dataset_airbnb['valid'], transform, mode='train')\n",
"test_set = WideDeepLoader(wd_dataset_airbnb['test'], transform, mode='test')\n",
"train_loader = torch.utils.data.DataLoader(dataset=train_set,\n",
" batch_size=128,shuffle=True)\n",
" batch_size=64,shuffle=True)\n",
"valid_loader = torch.utils.data.DataLoader(dataset=valid_set,\n",
" batch_size=128,shuffle=True)\n",
" batch_size=64,shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(dataset=test_set,\n",
" batch_size=32,shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 24/24 [00:21<00:00, 1.24it/s, loss=128]\n",
"valid: 100%|██████████| 8/8 [00:03<00:00, 2.11it/s, loss=1.3e+3] \n",
"epoch 2: 100%|██████████| 24/24 [00:21<00:00, 1.25it/s, loss=110]\n",
"valid: 100%|██████████| 8/8 [00:03<00:00, 2.07it/s, loss=98.8]\n",
"epoch 3: 100%|██████████| 24/24 [00:21<00:00, 1.25it/s, loss=106]\n",
"valid: 100%|██████████| 8/8 [00:03<00:00, 2.08it/s, loss=98.3]\n",
"epoch 4: 100%|██████████| 24/24 [00:21<00:00, 1.25it/s, loss=106]\n",
"valid: 100%|██████████| 8/8 [00:04<00:00, 2.03it/s, loss=97.6]\n",
"epoch 5: 100%|██████████| 24/24 [00:21<00:00, 1.26it/s, loss=104]\n",
"valid: 100%|██████████| 8/8 [00:04<00:00, 2.05it/s, loss=98.2]\n"
"epoch 1: 100%|██████████| 47/47 [00:22<00:00, 2.05it/s, loss=118]\n",
"valid: 100%|██████████| 16/16 [00:04<00:00, 4.06it/s, loss=117]\n",
"epoch 2: 100%|██████████| 47/47 [00:22<00:00, 2.14it/s, loss=105]\n",
"valid: 100%|██████████| 16/16 [00:04<00:00, 4.20it/s, loss=158]\n",
"epoch 3: 100%|██████████| 47/47 [00:22<00:00, 2.15it/s, loss=99.2]\n",
"valid: 100%|██████████| 16/16 [00:03<00:00, 4.27it/s, loss=99] \n",
"epoch 4: 100%|██████████| 47/47 [00:22<00:00, 2.13it/s, loss=97] \n",
"valid: 100%|██████████| 16/16 [00:04<00:00, 4.06it/s, loss=99.6]\n",
"epoch 5: 100%|██████████| 47/47 [00:22<00:00, 2.15it/s, loss=94.7]\n",
"valid: 100%|██████████| 16/16 [00:04<00:00, 4.24it/s, loss=99] \n"
]
}
],
......@@ -968,21 +970,21 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"predict: 100%|██████████| 32/32 [00:04<00:00, 6.40it/s]"
"predict: 100%|██████████| 32/32 [00:04<00:00, 6.73it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"103.62761727684484\n"
"105.88571074249793\n"
]
},
{
......@@ -1008,7 +1010,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
......@@ -1017,48 +1019,48 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Hackney': array([-1.974714, 0.863323, -1.459528, 1.365437, ..., -0.14018 , -0.524041, -0.316058, -0.410313], dtype=float32),\n",
" 'Camden': array([ 0.410322, -0.724736, -0.56029 , 0.079272, ..., -0.220795, -1.234073, -2.003421, 0.949352], dtype=float32),\n",
" 'Hillingdon': array([-0.853097, 0.266621, 0.08445 , 1.554609, ..., -0.139359, -0.419641, 0.74903 , -0.73454 ], dtype=float32),\n",
" 'Southwark': array([-0.025719, 0.1034 , -0.930781, 0.259913, ..., -0.762355, 0.91882 , -0.158588, -0.399257], dtype=float32),\n",
" 'Islington': array([-1.344581, 0.027291, -0.482429, -0.05447 , ..., 0.00804 , 1.058191, 1.389844, 0.976658], dtype=float32),\n",
" 'Lambeth': array([-0.503692, 1.997271, 0.204253, 2.161032, ..., -1.5392 , -0.274035, 1.052611, -1.117137], dtype=float32),\n",
" 'Westminster': array([-0.160366, -0.897001, -1.555466, -0.108475, ..., 0.095278, -1.265424, 0.48901 , 1.249886], dtype=float32),\n",
" 'Kensington and Chelsea': array([-0.941537, -0.79041 , -0.163071, 2.514469, ..., -1.033578, -2.369078, 0.659031, -1.707059], dtype=float32),\n",
" 'Ealing': array([-0.175527, 0.796092, -0.495381, -0.922058, ..., 2.760556, 1.751993, 0.105334, -0.786012], dtype=float32),\n",
" 'Newham': array([-1.281402, 0.662634, 1.103911, -0.32145 , ..., -0.99564 , -0.359498, -0.36606 , -0.431747], dtype=float32),\n",
" 'Tower Hamlets': array([ 0.271887, -0.468121, 0.431422, 0.281864, ..., 1.437819, -0.144674, -0.62363 , 1.09578 ], dtype=float32),\n",
" 'Waltham Forest': array([ 0.666032, -0.498509, 1.038276, -0.642088, ..., 1.354977, -0.626648, 0.817839, -0.277589], dtype=float32),\n",
" 'Wandsworth': array([ 0.07317 , 1.094801, 1.50066 , -1.284966, ..., -0.905864, -0.143956, 0.490856, -1.177169], dtype=float32),\n",
" 'Bromley': array([-0.486731, 1.064391, -0.700393, 1.08445 , ..., 1.852425, -2.306012, 0.341955, 1.354509], dtype=float32),\n",
" 'Merton': array([ 0.910861, 0.590595, 1.780486, 0.106199, ..., -0.178996, 1.284232, -1.418589, -0.559087], dtype=float32),\n",
" 'Harrow': array([-0.586278, -1.161813, -0.196221, -1.276113, ..., -0.647752, -0.788815, 0.564165, -1.408766], dtype=float32),\n",
" 'Hammersmith and Fulham': array([ 0.126797, 1.058752, -0.480162, -0.246861, ..., -1.77949 , -0.367484, -0.493584, 1.963025], dtype=float32),\n",
" 'Croydon': array([ 0.831145, 1.084541, 0.18434 , -1.17936 , ..., -0.644541, 1.287308, -0.437237, -0.036385], dtype=float32),\n",
" 'Kingston upon Thames': array([-1.138847, 0.353119, 0.453814, -1.537953, ..., 0.433265, -1.492994, 0.49617 , 0.32435 ], dtype=float32),\n",
" 'Greenwich': array([-0.375921, -0.006278, -1.202842, 0.710625, ..., -1.025979, 1.733119, 0.30344 , -0.333253], dtype=float32),\n",
" 'Barnet': array([ 0.112118, 0.253998, -0.083625, 0.991614, ..., -0.534774, -0.836971, -0.546797, -0.074188], dtype=float32),\n",
" 'Sutton': array([-0.581444, 1.245627, -0.268801, 0.220981, ..., -0.484765, 0.235929, -0.339714, 0.383085], dtype=float32),\n",
" 'Brent': array([ 0.684018, 1.122321, -0.578836, 1.606611, ..., 0.772141, 1.470037, 0.896998, -0.527985], dtype=float32),\n",
" 'Hounslow': array([ 0.186708, -0.859478, 1.421476, -0.367777, ..., 0.641241, -1.193496, -0.028203, -0.5517 ], dtype=float32),\n",
" 'Haringey': array([-0.520568, 1.791861, -1.58973 , -0.191609, ..., -0.763657, 0.877533, -0.751659, -0.779809], dtype=float32),\n",
" 'Lewisham': array([ 0.94821 , 1.358798, 2.511026, -0.085026, ..., -0.201893, 0.56487 , -1.297939, -0.988157], dtype=float32),\n",
" 'Enfield': array([ 0.839919, 3.486744, 0.209266, 0.486838, ..., -0.255266, 1.530212, 0.036456, -1.841724], dtype=float32),\n",
" 'Redbridge': array([ 0.562378, 0.117497, -0.163867, -0.418552, ..., -1.083247, 0.736799, 0.942412, -1.175167], dtype=float32),\n",
" 'Richmond upon Thames': array([-0.454952, 1.581438, 2.934833, 0.976836, ..., -1.138685, -0.491665, -1.153073, 0.511827], dtype=float32),\n",
" 'Barking and Dagenham': array([-1.24154 , -0.185623, 0.239697, 1.636263, ..., 1.715875, -1.810889, 0.612478, 0.533097], dtype=float32),\n",
" 'Bexley': array([ 0.593676, 1.990228, 0.449346, -1.309624, ..., 0.541702, 1.06037 , -0.971935, -1.45975 ], dtype=float32),\n",
" 'City of London': array([ 0.210624, 0.413433, -0.161755, 0.762585, ..., 2.038467, 0.73799 , 0.154868, -0.416979], dtype=float32),\n",
" 'Havering': array([ 0.082899, -0.406886, 0.492994, 0.116197, ..., 0.465383, -1.56611 , -1.431871, -0.940736], dtype=float32)}"
"{'Hammersmith and Fulham': array([-1.851022, 0.703244, 0.34214 , -1.215795, ..., -0.218133, -0.107084, -0.050649, -1.276854], dtype=float32),\n",
" 'Barnet': array([-1.739752, -0.935879, 0.335114, -1.109356, ..., -2.491122, -0.632074, 2.492794, 1.231859], dtype=float32),\n",
" 'Ealing': array([-0.89164 , 0.005264, -2.254685, 0.572432, ..., -1.358266, -1.876953, 1.084197, 0.825844], dtype=float32),\n",
" 'Greenwich': array([-0.686692, -0.623536, 1.663162, 1.130035, ..., -0.359084, -0.609575, 0.304735, -1.042224], dtype=float32),\n",
" 'Lambeth': array([-0.047345, -0.697568, 0.793924, -0.18951 , ..., 0.230893, -0.170741, -0.592736, -0.755723], dtype=float32),\n",
" 'Lewisham': array([-0.302935, 1.052123, 0.883626, 0.127071, ..., -0.047294, -0.667769, 1.237696, 1.278981], dtype=float32),\n",
" 'Richmond upon Thames': array([-1.108736, 0.175303, -1.596437, -0.13958 , ..., 0.557685, 0.076416, -0.171436, 1.561785], dtype=float32),\n",
" 'Wandsworth': array([-0.133121, -1.265229, -0.536881, -0.235154, ..., 0.39987 , -0.759289, 0.188098, -1.317402], dtype=float32),\n",
" 'Camden': array([ 1.093842, -0.654579, 0.45953 , -1.833444, ..., 0.678493, -0.840447, -0.144676, 0.803519], dtype=float32),\n",
" 'Southwark': array([ 0.632342, -1.031606, -1.9757 , 1.434942, ..., 0.477039, -0.716686, -1.574186, 0.361259], dtype=float32),\n",
" 'Westminster': array([-0.006769, 1.014129, 0.38176 , 1.087195, ..., 0.619644, 0.145372, -1.583134, -0.53737 ], dtype=float32),\n",
" 'Newham': array([-0.940997, 0.449212, 0.006719, -0.971067, ..., -1.480452, 1.291778, -2.473881, -0.788751], dtype=float32),\n",
" 'Tower Hamlets': array([-0.257651, -1.176663, -0.254655, 0.915376, ..., 0.678123, -1.044624, 0.056251, -2.595061], dtype=float32),\n",
" 'Hackney': array([-0.032634, -1.021449, 0.060701, -0.706772, ..., 1.819459, -0.264873, -0.062177, 0.125787], dtype=float32),\n",
" 'Merton': array([-0.380734, -0.538981, -0.415401, 1.023716, ..., 1.238271, 1.291368, 0.297376, -0.182454], dtype=float32),\n",
" 'Haringey': array([-0.39299 , -0.500936, -1.375325, 0.125624, ..., 1.305327, 0.661907, 0.926493, -2.139146], dtype=float32),\n",
" 'Islington': array([ 0.407774, -0.393032, 0.543974, 0.567474, ..., -0.864082, -0.724516, -0.102573, -1.721182], dtype=float32),\n",
" 'Havering': array([ 0.012448, 0.703619, 1.464297, 0.258891, ..., -0.021862, -0.373643, -0.002513, -0.207162], dtype=float32),\n",
" 'Brent': array([-1.178367, 0.598762, -0.947129, -0.834452, ..., -0.257813, 0.121773, -1.224157, -0.314848], dtype=float32),\n",
" 'Kensington and Chelsea': array([-0.056346, -0.968234, 0.959274, -1.280915, ..., 0.104357, 0.123029, 0.263767, -1.308004], dtype=float32),\n",
" 'Croydon': array([-0.697593, -0.6681 , 0.333914, 1.499083, ..., -0.066999, -0.861174, 1.097988, 0.951798], dtype=float32),\n",
" 'Hounslow': array([-1.200491, 1.092508, -1.106972, -0.012594, ..., -0.200799, -1.20688 , -0.161942, 1.381369], dtype=float32),\n",
" 'Hillingdon': array([ 0.560095, -1.294285, 1.093274, -0.846406, ..., -1.553291, 0.536398, -0.276216, 1.071183], dtype=float32),\n",
" 'Enfield': array([ 0.907168, -0.619706, 1.104722, 0.555476, ..., -1.789324, -1.150976, -0.895986, 0.604396], dtype=float32),\n",
" 'Waltham Forest': array([-1.09497 , 0.34694 , -2.238035, 0.603189, ..., -0.058768, -1.1416 , -0.313438, 0.146215], dtype=float32),\n",
" 'Harrow': array([-1.697243, -0.534554, -0.166178, 0.088522, ..., -0.103859, -1.85962 , -0.387001, -0.54297 ], dtype=float32),\n",
" 'Redbridge': array([-0.439203, -1.083957, -0.588426, -0.761668, ..., 0.055353, 1.498796, 0.979085, -0.972337], dtype=float32),\n",
" 'Bromley': array([-2.125255, -0.03386 , 1.223974, 0.211232, ..., 0.094852, 0.507099, -0.205213, 1.35158 ], dtype=float32),\n",
" 'Sutton': array([ 0.3636 , -0.62571 , 0.074913, -0.408316, ..., 1.613977, 1.136851, -2.397302, -0.385847], dtype=float32),\n",
" 'City of London': array([ 1.146366, 0.461671, -1.337096, 0.136036, ..., 0.683757, -0.658017, 0.520523, 0.580623], dtype=float32),\n",
" 'Barking and Dagenham': array([-2.43159 , -0.656072, -1.209697, 1.669525, ..., 0.124255, 2.291806, -0.741579, -1.083829], dtype=float32),\n",
" 'Kingston upon Thames': array([-1.791893, -0.873715, 0.819117, -1.091105, ..., -0.759607, -0.174701, 1.495903, 0.518327], dtype=float32),\n",
" 'Bexley': array([-0.761755, -0.174487, -1.790552, -0.710304, ..., 1.084789, 1.210242, 0.164761, 0.111679], dtype=float32)}"
]
},
"execution_count": 17,
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
......@@ -1085,7 +1087,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
......@@ -1095,7 +1097,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
......@@ -1127,22 +1129,21 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 69,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Reading Images from data/airbnb/property_picture\n",
"BGR to RGB\n"
"Reading Images from data/airbnb/property_picture\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 1%| | 41/4998 [00:00<00:12, 402.31it/s]"
" 1%| | 41/5000 [00:00<00:12, 402.21it/s]"
]
},
{
......@@ -1156,18 +1157,18 @@
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 4998/4998 [00:12<00:00, 392.41it/s]\n"
"100%|██████████| 5000/5000 [00:12<00:00, 387.15it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Our vocabulary contains 12433 words\n",
"Our vocabulary contains 12675 words\n",
"Indexing word vectors...\n",
"Loaded 400000 word vectors\n",
"Preparing embeddings matrix...\n",
"6869 words in our vocabulary had glove vectors and appear more than the min frequency\n",
"6786 words in our vocabulary had glove vectors and appear more than the min frequency\n",
"Wide and Deep airbnb data preparation completed.\n"
]
}
......@@ -1198,689 +1199,16 @@
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[[[115, 153, 158],\n",
" [139, 182, 189],\n",
" [152, 195, 201],\n",
" [157, 205, 199],\n",
" ...,\n",
" [128, 87, 61],\n",
" [124, 81, 54],\n",
" [132, 85, 50],\n",
" [140, 93, 58]],\n",
"\n",
" [[143, 179, 182],\n",
" [149, 189, 193],\n",
" [149, 189, 194],\n",
" [149, 194, 190],\n",
" ...,\n",
" [180, 127, 91],\n",
" [166, 114, 78],\n",
" [163, 113, 72],\n",
" [161, 111, 71]],\n",
"\n",
" [[138, 170, 169],\n",
" [165, 201, 201],\n",
" [142, 178, 179],\n",
" [147, 186, 187],\n",
" ...,\n",
" [168, 104, 57],\n",
" [148, 87, 42],\n",
" [141, 86, 38],\n",
" [138, 86, 41]],\n",
"\n",
" [[145, 174, 170],\n",
" [170, 202, 198],\n",
" [143, 175, 173],\n",
" [141, 174, 177],\n",
" ...,\n",
" [140, 75, 26],\n",
" [133, 72, 24],\n",
" [137, 81, 29],\n",
" [135, 83, 34]],\n",
"\n",
" ...,\n",
"\n",
" [[ 82, 73, 71],\n",
" [ 83, 77, 74],\n",
" [ 81, 75, 73],\n",
" [ 64, 50, 46],\n",
" ...,\n",
" [ 76, 67, 68],\n",
" [ 74, 66, 67],\n",
" [ 74, 68, 68],\n",
" [ 74, 65, 66]],\n",
"\n",
" [[ 84, 75, 73],\n",
" [ 86, 80, 77],\n",
" [ 74, 67, 66],\n",
" [ 58, 46, 43],\n",
" ...,\n",
" [ 76, 67, 68],\n",
" [ 73, 65, 66],\n",
" [ 73, 67, 67],\n",
" [ 74, 66, 67]],\n",
"\n",
" [[ 85, 77, 74],\n",
" [ 85, 79, 76],\n",
" [ 64, 58, 56],\n",
" [ 58, 47, 46],\n",
" ...,\n",
" [ 77, 68, 69],\n",
" [ 76, 68, 69],\n",
" [ 73, 67, 67],\n",
" [ 72, 64, 64]],\n",
"\n",
" [[ 85, 76, 74],\n",
" [ 81, 75, 72],\n",
" [ 66, 60, 59],\n",
" [ 69, 60, 60],\n",
" ...,\n",
" [ 78, 69, 70],\n",
" [ 77, 69, 70],\n",
" [ 74, 68, 68],\n",
" [ 75, 66, 67]]],\n",
"\n",
"\n",
" [[[197, 168, 128],\n",
" [196, 168, 128],\n",
" [200, 173, 136],\n",
" [209, 187, 154],\n",
" ...,\n",
" [180, 150, 112],\n",
" [180, 150, 112],\n",
" [180, 150, 112],\n",
" [179, 149, 111]],\n",
"\n",
" [[208, 179, 137],\n",
" [203, 174, 134],\n",
" [196, 167, 129],\n",
" [193, 168, 132],\n",
" ...,\n",
" [180, 150, 112],\n",
" [180, 150, 112],\n",
" [180, 150, 112],\n",
" [179, 149, 111]],\n",
"\n",
" [[210, 179, 134],\n",
" [210, 180, 138],\n",
" [210, 180, 140],\n",
" [206, 177, 136],\n",
" ...,\n",
" [180, 150, 112],\n",
" [180, 150, 112],\n",
" [180, 150, 112],\n",
" [179, 149, 111]],\n",
"\n",
" [[211, 181, 132],\n",
" [212, 181, 134],\n",
" [212, 180, 135],\n",
" [215, 183, 136],\n",
" ...,\n",
" [180, 150, 112],\n",
" [180, 150, 112],\n",
" [180, 150, 112],\n",
" [179, 149, 111]],\n",
"\n",
" ...,\n",
"\n",
" [[183, 209, 244],\n",
" [181, 207, 242],\n",
" [180, 205, 240],\n",
" [179, 202, 236],\n",
" ...,\n",
" [175, 112, 58],\n",
" [174, 111, 57],\n",
" [173, 110, 56],\n",
" [172, 109, 55]],\n",
"\n",
" [[180, 205, 240],\n",
" [177, 202, 237],\n",
" [173, 198, 232],\n",
" [169, 192, 226],\n",
" ...,\n",
" [174, 111, 57],\n",
" [174, 111, 57],\n",
" [175, 112, 58],\n",
" [174, 111, 57]],\n",
"\n",
" [[176, 199, 233],\n",
" [172, 195, 229],\n",
" [168, 191, 225],\n",
" [168, 191, 223],\n",
" ...,\n",
" [175, 112, 58],\n",
" [174, 111, 57],\n",
" [173, 110, 56],\n",
" [172, 109, 55]],\n",
"\n",
" [[170, 191, 226],\n",
" [167, 188, 222],\n",
" [169, 189, 222],\n",
" [169, 192, 224],\n",
" ...,\n",
" [177, 114, 60],\n",
" [175, 112, 58],\n",
" [174, 111, 57],\n",
" [172, 109, 55]]],\n",
"\n",
"\n",
" [[[197, 201, 204],\n",
" [198, 202, 205],\n",
" [198, 202, 205],\n",
" [201, 202, 206],\n",
" ...,\n",
" [189, 190, 192],\n",
" [188, 189, 191],\n",
" [187, 188, 190],\n",
" [186, 187, 189]],\n",
"\n",
" [[196, 200, 203],\n",
" [197, 201, 204],\n",
" [198, 202, 205],\n",
" [201, 202, 206],\n",
" ...,\n",
" [189, 190, 192],\n",
" [189, 190, 192],\n",
" [188, 189, 191],\n",
" [187, 188, 190]],\n",
"\n",
" [[196, 199, 202],\n",
" [197, 200, 203],\n",
" [198, 200, 204],\n",
" [200, 201, 205],\n",
" ...,\n",
" [190, 191, 193],\n",
" [190, 191, 193],\n",
" [189, 190, 192],\n",
" [188, 189, 191]],\n",
"\n",
" [[197, 198, 202],\n",
" [198, 199, 203],\n",
" [198, 199, 203],\n",
" [199, 200, 204],\n",
" ...,\n",
" [191, 192, 194],\n",
" [190, 191, 193],\n",
" [189, 190, 192],\n",
" [189, 190, 192]],\n",
"\n",
" ...,\n",
"\n",
" [[152, 143, 98],\n",
" [152, 143, 100],\n",
" [158, 145, 103],\n",
" [150, 137, 93],\n",
" ...,\n",
" [179, 178, 180],\n",
" [175, 172, 171],\n",
" [175, 165, 161],\n",
" [171, 161, 155]],\n",
"\n",
" [[151, 142, 97],\n",
" [150, 141, 98],\n",
" [156, 144, 102],\n",
" [157, 147, 107],\n",
" ...,\n",
" [195, 206, 217],\n",
" [188, 195, 206],\n",
" [186, 186, 195],\n",
" [203, 201, 208]],\n",
"\n",
" [[158, 149, 103],\n",
" [159, 151, 105],\n",
" [171, 159, 116],\n",
" [167, 159, 120],\n",
" ...,\n",
" [188, 197, 212],\n",
" [184, 191, 207],\n",
" [175, 178, 196],\n",
" [187, 189, 205]],\n",
"\n",
" [[170, 161, 115],\n",
" [172, 164, 118],\n",
" [175, 163, 119],\n",
" [174, 163, 122],\n",
" ...,\n",
" [188, 194, 209],\n",
" [187, 192, 208],\n",
" [183, 185, 203],\n",
" [174, 177, 194]]],\n",
"\n",
"\n",
" [[[ 87, 95, 107],\n",
" [ 84, 92, 103],\n",
" [ 83, 90, 101],\n",
" [ 80, 87, 93],\n",
" ...,\n",
" [106, 96, 87],\n",
" [106, 96, 87],\n",
" [106, 97, 88],\n",
" [105, 96, 87]],\n",
"\n",
" [[ 89, 97, 110],\n",
" [ 86, 94, 106],\n",
" [ 84, 92, 103],\n",
" [ 82, 89, 95],\n",
" ...,\n",
" [106, 96, 87],\n",
" [106, 96, 87],\n",
" [106, 97, 88],\n",
" [105, 96, 87]],\n",
"\n",
" [[ 92, 101, 114],\n",
" [ 89, 97, 110],\n",
" [ 87, 94, 107],\n",
" [ 87, 92, 98],\n",
" ...,\n",
" [106, 96, 87],\n",
" [106, 96, 87],\n",
" [106, 97, 88],\n",
" [105, 96, 87]],\n",
"\n",
" [[ 96, 103, 120],\n",
" [ 91, 99, 115],\n",
" [ 89, 96, 110],\n",
" [ 90, 95, 101],\n",
" ...,\n",
" [106, 96, 87],\n",
" [106, 96, 87],\n",
" [106, 97, 88],\n",
" [105, 96, 87]],\n",
"\n",
" ...,\n",
"\n",
" [[ 65, 62, 57],\n",
" [ 65, 62, 57],\n",
" [ 65, 62, 57],\n",
" [ 65, 62, 57],\n",
" ...,\n",
" [ 53, 39, 38],\n",
" [ 51, 37, 36],\n",
" [ 50, 36, 35],\n",
" [ 51, 37, 36]],\n",
"\n",
" [[ 65, 62, 57],\n",
" [ 65, 62, 57],\n",
" [ 65, 62, 57],\n",
" [ 65, 62, 57],\n",
" ...,\n",
" [ 49, 35, 34],\n",
" [ 47, 33, 32],\n",
" [ 45, 31, 30],\n",
" [ 49, 35, 34]],\n",
"\n",
" [[ 65, 62, 57],\n",
" [ 65, 62, 57],\n",
" [ 65, 62, 57],\n",
" [ 65, 62, 57],\n",
" ...,\n",
" [ 41, 27, 26],\n",
" [ 43, 29, 28],\n",
" [ 46, 32, 31],\n",
" [ 46, 32, 31]],\n",
"\n",
" [[ 66, 63, 58],\n",
" [ 65, 62, 57],\n",
" [ 64, 61, 56],\n",
" [ 65, 62, 57],\n",
" ...,\n",
" [ 35, 21, 20],\n",
" [ 44, 30, 29],\n",
" [ 50, 36, 35],\n",
" [ 44, 30, 29]]],\n",
"\n",
"\n",
" ...,\n",
"\n",
"\n",
" [[[180, 179, 175],\n",
" [180, 179, 175],\n",
" [180, 179, 175],\n",
" [182, 181, 177],\n",
" ...,\n",
" [198, 197, 192],\n",
" [195, 194, 189],\n",
" [190, 189, 184],\n",
" [195, 194, 189]],\n",
"\n",
" [[180, 179, 175],\n",
" [180, 179, 175],\n",
" [180, 179, 175],\n",
" [182, 181, 177],\n",
" ...,\n",
" [198, 197, 192],\n",
" [193, 192, 187],\n",
" [190, 189, 184],\n",
" [195, 194, 189]],\n",
"\n",
" [[180, 179, 175],\n",
" [180, 179, 175],\n",
" [180, 179, 175],\n",
" [182, 181, 177],\n",
" ...,\n",
" [197, 196, 191],\n",
" [192, 191, 186],\n",
" [190, 189, 184],\n",
" [195, 194, 189]],\n",
"\n",
" [[180, 179, 175],\n",
" [180, 179, 175],\n",
" [180, 179, 175],\n",
" [182, 181, 177],\n",
" ...,\n",
" [196, 195, 190],\n",
" [193, 192, 187],\n",
" [190, 189, 184],\n",
" [195, 194, 189]],\n",
"\n",
" ...,\n",
"\n",
" [[ 55, 47, 44],\n",
" [ 83, 75, 72],\n",
" [100, 92, 89],\n",
" [ 72, 64, 62],\n",
" ...,\n",
" [ 76, 36, 44],\n",
" [ 77, 34, 41],\n",
" [ 79, 33, 38],\n",
" [ 85, 44, 48]],\n",
"\n",
" [[ 74, 64, 62],\n",
" [ 65, 55, 53],\n",
" [ 73, 64, 62],\n",
" [ 86, 78, 76],\n",
" ...,\n",
" [ 77, 37, 45],\n",
" [ 89, 47, 55],\n",
" [ 79, 36, 44],\n",
" [ 71, 33, 39]],\n",
"\n",
" [[ 61, 51, 49],\n",
" [ 80, 70, 68],\n",
" [ 67, 57, 55],\n",
" [ 64, 56, 54],\n",
" ...,\n",
" [ 85, 45, 53],\n",
" [ 86, 46, 54],\n",
" [ 79, 39, 48],\n",
" [ 61, 27, 33]],\n",
"\n",
" [[ 81, 71, 69],\n",
" [ 73, 63, 61],\n",
" [ 63, 53, 51],\n",
" [ 62, 54, 52],\n",
" ...,\n",
" [ 77, 37, 45],\n",
" [ 78, 38, 46],\n",
" [ 85, 48, 55],\n",
" [ 46, 14, 19]]],\n",
"\n",
"\n",
" [[[224, 224, 232],\n",
" [224, 224, 232],\n",
" [224, 224, 232],\n",
" [225, 225, 235],\n",
" ...,\n",
" [215, 208, 215],\n",
" [216, 208, 215],\n",
" [217, 207, 215],\n",
" [217, 207, 215]],\n",
"\n",
" [[225, 225, 233],\n",
" [224, 224, 232],\n",
" [223, 223, 231],\n",
" [223, 223, 233],\n",
" ...,\n",
" [215, 208, 215],\n",
" [215, 207, 215],\n",
" [217, 207, 215],\n",
" [217, 207, 215]],\n",
"\n",
" [[227, 227, 235],\n",
" [225, 225, 233],\n",
" [224, 224, 232],\n",
" [223, 223, 233],\n",
" ...,\n",
" [214, 207, 214],\n",
" [215, 206, 214],\n",
" [216, 206, 214],\n",
" [216, 206, 214]],\n",
"\n",
" [[228, 228, 236],\n",
" [228, 228, 236],\n",
" [228, 228, 236],\n",
" [227, 227, 237],\n",
" ...,\n",
" [213, 206, 213],\n",
" [214, 206, 213],\n",
" [215, 205, 213],\n",
" [215, 205, 213]],\n",
"\n",
" ...,\n",
"\n",
" [[120, 114, 121],\n",
" [123, 116, 123],\n",
" [125, 118, 125],\n",
" [124, 117, 124],\n",
" ...,\n",
" [ 83, 48, 15],\n",
" [ 76, 44, 15],\n",
" [ 57, 32, 14],\n",
" [127, 110, 106]],\n",
"\n",
" [[125, 119, 125],\n",
" [124, 117, 124],\n",
" [125, 118, 125],\n",
" [129, 122, 129],\n",
" ...,\n",
" [ 80, 44, 8],\n",
" [ 74, 42, 10],\n",
" [ 58, 32, 13],\n",
" [128, 109, 105]],\n",
"\n",
" [[130, 124, 130],\n",
" [129, 122, 129],\n",
" [131, 124, 131],\n",
" [131, 124, 131],\n",
" ...,\n",
" [ 81, 46, 7],\n",
" [ 77, 43, 10],\n",
" [ 59, 32, 13],\n",
" [130, 109, 105]],\n",
"\n",
" [[145, 139, 145],\n",
" [132, 125, 132],\n",
" [132, 125, 132],\n",
" [129, 122, 129],\n",
" ...,\n",
" [ 95, 58, 17],\n",
" [ 85, 50, 15],\n",
" [ 66, 35, 15],\n",
" [133, 110, 106]]],\n",
"\n",
"\n",
" [[[214, 203, 197],\n",
" [216, 205, 199],\n",
" [216, 205, 199],\n",
" [216, 205, 199],\n",
" ...,\n",
" [200, 186, 183],\n",
" [200, 186, 183],\n",
" [200, 186, 182],\n",
" [199, 186, 180]],\n",
"\n",
" [[214, 203, 197],\n",
" [216, 205, 199],\n",
" [216, 205, 199],\n",
" [216, 205, 199],\n",
" ...,\n",
" [200, 186, 183],\n",
" [200, 186, 183],\n",
" [200, 186, 182],\n",
" [199, 186, 180]],\n",
"\n",
" [[214, 203, 197],\n",
" [216, 205, 199],\n",
" [216, 205, 199],\n",
" [216, 205, 199],\n",
" ...,\n",
" [200, 186, 183],\n",
" [200, 186, 183],\n",
" [200, 186, 182],\n",
" [199, 186, 180]],\n",
"\n",
" [[214, 203, 197],\n",
" [216, 205, 199],\n",
" [216, 205, 199],\n",
" [216, 205, 199],\n",
" ...,\n",
" [200, 186, 183],\n",
" [200, 186, 183],\n",
" [200, 186, 182],\n",
" [199, 186, 180]],\n",
"\n",
" ...,\n",
"\n",
" [[198, 153, 111],\n",
" [196, 151, 109],\n",
" [199, 152, 110],\n",
" [199, 152, 108],\n",
" ...,\n",
" [224, 229, 235],\n",
" [224, 229, 235],\n",
" [224, 229, 235],\n",
" [224, 227, 234]],\n",
"\n",
" [[199, 154, 112],\n",
" [196, 151, 108],\n",
" [200, 153, 109],\n",
" [197, 150, 106],\n",
" ...,\n",
" [224, 229, 235],\n",
" [224, 229, 235],\n",
" [224, 229, 235],\n",
" [224, 227, 234]],\n",
"\n",
" [[199, 154, 111],\n",
" [197, 152, 108],\n",
" [198, 151, 107],\n",
" [191, 144, 99],\n",
" ...,\n",
" [224, 229, 235],\n",
" [224, 229, 235],\n",
" [224, 229, 235],\n",
" [224, 227, 234]],\n",
"\n",
" [[197, 153, 108],\n",
" [197, 152, 107],\n",
" [197, 150, 105],\n",
" [196, 149, 103],\n",
" ...,\n",
" [224, 229, 235],\n",
" [224, 229, 235],\n",
" [224, 229, 235],\n",
" [224, 227, 234]]],\n",
"\n",
"\n",
" [[[232, 234, 233],\n",
" [232, 234, 233],\n",
" [232, 234, 233],\n",
" [231, 233, 230],\n",
" ...,\n",
" [164, 158, 142],\n",
" [163, 157, 141],\n",
" [163, 157, 141],\n",
" [163, 157, 141]],\n",
"\n",
" [[232, 234, 233],\n",
" [232, 234, 233],\n",
" [232, 234, 233],\n",
" [231, 233, 230],\n",
" ...,\n",
" [164, 158, 142],\n",
" [163, 157, 141],\n",
" [163, 157, 141],\n",
" [163, 157, 141]],\n",
"\n",
" [[232, 234, 233],\n",
" [232, 234, 233],\n",
" [232, 234, 233],\n",
" [231, 233, 230],\n",
" ...,\n",
" [164, 158, 142],\n",
" [163, 157, 141],\n",
" [163, 157, 141],\n",
" [163, 157, 141]],\n",
"\n",
" [[232, 234, 233],\n",
" [232, 234, 233],\n",
" [232, 234, 233],\n",
" [231, 233, 230],\n",
" ...,\n",
" [164, 158, 142],\n",
" [163, 157, 141],\n",
" [163, 157, 141],\n",
" [163, 157, 141]],\n",
"\n",
" ...,\n",
"\n",
" [[113, 104, 99],\n",
" [117, 108, 103],\n",
" [116, 107, 101],\n",
" [109, 105, 94],\n",
" ...,\n",
" [124, 85, 53],\n",
" [121, 84, 51],\n",
" [111, 81, 46],\n",
" [ 94, 69, 41]],\n",
"\n",
" [[114, 105, 100],\n",
" [114, 105, 100],\n",
" [110, 101, 96],\n",
" [105, 101, 92],\n",
" ...,\n",
" [119, 81, 43],\n",
" [114, 77, 39],\n",
" [111, 79, 42],\n",
" [ 94, 68, 38]],\n",
"\n",
" [[114, 105, 100],\n",
" [110, 101, 96],\n",
" [107, 99, 93],\n",
" [167, 164, 158],\n",
" ...,\n",
" [125, 88, 44],\n",
" [116, 79, 36],\n",
" [115, 82, 43],\n",
" [ 99, 72, 39]],\n",
"\n",
" [[111, 102, 97],\n",
" [107, 98, 93],\n",
" [165, 157, 152],\n",
" [213, 211, 207],\n",
" ...,\n",
" [128, 91, 46],\n",
" [130, 93, 48],\n",
" [124, 90, 48],\n",
" [111, 83, 48]]]], dtype=uint8)"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 70,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Let's drop the image dataset, this time \"only\" with Wide, Deep_Dense and Deep_Text\n",
"wd_dataset_airbnb['train'].pop('deep_img')\n",
"wd_dataset_airbnb['valid'].pop('deep_img')\n",
"wd_dataset_airbnb['test'].pop('deep_img')"
"del wd_dataset_airbnb['train']['deep_img']\n",
"del wd_dataset_airbnb['valid']['deep_img']\n",
"del wd_dataset_airbnb['test']['deep_img']"
]
},
{
......@@ -1892,7 +1220,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
......@@ -1931,7 +1259,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
......@@ -1942,7 +1270,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 73,
"metadata": {},
"outputs": [
{
......@@ -1955,11 +1283,11 @@
" (deep_dense): DeepDense(\n",
" (emb_layer_neighbourhood_cleansed): Embedding(33, 64)\n",
" (emb_layer_bathrooms_catg): Embedding(3, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(4, 16)\n",
" (emb_layer_minimum_nights_catg): Embedding(3, 16)\n",
" (emb_layer_beds_catg): Embedding(4, 16)\n",
" (emb_layer_bedrooms_catg): Embedding(4, 16)\n",
" (emb_layer_guests_included_catg): Embedding(3, 16)\n",
" (emb_layer_beds_catg): Embedding(4, 16)\n",
" (emb_layer_minimum_nights_catg): Embedding(3, 16)\n",
" (emb_layer_host_listings_count_catg): Embedding(4, 16)\n",
" (emb_layer_accommodates_catg): Embedding(3, 16)\n",
" (dense): Sequential(\n",
" (dense_layer_0): Sequential(\n",
......@@ -1977,14 +1305,14 @@
" )\n",
" (deep_text): DeepText(\n",
" (embedding_dropout): Dropout2d(p=0.1)\n",
" (embedding): Embedding(7208, 300, padding_idx=1)\n",
" (embedding): Embedding(7097, 300, padding_idx=1)\n",
" (rnn): GRU(300, 64, num_layers=3, batch_first=True, dropout=0.5)\n",
" (dtlinear): Linear(in_features=64, out_features=3, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 38,
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
......@@ -2002,7 +1330,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
......@@ -2012,7 +1340,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
......@@ -2021,7 +1349,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
......@@ -2030,7 +1358,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
......@@ -2047,7 +1375,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 78,
"metadata": {},
"outputs": [
{
......@@ -2056,7 +1384,7 @@
"['wide', 'deep_dense', 'deep_text', 'target']"
]
},
"execution_count": 43,
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
......@@ -2067,23 +1395,23 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████| 24/24 [00:03<00:00, 7.01it/s, acc=0.512, loss=1.01]\n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 38.83it/s, acc=0.552, loss=0.986]\n",
"epoch 2: 100%|██████████| 24/24 [00:03<00:00, 7.47it/s, acc=0.577, loss=0.96] \n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 40.44it/s, acc=0.557, loss=0.967]\n",
"epoch 3: 100%|██████████| 24/24 [00:03<00:00, 7.74it/s, acc=0.61, loss=0.929] \n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 41.64it/s, acc=0.567, loss=0.961]\n",
"epoch 4: 100%|██████████| 24/24 [00:03<00:00, 8.20it/s, acc=0.635, loss=0.911]\n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 42.00it/s, acc=0.562, loss=0.973]\n",
"epoch 5: 100%|██████████| 24/24 [00:02<00:00, 8.62it/s, acc=0.646, loss=0.894]\n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 45.36it/s, acc=0.536, loss=0.981]\n"
"epoch 1: 100%|██████████| 24/24 [00:03<00:00, 7.05it/s, acc=0.547, loss=0.985]\n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 37.03it/s, acc=0.588, loss=0.949]\n",
"epoch 2: 100%|██████████| 24/24 [00:03<00:00, 7.96it/s, acc=0.601, loss=0.936]\n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 41.34it/s, acc=0.58, loss=0.951] \n",
"epoch 3: 100%|██████████| 24/24 [00:02<00:00, 8.47it/s, acc=0.634, loss=0.904]\n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 42.72it/s, acc=0.597, loss=0.936]\n",
"epoch 4: 100%|██████████| 24/24 [00:02<00:00, 8.47it/s, acc=0.663, loss=0.887]\n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 43.66it/s, acc=0.581, loss=0.945]\n",
"epoch 5: 100%|██████████| 24/24 [00:02<00:00, 8.63it/s, acc=0.67, loss=0.876] \n",
"valid: 100%|██████████| 8/8 [00:00<00:00, 43.78it/s, acc=0.576, loss=0.943]\n"
]
}
],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册