Quick Start
***********

This is an example of binary classification with the `adult census
<https://www.kaggle.com/wenruliu/adult-income-dataset?select=adult.csv>`__
dataset, using a combination of a ``Wide`` and a ``DeepDense`` model with
default settings.


Read and split the dataset
--------------------------

The following code snippet is not directly related to ``pytorch-widedeep``.

.. code-block:: python

    import pandas as pd
    from sklearn.model_selection import train_test_split

    df = pd.read_csv("data/adult/adult.csv.zip")
    df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
    df.drop("income", axis=1, inplace=True)
    df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)
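
As an optional sanity check (plain ``pandas``, not part of the library
itself), one can verify that the stratified split keeps the class balance in
both partitions:

.. code-block:: python

    # stratify should keep the positive-class proportion roughly
    # equal in the train and test partitions
    print(df_train["income_label"].mean(), df_test["income_label"].mean())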


Prepare the wide and deep columns
---------------------------------

.. code-block:: python

    from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor
    from pytorch_widedeep.models import Wide, DeepDense, WideDeep
    from pytorch_widedeep.metrics import Accuracy

    # prepare wide, crossed, embedding and continuous columns
    wide_cols = [
        "education",
        "relationship",
        "workclass",
        "occupation",
        "native-country",
        "gender",
    ]
    cross_cols = [("education", "occupation"), ("native-country", "occupation")]
    embed_cols = [
        ("education", 16),
        ("workclass", 16),
        ("occupation", 16),
        ("native-country", 32),
    ]
    cont_cols = ["age", "hours-per-week"]
    target_col = "income_label"

    # target
    target = df_train[target_col].values
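
Crossed columns capture interactions between two categorical features by
treating each pair of values as a single new category for the wide (linear)
component. Conceptually (a plain ``pandas`` illustration, not the library's
API), a cross of ``education`` and ``occupation`` amounts to:

.. code-block:: python

    # conceptual illustration of a crossed column: each
    # (education, occupation) pair becomes one categorical value
    crossed = df_train["education"] + "-" + df_train["occupation"]
    print(crossed.nunique(), "distinct education-occupation crosses")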


Preprocessing and model components definition
---------------------------------------------

.. code-block:: python

    # wide
    preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
    X_wide = preprocess_wide.fit_transform(df_train)
    wide = Wide(wide_dim=X_wide.shape[1], pred_dim=1)

    # deepdense
    preprocess_deep = DensePreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
    X_deep = preprocess_deep.fit_transform(df_train)
    deepdense = DeepDense(
        hidden_layers=[64, 32],
        deep_column_idx=preprocess_deep.deep_column_idx,
        embed_input=preprocess_deep.embeddings_input,
        continuous_cols=cont_cols,
    )
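
At this point it can be helpful to inspect the arrays the preprocessors
produce (the exact dimensions depend on the data):

.. code-block:: python

    # X_wide is a 2d array whose width defines wide_dim above; X_deep
    # should have one column per embedding and continuous column
    print(X_wide.shape, X_deep.shape)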


Build, compile, fit and predict
-------------------------------

.. code-block:: python

    # build, compile and fit
    model = WideDeep(wide=wide, deepdense=deepdense)
    model.compile(method="binary", metrics=[Accuracy])
    model.fit(
        X_wide=X_wide,
        X_deep=X_deep,
        target=target,
        n_epochs=5,
        batch_size=256,
        val_split=0.1,
    )

    # predict
    X_wide_te = preprocess_wide.transform(df_test)
    X_deep_te = preprocess_deep.transform(df_test)
    preds = model.predict(X_wide=X_wide_te, X_deep=X_deep_te)
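
Assuming ``predict`` returns hard class labels for a binary problem, the
predictions can be scored against the held-out labels with scikit-learn:

.. code-block:: python

    from sklearn.metrics import accuracy_score

    # preds are expected to be 0/1 labels when method="binary"
    print(accuracy_score(df_test["income_label"], preds))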

Of course, one can do much more, such as using different initializers,
optimizers or learning rate schedulers for each component of the overall
model, adding FC-heads to the text and image components, using the Focal
Loss, or warming up individual components before joint training. See the
`examples
<https://github.com/jrzaurin/pytorch-widedeep/tree/build_docs/examples>`__
directory for a better understanding of the content of the package and its
functionalities.
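
For instance, a minimal sketch of per-component optimizers (assuming the
components are exposed as ``model.wide`` and ``model.deepdense`` and that
``compile`` accepts a dictionary of optimizers keyed by component name, as in
the examples referenced above) could look like:

.. code-block:: python

    import torch

    # one optimizer per model component, keyed by component name
    wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
    deep_opt = torch.optim.Adam(model.deepdense.parameters(), lr=0.001)

    model.compile(
        method="binary",
        optimizers={"wide": wide_opt, "deepdense": deep_opt},
        metrics=[Accuracy],
    )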