{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Custom components\n", "\n", "As I mentioned earlier in the example notebooks, and also in the `README`, it is possible to customise almost every component in `pytorch-widedeep`.\n", "\n", "Let's now go through a couple of simple examples to illustrate how that could be done. \n", "\n", "First let's load and process the data \"as usual\", let's start with a regression and the [airbnb](http://insideairbnb.com/get-the-data.html) dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import os\n", "import torch\n", "\n", "from pytorch_widedeep import Trainer\n", "from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor, TextPreprocessor, ImagePreprocessor\n", "from pytorch_widedeep.models import Wide, TabMlp, TabResnet, DeepText, DeepImage, WideDeep\n", "from pytorch_widedeep.losses import RMSELoss\n", "from pytorch_widedeep.initializers import *\n", "from pytorch_widedeep.callbacks import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idhost_iddescriptionhost_listings_counthost_identity_verifiedneighbourhood_cleansedlatitudelongitudeis_location_exactproperty_typeroom_typeaccommodatesbathroomsbedroomsbedsguests_includedminimum_nightsinstant_bookablecancellation_policyhas_house_ruleshost_genderaccommodates_catgguests_included_catgminimum_nights_catghost_listings_count_catgbathrooms_catgbedrooms_catgbeds_catgamenity_24-hour_check-inamenity__toiletamenity_accessible-height_bedamenity_accessible-height_toiletamenity_air_conditioningamenity_air_purifieramenity_alfresco_bathtubamenity_amazon_echoamenity_baby_bathamenity_baby_monitoramenity_babysitter_recommendationsamenity_balconyamenity_bath_towelamenity_bathroom_essentialsamenity_bathtubamenity_bathtub_with_bath_chairamenity_bbq_grillamenity_beach_essentialsamenity_beach_viewamenity_beachfrontamenity_bed_linensamenity_bedroom_comforts...amenity_roll-in_showeramenity_room-darkening_shadesamenity_safety_cardamenity_saunaamenity_self_check-inamenity_shampooamenity_shared_gymamenity_shared_hot_tubamenity_shared_poolamenity_shower_chairamenity_single_level_homeamenity_ski-in_ski-outamenity_smart_lockamenity_smart_tvamenity_smoke_detectoramenity_smoking_allowedamenity_soaking_tubamenity_sound_systemamenity_stair_gatesamenity_stand_alone_steam_showeramenity_standing_valetamenity_steam_ovenamenity_stoveamenity_suitable_for_eventsamenity_sun_loungersamenity_table_corner_guardsamenity_tennis_courtamenity_terraceamenity_toilet_paperamenity_touchless_faucetsamenity_tvamenity_walk-in_showeramenity_warming_draweramenity_washeramenity_washer_dryeramenity_waterfrontamenity_well-lit_path_to_entranceamenity_wheelchair_accessibleamenity_wide_clearance_to_showeramenity_wide_doorway_to_guest_bathroomamenity_wide_entranceamenity_wide_entrance_for_guestsamenity_wide_entrywayamenity_wide_hallwaysamenity_wifiamenity_window_guardsamenity_wine_coolersecurity_depositextra_peopleyield
013913.jpg54730My bright double bedroom with a large window has a relaxed feeling! It comfortably fits one or t...4.0fIslington51.56802-0.11121tapartmentprivate_room21.01.00.011fmoderate1female21131100011000000100011000010...11000100000000110000001000000010010000011000100100.015.012.00
115400.jpg60302Lots of windows and light. St Luke's Gardens are at the end of the block, and the river not too...1.0tKensington and Chelsea51.48796-0.16898tapartmententire_home/apt21.01.01.023fstrict_14_with_grace_period1female22311111000100000000000000000...00000100000000100000000000000010010000000000100150.00.0109.50
217402.jpg67564Open from June 2018 after a 3-year break, we are delighted to be welcoming guests again to this ...19.0tWestminster51.52098-0.14002tapartmententire_home/apt62.03.03.043tstrict_14_with_grace_period1female33332330000000000000000000010...00001100000000100000001000000010010000000000100350.010.0149.65
324328.jpg41759Artist house, bright high ceiling rooms, private parking and a communal garden in a conservation...2.0tWandsworth51.47298-0.16376totherentire_home/apt21.51.01.0230fmoderate1male22322111000000000000000000000...00001100000000100000000000000010010000000000100250.00.0215.60
425023.jpg102813Large, all comforts, 2-bed flat; first floor; lift; pretty communal gardens + off-street parking...1.0fWandsworth51.44687-0.21874tapartmententire_home/apt41.02.02.024fmoderate1female32311220000000000000000000000...00000000000000100000000000000010010000000000100250.011.079.35
\n", "

5 rows × 223 columns

\n", "
" ], "text/plain": [ " id host_id \\\n", "0 13913.jpg 54730 \n", "1 15400.jpg 60302 \n", "2 17402.jpg 67564 \n", "3 24328.jpg 41759 \n", "4 25023.jpg 102813 \n", "\n", " description \\\n", "0 My bright double bedroom with a large window has a relaxed feeling! It comfortably fits one or t... \n", "1 Lots of windows and light. St Luke's Gardens are at the end of the block, and the river not too... \n", "2 Open from June 2018 after a 3-year break, we are delighted to be welcoming guests again to this ... \n", "3 Artist house, bright high ceiling rooms, private parking and a communal garden in a conservation... \n", "4 Large, all comforts, 2-bed flat; first floor; lift; pretty communal gardens + off-street parking... \n", "\n", " host_listings_count host_identity_verified neighbourhood_cleansed \\\n", "0 4.0 f Islington \n", "1 1.0 t Kensington and Chelsea \n", "2 19.0 t Westminster \n", "3 2.0 t Wandsworth \n", "4 1.0 f Wandsworth \n", "\n", " latitude longitude is_location_exact property_type room_type \\\n", "0 51.56802 -0.11121 t apartment private_room \n", "1 51.48796 -0.16898 t apartment entire_home/apt \n", "2 51.52098 -0.14002 t apartment entire_home/apt \n", "3 51.47298 -0.16376 t other entire_home/apt \n", "4 51.44687 -0.21874 t apartment entire_home/apt \n", "\n", " accommodates bathrooms bedrooms beds guests_included minimum_nights \\\n", "0 2 1.0 1.0 0.0 1 1 \n", "1 2 1.0 1.0 1.0 2 3 \n", "2 6 2.0 3.0 3.0 4 3 \n", "3 2 1.5 1.0 1.0 2 30 \n", "4 4 1.0 2.0 2.0 2 4 \n", "\n", " instant_bookable cancellation_policy has_house_rules host_gender \\\n", "0 f moderate 1 female \n", "1 f strict_14_with_grace_period 1 female \n", "2 t strict_14_with_grace_period 1 female \n", "3 f moderate 1 male \n", "4 f moderate 1 female \n", "\n", " accommodates_catg guests_included_catg minimum_nights_catg \\\n", "0 2 1 1 \n", "1 2 2 3 \n", "2 3 3 3 \n", "3 2 2 3 \n", "4 3 2 3 \n", "\n", " host_listings_count_catg bathrooms_catg bedrooms_catg beds_catg \\\n", "0 3 1 1 0 \n", "1 1 
1 1 1 \n", "2 3 2 3 3 \n", "3 2 2 1 1 \n", "4 1 1 2 2 \n", "\n", " amenity_24-hour_check-in amenity__toilet amenity_accessible-height_bed \\\n", "0 0 0 1 \n", "1 1 0 0 \n", "2 0 0 0 \n", "3 1 0 0 \n", "4 0 0 0 \n", "\n", " amenity_accessible-height_toilet amenity_air_conditioning \\\n", "0 1 0 \n", "1 0 1 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " amenity_air_purifier amenity_alfresco_bathtub amenity_amazon_echo \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_baby_bath amenity_baby_monitor \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " amenity_babysitter_recommendations amenity_balcony amenity_bath_towel \\\n", "0 1 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_bathroom_essentials amenity_bathtub \\\n", "0 0 1 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " amenity_bathtub_with_bath_chair amenity_bbq_grill \\\n", "0 1 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " amenity_beach_essentials amenity_beach_view amenity_beachfront \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_bed_linens amenity_bedroom_comforts ... amenity_roll-in_shower \\\n", "0 1 0 ... 1 \n", "1 0 0 ... 0 \n", "2 1 0 ... 0 \n", "3 0 0 ... 0 \n", "4 0 0 ... 
0 \n", "\n", " amenity_room-darkening_shades amenity_safety_card amenity_sauna \\\n", "0 1 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_self_check-in amenity_shampoo amenity_shared_gym \\\n", "0 0 1 0 \n", "1 0 1 0 \n", "2 1 1 0 \n", "3 1 1 0 \n", "4 0 0 0 \n", "\n", " amenity_shared_hot_tub amenity_shared_pool amenity_shower_chair \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_single_level_home amenity_ski-in_ski-out amenity_smart_lock \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_smart_tv amenity_smoke_detector amenity_smoking_allowed \\\n", "0 0 1 1 \n", "1 0 1 0 \n", "2 0 1 0 \n", "3 0 1 0 \n", "4 0 1 0 \n", "\n", " amenity_soaking_tub amenity_sound_system amenity_stair_gates \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_stand_alone_steam_shower amenity_standing_valet \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " amenity_steam_oven amenity_stove amenity_suitable_for_events \\\n", "0 0 1 0 \n", "1 0 0 0 \n", "2 0 1 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_sun_loungers amenity_table_corner_guards amenity_tennis_court \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_terrace amenity_toilet_paper amenity_touchless_faucets \\\n", "0 0 0 0 \n", "1 0 0 0 \n", "2 0 0 0 \n", "3 0 0 0 \n", "4 0 0 0 \n", "\n", " amenity_tv amenity_walk-in_shower amenity_warming_drawer amenity_washer \\\n", "0 1 0 0 1 \n", "1 1 0 0 1 \n", "2 1 0 0 1 \n", "3 1 0 0 1 \n", "4 1 0 0 1 \n", "\n", " amenity_washer_dryer amenity_waterfront \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " amenity_well-lit_path_to_entrance amenity_wheelchair_accessible \\\n", "0 0 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " amenity_wide_clearance_to_shower amenity_wide_doorway_to_guest_bathroom \\\n", "0 0 1 \n", "1 0 
0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " amenity_wide_entrance amenity_wide_entrance_for_guests \\\n", "0 1 0 \n", "1 0 0 \n", "2 0 0 \n", "3 0 0 \n", "4 0 0 \n", "\n", " amenity_wide_entryway amenity_wide_hallways amenity_wifi \\\n", "0 0 0 1 \n", "1 0 0 1 \n", "2 0 0 1 \n", "3 0 0 1 \n", "4 0 0 1 \n", "\n", " amenity_window_guards amenity_wine_cooler security_deposit extra_people \\\n", "0 0 0 100.0 15.0 \n", "1 0 0 150.0 0.0 \n", "2 0 0 350.0 10.0 \n", "3 0 0 250.0 0.0 \n", "4 0 0 250.0 11.0 \n", "\n", " yield \n", "0 12.00 \n", "1 109.50 \n", "2 149.65 \n", "3 215.60 \n", "4 79.35 \n", "\n", "[5 rows x 223 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('data/airbnb/airbnb_sample.csv')\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# There are a number of columns that are already binary. Therefore, no need to one hot encode them\n", "crossed_cols = [('property_type', 'room_type')]\n", "already_dummies = [c for c in df.columns if 'amenity' in c] + ['has_house_rules']\n", "wide_cols = ['is_location_exact', 'property_type', 'room_type', 'host_gender',\n", "'instant_bookable'] + already_dummies\n", "cat_embed_cols = [(c, 16) for c in df.columns if 'catg' in c] + \\\n", " [('neighbourhood_cleansed', 64), ('cancellation_policy', 16)]\n", "continuous_cols = ['latitude', 'longitude', 'security_deposit', 'extra_people']\n", "# it does not make sense to standarised Latitude and Longitude\n", "already_standard = ['latitude', 'longitude']\n", "# text and image colnames\n", "text_col = 'description'\n", "img_col = 'id'\n", "# path to pretrained word embeddings and the images\n", "word_vectors_path = 'data/glove.6B/glove.6B.100d.txt'\n", "img_path = 'data/airbnb/property_picture'\n", "# target\n", "target_col = 'yield'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "target = 
df[target_col].values" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The vocabulary contains 2192 tokens\n", "Indexing word vectors...\n", "Loaded 400000 word vectors\n", "Preparing embeddings matrix...\n", "2175 words in the vocabulary had data/glove.6B/glove.6B.100d.txt vectors and appear more than 5 times\n", "Reading Images from data/airbnb/property_picture\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 4%|▍ | 42/1001 [00:00<00:02, 414.31it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Resizing\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1001/1001 [00:02<00:00, 387.88it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Computing normalisation metrics\n" ] } ], "source": [ "wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n", "X_wide = wide_preprocessor.fit_transform(df)\n", "\n", "tab_preprocessor = TabPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n", "X_tab = tab_preprocessor.fit_transform(df)\n", "\n", "text_preprocessor = TextPreprocessor(word_vectors_path=word_vectors_path, text_col=text_col)\n", "X_text = text_preprocessor.fit_transform(df)\n", "\n", "image_processor = ImagePreprocessor(img_col = img_col, img_path = img_path)\n", "X_images = image_processor.fit_transform(df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are ready to build a wide and deep model. Three of the four components we will use are included in this package, and they will be combined with a custom `deeptext` component. 
Then the fit process will run with a custom loss function.\n", "\n", "Let's have a look" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Linear model\n", "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n", "\n", "# DeepDense: 2 Dense layers\n", "deeptabular = TabMlp(\n", " column_idx = tab_preprocessor.column_idx,\n", " mlp_hidden_dims=[128,64],\n", " mlp_dropout = 0.1,\n", " mlp_batchnorm = True,\n", " embed_input=tab_preprocessor.embeddings_input,\n", " embed_dropout = 0.1,\n", " continuous_cols = continuous_cols,\n", " batchnorm_cont = True\n", ")\n", " \n", "# Pretrained Resnet 18 (default is all but last 2 conv blocks frozen) plus a FC-Head 512->256->128\n", "deepimage = DeepImage(pretrained=True, head_hidden_dims=[512, 256, 128])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Custom `deeptext`\n", "\n", "Standard Pytorch model" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class MyDeepText(nn.Module):\n", " def __init__(self, vocab_size, padding_idx=1, embed_dim=100, hidden_dim=64):\n", " super(MyDeepText, self).__init__()\n", "\n", " # word/token embeddings\n", " self.word_embed = nn.Embedding(\n", " vocab_size, embed_dim, padding_idx=padding_idx\n", " )\n", "\n", " # stack of RNNs\n", " self.rnn = nn.GRU(\n", " embed_dim,\n", " hidden_dim,\n", " num_layers=2,\n", " bidirectional=True,\n", " batch_first=True,\n", " )\n", "\n", " # Remember, this must be defined. 
If not WideDeep will throw an error\n", "        self.output_dim = hidden_dim * 2\n", "\n", "    def forward(self, X):\n", "        embed = self.word_embed(X.long())\n", "        o, h = self.rnn(embed)\n", "        return torch.cat((h[-2], h[-1]), dim=1)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "mydeeptext = MyDeepText(vocab_size=len(text_preprocessor.vocab.itos))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=mydeeptext, deepimage=deepimage)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Custom loss function\n", "\n", "Loss functions must simply inherit pytorch's `nn.Module`. For example, let's say we want to use `RMSE` (note that this is already available in the package, but I will pass it here as a custom loss for illustration purposes)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class RMSELoss(nn.Module):\n", "    def __init__(self):\n", "        \"\"\"root mean squared error\"\"\"\n", "        super().__init__()\n", "        self.mse = nn.MSELoss()\n", "\n", "    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n", "        return torch.sqrt(self.mse(input, target))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and now we just instantiate the ``Trainer`` as usual. Needless to say, but this runs with 1000 random observations, so loss and metric values are meaningless. 
This is just an example" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(model, objective='regression', custom_loss_function=RMSELoss())" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "epoch 1: 100%|██████████| 25/25 [02:13<00:00, 5.33s/it, loss=118]\n", "valid: 100%|██████████| 7/7 [00:15<00:00, 2.23s/it, loss=101] \n" ] } ], "source": [ "trainer.fit(X_wide=X_wide, X_tab=X_tab, X_text=X_text, X_img=X_images,\n", " target=target, n_epochs=1, batch_size=32, val_split=0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to model components and loss functions, we can also use custom callbacks or custom metrics. The former need to be of type `Callback` and the latter need to be of type `Metric`. See:\n", "\n", "```python\n", "pytorch-widedeep.callbacks\n", "```\n", "and \n", "\n", "```python\n", "pytorch-widedeep.metrics\n", "```\n", "\n", "For this example let me use the adult dataset. Again, we first prepare the data as usual" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducational-nummarital-statusoccupationrelationshipracegendercapital-gaincapital-losshours-per-weeknative-countryincome
025Private22680211th7Never-marriedMachine-op-inspctOwn-childBlackMale0040United-States<=50K
138Private89814HS-grad9Married-civ-spouseFarming-fishingHusbandWhiteMale0050United-States<=50K
228Local-gov336951Assoc-acdm12Married-civ-spouseProtective-servHusbandWhiteMale0040United-States>50K
344Private160323Some-college10Married-civ-spouseMachine-op-inspctHusbandBlackMale7688040United-States>50K
418?103497Some-college10Never-married?Own-childWhiteFemale0030United-States<=50K
\n", "
" ], "text/plain": [ " age workclass fnlwgt education educational-num marital-status \\\n", "0 25 Private 226802 11th 7 Never-married \n", "1 38 Private 89814 HS-grad 9 Married-civ-spouse \n", "2 28 Local-gov 336951 Assoc-acdm 12 Married-civ-spouse \n", "3 44 Private 160323 Some-college 10 Married-civ-spouse \n", "4 18 ? 103497 Some-college 10 Never-married \n", "\n", " occupation relationship race gender capital-gain capital-loss \\\n", "0 Machine-op-inspct Own-child Black Male 0 0 \n", "1 Farming-fishing Husband White Male 0 0 \n", "2 Protective-serv Husband White Male 0 0 \n", "3 Machine-op-inspct Husband Black Male 7688 0 \n", "4 ? Own-child White Female 0 0 \n", "\n", " hours-per-week native-country income \n", "0 40 United-States <=50K \n", "1 50 United-States <=50K \n", "2 40 United-States >50K \n", "3 40 United-States >50K \n", "4 30 United-States <=50K " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('data/adult/adult.csv.zip')\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ageworkclassfnlwgteducationeducational_nummarital_statusoccupationrelationshipracegendercapital_gaincapital_losshours_per_weeknative_countryincome_label
025Private22680211th7Never-marriedMachine-op-inspctOwn-childBlackMale0040United-States0
138Private89814HS-grad9Married-civ-spouseFarming-fishingHusbandWhiteMale0050United-States0
228Local-gov336951Assoc-acdm12Married-civ-spouseProtective-servHusbandWhiteMale0040United-States1
344Private160323Some-college10Married-civ-spouseMachine-op-inspctHusbandBlackMale7688040United-States1
418?103497Some-college10Never-married?Own-childWhiteFemale0030United-States0
\n", "
" ], "text/plain": [ " age workclass fnlwgt education educational_num marital_status \\\n", "0 25 Private 226802 11th 7 Never-married \n", "1 38 Private 89814 HS-grad 9 Married-civ-spouse \n", "2 28 Local-gov 336951 Assoc-acdm 12 Married-civ-spouse \n", "3 44 Private 160323 Some-college 10 Married-civ-spouse \n", "4 18 ? 103497 Some-college 10 Never-married \n", "\n", " occupation relationship race gender capital_gain capital_loss \\\n", "0 Machine-op-inspct Own-child Black Male 0 0 \n", "1 Farming-fishing Husband White Male 0 0 \n", "2 Protective-serv Husband White Male 0 0 \n", "3 Machine-op-inspct Husband Black Male 7688 0 \n", "4 ? Own-child White Female 0 0 \n", "\n", " hours_per_week native_country income_label \n", "0 40 United-States 0 \n", "1 50 United-States 0 \n", "2 40 United-States 1 \n", "3 40 United-States 1 \n", "4 30 United-States 0 " ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# For convenience, we'll replace '-' with '_'\n", "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n", "# binary target\n", "df['income_label'] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n", "df.drop('income', axis=1, inplace=True)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "wide_cols = ['education', 'relationship','workclass','occupation','native_country','gender']\n", "crossed_cols = [('education', 'occupation'), ('native_country', 'occupation')]\n", "cat_embed_cols = [('education',16), ('relationship',8), ('workclass',16), ('occupation',16),('native_country',16)]\n", "continuous_cols = [\"age\",\"hours_per_week\"]\n", "target_col = 'income_label'" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# TARGET\n", "target = df[target_col].values\n", "\n", "# wide\n", "wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n", "X_wide = 
wide_preprocessor.fit_transform(df)\n", "\n", "# deeptabular\n", "tab_preprocessor = TabPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n", "X_tab = tab_preprocessor.fit_transform(df)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n", "deeptabular = TabMlp(mlp_hidden_dims=[64,32], \n", " column_idx=tab_preprocessor.column_idx,\n", " embed_input=tab_preprocessor.embeddings_input,\n", " continuous_cols=continuous_cols)\n", "model = WideDeep(wide=wide, deeptabular=deeptabular)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Custom metric\n", "\n", "Let's say we want to use our own accuracy metric (again, this is already available in the package, but I will pass it here as a custom loss for illustration purposes). \n", "\n", "This could be done as:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from pytorch_widedeep.metrics import Metric" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "class Accuracy(Metric):\n", " def __init__(self, top_k: int = 1):\n", " super(Accuracy, self).__init__()\n", "\n", " self.top_k = top_k\n", " self.correct_count = 0\n", " self.total_count = 0\n", "\n", " # metric name needs to be defined\n", " self._name = \"acc\"\n", "\n", " def reset(self):\n", " self.correct_count = 0\n", " self.total_count = 0\n", "\n", " def __call__(self, y_pred: Tensor, y_true: Tensor) -> float:\n", " num_classes = y_pred.size(1)\n", "\n", " if num_classes == 1:\n", " y_pred = y_pred.round()\n", " y_true = y_true\n", " elif num_classes > 1:\n", " y_pred = y_pred.topk(self.top_k, 1)[1]\n", " y_true = y_true.view(-1, 1).expand_as(y_pred)\n", "\n", " self.correct_count += y_pred.eq(y_true).sum().item()\n", " self.total_count += len(y_pred)\n", " accuracy = float(self.correct_count) / float(self.total_count)\n", " return 
accuracy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Custom Callback\n", "\n", "Let's code a callback that records the current epoch at the beginning and the end of each epoch (silly, but you know, this is just an example)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# have a look to the class\n", "from pytorch_widedeep.callbacks import Callback" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "class SillyCallback(Callback):\n", " def on_train_begin(self, logs = None):\n", " # recordings will be the trainer object attributes\n", " self.trainer.silly_callback = {}\n", "\n", " self.trainer.silly_callback['beginning'] = []\n", " self.trainer.silly_callback['end'] = []\n", "\n", " def on_epoch_begin(self, epoch, logs=None):\n", " self.trainer.silly_callback['beginning'].append(epoch+1)\n", "\n", " def on_epoch_end(self, epoch, logs=None):\n", " self.trainer.silly_callback['end'].append(epoch+1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and now, as usual:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "trainer = Trainer(model, objective='binary', metrics=[Accuracy], callbacks=[SillyCallback])" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "epoch 1: 100%|██████████| 611/611 [00:06<00:00, 92.66it/s, loss=0.397, metrics={'acc': 0.8112}]\n", "valid: 100%|██████████| 153/153 [00:00<00:00, 163.83it/s, loss=0.364, metrics={'acc': 0.8154}]\n", "epoch 2: 100%|██████████| 611/611 [00:06<00:00, 93.55it/s, loss=0.363, metrics={'acc': 0.8289}]\n", "valid: 100%|██████████| 153/153 [00:00<00:00, 167.03it/s, loss=0.356, metrics={'acc': 0.8304}]\n", "epoch 3: 100%|██████████| 611/611 [00:06<00:00, 93.64it/s, loss=0.357, metrics={'acc': 0.8325}]\n", "valid: 100%|██████████| 153/153 [00:00<00:00, 164.14it/s, 
loss=0.35, metrics={'acc': 0.834}] \n", "epoch 4: 100%|██████████| 611/611 [00:06<00:00, 92.59it/s, loss=0.352, metrics={'acc': 0.8347}]\n", "valid: 100%|██████████| 153/153 [00:00<00:00, 171.96it/s, loss=0.349, metrics={'acc': 0.8359}]\n", "epoch 5: 100%|██████████| 611/611 [00:06<00:00, 93.63it/s, loss=0.348, metrics={'acc': 0.8361}]\n", "valid: 100%|██████████| 153/153 [00:00<00:00, 162.69it/s, loss=0.347, metrics={'acc': 0.8372}]\n" ] } ], "source": [ "trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=5, batch_size=64, val_split=0.2)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'beginning': [1, 2, 3, 4, 5], 'end': [1, 2, 3, 4, 5]}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.silly_callback" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 4 }