提交 1ac0cd59 编写于 作者: H Hyo-kyun Park 提交者: hyokyun-park

clear all ouput from .ipynb

上级 d490ae12
......@@ -13,183 +13,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>education_num</th>\n",
" <th>marital_status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>gender</th>\n",
" <th>capital_gain</th>\n",
" <th>capital_loss</th>\n",
" <th>hours_per_week</th>\n",
" <th>native_country</th>\n",
" <th>income_bracket</th>\n",
" <th>income_label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>39</td>\n",
" <td>State-gov</td>\n",
" <td>77516</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Never-married</td>\n",
" <td>Adm-clerical</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>2174</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>50</td>\n",
" <td>Self-emp-not-inc</td>\n",
" <td>83311</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Exec-managerial</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>13</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>215646</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Divorced</td>\n",
" <td>Handlers-cleaners</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>53</td>\n",
" <td>Private</td>\n",
" <td>234721</td>\n",
" <td>11th</td>\n",
" <td>7</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Handlers-cleaners</td>\n",
" <td>Husband</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>28</td>\n",
" <td>Private</td>\n",
" <td>338409</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Prof-specialty</td>\n",
" <td>Wife</td>\n",
" <td>Black</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>Cuba</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education education_num \\\n",
"0 39 State-gov 77516 Bachelors 13 \n",
"1 50 Self-emp-not-inc 83311 Bachelors 13 \n",
"2 38 Private 215646 HS-grad 9 \n",
"3 53 Private 234721 11th 7 \n",
"4 28 Private 338409 Bachelors 13 \n",
"\n",
" marital_status occupation relationship race gender \\\n",
"0 Never-married Adm-clerical Not-in-family White Male \n",
"1 Married-civ-spouse Exec-managerial Husband White Male \n",
"2 Divorced Handlers-cleaners Not-in-family White Male \n",
"3 Married-civ-spouse Handlers-cleaners Husband Black Male \n",
"4 Married-civ-spouse Prof-specialty Wife Black Female \n",
"\n",
" capital_gain capital_loss hours_per_week native_country income_bracket \\\n",
"0 2174 0 40 United-States <=50K \n",
"1 0 0 13 United-States <=50K \n",
"2 0 0 40 United-States <=50K \n",
"3 0 0 40 United-States <=50K \n",
"4 0 0 40 Cuba <=50K \n",
"\n",
" income_label \n",
"0 0 \n",
"1 0 \n",
"2 0 \n",
"3 0 \n",
"4 0 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import pandas as pd\n",
......@@ -218,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -243,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -270,7 +96,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -298,25 +124,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 Bachelors-Adm-clerical\n",
"1 Bachelors-Exec-managerial\n",
"2 HS-grad-Handlers-cleaners\n",
"3 11th-Handlers-cleaners\n",
"4 Bachelors-Prof-specialty\n",
"Name: education_occupation, dtype: object"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"df_tmp['education_occupation'].head()"
]
......@@ -339,7 +149,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -393,7 +203,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -425,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -450,176 +260,36 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train_dataset(wide=array([[46, 50, 0, ..., 0, 0, 0],\n",
" [32, 45, 1, ..., 0, 0, 0],\n",
" [30, 30, 0, ..., 0, 0, 0],\n",
" ...,\n",
" [40, 40, 0, ..., 0, 0, 0],\n",
" [45, 37, 1, ..., 0, 0, 0],\n",
" [40, 45, 1, ..., 0, 0, 0]]), deep=array([[ 3. , 1. , 6. , ..., 0. ,\n",
" 0.53655844, 0.77292975],\n",
" [ 0. , 0. , 2. , ..., 0. ,\n",
" -0.48456647, 0.36942139],\n",
" [ 1. , 4. , 2. , ..., 0. ,\n",
" -0.63044146, -0.84110367],\n",
" ...,\n",
" [ 1. , 0. , 2. , ..., 0. ,\n",
" 0.09893348, -0.03408696],\n",
" [ 0. , 1. , 2. , ..., 0. ,\n",
" 0.46362095, -0.27619198],\n",
" [ 0. , 1. , 2. , ..., 0. ,\n",
" 0.09893348, 0.36942139]]), labels=array([1, 0, 0, ..., 0, 0, 0]))\n"
]
}
],
"outputs": [],
"source": [
"print(wd_dataset['train_dataset'])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('education', 16, 10), ('relationship', 6, 8), ('native_country', 42, 12), ('workclass', 9, 10), ('occupation', 15, 10)]\n"
]
}
],
"outputs": [],
"source": [
"print(wd_dataset['embeddings_input'])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'education': 0, 'relationship': 1, 'workclass': 2, 'occupation': 3, 'native_country': 4, 'age': 5, 'hours_per_week': 6}\n"
]
}
],
"outputs": [],
"source": [
"print(wd_dataset['deep_column_idx'])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'education': {'Bachelors': 0,\n",
" 'HS-grad': 1,\n",
" '11th': 2,\n",
" 'Masters': 3,\n",
" '9th': 4,\n",
" 'Some-college': 5,\n",
" 'Assoc-acdm': 6,\n",
" 'Assoc-voc': 7,\n",
" '7th-8th': 8,\n",
" 'Doctorate': 9,\n",
" 'Prof-school': 10,\n",
" '5th-6th': 11,\n",
" '10th': 12,\n",
" '1st-4th': 13,\n",
" 'Preschool': 14,\n",
" '12th': 15},\n",
" 'relationship': {'Not-in-family': 0,\n",
" 'Husband': 1,\n",
" 'Wife': 2,\n",
" 'Own-child': 3,\n",
" 'Unmarried': 4,\n",
" 'Other-relative': 5},\n",
" 'native_country': {'United-States': 0,\n",
" 'Cuba': 1,\n",
" 'Jamaica': 2,\n",
" 'India': 3,\n",
" '?': 4,\n",
" 'Mexico': 5,\n",
" 'South': 6,\n",
" 'Puerto-Rico': 7,\n",
" 'Honduras': 8,\n",
" 'England': 9,\n",
" 'Canada': 10,\n",
" 'Germany': 11,\n",
" 'Iran': 12,\n",
" 'Philippines': 13,\n",
" 'Italy': 14,\n",
" 'Poland': 15,\n",
" 'Columbia': 16,\n",
" 'Cambodia': 17,\n",
" 'Thailand': 18,\n",
" 'Ecuador': 19,\n",
" 'Laos': 20,\n",
" 'Taiwan': 21,\n",
" 'Haiti': 22,\n",
" 'Portugal': 23,\n",
" 'Dominican-Republic': 24,\n",
" 'El-Salvador': 25,\n",
" 'France': 26,\n",
" 'Guatemala': 27,\n",
" 'China': 28,\n",
" 'Japan': 29,\n",
" 'Yugoslavia': 30,\n",
" 'Peru': 31,\n",
" 'Outlying-US(Guam-USVI-etc)': 32,\n",
" 'Scotland': 33,\n",
" 'Trinadad&Tobago': 34,\n",
" 'Greece': 35,\n",
" 'Nicaragua': 36,\n",
" 'Vietnam': 37,\n",
" 'Hong': 38,\n",
" 'Ireland': 39,\n",
" 'Hungary': 40,\n",
" 'Holand-Netherlands': 41},\n",
" 'workclass': {'State-gov': 0,\n",
" 'Self-emp-not-inc': 1,\n",
" 'Private': 2,\n",
" 'Federal-gov': 3,\n",
" 'Local-gov': 4,\n",
" '?': 5,\n",
" 'Self-emp-inc': 6,\n",
" 'Without-pay': 7,\n",
" 'Never-worked': 8},\n",
" 'occupation': {'Adm-clerical': 0,\n",
" 'Exec-managerial': 1,\n",
" 'Handlers-cleaners': 2,\n",
" 'Prof-specialty': 3,\n",
" 'Other-service': 4,\n",
" 'Sales': 5,\n",
" 'Craft-repair': 6,\n",
" 'Transport-moving': 7,\n",
" 'Farming-fishing': 8,\n",
" 'Machine-op-inspct': 9,\n",
" 'Tech-support': 10,\n",
" '?': 11,\n",
" 'Protective-serv': 12,\n",
" 'Armed-Forces': 13,\n",
" 'Priv-house-serv': 14}}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"wd_dataset['encoding_dict']"
]
......@@ -633,7 +303,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......
因为 它太大了无法显示 source diff 。你可以改为 查看blob
......@@ -22,183 +22,9 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>workclass</th>\n",
" <th>fnlwgt</th>\n",
" <th>education</th>\n",
" <th>education_num</th>\n",
" <th>marital_status</th>\n",
" <th>occupation</th>\n",
" <th>relationship</th>\n",
" <th>race</th>\n",
" <th>gender</th>\n",
" <th>capital_gain</th>\n",
" <th>capital_loss</th>\n",
" <th>hours_per_week</th>\n",
" <th>native_country</th>\n",
" <th>income_bracket</th>\n",
" <th>income_label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>39</td>\n",
" <td>State-gov</td>\n",
" <td>77516</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Never-married</td>\n",
" <td>Adm-clerical</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>2174</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>50</td>\n",
" <td>Self-emp-not-inc</td>\n",
" <td>83311</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Exec-managerial</td>\n",
" <td>Husband</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>13</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>38</td>\n",
" <td>Private</td>\n",
" <td>215646</td>\n",
" <td>HS-grad</td>\n",
" <td>9</td>\n",
" <td>Divorced</td>\n",
" <td>Handlers-cleaners</td>\n",
" <td>Not-in-family</td>\n",
" <td>White</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>53</td>\n",
" <td>Private</td>\n",
" <td>234721</td>\n",
" <td>11th</td>\n",
" <td>7</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Handlers-cleaners</td>\n",
" <td>Husband</td>\n",
" <td>Black</td>\n",
" <td>Male</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>United-States</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>28</td>\n",
" <td>Private</td>\n",
" <td>338409</td>\n",
" <td>Bachelors</td>\n",
" <td>13</td>\n",
" <td>Married-civ-spouse</td>\n",
" <td>Prof-specialty</td>\n",
" <td>Wife</td>\n",
" <td>Black</td>\n",
" <td>Female</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>40</td>\n",
" <td>Cuba</td>\n",
" <td>&lt;=50K</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age workclass fnlwgt education education_num \\\n",
"0 39 State-gov 77516 Bachelors 13 \n",
"1 50 Self-emp-not-inc 83311 Bachelors 13 \n",
"2 38 Private 215646 HS-grad 9 \n",
"3 53 Private 234721 11th 7 \n",
"4 28 Private 338409 Bachelors 13 \n",
"\n",
" marital_status occupation relationship race gender \\\n",
"0 Never-married Adm-clerical Not-in-family White Male \n",
"1 Married-civ-spouse Exec-managerial Husband White Male \n",
"2 Divorced Handlers-cleaners Not-in-family White Male \n",
"3 Married-civ-spouse Handlers-cleaners Husband Black Male \n",
"4 Married-civ-spouse Prof-specialty Wife Black Female \n",
"\n",
" capital_gain capital_loss hours_per_week native_country income_bracket \\\n",
"0 2174 0 40 United-States <=50K \n",
"1 0 0 13 United-States <=50K \n",
"2 0 0 40 United-States <=50K \n",
"3 0 0 40 United-States <=50K \n",
"4 0 0 40 Cuba <=50K \n",
"\n",
" income_label \n",
"0 0 \n",
"1 0 \n",
"2 0 \n",
"3 0 \n",
"4 0 "
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import pandas as pd\n",
......@@ -225,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -252,7 +78,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -271,7 +97,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
......@@ -301,26 +127,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WideDeep(\n",
" (emb_layer_native_country): Embedding(42, 10)\n",
" (emb_layer_relationship): Embedding(6, 8)\n",
" (emb_layer_occupation): Embedding(15, 10)\n",
" (emb_layer_education): Embedding(16, 10)\n",
" (emb_layer_workclass): Embedding(9, 10)\n",
" (linear_1): Linear(in_features=50, out_features=100, bias=True)\n",
" (linear_2): Linear(in_features=100, out_features=50, bias=True)\n",
" (output): Linear(in_features=848, out_features=1, bias=True)\n",
")\n"
]
}
],
"outputs": [],
"source": [
"print(model)"
]
......@@ -334,27 +143,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 of 10, Loss: 0.136, accuracy: 0.8246\n",
"Epoch 2 of 10, Loss: 0.106, accuracy: 0.8392\n",
"Epoch 3 of 10, Loss: 0.513, accuracy: 0.8421\n",
"Epoch 4 of 10, Loss: 0.345, accuracy: 0.8414\n",
"Epoch 5 of 10, Loss: 0.29, accuracy: 0.843\n",
"Epoch 6 of 10, Loss: 0.227, accuracy: 0.8443\n",
"Epoch 7 of 10, Loss: 0.426, accuracy: 0.845\n",
"Epoch 8 of 10, Loss: 0.183, accuracy: 0.8454\n",
"Epoch 9 of 10, Loss: 0.322, accuracy: 0.8461\n",
"Epoch 10 of 10, Loss: 0.246, accuracy: 0.8469\n",
"0.8382583771241384\n"
]
}
],
"outputs": [],
"source": [
"train_dataset = wd_dataset['train_dataset']\n",
"test_dataset = wd_dataset['test_dataset']\n",
......@@ -376,67 +167,9 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'Bachelors': array([-1.1927266 , 0.13337217, 0.751513 , -0.3854133 , -1.512503 ,\n",
" 0.43075648, 0.03185017, 0.2740599 , -1.3502986 , -0.51524764],\n",
" dtype=float32),\n",
" 'HS-grad': array([ 0.01510752, -0.41036212, -1.2737428 , -0.03190449, 0.30465913,\n",
" -0.4891645 , -0.35087353, 1.7667191 , 0.90333945, -0.42637545],\n",
" dtype=float32),\n",
" '11th': array([-1.3361819 , -1.0304003 , -0.7671982 , 1.1118906 , 0.6290409 ,\n",
" 0.09973534, -0.41261104, -0.79101914, 1.2672484 , 0.7189385 ],\n",
" dtype=float32),\n",
" 'Masters': array([ 0.5837133 , -1.3451334 , 0.9863935 , 0.35932744, -0.13541682,\n",
" 0.34770364, -0.8982047 , 0.4550249 , -1.326133 , -0.08214497],\n",
" dtype=float32),\n",
" '9th': array([ 0.00944321, -0.2883264 , 1.1186845 , 0.16699162, 0.20891678,\n",
" -2.222243 , 0.90257394, -2.499814 , 0.32215422, -0.02830464],\n",
" dtype=float32),\n",
" 'Some-college': array([ 0.11737815, -0.9354352 , -1.6950701 , -0.3879866 , -0.34800476,\n",
" 0.65498114, -1.0632497 , 1.2390918 , -1.3980893 , -1.5068939 ],\n",
" dtype=float32),\n",
" 'Assoc-acdm': array([-0.3248521 , 0.67525595, -0.7607256 , -1.688361 , -0.01024881,\n",
" -0.17185631, -1.5726321 , -0.33589116, -0.6568722 , -0.83356154],\n",
" dtype=float32),\n",
" 'Assoc-voc': array([ 0.11711311, 1.6658193 , 0.2525636 , -1.7053522 , 0.11374688,\n",
" 0.69635576, 0.39209226, 0.55386406, 1.4460421 , -0.4076955 ],\n",
" dtype=float32),\n",
" '7th-8th': array([ 0.8109543 , -0.9696295 , -1.1880634 , -2.673678 , 1.387889 ,\n",
" 0.03207216, 0.28635803, 0.32005164, -0.14126171, -0.12705447],\n",
" dtype=float32),\n",
" 'Doctorate': array([ 2.5456786 , 0.9495662 , -0.65327275, 0.63417935, -1.4665067 ,\n",
" -1.0520831 , -0.8822009 , 1.7168643 , 1.3397688 , 1.0705113 ],\n",
" dtype=float32),\n",
" 'Prof-school': array([-0.3236308 , 0.36975744, 0.79298687, -0.24033554, 0.8012961 ,\n",
" -0.38213903, 0.20259416, -0.30737472, -2.190927 , 0.47054496],\n",
" dtype=float32),\n",
" '5th-6th': array([-0.71498626, -1.3042029 , 0.04956457, -0.20074964, 0.85997975,\n",
" 2.4887364 , 0.9329344 , -0.33221987, -0.37141427, 1.9041626 ],\n",
" dtype=float32),\n",
" '10th': array([-0.5197668 , -1.2800047 , 1.5472891 , -1.141539 , 0.00724531,\n",
" 1.3354197 , 1.4840577 , 0.9995618 , -0.03808165, -1.1237134 ],\n",
" dtype=float32),\n",
" '1st-4th': array([ 1.1701114 , -1.0981313 , -1.5367142 , 0.16519445, -0.0972092 ,\n",
" -0.3711076 , 0.9954778 , -0.94091356, -0.75837976, -1.9332327 ],\n",
" dtype=float32),\n",
" 'Preschool': array([-0.7313262 , -0.56184304, 0.30143896, 0.8417214 , 2.0694172 ,\n",
" -1.2695692 , 1.461705 , -0.6897159 , 1.6769298 , 0.55851436],\n",
" dtype=float32),\n",
" '12th': array([-1.2723062 , -0.27862272, -0.3878713 , -0.9044023 , -0.00804312,\n",
" -1.1498355 , 1.0327121 , 0.29477796, 0.2951289 , 0.96019965],\n",
" dtype=float32)}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"model.get_embeddings('education')"
]
......@@ -457,28 +190,9 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WideDeep(\n",
" (emb_layer_native_country): Embedding(42, 10)\n",
" (emb_layer_relationship): Embedding(6, 10)\n",
" (emb_layer_occupation): Embedding(15, 10)\n",
" (emb_layer_education): Embedding(16, 10)\n",
" (emb_layer_workclass): Embedding(9, 10)\n",
" (linear_1): Linear(in_features=51, out_features=100, bias=True)\n",
" (linear_1_drop): Dropout(p=0.5)\n",
" (linear_2): Linear(in_features=100, out_features=50, bias=True)\n",
" (linear_2_drop): Dropout(p=0.2)\n",
" (output): Linear(in_features=847, out_features=3, bias=True)\n",
")\n"
]
}
],
"outputs": [],
"source": [
"# Let's define age groups\n",
"age_groups = [0, 25, 50, 90]\n",
......@@ -513,34 +227,9 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 of 10, Loss: 1.131, accuracy: 0.6735\n",
"Epoch 2 of 10, Loss: 0.846, accuracy: 0.6843\n",
"Epoch 3 of 10, Loss: 0.902, accuracy: 0.686\n",
"Epoch 4 of 10, Loss: 0.806, accuracy: 0.691\n",
"Epoch 5 of 10, Loss: 1.015, accuracy: 0.6931\n",
"Epoch 6 of 10, Loss: 0.77, accuracy: 0.694\n",
"Epoch 7 of 10, Loss: 0.868, accuracy: 0.6962\n",
"Epoch 8 of 10, Loss: 0.808, accuracy: 0.6973\n",
"Epoch 9 of 10, Loss: 0.977, accuracy: 0.6972\n",
"Epoch 10 of 10, Loss: 0.851, accuracy: 0.6968\n",
"\n",
" [[9.9808323e-01 1.9167198e-03 1.2708337e-07]\n",
" [1.8705309e-12 1.0000000e+00 1.0048575e-09]\n",
" [2.1682714e-08 9.9999905e-01 9.1604261e-07]\n",
" ...\n",
" [1.0082698e-03 9.6010476e-01 3.8887005e-02]\n",
" [2.4448596e-07 9.9994826e-01 5.1442345e-05]\n",
" [6.8863249e-01 3.0600473e-01 5.3628702e-03]]\n"
]
}
],
"outputs": [],
"source": [
"train_dataset = wd_dataset['train_dataset']\n",
"model.fit(dataset=train_dataset, n_epochs=10, batch_size=64)\n",
......@@ -553,20 +242,9 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" 0.7318796365288384\n",
"\n",
" 0.7006756295639118\n"
]
}
],
"outputs": [],
"source": [
"from sklearn.metrics import f1_score, accuracy_score\n",
"\n",
......@@ -586,28 +264,9 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WideDeep(\n",
" (emb_layer_native_country): Embedding(42, 10)\n",
" (emb_layer_relationship): Embedding(6, 8)\n",
" (emb_layer_occupation): Embedding(15, 10)\n",
" (emb_layer_education): Embedding(16, 10)\n",
" (emb_layer_workclass): Embedding(9, 10)\n",
" (linear_1): Linear(in_features=49, out_features=100, bias=True)\n",
" (linear_1_drop): Dropout(p=0.5)\n",
" (linear_2): Linear(in_features=100, out_features=50, bias=True)\n",
" (linear_2_drop): Dropout(p=0.2)\n",
" (output): Linear(in_features=847, out_features=1, bias=True)\n",
")\n"
]
}
],
"outputs": [],
"source": [
"# Set the experiment\n",
"wide_cols = ['hours_per_week','education', 'relationship','workclass',\n",
......@@ -636,28 +295,9 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1 of 10, Loss: 78.643\n",
"Epoch 2 of 10, Loss: 144.536\n",
"Epoch 3 of 10, Loss: 53.688\n",
"Epoch 4 of 10, Loss: 177.116\n",
"Epoch 5 of 10, Loss: 198.454\n",
"Epoch 6 of 10, Loss: 90.156\n",
"Epoch 7 of 10, Loss: 44.655\n",
"Epoch 8 of 10, Loss: 205.163\n",
"Epoch 9 of 10, Loss: 246.263\n",
"Epoch 10 of 10, Loss: 102.745\n",
"\n",
" RMSE: 11.113094884527147\n"
]
}
],
"outputs": [],
"source": [
"train_dataset = wd_dataset['train_dataset']\n",
"model.fit(dataset=train_dataset, n_epochs=10, batch_size=64)\n",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册