Commit c36b38c7 authored by jrzaurin

lr scheduler step as a Callback and adjusted a few tests

Parent 6def0a0b
@@ -81,8 +81,8 @@ if __name__ == "__main__":
     wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
     deep_opt = RAdam(model.deeptabular.parameters())
-    wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)
-    deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)
+    wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=2)
+    deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
     optimizers = {"wide": wide_opt, "deeptabular": deep_opt}
     schedulers = {"wide": wide_sch, "deeptabular": deep_sch}
@@ -108,7 +108,7 @@ if __name__ == "__main__":
         X_wide=X_wide,
         X_tab=X_tab,
         target=target,
-        n_epochs=4,
+        n_epochs=10,
         batch_size=64,
         val_split=0.2,
     )
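
The two hunks above retune the example script: with n_epochs raised from 4 to 10, StepLR step sizes of 2 and 3 let each schedule fire several times within the run. For reference, a standalone sketch of the resulting decay in plain PyTorch, using dummy parameters and assuming StepLR's default gamma=0.1 (the script does not set gamma):

import torch

# dummy parameter groups standing in for model.wide / model.deeptabular
wide_params = [torch.nn.Parameter(torch.zeros(1))]
deep_params = [torch.nn.Parameter(torch.zeros(1))]
wide_opt = torch.optim.Adam(wide_params, lr=0.01)
deep_opt = torch.optim.Adam(deep_params, lr=0.01)  # stand-in for RAdam

wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=2)  # gamma defaults to 0.1
deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)

for epoch in range(10):
    wide_opt.step()  # a real run would do forward/backward per batch here
    deep_opt.step()
    wide_sch.step()  # StepLR steps once per epoch
    deep_sch.step()
    print(epoch, wide_sch.get_last_lr(), deep_sch.get_last_lr())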
......
@@ -134,6 +134,53 @@ class History(Callback):
             self.trainer.history.setdefault(k, []).append(v)
 
 
+class LRShedulerCallback(Callback):
+    r"""Callback for the learning rate schedulers to take a step.
+
+    This callback runs by default within the :obj:`Trainer`; therefore, it
+    should not be passed to the ``Trainer``. It is included here just for
+    completeness.
+    """
+
+    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
+        if self.trainer.lr_scheduler is not None:
+            if self._multiple_scheduler():
+                for (
+                    model_name,
+                    scheduler,
+                ) in self.trainer.lr_scheduler._schedulers.items():
+                    if self._is_cyclic(model_name):
+                        scheduler.step()
+            elif self.trainer.cyclic_lr:
+                self.trainer.lr_scheduler.step()
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+        if self.trainer.lr_scheduler is not None:
+            if self._multiple_scheduler():
+                for (
+                    model_name,
+                    scheduler,
+                ) in self.trainer.lr_scheduler._schedulers.items():
+                    if not self._is_cyclic(model_name):
+                        scheduler.step()
+            elif not self.trainer.cyclic_lr:
+                self.trainer.lr_scheduler.step()
+
+    def _multiple_scheduler(self):
+        return self.trainer.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
+
+    def _is_cyclic(self, model_name: str):
+        return (
+            self._has_scheduler(model_name)
+            and "cycl"
+            in self.trainer.lr_scheduler._schedulers[
+                model_name
+            ].__class__.__name__.lower()
+        )
+
+    def _has_scheduler(self, model_name: str):
+        return model_name in self.trainer.lr_scheduler._schedulers
+
+
 class LRHistory(Callback):
     def __init__(self, n_epochs: int):
         r"""Saves the learning rates during training to a ``lr_history`` attribute.
......
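
The rule the new callback encodes — schedulers whose class name contains "cycl" (CyclicLR, OneCycleLR) step after every batch, everything else steps after every epoch — can be illustrated outside the library. A minimal sketch in plain PyTorch; the two-scheduler setup and the is_cyclic helper are illustrative stand-ins, not the library's API:

import torch

wide = torch.nn.Linear(5, 1)
deep = torch.nn.Linear(5, 1)
wide_opt = torch.optim.SGD(wide.parameters(), lr=0.1)
deep_opt = torch.optim.SGD(deep.parameters(), lr=0.1)

n_epochs, n_batches = 3, 4
schedulers = {
    "wide": torch.optim.lr_scheduler.OneCycleLR(
        wide_opt, max_lr=0.1, epochs=n_epochs, steps_per_epoch=n_batches
    ),
    "deeptabular": torch.optim.lr_scheduler.StepLR(deep_opt, step_size=2),
}


def is_cyclic(sch) -> bool:
    # same heuristic as LRShedulerCallback._is_cyclic: look for "cycl" in the class name
    return "cycl" in sch.__class__.__name__.lower()


for epoch in range(n_epochs):
    for batch in range(n_batches):
        wide_opt.step()  # forward/backward would happen here in a real loop
        deep_opt.step()
        for sch in schedulers.values():  # mirrors on_batch_end
            if is_cyclic(sch):
                sch.step()
    for sch in schedulers.values():  # mirrors on_epoch_end
        if not is_cyclic(sch):
            sch.step()
    print(epoch, {name: sch.get_last_lr() for name, sch in schedulers.items()})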
@@ -367,17 +367,17 @@ class EmbeddingsAndContinuous(nn.Module):
             }
         )
         self.embedding_dropout = nn.Dropout(embed_dropout)
-        emb_inp_dim = np.sum([embed[2] for embed in self.embed_input])
+        emb_out_dim = np.sum([embed[2] for embed in self.embed_input])
 
         # Continuous
         if self.continuous_cols is not None:
-            cont_inp_dim = len(self.continuous_cols)
+            cont_out_dim = len(self.continuous_cols)
             if self.batchnorm_cont:
-                self.norm = nn.BatchNorm1d(cont_inp_dim)
+                self.norm = nn.BatchNorm1d(cont_out_dim)
         else:
-            cont_inp_dim = 0
+            cont_out_dim = 0
 
-        self.output_dim = emb_inp_dim + cont_inp_dim
+        self.output_dim = emb_out_dim + cont_out_dim
 
     def forward(self, X):
         embed = [
......
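
The renaming above (inp_dim to out_dim) only relabels what the quantities are: the module's output is the concatenation of all embedding outputs plus the continuous columns. A quick sketch of that arithmetic with made-up column names, assuming the (column_name, n_unique_values, embedding_dim) layout of embed_input used in the hunk above:

import numpy as np

# hypothetical embed_input: (column_name, n_unique_values, embedding_dim)
embed_input = [("education", 16, 8), ("relationship", 6, 4), ("workclass", 9, 4)]
continuous_cols = ["age", "hours_per_week"]

emb_out_dim = int(np.sum([embed[2] for embed in embed_input]))              # 8 + 4 + 4 = 16
cont_out_dim = len(continuous_cols) if continuous_cols is not None else 0   # 2

output_dim = emb_out_dim + cont_out_dim  # 18: the input size of the MLP that follows
print(output_dim)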
@@ -164,10 +164,19 @@ class WideDeep(nn.Module):
         if self.deeptabular is not None:
             self.is_tabnet = deeptabular.__class__.__name__ == "TabNet"
         else:
             self.is_tabnet = False
 
         if self.deephead is None:
             if head_hidden_dims is not None:
-                self._build_deephead()
+                self._build_deephead(
+                    head_hidden_dims,
+                    head_activation,
+                    head_dropout,
+                    head_batchnorm,
+                    head_batchnorm_last,
+                    head_linear_first,
+                )
             else:
                 self._add_pred_layer()
@@ -178,7 +187,15 @@
         else:
             return self._forward_deep(X, wide_out)
 
-    def _build_deephead(self):
+    def _build_deephead(
+        self,
+        head_hidden_dims,
+        head_activation,
+        head_dropout,
+        head_batchnorm,
+        head_batchnorm_last,
+        head_linear_first,
+    ):
         deep_dim = 0
         if self.deeptabular is not None:
             deep_dim += self.deeptabular.output_dim
......
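
The signature change above passes the head_* arguments into _build_deephead explicitly instead of reading them from stored attributes. The body is not shown in the diff; a hedged sketch of what building such a head from head_hidden_dims could look like (an illustration, not the library's implementation):

import torch.nn as nn


def build_deephead_sketch(deep_dim, head_hidden_dims, head_activation="relu",
                          head_dropout=0.1, head_batchnorm=False):
    # chain Linear -> (BatchNorm) -> activation -> Dropout blocks over the hidden dims
    dims = [deep_dim] + list(head_hidden_dims)
    layers = []
    for i in range(1, len(dims)):
        layers.append(nn.Linear(dims[i - 1], dims[i]))
        if head_batchnorm:
            layers.append(nn.BatchNorm1d(dims[i]))
        layers.append(nn.ReLU() if head_activation == "relu" else nn.LeakyReLU())
        layers.append(nn.Dropout(head_dropout))
    return nn.Sequential(*layers)


print(build_deephead_sketch(deep_dim=64, head_hidden_dims=[32, 16]))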
@@ -12,7 +12,7 @@ from pytorch_widedeep.losses import MSLELoss, RMSELoss, FocalLoss, RMSLELoss
 from pytorch_widedeep.models import WideDeep
 from pytorch_widedeep.metrics import Metric, MetricCallback, MultipleMetrics
 from pytorch_widedeep.wdtypes import *  # noqa: F403
-from pytorch_widedeep.callbacks import History, Callback, CallbackContainer
+from pytorch_widedeep.callbacks import History, LRShedulerCallback, Callback, CallbackContainer
 from pytorch_widedeep.initializers import Initializer, MultipleInitializer
 from pytorch_widedeep.training._finetune import FineTune
 from pytorch_widedeep.utils.general_utils import Alias
@@ -554,8 +554,6 @@ class Trainer:
                         )
                     else:
                         t.set_postfix(loss=train_loss)
-                    if self.lr_scheduler:
-                        self._lr_scheduler_step(step_location="on_batch_end")
                     self.callback_container.on_batch_end(batch=batch_idx)
             epoch_logs["train_loss"] = train_loss
             if score is not None:
@@ -582,8 +580,6 @@
                     for k, v in score.items():
                         log_k = "_".join(["val", k])
                         epoch_logs[log_k] = v
-            if self.lr_scheduler:
-                self._lr_scheduler_step(step_location="on_epoch_end")
             self.callback_container.on_epoch_end(epoch, epoch_logs)
             if self.early_stop:
                 self.callback_container.on_train_end(epoch_logs)
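
With the two blocks above removed, the Trainer no longer decides when each scheduler steps; it only fires the container hooks and the new callback reacts to them. A simplified sketch of that dispatch pattern, using stripped-down stand-ins rather than the library's Callback/CallbackContainer implementations:

from typing import Dict, List, Optional


class Callback:
    # no-op hooks; concrete callbacks override what they need
    def on_batch_end(self, batch: int, logs: Optional[Dict] = None): ...
    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None): ...


class CallbackContainer:
    def __init__(self, callbacks: List[Callback]):
        self.callbacks = callbacks

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
        for cb in self.callbacks:
            cb.on_batch_end(batch, logs)

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        for cb in self.callbacks:
            cb.on_epoch_end(epoch, logs)


class PrintAtEpochEnd(Callback):
    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
        print(f"epoch {epoch} done, logs={logs}")


container = CallbackContainer([Callback(), PrintAtEpochEnd()])
for epoch in range(2):
    for batch in range(3):
        container.on_batch_end(batch)                        # the scheduler callback steps cyclic schedulers here
    container.on_epoch_end(epoch, {"train_loss": 0.0})       # and non-cyclic ones here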
@@ -936,46 +932,6 @@ class Trainer:
                 self.model.deepimage, "deepimage", loader, n_epochs, max_lr
             )
 
-    def _lr_scheduler_step(self, step_location: str):  # noqa: C901
-        r"""
-        Function to execute the learning rate scheduler steps.
-
-        If the lr_scheduler is cyclic (i.e. CyclicLR or OneCycleLR), the step
-        must happen after training each batch. On the other hand, if the
-        scheduler is not cyclic, it is expected to be called after
-        validation. (Consider coding this function as a callback)
-
-        Parameters
-        ----------
-        step_location: str
-            Indicates where to run the lr_scheduler step
-        """
-        if (
-            self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
-            and self.cyclic_lr
-        ):
-            if step_location == "on_batch_end":
-                for model_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
-                    if "cycl" in scheduler.__class__.__name__.lower():
-                        scheduler.step()  # type: ignore
-            elif step_location == "on_epoch_end":
-                for scheduler_name, scheduler in self.lr_scheduler._schedulers.items():  # type: ignore
-                    if "cycl" not in scheduler.__class__.__name__.lower():
-                        scheduler.step()  # type: ignore
-        elif self.cyclic_lr:
-            if step_location == "on_batch_end":
-                self.lr_scheduler.step()  # type: ignore
-            else:
-                pass
-        elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler":
-            if step_location == "on_epoch_end":
-                self.lr_scheduler.step()  # type: ignore
-            else:
-                pass
-        elif step_location == "on_epoch_end":
-            self.lr_scheduler.step()  # type: ignore
-        else:
-            pass
-
     def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int):
         self.model.train()
         X = {k: v.cuda() for k, v in data.items()} if use_cuda else data
@@ -1192,7 +1148,7 @@ class Trainer:
             return None
 
     def _set_callbacks_and_metrics(self, callbacks, metrics):
-        self.callbacks: List = [History()]
+        self.callbacks: List = [History(), LRShedulerCallback()]
         if callbacks is not None:
             for callback in callbacks:
                 if isinstance(callback, type):
......
@@ -110,7 +110,7 @@ def test_non_instantiated_callbacks():
     model = WideDeep(wide=wide, deeptabular=tabmlp)
     callbacks = [EarlyStopping]
    trainer = Trainer(model, objective="binary", callbacks=callbacks)
-    assert trainer.callbacks[1].__class__.__name__ == "EarlyStopping"
+    assert trainer.callbacks[2].__class__.__name__ == "EarlyStopping"
 
 
 ###############################################################################
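
The test change above follows from the new default callback list: History and LRShedulerCallback now occupy positions 0 and 1, so the first user callback (here passed as a class and instantiated by the Trainer) lands at index 2. A tiny sketch of that bookkeeping with stub classes, not the library's implementations:

class History: ...
class LRShedulerCallback: ...
class EarlyStopping: ...


def set_callbacks(user_callbacks):
    callbacks = [History(), LRShedulerCallback()]  # defaults added by the Trainer
    for callback in user_callbacks or []:
        # mirror the isinstance(callback, type) branch: classes get instantiated
        callbacks.append(callback() if isinstance(callback, type) else callback)
    return callbacks


callbacks = set_callbacks([EarlyStopping])
assert callbacks[2].__class__.__name__ == "EarlyStopping"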
......