# airbnb_script_multiclass.py (last committed by jrzaurin)
import numpy as np
import pandas as pd
import torch

from pytorch_widedeep.preprocessing import WidePreprocessor, DeepPreprocessor
from pytorch_widedeep.models import Wide, DeepDense, WideDeep
from pytorch_widedeep.metrics import CategoricalAccuracy

# Whether torch can see a CUDA device.
# NOTE(review): not referenced anywhere in this script; kept so the
# module-level name stays available.
use_cuda = torch.cuda.is_available()


def bin_yield(df, bins=(0.2, 65, 163, 600), target="yield_cat"):
    """Return a copy of ``df`` with ``yield`` replaced by integer class labels.

    ``yield`` is discretised with ``pd.cut`` into ``len(bins) - 1`` classes
    labelled ``0 .. len(bins) - 2``. The labels are stored in column
    ``target``, and the original ``yield`` column is dropped.

    ``pd.cut`` builds right-closed intervals ``(bins[i], bins[i + 1]]``. A
    yield that is <= ``bins[0]``, > ``bins[-1]`` or missing therefore falls
    in no bin and would get a NaN label. Such rows are dropped here because
    a NaN target cannot be used for multiclass training. Dropping them from
    the frame (rather than from the target array alone) keeps the features
    later built from this frame row-aligned with the labels.

    Args:
        df: frame containing a numeric ``yield`` column. Not modified.
        bins: increasing bin edges (``pd.cut`` requires monotonic bins).
        target: name of the class-label column to create.

    Returns:
        A new frame with a fresh ``RangeIndex`` and an int64 ``target`` column.
    """
    out = df.copy()
    labels = list(range(len(bins) - 1))
    out[target] = pd.cut(out["yield"], bins=list(bins), labels=labels)
    # Rows outside every bin have a NaN label; drop them before casting.
    out = out.dropna(subset=[target]).drop(columns="yield")
    out[target] = out[target].astype("int64")
    return out.reset_index(drop=True)


if __name__ == "__main__":

    df = pd.read_csv("data/airbnb/airbnb_sample.csv")

    # Multiclass target: bucket the continuous `yield` into classes.
    # bin_yield drops rows whose yield falls outside the bins, so no row is
    # left with a NaN label.
    target_col = "yield_cat"
    yield_bins = (0.2, 65, 163, 600)
    n_classes = len(yield_bins) - 1
    df = bin_yield(df, bins=yield_bins, target=target_col)

    # Wide component: categorical/binary columns plus a crossed
    # property_type x room_type feature.
    crossed_cols = (["property_type", "room_type"],)
    already_dummies = [c for c in df.columns if "amenity" in c] + ["has_house_rules"]
    wide_cols = [
        "is_location_exact",
        "property_type",
        "room_type",
        "host_gender",
        "instant_bookable",
    ] + already_dummies

    # Deep component: (column, int) pairs for the embedded categoricals.
    # The int is presumably the embedding size; confirm against
    # DeepPreprocessor.
    cat_embed_cols = [(c, 16) for c in df.columns if "catg" in c] + [
        ("neighbourhood_cleansed", 64),
        ("cancellation_policy", 16),
    ]
    continuous_cols = ["latitude", "longitude", "security_deposit", "extra_people"]
    # NOTE(review): defined but never passed to DeepPreprocessor, so these
    # columns get the same handling as the other continuous columns.
    # Confirm whether the preprocessor should be told they are already
    # standardised.
    already_standard = ["latitude", "longitude"]

    target = df[target_col].values

    prepare_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)
    X_wide = prepare_wide.fit_transform(df)

    prepare_deep = DeepPreprocessor(
        embed_cols=cat_embed_cols, continuous_cols=continuous_cols
    )
    X_deep = prepare_deep.fit_transform(df)

    # Every output_dim below is the number of yield classes, so the class
    # count is defined once (from the bins) instead of repeated as a literal.
    wide = Wide(wide_dim=X_wide.shape[1], output_dim=n_classes)
    deepdense = DeepDense(
        hidden_layers=[64, 32],
        dropout=[0.2, 0.2],
        deep_column_idx=prepare_deep.deep_column_idx,
        embed_input=prepare_deep.embeddings_input,
        continuous_cols=continuous_cols,
    )
    model = WideDeep(wide=wide, deepdense=deepdense, output_dim=n_classes)
    model.compile(method="multiclass", metrics=[CategoricalAccuracy])
    model.fit(
        X_wide=X_wide,
        X_deep=X_deep,
        target=target,
        n_epochs=1,
        batch_size=32,
        val_split=0.2,
    )