提交 801c597a 编写于 作者: J Javier

Adjusted the notebooks to show how one can use the 'load_movielens100k' function in the library

上级 d30203a0
......@@ -23,7 +23,9 @@
"import warnings\n",
"\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split"
"from sklearn.model_selection import train_test_split\n",
"\n",
"from pytorch_widedeep.datasets import load_movielens100k"
]
},
{
......@@ -43,44 +45,31 @@
"metadata": {},
"outputs": [],
"source": [
"raw_data_path = Path(\"~/ml_projects/wide_deep_learning_for_recsys/ml-100k\")\n",
"\n",
"save_path = Path(\"prepared_data\")\n",
"if not save_path.exists():\n",
" save_path.mkdir(parents=True, exist_ok=True)"
]
},
{
"cell_type": "markdown",
"id": "929a9712",
"cell_type": "code",
"execution_count": 4,
"id": "5de7a941",
"metadata": {},
"outputs": [],
"source": [
"Let's first start by loading the interactions, user and item data"
"data, users, items = load_movielens100k(as_frame=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "38de36ff",
"execution_count": 5,
"id": "7a288aee",
"metadata": {},
"outputs": [],
"source": [
"# Load the Ratings/Interaction (triplets (user, item, rating) plus timestamp)\n",
"data = pd.read_csv(raw_data_path / \"u.data\", sep=\"\\t\", header=None)\n",
"data.columns = [\"user_id\", \"movie_id\", \"rating\", \"timestamp\"]\n",
"\n",
"# Load the User features\n",
"users = pd.read_csv(raw_data_path / \"u.user\", sep=\"|\", encoding=\"latin-1\", header=None)\n",
"users.columns = [\"user_id\", \"age\", \"gender\", \"occupation\", \"zip_code\"]\n",
"\n",
"# Load the Item features\n",
"items = pd.read_csv(raw_data_path / \"u.item\", sep=\"|\", encoding=\"latin-1\", header=None)\n",
"items.columns = [\n",
" \"movie_id\",\n",
" \"movie_title\",\n",
" \"release_date\",\n",
" \"video_release_date\",\n",
" \"IMDb_URL\",\n",
"# Alternatively, as specified in the docs: 'The last 19 fields are the genres' so:\n",
"# list_of_genres = items.columns.tolist()[-19:]\n",
"list_of_genres = [\n",
" \"unknown\",\n",
" \"Action\",\n",
" \"Adventure\",\n",
......@@ -103,9 +92,17 @@
"]"
]
},
{
"cell_type": "markdown",
"id": "929a9712",
"metadata": {},
"source": [
"Let's first start by loading the interactions, user and item data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "f4c09273",
"metadata": {},
"outputs": [
......@@ -185,7 +182,7 @@
"4 166 346 1 886397596"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
......@@ -196,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "18c3faa0",
"metadata": {},
"outputs": [
......@@ -282,7 +279,7 @@
"4 5 33 F other 15213"
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
......@@ -293,7 +290,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "1dbad7b1",
"metadata": {},
"outputs": [
......@@ -499,55 +496,13 @@
"[5 rows x 24 columns]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"items.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "7b1ce069",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['unknown',\n",
" 'Action',\n",
" 'Adventure',\n",
" 'Animation',\n",
" \"Children's\",\n",
" 'Comedy',\n",
" 'Crime',\n",
" 'Documentary',\n",
" 'Drama',\n",
" 'Fantasy',\n",
" 'Film-Noir',\n",
" 'Horror',\n",
" 'Musical',\n",
" 'Mystery',\n",
" 'Romance',\n",
" 'Sci-Fi',\n",
" 'Thriller',\n",
" 'War',\n",
" 'Western']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list_of_genres = pd.read_csv(\n",
" raw_data_path / \"u.genre\", sep=\"|\", header=None, usecols=[0]\n",
")[0].tolist()\n",
"list_of_genres"
"items.head()"
]
},
{
......@@ -1899,16 +1854,7 @@
"execution_count": 29,
"id": "68555183",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:309: UserWarning: Continuous columns will not be normalised\n",
" warnings.warn(\"Continuous columns will not be normalised\")\n"
]
}
],
"outputs": [],
"source": [
"X_train_tab = tab_preprocessor.fit_transform(X_train.fillna(0))\n",
"X_test_tab = tab_preprocessor.transform(X_test.fillna(0))"
......@@ -2291,16 +2237,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"epoch 1: 100%|██████████████████████████████| 149/149 [00:16<00:00, 9.08it/s, loss=6.66]\n",
"valid: 100%|██████████████████████████████████| 38/38 [00:01<00:00, 23.53it/s, loss=6.61]\n",
"epoch 2: 100%|██████████████████████████████| 149/149 [00:16<00:00, 9.11it/s, loss=5.99]\n",
"valid: 100%|██████████████████████████████████| 38/38 [00:01<00:00, 23.07it/s, loss=6.55]\n",
"epoch 3: 100%|██████████████████████████████| 149/149 [00:16<00:00, 9.11it/s, loss=5.67]\n",
"valid: 100%|██████████████████████████████████| 38/38 [00:01<00:00, 22.89it/s, loss=6.55]\n",
"epoch 4: 100%|██████████████████████████████| 149/149 [00:16<00:00, 8.81it/s, loss=5.43]\n",
"valid: 100%|██████████████████████████████████| 38/38 [00:01<00:00, 21.43it/s, loss=6.57]\n",
"epoch 5: 100%|██████████████████████████████| 149/149 [00:16<00:00, 8.79it/s, loss=5.24]\n",
"valid: 100%|███████████████████████████████████| 38/38 [00:01<00:00, 22.39it/s, loss=6.6]\n"
"epoch 1: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:16<00:00, 8.82it/s, loss=6.66]\n",
"valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 21.14it/s, loss=6.61]\n",
"epoch 2: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.53it/s, loss=5.98]\n",
"valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.20it/s, loss=6.53]\n",
"epoch 3: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.61it/s, loss=5.66]\n",
"valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.16it/s, loss=6.54]\n",
"epoch 4: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.76it/s, loss=5.43]\n",
"valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 22.03it/s, loss=6.56]\n",
"epoch 5: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.28it/s, loss=5.23]\n",
"valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 22.60it/s, loss=6.59]\n"
]
}
],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册