{ "cells": [ { "cell_type": "markdown", "id": "d298e185", "metadata": {}, "source": [ "The goal of this, and the companion (part 2) notebooks is to illustrate how one could use this library in the context of recommendation systems. In particular, this notebook and the scripts at the `wide_deep_for_recsys` dir are a response to this [issue](https://github.com/jrzaurin/pytorch-widedeep/issues/133). Therefore, we will use the [Kaggle notebook](https://www.kaggle.com/code/matanivanov/wide-deep-learning-for-recsys-with-pytorch) referred in that issue here.\n", "\n", "In order to keep the length of the notebook tractable, we will split this exercise in 2. In this first notebook we will prepare the [data](https://www.kaggle.com/datasets/prajitdatta/movielens-100k-dataset) in almost the exact same way as it is done in the Kaggle notebook and also show how one could use `pytorch-widedeep` to build a model almost identical to the one in that notebook. \n", "\n", "In a second notebook, we will show how one could use this library to implement other models, still following the same problem formulation." 
] }, { "cell_type": "code", "execution_count": 1, "id": "ebd9980d", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import warnings\n", "\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "\n", "from pytorch_widedeep.datasets import load_movielens100k" ] }, { "cell_type": "code", "execution_count": 2, "id": "7cd76bce", "metadata": {}, "outputs": [], "source": [ "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "0aed611e", "metadata": {}, "outputs": [], "source": [ "save_path = Path(\"prepared_data\")\n", "if not save_path.exists():\n", " save_path.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 4, "id": "5de7a941", "metadata": {}, "outputs": [], "source": [ "data, users, items = load_movielens100k(as_frame=True)" ] }, { "cell_type": "code", "execution_count": 5, "id": "7a288aee", "metadata": {}, "outputs": [], "source": [ "# Alternatively, as specified in the docs: 'The last 19 fields are the genres' so:\n", "# list_of_genres = items.columns.tolist()[-19:]\n", "list_of_genres = [\n", " \"unknown\",\n", " \"Action\",\n", " \"Adventure\",\n", " \"Animation\",\n", " \"Children's\",\n", " \"Comedy\",\n", " \"Crime\",\n", " \"Documentary\",\n", " \"Drama\",\n", " \"Fantasy\",\n", " \"Film-Noir\",\n", " \"Horror\",\n", " \"Musical\",\n", " \"Mystery\",\n", " \"Romance\",\n", " \"Sci-Fi\",\n", " \"Thriller\",\n", " \"War\",\n", " \"Western\",\n", "]" ] }, { "cell_type": "markdown", "id": "929a9712", "metadata": {}, "source": [ "Let's first start by loading the interactions, user and item data" ] }, { "cell_type": "code", "execution_count": 6, "id": "f4c09273", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
user_idmovie_idratingtimestamp
01962423881250949
11863023891717742
2223771878887116
3244512880606923
41663461886397596
\n", "
" ], "text/plain": [ " user_id movie_id rating timestamp\n", "0 196 242 3 881250949\n", "1 186 302 3 891717742\n", "2 22 377 1 878887116\n", "3 244 51 2 880606923\n", "4 166 346 1 886397596" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "code", "execution_count": 7, "id": "18c3faa0", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
user_idagegenderoccupationzip_code
0124Mtechnician85711
1253Fother94043
2323Mwriter32067
3424Mtechnician43537
4533Fother15213
\n", "
" ], "text/plain": [ " user_id age gender occupation zip_code\n", "0 1 24 M technician 85711\n", "1 2 53 F other 94043\n", "2 3 23 M writer 32067\n", "3 4 24 M technician 43537\n", "4 5 33 F other 15213" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "users.head()" ] }, { "cell_type": "code", "execution_count": 8, "id": "1dbad7b1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
movie_idmovie_titlerelease_datevideo_release_dateIMDb_URLunknownActionAdventureAnimationChildren's...FantasyFilm-NoirHorrorMusicalMysteryRomanceSci-FiThrillerWarWestern
01Toy Story (1995)01-Jan-1995NaNhttp://us.imdb.com/M/title-exact?Toy%20Story%2...00011...0000000000
12GoldenEye (1995)01-Jan-1995NaNhttp://us.imdb.com/M/title-exact?GoldenEye%20(...01100...0000000100
23Four Rooms (1995)01-Jan-1995NaNhttp://us.imdb.com/M/title-exact?Four%20Rooms%...00000...0000000100
34Get Shorty (1995)01-Jan-1995NaNhttp://us.imdb.com/M/title-exact?Get%20Shorty%...01000...0000000000
45Copycat (1995)01-Jan-1995NaNhttp://us.imdb.com/M/title-exact?Copycat%20(1995)00000...0000000100
\n", "

5 rows × 24 columns

\n", "
" ], "text/plain": [ " movie_id movie_title release_date video_release_date \\\n", "0 1 Toy Story (1995) 01-Jan-1995 NaN \n", "1 2 GoldenEye (1995) 01-Jan-1995 NaN \n", "2 3 Four Rooms (1995) 01-Jan-1995 NaN \n", "3 4 Get Shorty (1995) 01-Jan-1995 NaN \n", "4 5 Copycat (1995) 01-Jan-1995 NaN \n", "\n", " IMDb_URL unknown Action \\\n", "0 http://us.imdb.com/M/title-exact?Toy%20Story%2... 0 0 \n", "1 http://us.imdb.com/M/title-exact?GoldenEye%20(... 0 1 \n", "2 http://us.imdb.com/M/title-exact?Four%20Rooms%... 0 0 \n", "3 http://us.imdb.com/M/title-exact?Get%20Shorty%... 0 1 \n", "4 http://us.imdb.com/M/title-exact?Copycat%20(1995) 0 0 \n", "\n", " Adventure Animation Children's ... Fantasy Film-Noir Horror Musical \\\n", "0 0 1 1 ... 0 0 0 0 \n", "1 1 0 0 ... 0 0 0 0 \n", "2 0 0 0 ... 0 0 0 0 \n", "3 0 0 0 ... 0 0 0 0 \n", "4 0 0 0 ... 0 0 0 0 \n", "\n", " Mystery Romance Sci-Fi Thriller War Western \n", "0 0 0 0 0 0 0 \n", "1 0 0 0 1 0 0 \n", "2 0 0 0 1 0 0 \n", "3 0 0 0 0 0 0 \n", "4 0 0 0 1 0 0 \n", "\n", "[5 rows x 24 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "items.head()" ] }, { "cell_type": "code", "execution_count": 9, "id": "3cb7bbc5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
user_idmovie_idratingtimestampnum_watched
0116858749654781
1117258749654782
2116558749655183
3115648749655564
4119658749656775
\n", "
" ], "text/plain": [ " user_id movie_id rating timestamp num_watched\n", "0 1 168 5 874965478 1\n", "1 1 172 5 874965478 2\n", "2 1 165 5 874965518 3\n", "3 1 156 4 874965556 4\n", "4 1 196 5 874965677 5" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# adding a column with the number of movies watched per user\n", "dataset = data.sort_values([\"user_id\", \"timestamp\"]).reset_index(drop=True)\n", "dataset[\"one\"] = 1\n", "dataset[\"num_watched\"] = dataset.groupby(\"user_id\")[\"one\"].cumsum()\n", "dataset.drop(\"one\", axis=1, inplace=True)\n", "dataset.head()" ] }, { "cell_type": "code", "execution_count": 10, "id": "cf7c5da2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
user_idmovie_idratingtimestampnum_watchedmean_rate
01168587496547815.00
11172587496547825.00
21165587496551835.00
31156487496555644.75
41196587496567754.80
\n", "
" ], "text/plain": [ " user_id movie_id rating timestamp num_watched mean_rate\n", "0 1 168 5 874965478 1 5.00\n", "1 1 172 5 874965478 2 5.00\n", "2 1 165 5 874965518 3 5.00\n", "3 1 156 4 874965556 4 4.75\n", "4 1 196 5 874965677 5 4.80" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# adding a column with the mean rating at a point in time per user\n", "dataset[\"mean_rate\"] = (\n", " dataset.groupby(\"user_id\")[\"rating\"].cumsum() / dataset[\"num_watched\"]\n", ")\n", "dataset.head()" ] }, { "cell_type": "markdown", "id": "29d1c399", "metadata": {}, "source": [ "### Problem formulation\n", "\n", "In this particular exercise the problem is formulated as predicting the next movie that will be watched (in consequence the last interactions will be discarded)" ] }, { "cell_type": "code", "execution_count": 11, "id": "0e9d1315", "metadata": {}, "outputs": [], "source": [ "dataset[\"target\"] = dataset.groupby(\"user_id\")[\"movie_id\"].shift(-1)" ] }, { "cell_type": "markdown", "id": "b38bba10", "metadata": {}, "source": [ "Following the same processing used by the author in the before-mentioned Kaggle notebook, we build sequences of previous movies watched" ] }, { "cell_type": "code", "execution_count": 12, "id": "f001f2b4", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
user_idmovie_idratingtimestampnum_watchedmean_ratetargetprev_movies
01168587496547815.00172.0[168]
11172587496547825.00165.0[168, 172]
21165587496551835.00156.0[168, 172, 165]
31156487496555644.75196.0[168, 172, 165, 156]
41196587496567754.80166.0[168, 172, 165, 156, 196]
\n", "
" ], "text/plain": [ " user_id movie_id rating timestamp num_watched mean_rate target \\\n", "0 1 168 5 874965478 1 5.00 172.0 \n", "1 1 172 5 874965478 2 5.00 165.0 \n", "2 1 165 5 874965518 3 5.00 156.0 \n", "3 1 156 4 874965556 4 4.75 196.0 \n", "4 1 196 5 874965677 5 4.80 166.0 \n", "\n", " prev_movies \n", "0 [168] \n", "1 [168, 172] \n", "2 [168, 172, 165] \n", "3 [168, 172, 165, 156] \n", "4 [168, 172, 165, 156, 196] " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Here the author builds the sequences\n", "dataset[\"prev_movies\"] = dataset[\"movie_id\"].apply(lambda x: str(x))\n", "dataset[\"prev_movies\"] = (\n", " dataset.groupby(\"user_id\")[\"prev_movies\"]\n", " .apply(lambda x: (x + \" \").cumsum().str.strip())\n", " .reset_index(drop=True)\n", ")\n", "dataset[\"prev_movies\"] = dataset[\"prev_movies\"].apply(lambda x: x.split())\n", "dataset.head()" ] }, { "cell_type": "markdown", "id": "a024b9c4", "metadata": {}, "source": [ "And now we add a `genre_rate` as the mean of all movies rated for a given genre per user\n" ] }, { "cell_type": "code", "execution_count": 13, "id": "5782f0c9", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
user_idmovie_idratingtimestampnum_watchedmean_ratetargetprev_moviesunknownAction...Fantasy_rateFilm-Noir_rateHorror_rateMusical_rateMystery_rateRomance_rateSci-Fi_rateThriller_rateWar_rateWestern_rate
01168587496547815.00172.0[168]0.00.000000...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
11172587496547825.00165.0[168, 172]0.00.500000...NaNNaNNaNNaNNaN5.05.0NaN5.0NaN
21165587496551835.00156.0[168, 172, 165]0.00.333333...NaNNaNNaNNaNNaN5.05.0NaN5.0NaN
31156487496555644.75196.0[168, 172, 165, 156]0.00.250000...NaNNaNNaNNaNNaN5.05.04.05.0NaN
41196587496567754.80166.0[168, 172, 165, 156, 196]0.00.200000...NaNNaNNaNNaNNaN5.05.04.05.0NaN
\n", "

5 rows × 46 columns

\n", "
" ], "text/plain": [ " user_id movie_id rating timestamp num_watched mean_rate target \\\n", "0 1 168 5 874965478 1 5.00 172.0 \n", "1 1 172 5 874965478 2 5.00 165.0 \n", "2 1 165 5 874965518 3 5.00 156.0 \n", "3 1 156 4 874965556 4 4.75 196.0 \n", "4 1 196 5 874965677 5 4.80 166.0 \n", "\n", " prev_movies unknown Action ... Fantasy_rate \\\n", "0 [168] 0.0 0.000000 ... NaN \n", "1 [168, 172] 0.0 0.500000 ... NaN \n", "2 [168, 172, 165] 0.0 0.333333 ... NaN \n", "3 [168, 172, 165, 156] 0.0 0.250000 ... NaN \n", "4 [168, 172, 165, 156, 196] 0.0 0.200000 ... NaN \n", "\n", " Film-Noir_rate Horror_rate Musical_rate Mystery_rate Romance_rate \\\n", "0 NaN NaN NaN NaN NaN \n", "1 NaN NaN NaN NaN 5.0 \n", "2 NaN NaN NaN NaN 5.0 \n", "3 NaN NaN NaN NaN 5.0 \n", "4 NaN NaN NaN NaN 5.0 \n", "\n", " Sci-Fi_rate Thriller_rate War_rate Western_rate \n", "0 NaN NaN NaN NaN \n", "1 5.0 NaN 5.0 NaN \n", "2 5.0 NaN 5.0 NaN \n", "3 5.0 4.0 5.0 NaN \n", "4 5.0 4.0 5.0 NaN \n", "\n", "[5 rows x 46 columns]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset = dataset.merge(items[[\"movie_id\"] + list_of_genres], on=\"movie_id\", how=\"left\")\n", "for genre in list_of_genres:\n", " dataset[f\"{genre}_rate\"] = dataset[genre] * dataset[\"rating\"]\n", " dataset[genre] = dataset.groupby(\"user_id\")[genre].cumsum()\n", " dataset[f\"{genre}_rate\"] = (\n", " dataset.groupby(\"user_id\")[f\"{genre}_rate\"].cumsum() / dataset[genre]\n", " )\n", "dataset[list_of_genres] = dataset[list_of_genres].apply(\n", " lambda x: x / dataset[\"num_watched\"]\n", ")\n", "dataset.head()" ] }, { "cell_type": "markdown", "id": "7029510d", "metadata": {}, "source": [ "Adding user features" ] }, { "cell_type": "code", "execution_count": 14, "id": "df698ec8", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
user_idmovie_idratingtimestampnum_watchedmean_ratetargetprev_moviesunknownAction...Mystery_rateRomance_rateSci-Fi_rateThriller_rateWar_rateWestern_rateagegenderoccupationzip_code
01168587496547815.00172.0[168]0.00.000000...NaNNaNNaNNaNNaNNaN24Mtechnician85711
11172587496547825.00165.0[168, 172]0.00.500000...NaN5.05.0NaN5.0NaN24Mtechnician85711
21165587496551835.00156.0[168, 172, 165]0.00.333333...NaN5.05.0NaN5.0NaN24Mtechnician85711
31156487496555644.75196.0[168, 172, 165, 156]0.00.250000...NaN5.05.04.05.0NaN24Mtechnician85711
41196587496567754.80166.0[168, 172, 165, 156, 196]0.00.200000...NaN5.05.04.05.0NaN24Mtechnician85711
\n", "

5 rows × 50 columns

\n", "
" ], "text/plain": [ " user_id movie_id rating timestamp num_watched mean_rate target \\\n", "0 1 168 5 874965478 1 5.00 172.0 \n", "1 1 172 5 874965478 2 5.00 165.0 \n", "2 1 165 5 874965518 3 5.00 156.0 \n", "3 1 156 4 874965556 4 4.75 196.0 \n", "4 1 196 5 874965677 5 4.80 166.0 \n", "\n", " prev_movies unknown Action ... Mystery_rate \\\n", "0 [168] 0.0 0.000000 ... NaN \n", "1 [168, 172] 0.0 0.500000 ... NaN \n", "2 [168, 172, 165] 0.0 0.333333 ... NaN \n", "3 [168, 172, 165, 156] 0.0 0.250000 ... NaN \n", "4 [168, 172, 165, 156, 196] 0.0 0.200000 ... NaN \n", "\n", " Romance_rate Sci-Fi_rate Thriller_rate War_rate Western_rate age \\\n", "0 NaN NaN NaN NaN NaN 24 \n", "1 5.0 5.0 NaN 5.0 NaN 24 \n", "2 5.0 5.0 NaN 5.0 NaN 24 \n", "3 5.0 5.0 4.0 5.0 NaN 24 \n", "4 5.0 5.0 4.0 5.0 NaN 24 \n", "\n", " gender occupation zip_code \n", "0 M technician 85711 \n", "1 M technician 85711 \n", "2 M technician 85711 \n", "3 M technician 85711 \n", "4 M technician 85711 \n", "\n", "[5 rows x 50 columns]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset = dataset.merge(users, on=\"user_id\", how=\"left\")\n", "dataset.head()" ] }, { "cell_type": "markdown", "id": "ee62d77e", "metadata": {}, "source": [ "Again, we use the same settings as those in the Kaggle notebook, but `COLD_START_TRESH` is pretty aggressive" ] }, { "cell_type": "code", "execution_count": 15, "id": "8060cf59", "metadata": {}, "outputs": [], "source": [ "COLD_START_TRESH = 5\n", "\n", "filtred_data = dataset[\n", " (dataset[\"num_watched\"] >= COLD_START_TRESH) & ~(dataset[\"target\"].isna())\n", "].sort_values(\"timestamp\")\n", "train_data, _test_data = train_test_split(filtred_data, test_size=0.2, shuffle=False)\n", "valid_data, test_data = train_test_split(_test_data, test_size=0.5, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 16, "id": "b1beb347", "metadata": {}, "outputs": [], "source": [ "cols_to_drop = [\n", " # \"rating\",\n", 
" \"timestamp\",\n", " \"num_watched\",\n", "]\n", "\n", "df_train = train_data.drop(cols_to_drop, axis=1)\n", "df_valid = valid_data.drop(cols_to_drop, axis=1)\n", "df_test = test_data.drop(cols_to_drop, axis=1)\n", "\n", "df_train.to_pickle(save_path / \"df_train.pkl\")\n", "df_valid.to_pickle(save_path / \"df_valid.pkl\")\n", "df_test.to_pickle(save_path / \"df_test.pkl\")" ] }, { "cell_type": "markdown", "id": "5bf71a82", "metadata": {}, "source": [ "Let's now build a model that is nearly identical to the one use in the[ Kaggle notebook](https://www.kaggle.com/code/matanivanov/wide-deep-learning-for-recsys-with-pytorch)" ] }, { "cell_type": "code", "execution_count": 17, "id": "6aa2e3f2", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from torch import nn\n", "from scipy.sparse import coo_matrix\n", "\n", "from pytorch_widedeep import Trainer\n", "from pytorch_widedeep.models import TabMlp, BasicRNN, WideDeep\n", "from pytorch_widedeep.preprocessing import TabPreprocessor" ] }, { "cell_type": "code", "execution_count": 18, "id": "42b0d88f", "metadata": {}, "outputs": [], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "save_path = Path(\"prepared_data\")\n", "\n", "PAD_IDX = 0" ] }, { "cell_type": "markdown", "id": "be204fe8", "metadata": {}, "source": [ "Let's use some of the functions the author of the kaggle's notebook uses to prepare the data" ] }, { "cell_type": "code", "execution_count": 19, "id": "206eb90e", "metadata": {}, "outputs": [], "source": [ "def get_coo_indexes(lil):\n", " rows = []\n", " cols = []\n", " for i, el in enumerate(lil):\n", " if type(el) != list:\n", " el = [el]\n", " for j in el:\n", " rows.append(i)\n", " cols.append(j)\n", " return rows, cols\n", "\n", "\n", "def get_sparse_features(series, shape):\n", " coo_indexes = get_coo_indexes(series.tolist())\n", " sparse_df = coo_matrix(\n", " (np.ones(len(coo_indexes[0])), (coo_indexes[0], coo_indexes[1])), 
shape=shape\n", " )\n", " return sparse_df\n", "\n", "\n", "def sparse_to_idx(data, pad_idx=-1):\n", " indexes = data.nonzero()\n", " indexes_df = pd.DataFrame()\n", " indexes_df[\"rows\"] = indexes[0]\n", " indexes_df[\"cols\"] = indexes[1]\n", " mdf = indexes_df.groupby(\"rows\").apply(lambda x: x[\"cols\"].tolist())\n", " max_len = mdf.apply(lambda x: len(x)).max()\n", " return mdf.apply(lambda x: pd.Series(x + [pad_idx] * (max_len - len(x)))).values" ] }, { "cell_type": "markdown", "id": "7ca8dd42", "metadata": {}, "source": [ "For the time being, we will not use a validation set for hyperparameter optimization, and we will simply concatenate the validation and the test set in one test set. I simply splitted the data into train/valid/test in case the reader wants to actually do hyperparameter optimization (and because I know in the future I will).\n", "\n", "There is also another caveat worth mentioning, related to the indexing of the movies. To build the matrices of movies watched, we use the entire dataset. A more realistic (and correct) approach would be to use ONLY the movies that appear in the training set and consider `unknown` or `unseen` those in the testing set that have not been seen during training. Nonetheless, this will not affect the purposes of this notebook, which is to illustrate how one could use `pytorch-widedeep` to build a recommendation algorithm. However, if one wanted to explore the performance of different algorithms in a \"proper\" way, these \"details\" need to be accounted for." 
] }, { "cell_type": "code", "execution_count": 20, "id": "39f778bc", "metadata": {}, "outputs": [], "source": [ "df_test = pd.concat([df_valid, df_test], ignore_index=True)" ] }, { "cell_type": "code", "execution_count": 21, "id": "ab7483c3", "metadata": {}, "outputs": [], "source": [ "id_cols = [\"user_id\", \"movie_id\"]\n", "max_movie_index = max(df_train.movie_id.max(), df_test.movie_id.max())" ] }, { "cell_type": "code", "execution_count": 22, "id": "3d17bd3d", "metadata": {}, "outputs": [], "source": [ "X_train = df_train.drop(id_cols + [\"rating\", \"prev_movies\", \"target\"], axis=1)\n", "y_train = np.array(df_train.target.values, dtype=\"int64\")\n", "train_movies_watched = get_sparse_features(\n", " df_train[\"prev_movies\"], (len(df_train), max_movie_index + 1)\n", ")\n", "\n", "X_test = df_test.drop(id_cols + [\"rating\", \"prev_movies\", \"target\"], axis=1)\n", "y_test = np.array(df_test.target.values, dtype=\"int64\")\n", "test_movies_watched = get_sparse_features(\n", " df_test[\"prev_movies\"], (len(df_test), max_movie_index + 1)\n", ")" ] }, { "cell_type": "markdown", "id": "511e95ed", "metadata": {}, "source": [ "let's have a look to the information in each dataset" ] }, { "cell_type": "code", "execution_count": 23, "id": "dd9e5ef3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
mean_rateunknownActionAdventureAnimationChildren'sComedyCrimeDocumentaryDrama...Mystery_rateRomance_rateSci-Fi_rateThriller_rateWar_rateWestern_rateagegenderoccupationzip_code
254234.0000000.00.4000000.2000000.00.00.4000000.00.00.200000...NaN4.04.04.0000004.0NaN21Mstudent48823
254254.0000000.00.2857140.1428570.00.00.4285710.00.00.285714...NaN4.04.04.0000004.0NaN21Mstudent48823
254244.0000000.00.3333330.1666670.00.00.3333330.00.00.333333...NaN4.04.04.0000004.0NaN21Mstudent48823
254263.8750000.00.2500000.1250000.00.00.3750000.00.00.250000...NaN4.04.03.6666674.0NaN21Mstudent48823
254273.8888890.00.2222220.1111110.00.00.3333330.00.00.333333...NaN4.04.03.6666674.0NaN21Mstudent48823
\n", "

5 rows × 43 columns

\n", "
" ], "text/plain": [ " mean_rate unknown Action Adventure Animation Children's \\\n", "25423 4.000000 0.0 0.400000 0.200000 0.0 0.0 \n", "25425 4.000000 0.0 0.285714 0.142857 0.0 0.0 \n", "25424 4.000000 0.0 0.333333 0.166667 0.0 0.0 \n", "25426 3.875000 0.0 0.250000 0.125000 0.0 0.0 \n", "25427 3.888889 0.0 0.222222 0.111111 0.0 0.0 \n", "\n", " Comedy Crime Documentary Drama ... Mystery_rate \\\n", "25423 0.400000 0.0 0.0 0.200000 ... NaN \n", "25425 0.428571 0.0 0.0 0.285714 ... NaN \n", "25424 0.333333 0.0 0.0 0.333333 ... NaN \n", "25426 0.375000 0.0 0.0 0.250000 ... NaN \n", "25427 0.333333 0.0 0.0 0.333333 ... NaN \n", "\n", " Romance_rate Sci-Fi_rate Thriller_rate War_rate Western_rate age \\\n", "25423 4.0 4.0 4.000000 4.0 NaN 21 \n", "25425 4.0 4.0 4.000000 4.0 NaN 21 \n", "25424 4.0 4.0 4.000000 4.0 NaN 21 \n", "25426 4.0 4.0 3.666667 4.0 NaN 21 \n", "25427 4.0 4.0 3.666667 4.0 NaN 21 \n", "\n", " gender occupation zip_code \n", "25423 M student 48823 \n", "25425 M student 48823 \n", "25424 M student 48823 \n", "25426 M student 48823 \n", "25427 M student 48823 \n", "\n", "[5 rows x 43 columns]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.head()" ] }, { "cell_type": "code", "execution_count": 24, "id": "840e59a2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([772, 288, 108, ..., 183, 432, 509])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train" ] }, { "cell_type": "code", "execution_count": 25, "id": "516d2fd5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<76228x1683 sparse matrix of type ''\n", "\twith 7957390 stored elements in COOrdinate format>" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_movies_watched" ] }, { "cell_type": "code", "execution_count": 26, "id": "a4cba74d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['173', '185', '255', 
'286', '298']" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sorted(df_train.prev_movies.tolist()[0])" ] }, { "cell_type": "code", "execution_count": 27, "id": "a4f11af4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([0, 0, 0, 0, 0]), array([173, 185, 255, 286, 298]))" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.where(train_movies_watched.todense()[0])" ] }, { "cell_type": "markdown", "id": "2d7dd7bc", "metadata": {}, "source": [ "From now on, the specifics related to this library start to appear. The only component that is going to be a bit different is the so-called tabular component, referred to as `continuous` in the notebook. \n", "\n", "In the case of `pytorch-widedeep` we have the `TabPreprocessor`, which allows for a lot of flexibility as to how we would like to process the tabular component of this Wide and Deep model. In other words, here our tabular component is a bit more elaborate than the one in the notebook, just a bit...\n" ] }, { "cell_type": "code", "execution_count": 28, "id": "733ea2a5", "metadata": {}, "outputs": [], "source": [ "cat_cols = [\"gender\", \"occupation\", \"zip_code\"]\n", "cont_cols = [c for c in X_train if c not in cat_cols]\n", "tab_preprocessor = TabPreprocessor(\n", " cat_embed_cols=cat_cols,\n", " continuous_cols=cont_cols,\n", ")" ] }, { "cell_type": "code", "execution_count": 29, "id": "68555183", "metadata": {}, "outputs": [], "source": [ "X_train_tab = tab_preprocessor.fit_transform(X_train.fillna(0))\n", "X_test_tab = tab_preprocessor.transform(X_test.fillna(0))" ] }, { "cell_type": "markdown", "id": "a00da28c", "metadata": {}, "source": [ "Now, in the notebook, the author moves the sparse matrices to sparse tensors and then turns them into dense tensors. In reality, this is not necessary: one can feed sparse tensors to `nn.Linear` layers in PyTorch. 
Nonetheless, this is not the most efficient implementation, and it is the reason why in our library the wide, linear component is implemented as an embedding layer. \n", "\n", "Still, to reproduce the notebook as faithfully as we can, and because currently the `Wide` model in `pytorch-widedeep` is not designed to receive sparse tensors (we might consider implementing this functionality), we will turn the sparse COO matrices into dense arrays. We will then code a fairly simple, custom `Wide` component." ] }, { "cell_type": "code", "execution_count": 30, "id": "20903dd2", "metadata": {}, "outputs": [], "source": [ "X_train_wide = np.array(train_movies_watched.todense())\n", "X_test_wide = np.array(test_movies_watched.todense())" ] }, { "cell_type": "markdown", "id": "377e7f90", "metadata": {}, "source": [ "Finally, the author of the notebook uses a simple `Embedding` layer to encode the sequences of movies watched, the `prev_movies` column. In my opinion, there is an element of information redundancy here: the wide and text components implicitly carry the same information, just in different forms. Moreover, both of the models used for these two components ignore the sequential element in the data. Nonetheless, we want to reproduce the Kaggle notebook as closely as possible, and, as one can verify later via simple ablation studies, the wide component seems to carry most of the predictive power."
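The claim above, that a wide/linear layer over a multi-hot input can equivalently be implemented as an embedding lookup, is easy to check numerically. A numpy sketch (illustrative only, not the library's actual code):

```python
import numpy as np

rng = np.random.default_rng(0)
n_items, n_out = 6, 4
W = rng.normal(size=(n_items, n_out))  # weight of a bias-free linear layer

# Multi-hot input with items 1 and 3 active.
x = np.zeros(n_items)
x[[1, 3]] = 1.0

# Dense matrix product vs. embedding-style lookup-and-sum over the active
# indices: identical outputs, but the lookup never touches the zero rows.
dense_out = x @ W
lookup_out = W[[1, 3]].sum(axis=0)

print(np.allclose(dense_out, lookup_out))  # True
```

The lookup does O(active items) work instead of O(vocabulary) per sample, which is why an embedding-based wide component scales much better for large item catalogues.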
] }, { "cell_type": "code", "execution_count": 31, "id": "c52fd52c", "metadata": {}, "outputs": [], "source": [ "X_train_text = sparse_to_idx(train_movies_watched, pad_idx=PAD_IDX)\n", "X_test_text = sparse_to_idx(test_movies_watched, pad_idx=PAD_IDX)" ] }, { "cell_type": "markdown", "id": "1ca8b84d", "metadata": {}, "source": [ "Let's now build the models." ] }, { "cell_type": "code", "execution_count": 32, "id": "44bc73d4", "metadata": {}, "outputs": [], "source": [ "class Wide(nn.Module):\n", "    def __init__(self, input_dim: int, pred_dim: int):\n", "        super().__init__()\n", "\n", "        self.input_dim = input_dim\n", "        self.pred_dim = pred_dim\n", "\n", "        # When I coded the library I never thought that someone would want to code\n", "        # their own wide component. However, if you do, the wide component must have\n", "        # a 'wide_linear' attribute. In other words, the linear layer must be\n", "        # called 'wide_linear'\n", "        self.wide_linear = nn.Linear(input_dim, pred_dim)\n", "\n", "    def forward(self, X):\n", "        out = self.wide_linear(X.type(torch.float32))\n", "        return out\n", "\n", "\n", "wide = Wide(X_train_wide.shape[1], max_movie_index + 1)" ] }, { "cell_type": "code", "execution_count": 33, "id": "6f66130d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Wide(\n", "  (wide_linear): Linear(in_features=1683, out_features=1683, bias=True)\n", ")" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wide" ] }, { "cell_type": "code", "execution_count": 34, "id": "25592d30", "metadata": {}, "outputs": [], "source": [ "class SimpleEmbed(nn.Module):\n", "    def __init__(self, vocab_size: int, embed_dim: int, pad_idx: int):\n", "        super().__init__()\n", "\n", "        self.vocab_size = vocab_size\n", "        self.embed_dim = embed_dim\n", "        self.pad_idx = pad_idx\n", "\n", "        # The sequences of movies watched are simply embedded in the Kaggle\n", "        # notebook. 
No RNN, Transformer, or any other model is used\n", "        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)\n", "\n", "    def forward(self, X):\n", "        embed = self.embed(X)\n", "        embed_mean = torch.mean(embed, dim=1)\n", "        return embed_mean\n", "\n", "    @property\n", "    def output_dim(self) -> int:\n", "        # All deep components in a custom 'pytorch-widedeep' model must have\n", "        # an output_dim property\n", "        return self.embed_dim\n", "\n", "\n", "# In the notebook the author simply uses embeddings\n", "simple_embed = SimpleEmbed(max_movie_index + 1, 16, 0)" ] }, { "cell_type": "code", "execution_count": 35, "id": "492f12c5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SimpleEmbed(\n", "  (embed): Embedding(1683, 16, padding_idx=0)\n", ")" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "simple_embed" ] }, { "cell_type": "markdown", "id": "fe9f137a", "metadata": {}, "source": [ "Maybe one would like to use an RNN to account for the sequential nature of the problem. If that were the case, it would be as easy as: " ] }, { "cell_type": "code", "execution_count": 36, "id": "0c3f17b2", "metadata": {}, "outputs": [], "source": [ "basic_rnn = BasicRNN(\n", "    vocab_size=max_movie_index + 1,\n", "    embed_dim=16,\n", "    hidden_dim=32,\n", "    n_layers=2,\n", "    rnn_type=\"gru\",\n", ")" ] }, { "cell_type": "markdown", "id": "e410d5d9", "metadata": {}, "source": [ "And finally, the tabular component, which in the notebook is simply a stack of linear + ReLU layers. 
In our case we have an embedding layer before the linear layers to encode the categorical and numerical columns" ] }, { "cell_type": "code", "execution_count": 37, "id": "ca721555", "metadata": {}, "outputs": [], "source": [ "tab_mlp = TabMlp(\n", "    column_idx=tab_preprocessor.column_idx,\n", "    cat_embed_input=tab_preprocessor.cat_embed_input,\n", "    continuous_cols=tab_preprocessor.continuous_cols,\n", "    cont_norm_layer=None,\n", "    mlp_hidden_dims=[1024, 512, 256],\n", "    mlp_activation=\"relu\",\n", ")" ] }, { "cell_type": "code", "execution_count": 38, "id": "25c25e3a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "TabMlp(\n", "  (cat_and_cont_embed): DiffSizeCatAndContEmbeddings(\n", "    (cat_embed): DiffSizeCatEmbeddings(\n", "      (embed_layers): ModuleDict(\n", "        (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n", "        (emb_layer_occupation): Embedding(22, 9, padding_idx=0)\n", "        (emb_layer_zip_code): Embedding(648, 60, padding_idx=0)\n", "      )\n", "      (embedding_dropout): Dropout(p=0.1, inplace=False)\n", "    )\n", "    (cont_norm): Identity()\n", "  )\n", "  (encoder): MLP(\n", "    (mlp): Sequential(\n", "      (dense_layer_0): Sequential(\n", "        (0): Dropout(p=0.1, inplace=False)\n", "        (1): Linear(in_features=111, out_features=1024, bias=True)\n", "        (2): ReLU(inplace=True)\n", "      )\n", "      (dense_layer_1): Sequential(\n", "        (0): Dropout(p=0.1, inplace=False)\n", "        (1): Linear(in_features=1024, out_features=512, bias=True)\n", "        (2): ReLU(inplace=True)\n", "      )\n", "      (dense_layer_2): Sequential(\n", "        (0): Dropout(p=0.1, inplace=False)\n", "        (1): Linear(in_features=512, out_features=256, bias=True)\n", "        (2): ReLU(inplace=True)\n", "      )\n", "    )\n", "  )\n", ")" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tab_mlp" ] }, { "cell_type": "markdown", "id": "b68c5bc9", "metadata": {}, "source": [ "Finally, we simply wrap up all models with the `WideDeep` 'collector' class and we are ready to train." ] }, { "cell_type": "code", "execution_count": 39, "id": "4c6acc08", "metadata": {}, "outputs": [], "source": [ "wide_deep_model = WideDeep(\n", "    wide=wide, deeptabular=tab_mlp, deeptext=simple_embed, pred_dim=max_movie_index + 1\n", ")" ] }, { "cell_type": "code", "execution_count": 40, "id": "bc8970f7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "WideDeep(\n", "  (wide): Wide(\n", "    (wide_linear): Linear(in_features=1683, out_features=1683, bias=True)\n", "  )\n", "  (deeptabular): Sequential(\n", "    (0): TabMlp(\n", "      (cat_and_cont_embed): DiffSizeCatAndContEmbeddings(\n", "        (cat_embed): DiffSizeCatEmbeddings(\n", "          (embed_layers): ModuleDict(\n", "            (emb_layer_gender): Embedding(3, 2, padding_idx=0)\n", "            (emb_layer_occupation): Embedding(22, 9, padding_idx=0)\n", "            (emb_layer_zip_code): Embedding(648, 60, padding_idx=0)\n", "          )\n", "          (embedding_dropout): Dropout(p=0.1, inplace=False)\n", "        )\n", "        (cont_norm): Identity()\n", "      )\n", "      (encoder): MLP(\n", "        (mlp): Sequential(\n", "          (dense_layer_0): Sequential(\n", "            (0): Dropout(p=0.1, inplace=False)\n", "            (1): Linear(in_features=111, out_features=1024, bias=True)\n", "            (2): ReLU(inplace=True)\n", "          )\n", "          (dense_layer_1): Sequential(\n", "            (0): Dropout(p=0.1, inplace=False)\n", "            (1): Linear(in_features=1024, out_features=512, bias=True)\n", "            (2): ReLU(inplace=True)\n", "          )\n", "          (dense_layer_2): Sequential(\n", "            (0): Dropout(p=0.1, inplace=False)\n", "            (1): Linear(in_features=512, out_features=256, bias=True)\n", "            (2): ReLU(inplace=True)\n", "          )\n", "        )\n", "      )\n", "    )\n", "    (1): Linear(in_features=256, out_features=1683, bias=True)\n", "  )\n", "  (deeptext): Sequential(\n", "    (0): SimpleEmbed(\n", "      (embed): Embedding(1683, 16, padding_idx=0)\n", "    )\n", "    (1): Linear(in_features=16, out_features=1683, bias=True)\n", "  )\n", ")" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wide_deep_model" ] }, { "cell_type": "markdown", "id": "e08d41ed", "metadata": {}, "source": [ "Note that the main difference between this wide and deep model and the Wide and Deep model in the Kaggle notebook is that, in that notebook, the author concatenates the embeddings and the tabular features, then passes this concatenation through a stack of linear + ReLU layers with a final output dim of 256. The author then concatenates this output with the binary features and connects the result to the final linear layer (so the input to that layer is of dim (batch_size, 256 + 1683)). Our implementation follows the notation of the original paper: instead of concatenating the tabular, text and wide components and then connecting them to the output neurons, we first compute their outputs and then add them (see Eq. 3 in https://arxiv.org/pdf/1606.07792.pdf). This is effectively the same, with the caveat that while in one case one initialises a single big weight matrix \"at once\", in our implementation we initialise a different matrix for each component. Anyway, let's give it a go."
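To make the "effectively the same" claim concrete, here is a small, self-contained sketch (toy dimensions and random tensors, hypothetical and unrelated to the model above) showing that summing the logits of per-component linear heads equals a single big linear layer over the concatenated inputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy dimensions: a "deep" output of dim 4, a "wide" (binary)
# input of dim 6, and 3 output classes, for a batch of 2
deep_out = torch.randn(2, 4)
wide_in = torch.randn(2, 6)

# Additive formulation (Eq. 3 of the Wide & Deep paper): one linear head
# per component, logits summed
deep_head = nn.Linear(4, 3, bias=True)
wide_head = nn.Linear(6, 3, bias=False)
additive = deep_head(deep_out) + wide_head(wide_in)

# Concatenation formulation (as in the Kaggle notebook): a single big
# linear layer whose weight matrix is the two head weights side by side
big = nn.Linear(4 + 6, 3, bias=True)
with torch.no_grad():
    big.weight.copy_(torch.cat([deep_head.weight, wide_head.weight], dim=1))
    big.bias.copy_(deep_head.bias)
concat = big(torch.cat([deep_out, wide_in], dim=1))

# The two parameterisations produce (numerically) identical logits
print(torch.allclose(additive, concat, atol=1e-6))  # True
```

The only practical difference is how the weights are initialised: one big matrix at once versus one matrix per component.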
] }, { "cell_type": "code", "execution_count": 41, "id": "538a34de", "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(\n", " model=wide_deep_model,\n", " objective=\"multiclass\",\n", " custom_loss_function=nn.CrossEntropyLoss(ignore_index=PAD_IDX),\n", " optimizers=torch.optim.Adam(wide_deep_model.parameters(), lr=1e-3),\n", ")" ] }, { "cell_type": "code", "execution_count": 42, "id": "77c02ed5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "epoch 1: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:16<00:00, 8.82it/s, loss=6.66]\n", "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 21.14it/s, loss=6.61]\n", "epoch 2: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.53it/s, loss=5.98]\n", "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.20it/s, loss=6.53]\n", "epoch 3: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.61it/s, loss=5.66]\n", "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.16it/s, loss=6.54]\n", "epoch 4: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.76it/s, loss=5.43]\n", "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 22.03it/s, loss=6.56]\n", "epoch 5: 100%|█████████████████████████████████████████████████████████████████████████| 149/149 [00:17<00:00, 8.28it/s, loss=5.23]\n", "valid: 100%|█████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 22.60it/s, loss=6.59]\n" ] } ], "source": [ "trainer.fit(\n", " X_train={\n", " \"X_wide\": X_train_wide,\n", " \"X_tab\": X_train_tab,\n", " 
\"X_text\": X_train_text,\n", " \"target\": y_train,\n", " },\n", " X_val={\n", " \"X_wide\": X_test_wide,\n", " \"X_tab\": X_test_tab,\n", " \"X_text\": X_test_text,\n", " \"target\": y_test,\n", " },\n", " n_epochs=5,\n", " batch_size=512,\n", " shuffle=False,\n", ")" ] }, { "cell_type": "markdown", "id": "a8f9aec7", "metadata": {}, "source": [ "Now one could continue to the 'compare' metrics section of the Kaggle notebook. However, for the purposes of illustrating how one could use `pytorch-widedeep` to build recommendation algorithms, we consider this notebook complete and move on to part 2." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 5 }