diff --git a/examples/adult_census.py b/examples/adult_census.py
index 7b2d9159ff184b8e61e5d2403ab58e7d1f665ff3..d1d3d7733c4235f267f75a337fc480a372422772 100644
--- a/examples/adult_census.py
+++ b/examples/adult_census.py
@@ -81,8 +81,8 @@ if __name__ == "__main__":
 
     wide_opt = torch.optim.Adam(model.wide.parameters(), lr=0.01)
     deep_opt = RAdam(model.deeptabular.parameters())
-    wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=3)
-    deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=5)
+    wide_sch = torch.optim.lr_scheduler.StepLR(wide_opt, step_size=2)
+    deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
 
     optimizers = {"wide": wide_opt, "deeptabular": deep_opt}
     schedulers = {"wide": wide_sch, "deeptabular": deep_sch}
@@ -108,7 +108,7 @@ if __name__ == "__main__":
         X_wide=X_wide,
         X_tab=X_tab,
         target=target,
-        n_epochs=4,
+        n_epochs=10,
         batch_size=64,
         val_split=0.2,
     )
diff --git a/pytorch_widedeep/callbacks.py b/pytorch_widedeep/callbacks.py
index 02901d7aecb9023545400e1a2dfa278da93375df..ee5b81ddd615ffa23f75657575eea09b0bc8d1ac 100644
--- a/pytorch_widedeep/callbacks.py
+++ b/pytorch_widedeep/callbacks.py
@@ -134,6 +134,53 @@ class History(Callback):
             self.trainer.history.setdefault(k, []).append(v)
 
 
+class LRShedulerCallback(Callback):
+    r"""Callback for the learning rate schedulers to take a step
+
+    This callback runs by default within :obj:`Trainer` and therefore should
+    not be passed to the ``Trainer``. It is included here just for completeness.
+    """
+
+    def on_batch_end(self, batch: int, logs: Optional[Dict] = None):
+        if self.trainer.lr_scheduler is not None:
+            if self._multiple_scheduler():
+                for (
+                    model_name,
+                    scheduler,
+                ) in self.trainer.lr_scheduler._schedulers.items():
+                    if self._is_cyclic(model_name):
+                        scheduler.step()
+            elif self.trainer.cyclic_lr:
+                self.trainer.lr_scheduler.step()
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None):
+        if self.trainer.lr_scheduler is not None:
+            if self._multiple_scheduler():
+                for (
+                    model_name,
+                    scheduler,
+                ) in self.trainer.lr_scheduler._schedulers.items():
+                    if not self._is_cyclic(model_name):
+                        scheduler.step()
+            elif not self.trainer.cyclic_lr:
+                self.trainer.lr_scheduler.step()
+
+    def _multiple_scheduler(self):
+        return self.trainer.lr_scheduler.__class__.__name__ == "MultipleLRScheduler"
+
+    def _is_cyclic(self, model_name: str):
+        return (
+            self._has_scheduler(model_name)
+            and "cycl"
+            in self.trainer.lr_scheduler._schedulers[
+                model_name
+            ].__class__.__name__.lower()
+        )
+
+    def _has_scheduler(self, model_name: str):
+        return model_name in self.trainer.lr_scheduler._schedulers
+
+
 class LRHistory(Callback):
     def __init__(self, n_epochs: int):
         r"""Saves the learning rates during training to a ``lr_history`` attribute.
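The new `LRShedulerCallback` reproduces the stepping rule of the removed `Trainer._lr_scheduler_step`: schedulers whose class name contains `"cycl"` (i.e. `CyclicLR` and `OneCycleLR`) step once per batch, all others step once per epoch. Below is a minimal sketch of that rule using plain `torch` schedulers; the `CyclicLR`/`StepLR` pairing and every hyperparameter shown are illustrative assumptions, not taken from this patch:

```python
import torch

# Two parameter groups, mirroring the wide/deeptabular split in
# examples/adult_census.py (models and values are made up here).
wide = torch.nn.Linear(10, 1)
deep = torch.nn.Linear(10, 1)
wide_opt = torch.optim.Adam(wide.parameters(), lr=0.01)
deep_opt = torch.optim.Adam(deep.parameters(), lr=0.01)
# cycle_momentum=False because Adam's param groups have no `momentum` entry
wide_sch = torch.optim.lr_scheduler.CyclicLR(
    wide_opt, base_lr=1e-4, max_lr=1e-2, cycle_momentum=False
)
deep_sch = torch.optim.lr_scheduler.StepLR(deep_opt, step_size=3)
schedulers = {"wide": wide_sch, "deeptabular": deep_sch}


def is_cyclic(scheduler) -> bool:
    # the same class-name test the callback's _is_cyclic applies
    return "cycl" in scheduler.__class__.__name__.lower()


for epoch in range(5):
    for batch in range(10):
        wide_opt.step()  # forward/backward elided
        deep_opt.step()
        for sch in schedulers.values():  # what on_batch_end now does
            if is_cyclic(sch):
                sch.step()
    for sch in schedulers.values():  # what on_epoch_end now does
        if not is_cyclic(sch):
            sch.step()
```

With this in place the user only passes the schedulers dict to the `Trainer`; no manual stepping is needed.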
diff --git a/pytorch_widedeep/models/tabnet/tab_net.py b/pytorch_widedeep/models/tabnet/tab_net.py
index 79b9b5b3769e48c7881d68e0b30a6ca85d8ed2aa..7ad23995acffd83018a27740c92c474bba4b8263 100644
--- a/pytorch_widedeep/models/tabnet/tab_net.py
+++ b/pytorch_widedeep/models/tabnet/tab_net.py
@@ -367,17 +367,17 @@ class EmbeddingsAndContinuous(nn.Module):
             }
         )
         self.embedding_dropout = nn.Dropout(embed_dropout)
-        emb_inp_dim = np.sum([embed[2] for embed in self.embed_input])
+        emb_out_dim = np.sum([embed[2] for embed in self.embed_input])
 
         # Continuous
         if self.continuous_cols is not None:
-            cont_inp_dim = len(self.continuous_cols)
+            cont_out_dim = len(self.continuous_cols)
             if self.batchnorm_cont:
-                self.norm = nn.BatchNorm1d(cont_inp_dim)
+                self.norm = nn.BatchNorm1d(cont_out_dim)
         else:
-            cont_inp_dim = 0
+            cont_out_dim = 0
 
-        self.output_dim = emb_inp_dim + cont_inp_dim
+        self.output_dim = emb_out_dim + cont_out_dim
 
     def forward(self, X):
         embed = [
diff --git a/pytorch_widedeep/models/wide_deep.py b/pytorch_widedeep/models/wide_deep.py
index fb10762868da724765f1cd6865069680316705f9..8015a14593f06bc826de0f224c3df3163fb0241b 100644
--- a/pytorch_widedeep/models/wide_deep.py
+++ b/pytorch_widedeep/models/wide_deep.py
@@ -164,10 +164,19 @@ class WideDeep(nn.Module):
 
         if self.deeptabular is not None:
             self.is_tabnet = deeptabular.__class__.__name__ == "TabNet"
+        else:
+            self.is_tabnet = False
 
         if self.deephead is None:
             if head_hidden_dims is not None:
-                self._build_deephead()
+                self._build_deephead(
+                    head_hidden_dims,
+                    head_activation,
+                    head_dropout,
+                    head_batchnorm,
+                    head_batchnorm_last,
+                    head_linear_first,
+                )
             else:
                 self._add_pred_layer()
 
@@ -178,7 +187,15 @@ class WideDeep(nn.Module):
         else:
             return self._forward_deep(X, wide_out)
 
-    def _build_deephead(self):
+    def _build_deephead(
+        self,
+        head_hidden_dims,
+        head_activation,
+        head_dropout,
+        head_batchnorm,
+        head_batchnorm_last,
+        head_linear_first,
+    ):
         deep_dim = 0
         if self.deeptabular is not None:
             deep_dim += self.deeptabular.output_dim
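Two fixes land here: `EmbeddingsAndContinuous` renames `*_inp_dim` to `*_out_dim` (the sums describe the module's output width, not its input), and `WideDeep` now threads the `head_*` arguments through to `_build_deephead`, which was previously called with no arguments; `is_tabnet` is also set explicitly when there is no `deeptabular` component. A minimal sketch exercising the fixed path; the column setup and the exact constructor arguments are assumptions for illustration, not part of the patch:

```python
from pytorch_widedeep.models import Wide, TabMlp, WideDeep

# Synthetic two-column tabular component; names and dims are made up.
tabmlp = TabMlp(
    column_idx={"age": 0, "hours_per_week": 1},
    continuous_cols=["age", "hours_per_week"],
    mlp_hidden_dims=[8, 4],
)
model = WideDeep(
    wide=Wide(wide_dim=10, pred_dim=1),
    deeptabular=tabmlp,
    head_hidden_dims=[8, 4],  # now actually reaches _build_deephead
)
assert model.is_tabnet is False  # set even though deeptabular is not TabNet
```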
self.model.deepimage, "deepimage", loader, n_epochs, max_lr ) - def _lr_scheduler_step(self, step_location: str): # noqa: C901 - r""" - Function to execute the learning rate schedulers steps. - If the lr_scheduler is Cyclic (i.e. CyclicLR or OneCycleLR), the step - must happen after training each bach durig training. On the other - hand, if the scheduler is not Cyclic, is expected to be called after - validation. (Consider coding this function as callback) - - Parameters - ---------- - step_location: Str - Indicates where to run the lr_scheduler step - """ - if ( - self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler" - and self.cyclic_lr - ): - if step_location == "on_batch_end": - for model_name, scheduler in self.lr_scheduler._schedulers.items(): # type: ignore - if "cycl" in scheduler.__class__.__name__.lower(): - scheduler.step() # type: ignore - elif step_location == "on_epoch_end": - for scheduler_name, scheduler in self.lr_scheduler._schedulers.items(): # type: ignore - if "cycl" not in scheduler.__class__.__name__.lower(): - scheduler.step() # type: ignore - elif self.cyclic_lr: - if step_location == "on_batch_end": - self.lr_scheduler.step() # type: ignore - else: - pass - elif self.lr_scheduler.__class__.__name__ == "MultipleLRScheduler": - if step_location == "on_epoch_end": - self.lr_scheduler.step() # type: ignore - else: - pass - elif step_location == "on_epoch_end": - self.lr_scheduler.step() # type: ignore - else: - pass - def _training_step(self, data: Dict[str, Tensor], target: Tensor, batch_idx: int): self.model.train() X = {k: v.cuda() for k, v in data.items()} if use_cuda else data @@ -1192,7 +1148,7 @@ class Trainer: return None def _set_callbacks_and_metrics(self, callbacks, metrics): - self.callbacks: List = [History()] + self.callbacks: List = [History(), LRShedulerCallback()] if callbacks is not None: for callback in callbacks: if isinstance(callback, type): diff --git a/tests/test_model_functioning/test_miscellaneous.py b/tests/test_model_functioning/test_miscellaneous.py index 5c3ecf302c03b177de50d05af1c2f478b68e2fd1..94e972dbe835f839059b4a54f641f158ea853529 100644 --- a/tests/test_model_functioning/test_miscellaneous.py +++ b/tests/test_model_functioning/test_miscellaneous.py @@ -110,7 +110,7 @@ def test_non_instantiated_callbacks(): model = WideDeep(wide=wide, deeptabular=tabmlp) callbacks = [EarlyStopping] trainer = Trainer(model, objective="binary", callbacks=callbacks) - assert trainer.callbacks[1].__class__.__name__ == "EarlyStopping" + assert trainer.callbacks[2].__class__.__name__ == "EarlyStopping" ###############################################################################