diff --git a/examples/notebooks/19_wide_and_deep_for_recsys_pt1.ipynb b/examples/notebooks/19_wide_and_deep_for_recsys_pt1.ipynb index 0d7434373c9601f0bcac026b2bec0bb5a4425fc6..b6a57e6dcd07dd90e3b4142246e89a57026e476b 100644 --- a/examples/notebooks/19_wide_and_deep_for_recsys_pt1.ipynb +++ b/examples/notebooks/19_wide_and_deep_for_recsys_pt1.ipynb @@ -23,7 +23,9 @@ "import warnings\n", "\n", "import pandas as pd\n", - "from sklearn.model_selection import train_test_split" + "from sklearn.model_selection import train_test_split\n", + "\n", + "from pytorch_widedeep.datasets import load_movielens100k" ] }, { @@ -43,44 +45,31 @@ "metadata": {}, "outputs": [], "source": [ - "raw_data_path = Path(\"~/ml_projects/wide_deep_learning_for_recsys/ml-100k\")\n", - "\n", "save_path = Path(\"prepared_data\")\n", "if not save_path.exists():\n", " save_path.mkdir(parents=True, exist_ok=True)" ] }, { - "cell_type": "markdown", - "id": "929a9712", + "cell_type": "code", + "execution_count": 4, + "id": "5de7a941", "metadata": {}, + "outputs": [], "source": [ - "Let's first start by loading the interactions, user and item data" + "data, users, items = load_movielens100k(as_frame=True)" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "38de36ff", + "execution_count": 5, + "id": "7a288aee", "metadata": {}, "outputs": [], "source": [ - "# Load the Ratings/Interaction (triplets (user, item, rating) plus timestamp)\n", - "data = pd.read_csv(raw_data_path / \"u.data\", sep=\"\\t\", header=None)\n", - "data.columns = [\"user_id\", \"movie_id\", \"rating\", \"timestamp\"]\n", - "\n", - "# Load the User features\n", - "users = pd.read_csv(raw_data_path / \"u.user\", sep=\"|\", encoding=\"latin-1\", header=None)\n", - "users.columns = [\"user_id\", \"age\", \"gender\", \"occupation\", \"zip_code\"]\n", - "\n", - "# Load the Item features\n", - "items = pd.read_csv(raw_data_path / \"u.item\", sep=\"|\", encoding=\"latin-1\", header=None)\n", - "items.columns = [\n", - " \"movie_id\",\n", - " \"movie_title\",\n", - " \"release_date\",\n", - " \"video_release_date\",\n", - " \"IMDb_URL\",\n", + "# Alternatively, as specified in the docs: 'The last 19 fields are the genres' so:\n", + "# list_of_genres = items.columns.tolist()[-19:]\n", + "list_of_genres = [\n", " \"unknown\",\n", " \"Action\",\n", " \"Adventure\",\n", @@ -103,9 +92,17 @@ "]" ] }, + { + "cell_type": "markdown", + "id": "929a9712", + "metadata": {}, + "source": [ + "Let's first start by loading the interactions, user and item data" + ] + }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "f4c09273", "metadata": {}, "outputs": [ @@ -185,7 +182,7 @@ "4 166 346 1 886397596" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -196,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "18c3faa0", "metadata": {}, "outputs": [ @@ -282,7 +279,7 @@ "4 5 33 F other 15213" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -293,7 +290,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "1dbad7b1", "metadata": {}, "outputs": [ @@ -499,55 +496,13 @@ "[5 rows x 24 columns]" ] }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "items.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "7b1ce069", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['unknown',\n", - " 'Action',\n", - " 'Adventure',\n", - " 'Animation',\n", - " \"Children's\",\n", - " 'Comedy',\n", - " 'Crime',\n", - " 'Documentary',\n", - " 'Drama',\n", - " 'Fantasy',\n", - " 'Film-Noir',\n", - " 'Horror',\n", - " 'Musical',\n", - " 'Mystery',\n", - " 'Romance',\n", - " 'Sci-Fi',\n", - " 'Thriller',\n", - " 'War',\n", - " 'Western']" - ] - }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "list_of_genres = pd.read_csv(\n", - " raw_data_path / \"u.genre\", sep=\"|\", header=None, usecols=[0]\n", - ")[0].tolist()\n", - "list_of_genres" + "items.head()" ] }, { @@ -1899,16 +1854,7 @@ "execution_count": 29, "id": "68555183", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:309: UserWarning: Continuous columns will not be normalised\n", - " warnings.warn(\"Continuous columns will not be normalised\")\n" - ] - } - ], + "outputs": [], "source": [ "X_train_tab = tab_preprocessor.fit_transform(X_train.fillna(0))\n", "X_test_tab = tab_preprocessor.transform(X_test.fillna(0))" @@ -2291,16 +2237,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "epoch 1: 100%|██████████████████████████████| 149/149 [00:16<00:00, 9.08it/s, loss=6.66]\n", - "valid: 100%|██████████████████████████████████| 38/38 [00:01<00:00, 23.53it/s, loss=6.61]\n", - "epoch 2: 100%|██████████████████████████████| 149/149 [00:16<00:00, 9.11it/s, loss=5.99]\n", - "valid: 100%|██████████████████████████████████| 38/38 [00:01<00:00, 23.07it/s, loss=6.55]\n", - "epoch 3: 100%|██████████████████████████████| 149/149 [00:16<00:00, 9.11it/s, loss=5.67]\n", - "valid: 100%|██████████████████████████████████| 38/38 [00:01<00:00, 22.89it/s, loss=6.55]\n", - "epoch 4: 100%|██████████████████████████████| 149/149 [00:16<00:00, 8.81it/s, loss=5.43]\n", - "valid: 100%|██████████████████████████████████| 38/38 [00:01<00:00, 21.43it/s, loss=6.57]\n", - "epoch 5: 100%|██████████████████████████████| 149/149 [00:16<00:00, 8.79it/s, loss=5.24]\n", - "valid: 100%|███████████████████████████████████| 38/38 [00:01<00:00, 22.39it/s, loss=6.6]\n" + "epoch 1: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:16<00:00, 8.82it/s, loss=6.66]\n", + "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 21.14it/s, loss=6.61]\n", + "epoch 2: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.53it/s, loss=5.98]\n", + "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.20it/s, loss=6.53]\n", + "epoch 3: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.61it/s, loss=5.66]\n", + "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.16it/s, loss=6.54]\n", + "epoch 4: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.76it/s, loss=5.43]\n", + "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 22.03it/s, loss=6.56]\n", + "epoch 5: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.28it/s, loss=5.23]\n", + "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 22.60it/s, loss=6.59]\n" ] } ],