diff --git a/VERSION b/VERSION index afaf360d37fb71bcfa8cc082882f910ac2628bda..c650d5af2e4bb1f21c2f517cdae22fdbf34fde4b 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0.0 \ No newline at end of file +0.4.8 \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index d30ef90a794db3c6881a5e181185a2c67a50db07..c79914f38c5a6bc7dc8d41059c35bd9cbe418c6e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,7 +26,7 @@ sys.path.insert(0, PACKAGEDIR) # -- Project information ----------------------------------------------------- project = "pytorch-widedeep" -copyright = "2020, Javier Rodriguez Zaurin" +copyright = "2021, Javier Rodriguez Zaurin" author = "Javier Rodriguez Zaurin" # # The full version, including alpha/beta/rc tags diff --git a/docs/examples.rst b/docs/examples.rst index 3f6e5b1b18144e35a810332ff3146688a87e5c21..f56a9646016014405802719026fec45cf63dcca7 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -12,3 +12,4 @@ them to address different problems * `Binary Classification with varying parameters `__ * `Regression with Images and Text `__ * `FineTune routines `__ +* `Custom Components `__ diff --git a/docs/index.rst b/docs/index.rst index c374ff06d1298aa5332510da0efe5481a818310f..cde440fedf431fed2c8ef7721b2c5865e9e49dcf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -53,7 +53,7 @@ within the faded-pink rectangle are concatenated. Note that it is not possible to illustrate the number of possible architectures and components available in ``pytorch-widedeep`` in one Figure. Therefore, for more details on possible architectures (and more) please read -this documentation, or seethe Examples folders in the repo. +this documentation, or see the Examples folder in the repo. In math terms, and following the notation in the `paper `_, the expression for the architecture diff --git a/docs/installation.rst b/docs/installation.rst index d517a744f493e1218aae10d0972f29424d4afc22..7c89ba623cb7a0517a1a49831a629cc6331540e6 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -40,3 +40,4 @@ Dependencies * torch * torchvision * einops +* wrapt \ No newline at end of file diff --git a/docs/utils/index.rst b/docs/utils/index.rst index 56929d279f96a101e834079eef8a23cb73acfaf8..9579f0a3cb42b195f06cd0d895198ea3a2fb604c 100644 --- a/docs/utils/index.rst +++ b/docs/utils/index.rst @@ -2,10 +2,10 @@ The ``utils`` module ==================== These are a series of utilities that might be useful for a number of -preprocessing tasks. All the classes and functions discussed here are -available directly from the ``utils`` module. For example, the -``LabelEncoder`` within the ``deeptabular_utils`` submodule can be imported -as: +preprocessing tasks, even if not directly related to ``pytorch-widedeep``. All +the classes and functions discussed here are available directly from the +``utils`` module. For example, the ``LabelEncoder`` within the +``deeptabular_utils`` submodule can be imported as: ..
code-block:: python diff --git a/examples/07_Custom_Components.ipynb b/examples/07_Custom_Components.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..9209f047ec9531f6483d2a480f03bea2ad66aceb --- /dev/null +++ b/examples/07_Custom_Components.ipynb @@ -0,0 +1,1788 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom components\n", + "\n", + "As I mentioned earlier in the example notebooks, and also in the `README`, it is possible to customise almost every component in `pytorch-widedeep`.\n", + "\n", + "Let's now go through a couple of simple examples to illustrate how that could be done. \n", + "\n", + "First, let's load and process the data \"as usual\", starting with a regression problem and the [airbnb](http://insideairbnb.com/get-the-data.html) dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import os\n", + "import torch\n", + "\n", + "# nn and Tensor are needed for the custom components defined later in this notebook\n", + "from torch import nn, Tensor\n", + "\n", + "from pytorch_widedeep import Trainer\n", + "from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor, TextPreprocessor, ImagePreprocessor\n", + "from pytorch_widedeep.models import Wide, TabMlp, TabResnet, DeepText, DeepImage, WideDeep\n", + "from pytorch_widedeep.losses import RMSELoss\n", + "from pytorch_widedeep.initializers import *\n", + "from pytorch_widedeep.callbacks import *" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " id host_id \\\n", + "0 13913.jpg 54730 \n", + "1 15400.jpg 60302 \n", + "2 17402.jpg 67564 \n", + "3 24328.jpg 41759 \n", + "4 25023.jpg 102813 \n", + "\n", + " description \\\n", + "0 My bright double bedroom with a large window has a relaxed feeling! It comfortably fits one or t... \n", + "1 Lots of windows and light. St Luke's Gardens are at the end of the block, and the river not too... \n", + "2 Open from June 2018 after a 3-year break, we are delighted to be welcoming guests again to this ... \n", + "3 Artist house, bright high ceiling rooms, private parking and a communal garden in a conservation... \n", + "4 Large, all comforts, 2-bed flat; first floor; lift; pretty communal gardens + off-street parking... \n", + "\n", + " host_listings_count host_identity_verified neighbourhood_cleansed \\\n", + "0 4.0 f Islington \n", + "1 1.0 t Kensington and Chelsea \n", + "2 19.0 t Westminster \n", + "3 2.0 t Wandsworth \n", + "4 1.0 f Wandsworth \n", + "\n", + " latitude longitude is_location_exact property_type room_type \\\n", + "0 51.56802 -0.11121 t apartment private_room \n", + "1 51.48796 -0.16898 t apartment entire_home/apt \n", + "2 51.52098 -0.14002 t apartment entire_home/apt \n", + "3 51.47298 -0.16376 t other entire_home/apt \n", + "4 51.44687 -0.21874 t apartment entire_home/apt \n", + "\n", + " accommodates bathrooms bedrooms beds guests_included minimum_nights \\\n", + "0 2 1.0 1.0 0.0 1 1 \n", + "1 2 1.0 1.0 1.0 2 3 \n", + "2 6 2.0 3.0 3.0 4 3 \n", + "3 2 1.5 1.0 1.0 2 30 \n", + "4 4 1.0 2.0 2.0 2 4 \n", + "\n", + " instant_bookable cancellation_policy has_house_rules host_gender \\\n", + "0 f moderate 1 female \n", + "1 f strict_14_with_grace_period 1 female \n", + "2 t strict_14_with_grace_period 1 female \n", + "3 f moderate 1 male \n", + "4 f moderate 1 female \n", + "\n", + " accommodates_catg guests_included_catg minimum_nights_catg \\\n", + "0 2 1 1 \n", + "1 2 2 3 \n", + "2 3 3 3 \n", + "3 2 2 3 \n", + "4 3 2 3 \n", + "\n", + " host_listings_count_catg bathrooms_catg bedrooms_catg beds_catg \\\n", + "0 3 1 1 0 \n", + "1 1 1 1 1 \n", + "2 3 2 3 3 \n", + "3 2 2 1 1 \n", + "4 1 1 2 2 \n", + "\n", + " amenity_24-hour_check-in amenity__toilet amenity_accessible-height_bed \\\n", + "0 0 0 1 \n", + "1 1 0 0 \n", + "2 0 0 0 \n", + "3 1 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_accessible-height_toilet amenity_air_conditioning \\\n", + "0 1 0 \n", + "1 0 1 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 0 0 \n", + "\n", + " amenity_air_purifier amenity_alfresco_bathtub amenity_amazon_echo \\\n", + "0 0 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_baby_bath amenity_baby_monitor \\\n", + "0 0 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 0 0 \n", + "\n", + " amenity_babysitter_recommendations amenity_balcony amenity_bath_towel \\\n", + "0 1 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_bathroom_essentials amenity_bathtub \\\n", + "0 0 1 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 0 0 \n", + "\n", + " amenity_bathtub_with_bath_chair amenity_bbq_grill \\\n", + "0 1 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 0 0 \n", + "\n", + " amenity_beach_essentials amenity_beach_view amenity_beachfront \\\n", + "0 0 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_bed_linens amenity_bedroom_comforts ... amenity_roll-in_shower \\\n", + "0 1 0 ... 1 \n", + "1 0 0 ... 0 \n", + "2 1 0 ... 0 \n", + "3 0 0 ... 
0 \n", + "4 0 0 ... 0 \n", + "\n", + " amenity_room-darkening_shades amenity_safety_card amenity_sauna \\\n", + "0 1 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_self_check-in amenity_shampoo amenity_shared_gym \\\n", + "0 0 1 0 \n", + "1 0 1 0 \n", + "2 1 1 0 \n", + "3 1 1 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_shared_hot_tub amenity_shared_pool amenity_shower_chair \\\n", + "0 0 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_single_level_home amenity_ski-in_ski-out amenity_smart_lock \\\n", + "0 0 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_smart_tv amenity_smoke_detector amenity_smoking_allowed \\\n", + "0 0 1 1 \n", + "1 0 1 0 \n", + "2 0 1 0 \n", + "3 0 1 0 \n", + "4 0 1 0 \n", + "\n", + " amenity_soaking_tub amenity_sound_system amenity_stair_gates \\\n", + "0 0 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_stand_alone_steam_shower amenity_standing_valet \\\n", + "0 0 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 0 0 \n", + "\n", + " amenity_steam_oven amenity_stove amenity_suitable_for_events \\\n", + "0 0 1 0 \n", + "1 0 0 0 \n", + "2 0 1 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_sun_loungers amenity_table_corner_guards amenity_tennis_court \\\n", + "0 0 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_terrace amenity_toilet_paper amenity_touchless_faucets \\\n", + "0 0 0 0 \n", + "1 0 0 0 \n", + "2 0 0 0 \n", + "3 0 0 0 \n", + "4 0 0 0 \n", + "\n", + " amenity_tv amenity_walk-in_shower amenity_warming_drawer amenity_washer \\\n", + "0 1 0 0 1 \n", + "1 1 0 0 1 \n", + "2 1 0 0 1 \n", + "3 1 0 0 1 \n", + "4 1 0 0 1 \n", + "\n", + " amenity_washer_dryer amenity_waterfront \\\n", + "0 0 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 0 0 \n", + "\n", + " amenity_well-lit_path_to_entrance amenity_wheelchair_accessible \\\n", + "0 0 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 0 0 \n", + "\n", + " amenity_wide_clearance_to_shower amenity_wide_doorway_to_guest_bathroom \\\n", + "0 0 1 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 0 0 \n", + "\n", + " amenity_wide_entrance amenity_wide_entrance_for_guests \\\n", + "0 1 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 0 0 \n", + "\n", + " amenity_wide_entryway amenity_wide_hallways amenity_wifi \\\n", + "0 0 0 1 \n", + "1 0 0 1 \n", + "2 0 0 1 \n", + "3 0 0 1 \n", + "4 0 0 1 \n", + "\n", + " amenity_window_guards amenity_wine_cooler security_deposit extra_people \\\n", + "0 0 0 100.0 15.0 \n", + "1 0 0 150.0 0.0 \n", + "2 0 0 350.0 10.0 \n", + "3 0 0 250.0 0.0 \n", + "4 0 0 250.0 11.0 \n", + "\n", + " yield \n", + "0 12.00 \n", + "1 109.50 \n", + "2 149.65 \n", + "3 215.60 \n", + "4 79.35 \n", + "\n", + "[5 rows x 223 columns]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('data/airbnb/airbnb_sample.csv')\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# There are a number of columns that are already binary. 
Therefore, no need to one-hot encode them\n", + "crossed_cols = [('property_type', 'room_type')]\n", + "already_dummies = [c for c in df.columns if 'amenity' in c] + ['has_house_rules']\n", + "wide_cols = ['is_location_exact', 'property_type', 'room_type', 'host_gender',\n", + "'instant_bookable'] + already_dummies\n", + "cat_embed_cols = [(c, 16) for c in df.columns if 'catg' in c] + \\\n", + "    [('neighbourhood_cleansed', 64), ('cancellation_policy', 16)]\n", + "continuous_cols = ['latitude', 'longitude', 'security_deposit', 'extra_people']\n", + "# it does not make sense to standardise latitude and longitude\n", + "already_standard = ['latitude', 'longitude']\n", + "# text and image colnames\n", + "text_col = 'description'\n", + "img_col = 'id'\n", + "# path to pretrained word embeddings and the images\n", + "word_vectors_path = 'data/glove.6B/glove.6B.100d.txt'\n", + "img_path = 'data/airbnb/property_picture'\n", + "# target\n", + "target_col = 'yield'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "target = df[target_col].values" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The vocabulary contains 2192 tokens\n", + "Indexing word vectors...\n", + "Loaded 400000 word vectors\n", + "Preparing embeddings matrix...\n", + "2175 words in the vocabulary had data/glove.6B/glove.6B.100d.txt vectors and appear more than 5 times\n", + "Reading Images from data/airbnb/property_picture\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 4%|▍ | 42/1001 [00:00<00:02, 414.31it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Resizing\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1001/1001 [00:02<00:00, 387.88it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Computing normalisation metrics\n" + ] + } + ], + "source": [ + "wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n", + "X_wide = wide_preprocessor.fit_transform(df)\n", + "\n", + "tab_preprocessor = TabPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n", + "X_tab = tab_preprocessor.fit_transform(df)\n", + "\n", + "text_preprocessor = TextPreprocessor(word_vectors_path=word_vectors_path, text_col=text_col)\n", + "X_text = text_preprocessor.fit_transform(df)\n", + "\n", + "image_processor = ImagePreprocessor(img_col=img_col, img_path=img_path)\n", + "X_images = image_processor.fit_transform(df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we are ready to build a wide and deep model. Three of the four components we will use are included in this package, and they will be combined with a custom `deeptext` component. 
Then the fit process will run with a custom loss function.\n", + "\n", + "Let's have a look" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Linear model\n", + "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n", + "\n", + "# deeptabular: a TabMlp with 2 dense layers\n", + "deeptabular = TabMlp(\n", + "    column_idx=tab_preprocessor.column_idx,\n", + "    mlp_hidden_dims=[128, 64],\n", + "    mlp_dropout=0.1,\n", + "    mlp_batchnorm=True,\n", + "    embed_input=tab_preprocessor.embeddings_input,\n", + "    embed_dropout=0.1,\n", + "    continuous_cols=continuous_cols,\n", + "    batchnorm_cont=True\n", + ")\n", + "\n", + "# Pretrained Resnet 18 (default is all but last 2 conv blocks frozen) plus a FC-Head 512->256->128\n", + "deepimage = DeepImage(pretrained=True, head_hidden_dims=[512, 256, 128])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Custom `deeptext`\n", + "\n", + "A standard Pytorch model" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "class MyDeepText(nn.Module):\n", + "    def __init__(self, vocab_size, padding_idx=1, embed_dim=100, hidden_dim=64):\n", + "        super(MyDeepText, self).__init__()\n", + "\n", + "        # word/token embeddings\n", + "        self.word_embed = nn.Embedding(\n", + "            vocab_size, embed_dim, padding_idx=padding_idx\n", + "        )\n", + "\n", + "        # stack of RNNs\n", + "        self.rnn = nn.GRU(\n", + "            embed_dim,\n", + "            hidden_dim,\n", + "            num_layers=2,\n", + "            bidirectional=True,\n", + "            batch_first=True,\n", + "        )\n", + "\n", + "        # Remember, this attribute must be defined. If it is not, WideDeep will throw an error\n", + "        self.output_dim = hidden_dim * 2\n", + "\n", + "    def forward(self, X):\n", + "        embed = self.word_embed(X.long())\n", + "        o, h = self.rnn(embed)\n", + "        return torch.cat((h[-2], h[-1]), dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "mydeeptext = MyDeepText(vocab_size=len(text_preprocessor.vocab.itos))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=mydeeptext, deepimage=deepimage)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Custom loss function\n", + "\n", + "Loss functions simply need to inherit from pytorch's `nn.Module`. For example, let's say we want to use `RMSE` (note that this is already available in the package, but I will pass it here as a custom loss for illustration purposes)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class RMSELoss(nn.Module):\n", + "    def __init__(self):\n", + "        \"\"\"root mean squared error\"\"\"\n", + "        super().__init__()\n", + "        self.mse = nn.MSELoss()\n", + "\n", + "    def forward(self, input: Tensor, target: Tensor) -> Tensor:\n", + "        return torch.sqrt(self.mse(input, target))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "and now we just instantiate the ``Trainer`` as usual. Needless to say, this runs with 1000 random observations, so loss and metric values are meaningless. 
This is just an example" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(model, objective='regression', custom_loss_function=RMSELoss())" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|██████████| 25/25 [02:13<00:00, 5.33s/it, loss=118]\n", + "valid: 100%|██████████| 7/7 [00:15<00:00, 2.23s/it, loss=101] \n" + ] + } + ], + "source": [ + "trainer.fit(X_wide=X_wide, X_tab=X_tab, X_text=X_text, X_img=X_images,\n", + "    target=target, n_epochs=1, batch_size=32, val_split=0.2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In addition to model components and loss functions, we can also use custom callbacks or custom metrics. The former need to be of type `Callback` and the latter need to be of type `Metric`. See:\n", + "\n", + "```python\n", + "pytorch_widedeep.callbacks\n", + "```\n", + "\n", + "and\n", + "\n", + "```python\n", + "pytorch_widedeep.metrics\n", + "```\n", + "\n", + "For this example let me use the adult dataset. Again, we first prepare the data as usual" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " age workclass fnlwgt education educational-num marital-status \\\n", + "0 25 Private 226802 11th 7 Never-married \n", + "1 38 Private 89814 HS-grad 9 Married-civ-spouse \n", + "2 28 Local-gov 336951 Assoc-acdm 12 Married-civ-spouse \n", + "3 44 Private 160323 Some-college 10 Married-civ-spouse \n", + "4 18 ? 103497 Some-college 10 Never-married \n", + "\n", + " occupation relationship race gender capital-gain capital-loss \\\n", + "0 Machine-op-inspct Own-child Black Male 0 0 \n", + "1 Farming-fishing Husband White Male 0 0 \n", + "2 Protective-serv Husband White Male 0 0 \n", + "3 Machine-op-inspct Husband Black Male 7688 0 \n", + "4 ? Own-child White Female 0 0 \n", + "\n", + " hours-per-week native-country income \n", + "0 40 United-States <=50K \n", + "1 50 United-States <=50K \n", + "2 40 United-States >50K \n", + "3 40 United-States >50K \n", + "4 30 United-States <=50K " + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('data/adult/adult.csv.zip')\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" ], + "text/plain": [ + "   age  workclass  fnlwgt     education  educational_num      marital_status \\\n", + "0   25    Private  226802          11th                7       Never-married \n", + "1   38    Private   89814       HS-grad                9  Married-civ-spouse \n", + "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse \n", + "3   44    Private  160323  Some-college               10  Married-civ-spouse \n", + "4   18          ?  103497  Some-college               10       Never-married \n", + "\n", + "          occupation relationship   race gender  capital_gain  capital_loss \\\n", + "0  Machine-op-inspct    Own-child  Black   Male             0             0 \n", + "1    Farming-fishing      Husband  White   Male             0             0 \n", + "2    Protective-serv      Husband  White   Male             0             0 \n", + "3  Machine-op-inspct      Husband  Black   Male          7688             0 \n", + "4                  ?    Own-child  White Female             0             0 \n", + "\n", + "   hours_per_week native_country  income_label \n", + "0              40  United-States             0 \n", + "1              50  United-States             0 \n", + "2              40  United-States             1 \n", + "3              40  United-States             1 \n", + "4              30  United-States             0 " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# For convenience, we'll replace '-' with '_'\n", + "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n", + "# binary target\n", + "df['income_label'] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n", + "df.drop('income', axis=1, inplace=True)\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "wide_cols = ['education', 'relationship', 'workclass', 'occupation', 'native_country', 'gender']\n", + "crossed_cols = [('education', 'occupation'), ('native_country', 'occupation')]\n", + "cat_embed_cols = [('education', 16), ('relationship', 8), ('workclass', 16), ('occupation', 16), ('native_country', 16)]\n", + "continuous_cols = [\"age\", \"hours_per_week\"]\n", + "target_col = 'income_label'" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# TARGET\n", + "target = df[target_col].values\n", + "\n", + "# wide\n", + "wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n", + "X_wide = wide_preprocessor.fit_transform(df)\n", + "\n", + "# deeptabular\n", + "tab_preprocessor = TabPreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n", + "X_tab = tab_preprocessor.fit_transform(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n", + "deeptabular = TabMlp(mlp_hidden_dims=[64, 32],\n", + "    column_idx=tab_preprocessor.column_idx,\n", + "    embed_input=tab_preprocessor.embeddings_input,\n", + "    continuous_cols=continuous_cols)\n", + "model = WideDeep(wide=wide, deeptabular=deeptabular)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Custom metric\n", + "\n", + "Let's say we want to use our own accuracy metric (again, this is already available in the package, but I will pass it here as a custom metric for illustration purposes). 
\n", + "\n", + "This could be done as:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from pytorch_widedeep.metrics import Metric" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "class Accuracy(Metric):\n", + "    def __init__(self, top_k: int = 1):\n", + "        super(Accuracy, self).__init__()\n", + "\n", + "        self.top_k = top_k\n", + "        self.correct_count = 0\n", + "        self.total_count = 0\n", + "\n", + "        # the metric name needs to be defined\n", + "        self._name = \"acc\"\n", + "\n", + "    def reset(self):\n", + "        self.correct_count = 0\n", + "        self.total_count = 0\n", + "\n", + "    def __call__(self, y_pred: Tensor, y_true: Tensor) -> float:\n", + "        num_classes = y_pred.size(1)\n", + "\n", + "        if num_classes == 1:\n", + "            y_pred = y_pred.round()\n", + "        elif num_classes > 1:\n", + "            y_pred = y_pred.topk(self.top_k, 1)[1]\n", + "            y_true = y_true.view(-1, 1).expand_as(y_pred)\n", + "\n", + "        self.correct_count += y_pred.eq(y_true).sum().item()\n", + "        self.total_count += len(y_pred)\n", + "        accuracy = float(self.correct_count) / float(self.total_count)\n", + "        return accuracy" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Custom Callback\n", + "\n", + "Let's code a callback that records the current epoch at the beginning and the end of each epoch (silly, but you know, this is just an example)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# have a look at the class\n", + "from pytorch_widedeep.callbacks import Callback" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "class SillyCallback(Callback):\n", + "    def on_train_begin(self, logs=None):\n", + "        # the recordings will be stored as attributes of the trainer object\n", + "        self.trainer.silly_callback = {}\n", + "\n", + "        self.trainer.silly_callback['beginning'] = []\n", + "        self.trainer.silly_callback['end'] = []\n", + "\n", + "    def on_epoch_begin(self, epoch, logs=None):\n", + "        self.trainer.silly_callback['beginning'].append(epoch + 1)\n", + "\n", + "    def on_epoch_end(self, epoch, logs=None):\n", + "        self.trainer.silly_callback['end'].append(epoch + 1)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "and now, as usual:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(model, objective='binary', metrics=[Accuracy], callbacks=[SillyCallback])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|██████████| 611/611 [00:06<00:00, 92.66it/s, loss=0.397, metrics={'acc': 0.8112}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 163.83it/s, loss=0.364, metrics={'acc': 0.8154}]\n", + "epoch 2: 100%|██████████| 611/611 [00:06<00:00, 93.55it/s, loss=0.363, metrics={'acc': 0.8289}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 167.03it/s, loss=0.356, metrics={'acc': 0.8304}]\n", + "epoch 3: 100%|██████████| 611/611 [00:06<00:00, 93.64it/s, loss=0.357, metrics={'acc': 0.8325}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 164.14it/s, loss=0.35, metrics={'acc': 0.834}] \n", + "epoch 4: 100%|██████████| 611/611 [00:06<00:00, 92.59it/s, loss=0.352, metrics={'acc': 0.8347}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 
171.96it/s, loss=0.349, metrics={'acc': 0.8359}]\n", + "epoch 5: 100%|██████████| 611/611 [00:06<00:00, 93.63it/s, loss=0.348, metrics={'acc': 0.8361}]\n", + "valid: 100%|██████████| 153/153 [00:00<00:00, 162.69it/s, loss=0.347, metrics={'acc': 0.8372}]\n" + ] + } + ], + "source": [ + "trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=5, batch_size=64, val_split=0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'beginning': [1, 2, 3, 4, 5], 'end': [1, 2, 3, 4, 5]}" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.silly_callback" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pytorch_widedeep/callbacks.py b/pytorch_widedeep/callbacks.py index bc6b7c3359095607500b0fdace52e316f0069ff4..02901d7aecb9023545400e1a2dfa278da93375df 100644 --- a/pytorch_widedeep/callbacks.py +++ b/pytorch_widedeep/callbacks.py @@ -121,10 +121,8 @@ class Callback(object): class History(Callback): r"""Callback that records metrics to a ``history`` attribute. - This callback runs by default within :obj:`Trainer`. Callbacks are passed - as input parameters to the ``Trainer`` class See - :class:`pytorch_widedeep.trainer.Trainer`. Documentation is included here - for completion. + This callback runs by default within :obj:`Trainer` and therefore should + not be passed to the ``Trainer``. It is included here just for completeness. """ def on_train_begin(self, logs: Optional[Dict] = None): diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index c379b91fb3598ffb774372cbc89eb98f4c25b132..3fc928916a61560408f71a927f09db6aeb4bab37 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -93,7 +93,7 @@ class Trainer: - ``root_mean_squared_error``, aliases: ``rmse`` - ``root_mean_squared_log_error``, aliases: ``rmsle`` - custom_loss: ``nn.Module``, Optional, default = None + custom_loss_function: ``nn.Module``, Optional, default = None object of class ``nn.Module``. If none of the loss functions available suits the user, it is possible to pass a custom loss function. See for example structure of the object or the `Examples `_ folder in the repo. + + .. note:: If ``custom_loss_function`` is not None, ``objective`` must be + 'binary', 'multiclass' or 'regression', consistent with the loss + function + optimizers: ``Optimizer`` or Dict, Optional, default= ``AdamW`` - An instance of Pytorch's ``Optimizer`` object (e.g. :obj:`torch.optim.Adam()`) or - a dictionary where the keys are the model components (i.e. 
@@ -222,7 +227,7 @@ class Trainer: "regression", ]: raise ValueError( - "If 'custom_loss_function' is not None, 'objective' might be 'binary' " + "If 'custom_loss_function' is not None, 'objective' must be 'binary', " "'multiclass' or 'regression', consistent with the loss function" ) diff --git a/pytorch_widedeep/version.py b/pytorch_widedeep/version.py index 5becc17c04a9e3ad1c2a15f53252b7bb5a7517e7..a3a9bd54437bb4d3bd18395b411d25a11093f300 100644 --- a/pytorch_widedeep/version.py +++ b/pytorch_widedeep/version.py @@ -1 +1 @@ -__version__ = "1.0.0" +__version__ = "0.4.8"