Commit 695f66e5 authored by Javier Rodriguez Zaurin

First commit towards a self supervised trainer

Parent 04544924
@@ -818,3 +818,76 @@ class HuberLoss(nn.Module):
        if lds_weight is not None:
            loss *= lds_weight
        return torch.mean(loss)

class InfoNCELoss(nn.Module):
    def __init__(self, temperature: float = 0.1, reduction: str = "mean"):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, z: Tensor, z_: Tensor) -> Tensor:
        norm_z = F.normalize(z, dim=-1)
        norm_z_ = F.normalize(z_, dim=-1)

        logits = (norm_z @ norm_z_.t()) / self.temperature
        logits_ = (norm_z_ @ norm_z.t()) / self.temperature

        # the target/labels are the entries on the diagonal
        target = torch.arange(len(norm_z), device=norm_z.device)

        loss = F.cross_entropy(logits, target, reduction=self.reduction)
        loss_ = F.cross_entropy(logits_, target, reduction=self.reduction)

        return (loss + loss_) / 2.0
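For orientation, a minimal usage sketch of the loss above (the import path is taken from the trainer's imports further down in this commit; the shapes and random inputs are made up for illustration and the snippet is not part of the commit):

    import torch
    from pytorch_widedeep.losses import InfoNCELoss

    # two "views" of the same batch of 8 samples, each encoded as a 16-dim vector
    z = torch.randn(8, 16)
    z_ = torch.randn(8, 16)

    # scalar: symmetric cross-entropy over the batch, matching pairs on the diagonal
    loss = InfoNCELoss(temperature=0.1)(z, z_)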

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature: float = 0.1, reduction: str = "mean"):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.reduction = reduction

    def forward(self, z: Tensor, z_: Tensor) -> Tensor:
        norm_z = F.normalize(z, dim=-1)
        norm_z_ = F.normalize(z_, dim=-1)

        logits = (norm_z @ norm_z_.t()) / self.temperature

        return torch.diagonal(-1 * logits).add_(1).pow_(2).sum()
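To make the chained one-liner in forward easier to read, here is an equivalent spelling of the same computation (illustrative only; the function name is made up and the snippet is not part of the commit):

    import torch
    import torch.nn.functional as F

    def contrastive_loss_expanded(z: torch.Tensor, z_: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
        # cosine similarity between matching rows of the two views
        sim = (F.normalize(z, dim=-1) * F.normalize(z_, dim=-1)).sum(dim=-1)
        # squared deviation of the temperature-scaled similarity from 1, summed over the batch
        return ((1.0 - sim / temperature) ** 2).sum()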

class DenoisingLoss(nn.Module):
    def __init__(self, lambda_cont: float, lambda_cat: float, reduction: str = "mean"):
        super(DenoisingLoss, self).__init__()
        self.lambda_cont = lambda_cont
        self.lambda_cat = lambda_cat
        self.reduction = reduction

    def forward(
        self,
        x_cont: Optional[Tensor],
        x_cat: Optional[Tensor],
        x_cont_: Optional[Tensor],
        x_cat_: Optional[Tensor],
    ) -> Tensor:
        loss_cont = (
            F.mse_loss(x_cont, x_cont_, reduction=self.reduction)
            if x_cont is not None
            else 0
        )
        loss_cat = (
            F.cross_entropy(x_cat, x_cat_, reduction=self.reduction)
            if x_cat is not None
            else 0
        )

        return self.lambda_cont * loss_cont + self.lambda_cat * loss_cat
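A hedged usage sketch for DenoisingLoss (not part of the commit). F.cross_entropy expects its first argument to be logits and its second to be integer class targets, so x_cat below is a [batch, n_classes] logit tensor and x_cat_ the true category indices; the exact shapes the in-progress trainer will eventually feed this loss are not pinned down in this commit:

    import torch
    from pytorch_widedeep.losses import DenoisingLoss

    # continuous features: reconstructed values vs. original values
    x_cont = torch.randn(8, 4)
    x_cont_ = torch.randn(8, 4)

    # categorical features: logits over 10 categories vs. true category indices
    x_cat = torch.randn(8, 10)
    x_cat_ = torch.randint(0, 10, (8,))

    loss = DenoisingLoss(lambda_cont=1.0, lambda_cat=1.0)(x_cont, x_cat, x_cont_, x_cat_)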
@@ -174,20 +174,6 @@ class SAINT(BaseTabularModelWithAttention):
            input_dim=input_dim,
        )
        self.column_idx = column_idx
        self.cat_embed_input = cat_embed_input
        self.cat_embed_dropout = cat_embed_dropout
        self.full_embed_dropout = full_embed_dropout
        self.shared_embed = shared_embed
        self.add_shared_embed = add_shared_embed
        self.frac_shared_embed = frac_shared_embed
        self.continuous_cols = continuous_cols
        self.cont_embed_activation = cont_embed_activation
        self.cont_embed_dropout = cont_embed_dropout
        self.cont_norm_layer = cont_norm_layer
        self.input_dim = input_dim
        self.use_qkv_bias = use_qkv_bias
        self.n_heads = n_heads
        self.n_blocks = n_blocks

import numpy as np
import torch

from pytorch_widedeep.wdtypes import *  # noqa: F403


def mix_up(p: Tensor, lam: float = 0.8) -> Tensor:
    # linear interpolation between each row and a randomly permuted row
    batch_size = p.size()[0]
    rand_idx = torch.randperm(batch_size).to(p.device)
    p_ = lam * p + (1 - lam) * p[rand_idx, :]
    return p_


def cut_mix(x: Tensor, lam: float = 0.8) -> Tensor:
    batch_size = x.size()[0]
    # mask == 0 (prob. lam) keeps the original value; mask == 1 takes the value
    # from the corresponding position of a randomly permuted row
    mask = torch.from_numpy(np.random.choice(2, (x.shape), p=[lam, 1 - lam])).to(
        x.device
    )
    rand_idx = torch.randperm(batch_size).to(x.device)
    x_ = x[rand_idx].clone()
    x_[mask == 0] = x[mask == 0]
    return x_
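The two augmentations above are the building blocks the trainer uses further down: cut_mix swaps individual feature values between rows of the raw (label-encoded) input, and mix_up interpolates whole rows, typically applied to embeddings. A small illustrative run (module path taken from the trainer's imports below; random data, not part of the commit):

    import torch
    from pytorch_widedeep.self_supervised_training._augmentations import cut_mix, mix_up

    # a toy batch of 8 rows of label-encoded categorical features
    X = torch.randint(0, 5, (8, 6))

    # each entry keeps its original value with probability lam, otherwise it is
    # replaced by the value at the same position of a randomly permuted row
    X_cut_mixed = cut_mix(X, lam=0.8)

    # mix_up interpolates rows; here applied to float embeddings,
    # as _pretrain_step does below
    embeddings = torch.randn(8, 6, 32)
    embeddings_mixed = mix_up(embeddings, lam=0.8)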

import os
import sys
from abc import ABC, abstractmethod

import torch

from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import (
    History,
    Callback,
    CallbackContainer,
    LRShedulerCallback,
)


class BaseSelfSupervisedTrainer(ABC):
    def __init__(
        self,
        model,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler,
        callbacks: Optional[List[Callback]],
        verbose: int,
        seed: int,
        **kwargs,
    ):
        self.device, self.num_workers = self._set_device_and_num_workers(**kwargs)

        self.model = model

        self.early_stop = False
        self.verbose = verbose
        self.seed = seed

        self.optimizer = (
            optimizer
            if optimizer is not None
            else torch.optim.AdamW(self.model.parameters())
        )
        self.lr_scheduler = lr_scheduler
        self._set_lr_scheduler_running_params(lr_scheduler, **kwargs)
        self._set_callbacks(callbacks)

        self.model.to(self.device)

    @abstractmethod
    def fit(self):
        pass

    def _set_lr_scheduler_running_params(self, lr_scheduler, **kwargs):
        if lr_scheduler is not None:
            self.cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
        else:
            self.cyclic_lr = False

    def _set_callbacks(self, callbacks):
        self.callbacks: List = [History(), LRShedulerCallback()]
        if callbacks is not None:
            for callback in callbacks:
                if isinstance(callback, type):
                    callback = callback()
                self.callbacks.append(callback)
        self.callback_container = CallbackContainer(self.callbacks)
        self.callback_container.set_model(self.model)
        self.callback_container.set_trainer(self)

    @staticmethod
    def _set_device_and_num_workers(**kwargs):
        default_num_workers = (
            0
            if sys.platform == "darwin" and sys.version_info.minor > 7
            else os.cpu_count()
        )
        default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device = kwargs.get("device", default_device)
        num_workers = kwargs.get("num_workers", default_num_workers)
        return device, num_workers
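One detail worth flagging in the base class above: device and num_workers are not explicit constructor arguments; they are read from **kwargs by the static helper, falling back to CUDA-if-available and an OS-dependent worker count. A tiny illustrative call of that helper (not part of the commit):

    import torch

    device, num_workers = BaseSelfSupervisedTrainer._set_device_and_num_workers(
        device=torch.device("cpu"), num_workers=0
    )
    # without the keyword arguments, the defaults described above are returned

So any concrete trainer built on this base, such as the SelfSupervisedTrainer below, can receive device=... and num_workers=... as extra keyword arguments.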

import numpy as np
import torch
from tqdm import trange
from torch.utils.data import DataLoader, TensorDataset

from pytorch_widedeep.losses import InfoNCELoss, DenoisingLoss, ContrastiveLoss
from pytorch_widedeep.wdtypes import *  # noqa: F403
from pytorch_widedeep.callbacks import Callback
from pytorch_widedeep.training._trainer_utils import save_epoch_logs
from pytorch_widedeep.self_supervised_training._augmentations import (
    mix_up,
    cut_mix,
)
from pytorch_widedeep.self_supervised_training._base_self_supervised_trainer import (
    BaseSelfSupervisedTrainer,
)

class SelfSupervisedTrainer(BaseSelfSupervisedTrainer):
    def __init__(
        self,
        model,
        optimizer: Optimizer,
        lr_scheduler: LRScheduler,
        callbacks: Optional[List[Callback]],
        verbose: int,
        seed: int,
        **kwargs,
    ):
        super().__init__(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            callbacks=callbacks,
            verbose=verbose,
            seed=seed,
            **kwargs,
        )

    # the base class declares `fit` as abstract; delegate to `pretrain` so the
    # class can actually be instantiated
    def fit(self, X_tab: np.ndarray, n_epochs: int = 1, batch_size: int = 32):
        return self.pretrain(X_tab, n_epochs, batch_size)

    def pretrain(
        self,
        X_tab: np.ndarray,
        n_epochs: int = 1,
        batch_size: int = 32,
    ):
        self.batch_size = batch_size

        pretrain_loader = DataLoader(
            dataset=TensorDataset(torch.from_numpy(X_tab)),
            batch_size=batch_size,
            num_workers=self.num_workers,
        )
        train_steps = len(pretrain_loader)

        self.callback_container.on_train_begin(
            {
                "batch_size": batch_size,
                "train_steps": train_steps,
                "n_epochs": n_epochs,
            }
        )
        for epoch in range(n_epochs):
            epoch_logs: Dict[str, float] = {}
            self.callback_container.on_epoch_begin(epoch, logs=epoch_logs)

            self.train_running_loss = 0.0
            with trange(train_steps, disable=self.verbose != 1) as t:
                # the TensorDataset yields 1-tuples, so only one tensor is unpacked
                for batch_idx, (X,) in zip(t, pretrain_loader):
                    t.set_description("epoch %i" % (epoch + 1))
                    train_loss = self._pretrain_step(X, batch_idx)
                    self.callback_container.on_batch_end(batch=batch_idx)
            epoch_logs = save_epoch_logs(epoch_logs, train_loss, None, "train")

            self.callback_container.on_epoch_end(epoch, epoch_logs)
            if self.early_stop:
                self.callback_container.on_train_end(epoch_logs)
                break

        self.callback_container.on_train_end(epoch_logs)
        self._restore_best_weights()
        self.model.train()
    def _pretrain_step(self, X_tab: Tensor, batch_idx: int):
        X = X_tab.to(self.device)

        self.optimizer.zero_grad()
        encoded = self.model(X)

        # second "view": cut-mix on the raw input, then mix-up on its embeddings
        cut_mixed_x = cut_mix(X)
        cut_mixed_x_embed = self.model.cat_and_cont_embed(cut_mixed_x)
        cut_mixed_x_mixed_up_embed = mix_up(cut_mixed_x_embed)
        encoded_ = self.model.transformer_blks(cut_mixed_x_mixed_up_embed)

        # NOTE: projection_head1, projection_head2 and self.loss are used here
        # but are not yet defined anywhere in this first commit
        proj_encoded = self.projection_head1(encoded)
        proj_encoded_ = (
            self.projection_head2(encoded_)
            if self.projection_head2 is not None
            else self.projection_head1(encoded_)
        )

        loss = self.loss(proj_encoded, proj_encoded_)
        loss.backward()
        self.optimizer.step()

        self.train_running_loss += loss.item()
        avg_loss = self.train_running_loss / (batch_idx + 1)

        return avg_loss
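Finally, a heavily hedged sketch of how the new trainer appears to be meant to be used. The import path for SelfSupervisedTrainer is an assumption based on the module layout above, and since _pretrain_step references self.loss and the projection heads without them being defined in this commit, the pretrain call is left commented out; this illustrates the intended API rather than something expected to run against this exact commit:

    import numpy as np
    from pytorch_widedeep.models import SAINT
    # import path assumed; the module/package name may differ in this commit
    from pytorch_widedeep.self_supervised_training import SelfSupervisedTrainer

    # placeholder input: in practice X_tab would come from the library's TabPreprocessor
    X_tab = np.random.randint(1, 4, (32, 2))

    model = SAINT(
        column_idx={"a": 0, "b": 1},
        cat_embed_input=[("a", 4), ("b", 4)],
    )

    trainer = SelfSupervisedTrainer(
        model=model,
        optimizer=None,      # the base trainer then falls back to torch.optim.AdamW
        lr_scheduler=None,
        callbacks=None,
        verbose=1,
        seed=1,
    )
    # trainer.pretrain(X_tab, n_epochs=1, batch_size=8)  # once loss/projection heads are wired up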