Commit fd54eb71 authored by: J jrzaurin

Modified the callbacks to accommodate the ReduceLROnPlateau scheduler...

Modified the callbacks to accommodate the ReduceLROnPlateau scheduler. Replaced weight with pos_weight in BCEWithLogitsLoss. Added an on_eval_begin method to reset metrics.
Parent d3a4453c
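For orientation, here is a minimal usage sketch (my own addition, not part of the commit) of how the pieces below fit together: a `ReduceLROnPlateau` scheduler can now be handed to the `Trainer`, which threads the monitored validation value through the callbacks into the scheduler's `step()`. The `Trainer` import path and the stand-in `nn.Linear` model are assumptions; `reduce_on` follows the new signature in this diff.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_widedeep.training import Trainer  # import path assumed

model = torch.nn.Linear(10, 1)  # stand-in for a pytorch-widedeep WideDeep model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# halve the lr once the validation loss stops improving for 3 epochs
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)

trainer = Trainer(
    model,
    objective="binary",
    optimizers=optimizer,
    lr_schedulers=scheduler,
    reduce_on="loss",  # step the scheduler on the validation loss
)
```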
@@ -9,6 +9,7 @@ import warnings
 import numpy as np
 import torch
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from pytorch_widedeep.wdtypes import *  # noqa: F403
@@ -17,6 +18,14 @@ def _get_current_time():
     return datetime.datetime.now().strftime("%B %d, %Y - %I:%M%p")
 
 
+def _is_metric(monitor: str):
+    # We assume no one will use f3 or more
+    if any([s in monitor for s in ["acc", "prec", "rec", "fscore", "f1", "f2"]]):
+        return True
+    else:
+        return False
+
+
 class CallbackContainer(object):
     """
     Container holding a list of callbacks.
@@ -52,10 +61,12 @@ class CallbackContainer(object):
         for callback in self.callbacks:
             callback.on_epoch_begin(epoch, logs)
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         logs = logs or {}
         for callback in self.callbacks:
-            callback.on_epoch_end(epoch, logs)
+            callback.on_epoch_end(epoch, logs, metric)
 
     def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
         logs = logs or {}
@@ -73,6 +84,12 @@ class CallbackContainer(object):
         for callback in self.callbacks:
             callback.on_train_begin(logs)
 
+    def on_eval_begin(self, logs: Optional[Dict] = None):
+        # at the moment only used to reset metrics before eval
+        logs = logs or {}
+        for callback in self.callbacks:
+            callback.on_eval_begin(logs)
+
     def on_train_end(self, logs: Optional[Dict] = None):
         logs = logs or {}
         # logs['final_loss'] = self.model.history.epoch_losses[-1],
@@ -102,7 +119,9 @@ class Callback(object):
     def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
         pass
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         pass
 
     def on_batch_begin(self, batch: int, logs: Optional[Dict] = None):
@@ -114,6 +133,10 @@ class Callback(object):
     def on_train_begin(self, logs: Optional[Dict] = None):
         pass
 
+    def on_eval_begin(self, logs: Optional[Dict] = None):
+        # at the moment only used to reset metrics before eval
+        pass
+
     def on_train_end(self, logs: Optional[Dict] = None):
         pass
@@ -128,7 +151,9 @@ class History(Callback):
     def on_train_begin(self, logs: Optional[Dict] = None):
         self.trainer.history = {}
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         logs = logs or {}
         for k, v in logs.items():
             self.trainer.history.setdefault(k, []).append(v)
@@ -153,7 +178,9 @@ class LRShedulerCallback(Callback):
         elif self.trainer.cyclic_lr:
             self.trainer.lr_scheduler.step()
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         if self.trainer.lr_scheduler is not None:
             if self._multiple_scheduler():
                 for (
@@ -161,9 +188,15 @@ class LRShedulerCallback(Callback):
                     scheduler,
                 ) in self.trainer.lr_scheduler._schedulers.items():
                     if not self._is_cyclic(model_name):
-                        scheduler.step()
+                        if isinstance(scheduler, ReduceLROnPlateau):
+                            scheduler.step(metric)
+                        else:
+                            scheduler.step()
             elif not self.trainer.cyclic_lr:
-                self.trainer.lr_scheduler.step()
+                if isinstance(self.trainer.lr_scheduler, ReduceLROnPlateau):
+                    self.trainer.lr_scheduler.step(metric)
+                else:
+                    self.trainer.lr_scheduler.step()
 
     def _multiple_scheduler(self):
         return self.trainer.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
@@ -224,7 +257,9 @@ class LRHistory(Callback):
         elif self.trainer.cyclic_lr:
             self._save_group_lr(self.trainer.optimizer)
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         if epoch != (self.n_epochs - 1) and self.trainer.lr_scheduler is not None:
             if self._multiple_scheduler():
                 self._save_group_lr_mulitple_scheduler(step_location="on_epoch_end")
@@ -293,10 +328,8 @@ class ModelCheckpoint(Callback):
         be added. e.g. ``filepath="path/to/output_weights/weights_out"`` And
         the saved files in that directory will be named: ``weights_out_1.pt,
         weights_out_2.pt, ...``
-    monitor: str, default="val_loss"
-        quantity to monitor. :obj:`ModelCheckpoint` will infer if this is a
-        loss (i.e. contains the str `'loss'`) or a metric (i.e. contains the
-        str `'acc'` or starts with `'fmeasure'`).
+    monitor: str, default="loss"
+        quantity to monitor. Typically 'val_loss' or a metric name (e.g. 'val_acc')
     verbose: int, default=0,
         verbosity mode
     save_best_only: bool, default=False,
@@ -305,8 +338,8 @@ class ModelCheckpoint(Callback):
     mode: str, default="auto",
         If ``save_best_only=True``, the decision to overwrite the current save
         file is made based on either the maximization or the minimization of
-        the monitored quantity. For `'val_acc'`, this should be `'max'`, for
-        `'val_loss'` this should be `'min'`, etc. In `'auto'` mode, the
+        the monitored quantity. For `'acc'`, this should be `'max'`, for
+        `'loss'` this should be `'min'`, etc. In `'auto'` mode, the
         direction is automatically inferred from the name of the monitored
         quantity.
     period: int, default=1,
@@ -366,14 +399,16 @@ class ModelCheckpoint(Callback):
             self.monitor_op = np.greater
             self.best = -np.Inf
         else:
-            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
+            if _is_metric(self.monitor):
                 self.monitor_op = np.greater
                 self.best = -np.Inf
             else:
                 self.monitor_op = np.less
                 self.best = np.Inf
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):  # noqa: C901
+    def on_epoch_end(  # noqa: C901
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         logs = logs or {}
         self.epochs_since_last_save += 1
         if self.epochs_since_last_save >= self.period:
@@ -453,7 +488,7 @@ class EarlyStopping(Callback):
     Parameters
     -----------
     monitor: str, default='val_loss'.
-        Quantity to be monitored.
+        Quantity to monitor. Typically 'val_loss' or a metric name (e.g. 'val_acc')
     min_delta: float, default=0.
         minimum change in the monitored quantity to qualify as an
         improvement, i.e. an absolute change of less than min_delta, will
@@ -517,7 +552,7 @@ class EarlyStopping(Callback):
         elif self.mode == "max":
             self.monitor_op = np.greater
         else:
-            if "acc" in self.monitor:
+            if _is_metric(self.monitor):
                 self.monitor_op = np.greater
             else:
                 self.monitor_op = np.less
@@ -536,7 +571,9 @@ class EarlyStopping(Callback):
         else:
             self.best = np.Inf if self.monitor_op == np.less else -np.Inf
 
-    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+    def on_epoch_end(
+        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
+    ):
         current = self.get_monitor_value(logs)
         if current is None:
             return
......
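The heart of the callback changes above is the step dispatch: `ReduceLROnPlateau` is the one PyTorch scheduler whose `step()` expects the monitored value, while every other scheduler steps without arguments. A self-contained sketch of that pattern in plain PyTorch (no widedeep objects involved):

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for scheduler in (StepLR(optimizer, step_size=2), ReduceLROnPlateau(optimizer)):
    val_loss = 0.5  # placeholder for the epoch's validation loss
    if isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step(val_loss)  # the plateau scheduler consumes the metric
    else:
        scheduler.step()  # every other scheduler steps unconditionally
```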
@@ -46,6 +46,9 @@ class MetricCallback(Callback):
     def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None):
         self.container.reset()
 
+    def on_eval_begin(self, logs: Optional[Dict] = None):
+        self.container.reset()
+
 
 class Accuracy(Metric):
     def __init__(self, top_k: int = 1):
......
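Why the `on_eval_begin` reset matters: the library's metrics are running metrics that accumulate counts across batches, so without a reset at the start of evaluation the validation score would still include the training batches of the same epoch. A simplified stand-in (not the library's `Metric` class) illustrating the failure mode the reset prevents:

```python
class RunningAccuracy:  # hypothetical stand-in for a running Metric
    def __init__(self):
        self.reset()

    def reset(self):
        self.correct, self.total = 0, 0

    def update(self, preds, targets):
        self.correct += sum(int(p == t) for p, t in zip(preds, targets))
        self.total += len(targets)
        return self.correct / self.total

metric = RunningAccuracy()
metric.update([1, 0, 1], [1, 1, 1])   # training batches accumulate state
metric.reset()                        # what on_eval_begin now triggers
print(metric.update([1, 1], [1, 1]))  # 1.0: validation starts from a clean slate
```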
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from tqdm import trange
 from scipy.sparse import csc_matrix
 from torch.utils.data import DataLoader
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 from pytorch_widedeep.metrics import Metric, MetricCallback, MultipleMetrics
 from pytorch_widedeep.wdtypes import *  # noqa: F403
@@ -54,6 +55,7 @@ class Trainer:
         custom_loss_function: Optional[Module] = None,
         optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
         lr_schedulers: Optional[Union[LRScheduler, Dict[str, LRScheduler]]] = None,
+        reduce_on: Optional[str] = "loss",
         initializers: Optional[Union[Initializer, Dict[str, Initializer]]] = None,
         transforms: Optional[List[Transforms]] = None,
         callbacks: Optional[List[Callback]] = None,
@@ -256,6 +258,15 @@ class Trainer:
                 "'multiclass' or 'regression', consistent with the loss function"
             )
 
+        self.reducelronplateau = False
+        self.reduce_on = reduce_on
+        if isinstance(lr_schedulers, Dict):
+            for _, scheduler in lr_schedulers.items():
+                if isinstance(scheduler, ReduceLROnPlateau):
+                    self.reducelronplateau = True
+        elif isinstance(lr_schedulers, ReduceLROnPlateau):
+            self.reducelronplateau = True
+
         if isinstance(model, str):
             self.model = torch.load(model)
         else:
@@ -309,7 +320,6 @@ class Trainer:
         n_epochs: int = 1,
         validation_freq: int = 1,
         batch_size: int = 32,
-        patience: int = 10,
         finetune: bool = False,
         finetune_epochs: int = 5,
         finetune_max_lr: float = 0.01,
@@ -362,9 +372,7 @@
         validation_freq: int, default=1
             epochs validation frequency
         batch_size: int, default=32
-        patience: int, default=10
-            Number of epochs without improving the target metric or loss
-            before the fit process stops
+            batch size
         finetune: bool, default=False
             param alias: ``warmup``
@@ -571,23 +579,32 @@
             with trange(train_steps, disable=self.verbose != 1) as t:
                 for batch_idx, (data, targett) in zip(t, train_loader):
                     t.set_description("epoch %i" % (epoch + 1))
-                    score, train_loss = self._train_step(data, targett, batch_idx)
-                    print_loss_and_metric(t, train_loss, score)
+                    train_score, train_loss = self._train_step(data, targett, batch_idx)
+                    print_loss_and_metric(t, train_loss, train_score)
                     self.callback_container.on_batch_end(batch=batch_idx)
-            epoch_logs = save_epoch_logs(epoch_logs, train_loss, score, "train")
+            epoch_logs = save_epoch_logs(epoch_logs, train_loss, train_score, "train")
 
+            on_epoch_end_metric = None
             if eval_set is not None and epoch % validation_freq == (
                 validation_freq - 1
             ):
+                self.callback_container.on_eval_begin()
                 self.valid_running_loss = 0.0
                 with trange(eval_steps, disable=self.verbose != 1) as v:
                     for i, (data, targett) in zip(v, eval_loader):
                         v.set_description("valid")
-                        score, val_loss = self._eval_step(data, targett, i)
-                        print_loss_and_metric(v, val_loss, score)
-                epoch_logs = save_epoch_logs(epoch_logs, val_loss, score, "val")
+                        val_score, val_loss = self._eval_step(data, targett, i)
+                        print_loss_and_metric(v, val_loss, val_score)
+                epoch_logs = save_epoch_logs(epoch_logs, val_loss, val_score, "val")
+                if self.reducelronplateau:
+                    if self.reduce_on == "loss":
+                        on_epoch_end_metric = val_loss
+                    else:
+                        on_epoch_end_metric = val_score[self.reduce_on]
 
-            self.callback_container.on_epoch_end(epoch, epoch_logs)
+            self.callback_container.on_epoch_end(epoch, epoch_logs, on_epoch_end_metric)
 
             if self.early_stop:
                 self.callback_container.on_train_end(epoch_logs)
                 break
@@ -1047,12 +1064,8 @@ class Trainer:
         return preds_l
 
     def _set_loss_fn(self, objective, class_weight, custom_loss_function, alpha, gamma):
-        if isinstance(class_weight, float):
-            class_weight = torch.tensor([1.0 - class_weight, class_weight])
-        elif isinstance(class_weight, (tuple, list)):
+        if class_weight is not None:
             class_weight = torch.tensor(class_weight)
-        else:
-            class_weight = None
         if custom_loss_function is not None:
             return custom_loss_function
         elif self.method != "regression" and "focal_loss" not in objective:
@@ -1092,11 +1105,12 @@ class Trainer:
 
     def _set_lr_scheduler(self, lr_schedulers):
         if lr_schedulers is not None:
-            if isinstance(lr_schedulers, LRScheduler):
-                lr_scheduler: Union[
-                    LRScheduler,
-                    MultipleLRScheduler,
-                ] = lr_schedulers
+            # ReduceLROnPlateau is special: the only scheduler that is 'just' an
+            # object rather than an LRScheduler
+            if isinstance(lr_schedulers, LRScheduler) or isinstance(
+                lr_schedulers, ReduceLROnPlateau
+            ):
+                lr_scheduler = lr_schedulers
                 cyclic_lr = "cycl" in lr_scheduler.__class__.__name__.lower()
             else:
                 lr_scheduler = MultipleLRScheduler(lr_schedulers)
......
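One caveat worth flagging (my note, not something the diff enforces): when `reduce_on` names a metric that should increase, such as accuracy, the `ReduceLROnPlateau` instance itself must be constructed with `mode="max"`; its default `mode="min"` only suits losses. For example:

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, mode="max", patience=2)

for acc in [0.70, 0.71, 0.71, 0.71, 0.71]:  # validation accuracy plateaus
    scheduler.step(acc)  # lr drops after `patience` epochs without improvement

print(optimizer.param_groups[0]["lr"])  # ~0.01: reduced from the initial 0.1
```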
@@ -214,7 +214,7 @@ def alias_to_loss(loss_fn: str, **kwargs):
             "or loss functions: {}".format(", ".join(_ObjectiveToMethod.keys()))
         )
     if loss_fn in _LossAliases.get("binary"):
-        return nn.BCEWithLogitsLoss(weight=kwargs["weight"])
+        return nn.BCEWithLogitsLoss(pos_weight=kwargs["weight"])
     if loss_fn in _LossAliases.get("multiclass"):
         return nn.CrossEntropyLoss(weight=kwargs["weight"])
     if loss_fn in _LossAliases.get("regression"):
......
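On the `weight` → `pos_weight` swap in the binary branch above: in `BCEWithLogitsLoss`, `weight` rescales the loss of each batch element, while `pos_weight` rescales the positive class, which is what a class weight for an imbalanced binary target is meant to do. A minimal illustration with made-up numbers:

```python
import torch
import torch.nn as nn

logits = torch.tensor([0.2, -1.3, 0.8])
targets = torch.tensor([1.0, 0.0, 1.0])

# weigh positive examples 3x, e.g. for a roughly 1:3 positive/negative imbalance
loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))
print(loss_fn(logits, targets))
```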