Commit 0fdbfdce authored by J Javier

first step towards adding an example on how to reproduce a kaggle notebook...

first step towards adding an example on how to reproduce a kaggle notebook (details in the code) with this library
Parent 4af577ae
@@ -21,6 +21,7 @@ tmp_dir/
weights/
pretrained_weights/
model_weights/
prepared_data/
# Unit Tests/Coverage
.coverage
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch import nn, cat, mean
from scipy.sparse import coo_matrix
device = "cuda" if torch.cuda.is_available() else "cpu"
save_path = Path("prepared_data")
def get_coo_indexes(lil):
rows = []
cols = []
for i, el in enumerate(lil):
        if not isinstance(el, list):
el = [el]
for j in el:
rows.append(i)
cols.append(j)
return rows, cols
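# Illustrative example (not part of the original script): for an input like
# [[1, 3], 2], get_coo_indexes returns rows = [0, 0, 1] and cols = [1, 3, 2],
# i.e. the COO coordinates of the implied binary interaction matrix.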
def get_sparse_features(series, shape):
coo_indexes = get_coo_indexes(series.tolist())
sparse_df = coo_matrix(
(np.ones(len(coo_indexes[0])), (coo_indexes[0], coo_indexes[1])), shape=shape
)
return sparse_df
def sparse_to_idx(data, pad_idx=-1):
indexes = data.nonzero()
indexes_df = pd.DataFrame()
indexes_df["rows"] = indexes[0]
indexes_df["cols"] = indexes[1]
mdf = indexes_df.groupby("rows").apply(lambda x: x["cols"].tolist())
max_len = mdf.apply(lambda x: len(x)).max()
return mdf.apply(lambda x: pd.Series(x + [pad_idx] * (max_len - len(x)))).values
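# Illustrative example (not part of the original script): if row 0 of the
# sparse matrix has nonzeros at columns 5 and 9 and the longest row has three
# nonzeros, sparse_to_idx returns [5, 9, pad_idx] for that row, i.e. a
# (n_rows, max_len) matrix of padded column indexes.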
def idx_to_sparse(idx, sparse_dim):
sparse = np.zeros(sparse_dim)
sparse[int(idx)] = 1
return pd.Series(sparse, dtype=int)
def process_cats_as_kaggle_notebook(df):
df["gender"] = (df["gender"] == "M").astype(int)
df = pd.concat(
[
df.drop("occupation", axis=1),
pd.get_dummies(df["occupation"]).astype(int),
],
axis=1,
)
df.drop("other", axis=1, inplace=True)
df.drop("zip_code", axis=1, inplace=True)
return df
id_cols = ["user_id", "movie_id"]
df_train = pd.read_pickle(save_path / "df_train.pkl")
df_valid = pd.read_pickle(save_path / "df_valid.pkl")
df_test = pd.read_pickle(save_path / "df_test.pkl")
df_test = pd.concat([df_valid, df_test], ignore_index=True)
df_train = process_cats_as_kaggle_notebook(df_train)
df_test = process_cats_as_kaggle_notebook(df_test)
# Another caveat: the maximum movie index (and hence the width of
# 'train_movies_watched') is computed using the full dataset, when in
# reality only the training data should be used
max_movie_index = max(df_train.movie_id.max(), df_test.movie_id.max())
X_train = df_train.drop(id_cols + ["prev_movies", "target"], axis=1)
y_train = df_train.target.values
train_movies_watched = get_sparse_features(
df_train["prev_movies"], (len(df_train), max_movie_index + 1)
)
X_test = df_test.drop(id_cols + ["prev_movies", "target"], axis=1)
y_test = df_test.target.values
test_movies_watched = get_sparse_features(
df_test["prev_movies"], (len(df_test), max_movie_index + 1)
)
PAD_IDX = 0
X_train_tensor = torch.Tensor(X_train.fillna(0).values).to(device)
train_movies_watched_tensor = (
torch.sparse_coo_tensor(
indices=train_movies_watched.nonzero(),
values=[1] * len(train_movies_watched.nonzero()[0]),
size=train_movies_watched.shape,
)
.to_dense()
.to(device)
)
movies_train_sequences = (
torch.Tensor(
sparse_to_idx(train_movies_watched, pad_idx=PAD_IDX),
)
.long()
.to(device)
)
target_train = torch.Tensor(y_train).long().to(device)
X_test_tensor = torch.Tensor(X_test.fillna(0).values).to(device)
test_movies_watched_tensor = (
torch.sparse_coo_tensor(
indices=test_movies_watched.nonzero(),
values=[1] * len(test_movies_watched.nonzero()[0]),
size=test_movies_watched.shape,
)
.to_dense()
.to(device)
)
movies_test_sequences = (
torch.Tensor(
sparse_to_idx(test_movies_watched, pad_idx=PAD_IDX),
)
.long()
.to(device)
)
target_test = torch.Tensor(y_test).long().to(device)
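# At this point (assuming the shapes implied by the code above) the model will
# receive three inputs per split: the continuous/one-hot user features of shape
# (n_samples, n_features), the dense multi-hot "movies watched" matrix of shape
# (n_samples, max_movie_index + 1), and the padded sequences of movie indexes
# of shape (n_samples, max_seq_len).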
class WideAndDeep(nn.Module):
def __init__(
self,
continious_feature_shape, # number of continious features
embed_size, # size of embedding for binary features
embed_dict_len, # number of unique binary features
pad_idx, # padding index
):
super(WideAndDeep, self).__init__()
self.embed = nn.Embedding(embed_dict_len, embed_size, padding_idx=pad_idx)
self.linear_relu_stack = nn.Sequential(
nn.Linear(embed_size + continious_feature_shape, 1024),
nn.ReLU(),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU(),
)
self.head = nn.Sequential(
nn.Linear(embed_dict_len + 256, embed_dict_len),
)
def forward(self, continious, binary, binary_idx):
# get embeddings for sequence of indexes
binary_embed = self.embed(binary_idx)
binary_embed_mean = mean(binary_embed, dim=1)
# get logits for "deep" part: continious features + binary embeddings
deep_logits = self.linear_relu_stack(
cat((continious, binary_embed_mean), dim=1)
)
# get final softmax logits for "deep" part and raw binary features
total_logits = self.head(cat((deep_logits, binary), dim=1))
return total_logits
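# Minimal shape check (illustrative only, with made-up dimensions):
#   _model = WideAndDeep(continious_feature_shape=4, embed_size=16,
#                        embed_dict_len=100, pad_idx=0)
#   _out = _model(torch.rand(2, 4), torch.rand(2, 100),
#                 torch.randint(1, 100, (2, 5)))
#   _out.shape  # -> torch.Size([2, 100]), one logit per movie id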
model = WideAndDeep(X_train.shape[1], 16, max_movie_index + 1, PAD_IDX).to(device)
print(model)
EPOCHS = 10
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for t in range(EPOCHS):
model.train()
pred_train = model(
X_train_tensor, train_movies_watched_tensor, movies_train_sequences
)
loss_train = loss_fn(pred_train, target_train)
# Backpropagation
optimizer.zero_grad()
loss_train.backward()
optimizer.step()
model.eval()
with torch.no_grad():
pred_test = model(
X_test_tensor, test_movies_watched_tensor, movies_test_sequences
)
loss_test = loss_fn(pred_test, target_test)
print(f"Epoch {t}")
print(f"Train loss: {loss_train:>7f}")
print(f"Test loss: {loss_test:>7f}")
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split
raw_data_path = Path("~/ml_projects/wide_deep_learning_for_recsys/ml-100k")
save_path = Path("prepared_data")
if not save_path.exists():
save_path.mkdir(parents=True, exist_ok=True)
# Load the Ratings/Interaction (triplets (user, item, rating) plus timestamp)
data = pd.read_csv(raw_data_path / "u.data", sep="\t", header=None)
data.columns = ["user_id", "movie_id", "rating", "timestamp"]
# Load the User features
users = pd.read_csv(raw_data_path / "u.user", sep="|", encoding="latin-1", header=None)
users.columns = ["user_id", "age", "gender", "occupation", "zip_code"]
# Load the Item features
items = pd.read_csv(raw_data_path / "u.item", sep="|", encoding="latin-1", header=None)
items.columns = [
"movie_id",
"movie_title",
"release_date",
"video_release_date",
"IMDb_URL",
"unknown",
"Action",
"Adventure",
"Animation",
"Children's",
"Comedy",
"Crime",
"Documentary",
"Drama",
"Fantasy",
"Film-Noir",
"Horror",
"Musical",
"Mystery",
"Romance",
"Sci-Fi",
"Thriller",
"War",
"Western",
]
list_of_genres = pd.read_csv(
raw_data_path / "u.genre", sep="|", header=None, usecols=[0]
)[0].tolist()
list_of_genres
# adding a column with the number of movies watched per user
dataset = data.sort_values(["user_id", "timestamp"]).reset_index(drop=True)
dataset["one"] = 1
dataset["num_watched"] = dataset.groupby("user_id")["one"].cumsum()
dataset.drop("one", axis=1, inplace=True)
# adding a column with the mean rating at a point in time per user
dataset["mean_rate"] = (
dataset.groupby("user_id")["rating"].cumsum() / dataset["num_watched"]
)
# In this particular exercise the problem is formulated as predicting the
# next movie that will be watched (consequently, each user's last interaction is discarded)
dataset["target"] = dataset.groupby("user_id")["movie_id"].shift(-1)
# Here the author builds the sequences
dataset["prev_movies"] = dataset["movie_id"].apply(lambda x: str(x))
dataset["prev_movies"] = (
dataset.groupby("user_id")["prev_movies"]
.apply(lambda x: (x + " ").cumsum().str.strip())
.reset_index(drop=True)
)
dataset["prev_movies"] = dataset["prev_movies"].apply(lambda x: x.split())
# Adding a genre_rate as the mean of all movies rated for a given genre per
# user
dataset = dataset.merge(items[["movie_id"] + list_of_genres], on="movie_id", how="left")
for genre in list_of_genres:
dataset[f"{genre}_rate"] = dataset[genre] * dataset["rating"]
dataset[genre] = dataset.groupby("user_id")[genre].cumsum()
dataset[f"{genre}_rate"] = (
dataset.groupby("user_id")[f"{genre}_rate"].cumsum() / dataset[genre]
)
dataset[list_of_genres] = dataset[list_of_genres].apply(
lambda x: x / dataset["num_watched"]
)
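# After this loop (as coded above) each genre column holds the fraction of the
# movies watched so far that belong to that genre, and each '<genre>_rate'
# column holds the user's running mean rating for that genre (undefined until
# the user has rated a movie of that genre).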
# Adding user features
dataset = dataset.merge(users, on="user_id", how="left")
# Again, we use the same settings as those in the Kaggle notebook,
# but 'COLD_START_TRESH' is pretty aggressive
COLD_START_TRESH = 5
filtred_data = dataset[
(dataset["num_watched"] >= COLD_START_TRESH) & ~(dataset["target"].isna())
].sort_values("timestamp")
train_data, _test_data = train_test_split(filtred_data, test_size=0.2, shuffle=False)
valid_data, test_data = train_test_split(_test_data, test_size=0.5, shuffle=False)
cols_to_drop = [
"rating",
"timestamp",
"num_watched",
]
df_train = train_data.drop(cols_to_drop, axis=1)
df_valid = valid_data.drop(cols_to_drop, axis=1)
df_test = test_data.drop(cols_to_drop, axis=1)
df_train.to_pickle(save_path / "df_train.pkl")
df_valid.to_pickle(save_path / "df_valid.pkl")
df_test.to_pickle(save_path / "df_test.pkl")
@@ -3,7 +3,7 @@ import math
import torch
from torch import nn
from pytorch_widedeep.wdtypes import Tensor
from pytorch_widedeep.wdtypes import Union, Tensor
class Wide(nn.Module):
@@ -38,17 +38,25 @@ class Wide(nn.Module):
>>> out = wide(X)
"""
def __init__(self, input_dim: int, pred_dim: int = 1):
def __init__(
self, input_dim: int, already_one_hot: bool = False, pred_dim: int = 1
):
super(Wide, self).__init__()
self.input_dim = input_dim
self.already_one_hot = already_one_hot
self.pred_dim = pred_dim
        # Embeddings: val + 1 because 0 is reserved for padding/unseen categories.
self.wide_linear = nn.Embedding(input_dim + 1, pred_dim, padding_idx=0)
# (Sum(Embedding) + bias) is equivalent to (OneHotVector + Linear)
self.bias = nn.Parameter(torch.zeros(pred_dim))
self._reset_parameters()
if self.already_one_hot:
self.wide_linear: Union[nn.Linear, nn.Embedding] = nn.Linear(
input_dim, pred_dim
)
else:
            # Embeddings: val + 1 because 0 is reserved for padding/unseen categories.
self.wide_linear = nn.Embedding(input_dim + 1, pred_dim, padding_idx=0)
# (Sum(Embedding) + bias) is equivalent to (OneHotVector + Linear)
self.bias = nn.Parameter(torch.zeros(pred_dim))
self._reset_parameters()
def _reset_parameters(self) -> None:
r"""initialize Embedding and bias like nn.Linear. See [original
@@ -60,7 +68,10 @@ class Wide(nn.Module):
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, X: Tensor) -> Tensor:
r"""Forward pass. Simply connecting the Embedding layer with the ouput
r"""Forward pass. Simply connecting the Embedding/Linear layer with the ouput
neuron(s)"""
out = self.wide_linear(X.long()).sum(dim=1) + self.bias
if self.already_one_hot:
out = self.wide_linear(X)
else:
out = self.wide_linear(X.long()).sum(dim=1) + self.bias
return out
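# Illustrative usage of the new flag (an assumed example, not part of this
# diff): with already_one_hot=True the layer expects a float (multi-)one-hot
# matrix of shape (batch_size, input_dim), e.g.
#   wide = Wide(input_dim=10, already_one_hot=True)
#   out = wide(torch.rand(4, 10))  # -> shape (4, 1)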
from pytorch_widedeep.models.text.basic_rnn import BasicRNN
from pytorch_widedeep.models.text.attentive_rnn import AttentiveRNN
from pytorch_widedeep.models.text.basic_transformer import Transformer
from pytorch_widedeep.models.text.stacked_attentive_rnn import (
StackedAttentiveRNN,
)
import math
import torch
from torch import nn
from pytorch_widedeep.wdtypes import Union, Tensor, Optional
from pytorch_widedeep.models.tabular.transformers._encoders import (
TransformerEncoder,
)
class Transformer(nn.Module):
def __init__(
self,
vocab_size: int,
embed_dim: int,
n_heads: int,
n_blocks: int,
attn_dropout: float = 0.1,
ff_dropout: float = 0.1,
activation: str = "gelu",
ff_dim_multiplier: float = 1.0,
*,
with_pos_encoding: bool = True,
pos_encoding_dropout: float = 0.1,
seq_length: Optional[int] = None,
pos_encoder: Optional[nn.Module] = None,
):
super().__init__()
self.embed_dim = embed_dim
self.n_heads = n_heads
self.n_blocks = n_blocks
self.attn_dropout = attn_dropout
self.ff_dropout = ff_dropout
self.activation = activation
self.ff_dim_multiplier = ff_dim_multiplier
self.with_pos_encoding = with_pos_encoding
self.pos_encoding_dropout = pos_encoding_dropout
self.seq_length = seq_length
self.embedding = nn.Embedding(vocab_size, embed_dim)
if with_pos_encoding:
            if pos_encoder is not None:
                self.pos_encoder: Union[
                    nn.Module, nn.Identity, PositionalEncoding
                ] = pos_encoder
else:
assert (
seq_length is not None
), "If positional encoding is used 'seq_length' must be passed to the model"
self.pos_encoder = PositionalEncoding(
embed_dim, pos_encoding_dropout, seq_length
)
else:
self.pos_encoder = nn.Identity()
self.encoder = nn.Sequential()
for i in range(n_blocks):
self.encoder.add_module(
"transformer_block" + str(i),
TransformerEncoder(
embed_dim,
n_heads,
False, # use_qkv_bias
attn_dropout,
ff_dropout,
activation,
),
)
def forward(self, X: Tensor) -> Tensor:
x = self.embedding(X)
x = self.pos_encoder(x)
out = self.encoder(x)
return out
class PositionalEncoding(nn.Module):
def __init__(self, embed_dim: int, dropout: float, seq_length: int):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(seq_length).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
)
        # store the encodings batch-first, (1, seq_length, embed_dim), so that
        # they broadcast over the batch dimension in forward
        pe = torch.zeros(1, seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, X: Tensor) -> Tensor:
return self.dropout(X + self.pe)
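# Illustrative usage (an assumed example, not part of this commit): with the
# defaults above, something like
#   transformer = Transformer(vocab_size=1000, embed_dim=32, n_heads=4,
#                             n_blocks=2, seq_length=20)
#   out = transformer(torch.randint(0, 1000, (8, 20)))
# should produce a tensor of shape (8, 20, 32): one contextualised
# embed_dim-sized vector per token.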