Commit 8727ce66 authored by J jrzaurin

ModelCheckpoint now saves the best epoch as well. Added a dropout option for TabNet. Adjusted RAdam for the new in-place operation signatures. Adjusted training so it can take ReduceLROnPlateau, and so that it automatically restores the best weights after training.
Parent fd54eb71
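
A minimal usage sketch of the options this commit introduces (illustrative only; column_idx, continuous_cols, X_tab and target are assumed to come from the usual preprocessing step, and other constructor arguments are elided):

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_widedeep.training import Trainer
from pytorch_widedeep.models import TabNet, WideDeep
from pytorch_widedeep.callbacks import ModelCheckpoint

# assumed to exist: column_idx, continuous_cols, X_tab, target (e.g. from a TabPreprocessor)
tabnet = TabNet(column_idx=column_idx, continuous_cols=continuous_cols, dropout=0.1)  # new dropout option
model = WideDeep(deeptabular=tabnet)
optimizer = torch.optim.Adam(model.parameters())
trainer = Trainer(
    model,
    objective="binary",
    optimizers=optimizer,
    lr_schedulers=ReduceLROnPlateau(optimizer),  # ReduceLROnPlateau is now handled
    reducelronplateau_criterion="loss",          # renamed from reduce_on
    callbacks=[ModelCheckpoint(filepath="weights/wd", save_best_only=True)],
)
trainer.fit(X_tab=X_tab, target=target, n_epochs=10)
# after fit, the best-epoch weights are restored automatically (see _restore_best_weights below)
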
......@@ -437,6 +437,7 @@ class ModelCheckpoint(Callback):
)
)
self.best = current
self.best_epoch = epoch
torch.save(self.model.state_dict(), filepath)
if self.max_save > 0:
if len(self.old_files) == self.max_save:
......@@ -590,9 +591,7 @@ class EarlyStopping(Callback):
self.trainer.early_stop = True
if self.restore_best_weights:
if self.verbose > 0:
print(
"Restoring model weights from the end of " "the best epoch"
)
print("Restoring model weights from the end of the best epoch")
self.model.load_state_dict(self.state_dict)
def on_train_end(self, logs: Optional[Dict] = None):
......
......@@ -24,7 +24,7 @@ def dense_layer(
bn: bool,
linear_first: bool,
):
# This is bascially the LinBnDrop class at the fastai library
# This is basically the LinBnDrop class at the fastai library
act_fn = _get_activation_fn(activation)
layers = [nn.BatchNorm1d(out if linear_first else inp)] if bn else []
if p != 0:
......
......@@ -55,6 +55,7 @@ class GLU_Layer(nn.Module):
self,
input_dim: int,
output_dim: int,
dropout: float,
fc: nn.Module = None,
ghost_bn: bool = True,
virtual_batch_size: int = 128,
......@@ -75,8 +76,10 @@ class GLU_Layer(nn.Module):
else:
self.bn = nn.BatchNorm1d(2 * output_dim, momentum=momentum)
self.dp = nn.Dropout(dropout)
def forward(self, X: Tensor) -> Tensor:
return F.glu(self.bn(self.fc(X)))
return self.dp(F.glu(self.bn(self.fc(X))))
class GLU_Block(nn.Module):
......@@ -84,6 +87,7 @@ class GLU_Block(nn.Module):
self,
input_dim: int,
output_dim: int,
dropout: float,
n_glu: int = 2,
first: bool = False,
shared_layers: nn.ModuleList = None,
......@@ -114,6 +118,7 @@ class GLU_Block(nn.Module):
GLU_Layer(
glu_dim[i],
glu_dim[i + 1],
dropout,
fc=fc,
ghost_bn=ghost_bn,
virtual_batch_size=virtual_batch_size,
......@@ -142,6 +147,7 @@ class FeatTransformer(nn.Module):
self,
input_dim: int,
output_dim: int,
dropout: float,
shared_layers: nn.ModuleList,
n_glu_step_dependent: int,
ghost_bn=True,
......@@ -159,6 +165,7 @@ class FeatTransformer(nn.Module):
self.shared = GLU_Block(
input_dim,
output_dim,
dropout,
n_glu=len(shared_layers),
first=True,
shared_layers=shared_layers,
......@@ -166,7 +173,12 @@ class FeatTransformer(nn.Module):
)
self.step_dependent = GLU_Block(
output_dim, output_dim, n_glu=n_glu_step_dependent, first=False, **params
output_dim,
output_dim,
dropout,
n_glu=n_glu_step_dependent,
first=False,
**params
)
def forward(self, X: Tensor) -> Tensor:
......@@ -216,6 +228,7 @@ class TabNetEncoder(nn.Module):
n_steps: int = 3,
step_dim: int = 8,
attn_dim: int = 8,
dropout: float = 0.0,
n_glu_step_dependent: int = 2,
n_glu_shared: int = 2,
ghost_bn: bool = True,
......@@ -258,6 +271,7 @@ class TabNetEncoder(nn.Module):
self.initial_splitter = FeatTransformer(
input_dim,
step_dim + attn_dim,
dropout,
shared_layers,
n_glu_step_dependent,
**params
......@@ -269,6 +283,7 @@ class TabNetEncoder(nn.Module):
feat_transformer = FeatTransformer(
input_dim,
step_dim + attn_dim,
dropout,
shared_layers,
n_glu_step_dependent,
**params
......@@ -406,6 +421,7 @@ class TabNet(nn.Module):
n_steps: int = 3,
step_dim: int = 8,
attn_dim: int = 8,
dropout: float = 0.0,
n_glu_step_dependent: int = 2,
n_glu_shared: int = 2,
ghost_bn: bool = True,
......@@ -517,9 +533,10 @@ class TabNet(nn.Module):
self.embed_and_cont_dim = self.embed_and_cont.output_dim
self.tabnet_encoder = TabNetEncoder(
self.embed_and_cont.output_dim,
n_steps,
step_dim,
attn_dim,
n_steps,
dropout,
n_glu_step_dependent,
n_glu_shared,
ghost_bn,
......
"""
Copied and pasted, with great gratitude from: https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam/radam.py
Adapted to avoid warnings about the signatures of the in-place operations in newer versions of PyTorch
"""
import math
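
For context on the adaptation mentioned in the docstring above (illustrative, not part of the diff): newer PyTorch releases deprecate passing the scalar factor as the first positional argument of the in-place ops, so it moves to the alpha/value keyword, which is what the changes below do.

import torch

grad = torch.ones(3)
exp_avg = torch.zeros(3)
beta1 = 0.9
# deprecated form that triggers the warning: exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # new keyword signature
# the same pattern applies to addcmul_ and addcdiv_, where the scalar becomes value=
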
......@@ -50,7 +52,7 @@ class RAdam(Optimizer):
def __setstate__(self, state):
super(RAdam, self).__setstate__(state)
def step(self, closure=None):
def step(self, closure=None): # noqa: C901
loss = None
if closure is not None:
......@@ -80,8 +82,8 @@ class RAdam(Optimizer):
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state["step"] += 1
buffered = group["buffer"][int(state["step"] % 10)]
......@@ -115,17 +117,18 @@ class RAdam(Optimizer):
if N_sma >= 5:
if group["weight_decay"] != 0:
p_data_fp32.add_(
-group["weight_decay"] * group["lr"], p_data_fp32
p_data_fp32,
alpha=-group["weight_decay"] * group["lr"],
)
denom = exp_avg_sq.sqrt().add_(group["eps"])
p_data_fp32.addcdiv_(-step_size * group["lr"], exp_avg, denom)
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])
p.data.copy_(p_data_fp32)
elif step_size > 0:
if group["weight_decay"] != 0:
p_data_fp32.add_(
-group["weight_decay"] * group["lr"], p_data_fp32
)
p_data_fp32.add_(-step_size * group["lr"], exp_avg)
p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])
p.data.copy_(p_data_fp32)
return loss
......@@ -55,7 +55,7 @@ class Trainer:
custom_loss_function: Optional[Module] = None,
optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
reduce_on: Optional[str] = "loss",
reducelronplateau_criterion: Optional[str] = "loss",
initializers: Optional[Union[Initializer, Dict[str, Initializer]]] = None,
transforms: Optional[List[Transforms]] = None,
callbacks: Optional[List[Callback]] = None,
......@@ -163,12 +163,8 @@ class Trainer:
- float indicating the weight of the minority class in binary classification
problems (e.g. 9.)
- a list or tuple with weights for the different classes in multiclass
classification problems (e.g. [1., 2., 3.]). The weights do
not neccesarily need to be normalised. If your loss function
uses reduction='mean', the loss will be normalized by the sum
of the corresponding weights for each element. If you are
using reduction='none', you would have to take care of the
normalization yourself. See `this discussion
classification problems (e.g. [1., 2., 3.]). The weights do not
need to be normalised. See `this discussion
<https://discuss.pytorch.org/t/passing-the-weights-to-crossentropyloss-correctly/14731/10>`_.
lambda_sparse: float. default=1e-3
Tabnet sparse regularization factor
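
An aside on the class-weight note above (illustrative only, not part of the diff): with the default reduction='mean', nn.CrossEntropyLoss normalises by the summed weights of the batch targets, so the weights do not need to sum to one.

import torch
import torch.nn as nn

class_weights = torch.tensor([1.0, 2.0, 3.0])        # unnormalised per-class weights
loss_fn = nn.CrossEntropyLoss(weight=class_weights)  # reduction='mean' by default
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])
print(loss_fn(logits, targets))                      # weighted mean over the batch
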
......@@ -259,7 +255,7 @@ class Trainer:
)
self.reducelronplateau = False
self.reduce_on = reduce_on
self.reducelronplateau_criterion = reducelronplateau_criterion
if isinstance(lr_schedulers, Dict):
for _, scheduler in lr_schedulers.items():
if isinstance(scheduler, ReduceLROnPlateau):
......@@ -598,10 +594,12 @@ class Trainer:
epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val")
if self.reducelronplateau:
if self.reduce_on == "loss":
if self.reducelronplateau_criterion == "loss":
on_epoch_end_metric = val_loss
else:
on_epoch_end_metric = val_score[self.reduce_on]
on_epoch_end_metric = val_score[
self.reducelronplateau_criterion
]
self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric)
......@@ -612,6 +610,7 @@ class Trainer:
self.callback_container.on_train_end(epoch_logs)
if self.model.is_tabnet:
self._compute_feature_importance(train_loader)
self._restore_best_weights()
self.model.train()
def predict( # type: ignore[return]
......@@ -852,6 +851,7 @@ class Trainer:
filename where the feature importances will be stored
"""
# TO DO: ask advice on saving strategy
if not os.path.exists(path):
os.makedirs(path)
......@@ -866,6 +866,37 @@ class Trainer:
with open(feature_importance_fname, "w") as fi:
json.dump(self.feature_importance, fi)
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
if already_restored:
pass
else:
for callback in self.callback_container.callbacks:
if callback.__class__.__name__ == "ModelCheckpoint":
if callback.save_best_only:
filepath = "{}_{}.p".format(
callback.filepath, callback.best_epoch + 1
)
if self.verbose:
print(
f"Model weights restored to best epoch: {callback.best_epoch + 1}"
)
self.model.load_state_dict(torch.load(filepath))
else:
if self.verbose:
print(
"Model weights after training corresponds to the those of the "
"final epoch which might not be the best performing weights"
)
def _finetune(
self,
loader: DataLoader,
......@@ -1065,7 +1096,7 @@ class Trainer:
def _set_loss_fn(self, objective, class_weight, custom_loss_function, alpha, gamma):
if class_weight is not None:
class_weight = torch.tensor(class_weight)
class_weight = torch.tensor(class_weight).to(device)
if custom_loss_function is not None:
return custom_loss_function
elif self.method != "regression" and "focal_loss" not in objective:
......