"""airbnb_script_multiclass.py

Example script: fit a pytorch_widedeep ``WideDeep`` model (a ``Wide``
component plus a ``TabMlp`` component) as a 3-class classifier on the Airbnb
sample CSV. The target is the ``yield`` column binned into three categories.
"""
import numpy as np
import torch
import pandas as pd

from pytorch_widedeep.models import Wide, WideDeep, TabMlp
from pytorch_widedeep.metrics import F1Score, Accuracy
from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor

# NOTE(review): computed at import time but not referenced anywhere in this
# script.
use_cuda = torch.cuda.is_available()

if __name__ == "__main__":

    df = pd.read_csv("data/airbnb/airbnb_sample.csv")

    # --- Column selection -------------------------------------------------
    crossed_cols = [("property_type", "room_type")]
    # Columns whose name contains "amenity", plus "has_house_rules"; the
    # variable name suggests they are already one-hot/binary encoded --
    # TODO confirm against the dataset.
    already_dummies = [c for c in df.columns if "amenity" in c] + ["has_house_rules"]
    wide_cols = [
        "is_location_exact",
        "property_type",
        "room_type",
        "host_gender",
        "instant_bookable",
    ] + already_dummies
    # (column name, int) pairs handed to TabPreprocessor as ``embed_cols``;
    # the int is presumably the embedding dimension -- verify against the
    # pytorch_widedeep documentation.
    cat_embed_cols = [(c, 16) for c in df.columns if "catg" in c] + [
        ("neighbourhood_cleansed", 64),
        ("cancellation_policy", 16),
    ]
    continuous_cols = ["latitude", "longitude", "security_deposit", "extra_people"]
    # NOTE(review): defined but never passed to any preprocessor below --
    # confirm whether it was meant to be forwarded to TabPreprocessor.
    already_standard = ["latitude", "longitude"]

    # --- Target -----------------------------------------------------------
    # Bin "yield" into three classes labelled 0, 1, 2 using the edges
    # (0.2, 65], (65, 163], (163, 600]. Values outside (0.2, 600] get no
    # label (NaN) from pd.cut -- NOTE(review): confirm the sample data has
    # no such rows.
    df["yield_cat"] = pd.cut(df["yield"], bins=[0.2, 65, 163, 600], labels=[0, 1, 2])
    df.drop("yield", axis=1, inplace=True)
    target = "yield_cat"

    # ``target`` is rebound from the column name to the array of labels.
    target = np.array(df[target].values)

    # --- Preprocessing ----------------------------------------------------
    prepare_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
    X_wide = prepare_wide.fit_transform(df)

    prepare_deep = TabPreprocessor(
        embed_cols=cat_embed_cols, continuous_cols=continuous_cols  # type: ignore[arg-type]
    )
    X_deep = prepare_deep.fit_transform(df)

    # --- Model ------------------------------------------------------------
    # wide_dim is the number of distinct values found in X_wide; pred_dim=3
    # matches the three target labels created above.
    wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=3)
    deepdense = TabMlp(
        mlp_hidden_dims=[64, 32],
        dropout=[0.2, 0.2],
        deep_column_idx=prepare_deep.deep_column_idx,
        embed_input=prepare_deep.embeddings_input,
        continuous_cols=continuous_cols,
    )
    model = WideDeep(wide=wide, deeptabular=deepdense, pred_dim=3)

    # --- Training ---------------------------------------------------------
    optimizer = torch.optim.Adam(model.parameters(), lr=0.03)
    model.compile(
        method="multiclass", metrics=[Accuracy, F1Score], optimizers=optimizer
    )

    model.fit(
        X_wide=X_wide,
        X_tab=X_deep,
        target=target,
        n_epochs=1,
        batch_size=32,
        val_split=0.2,
    )