diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 38d5b710c75de6cb63a6ed4effd26c5f4b762adc..ac7534dd506381a37fe03d6dcfad6607b9f81f06 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -369,6 +369,9 @@ class Engine(object): ema_module) if metric_info is not None: best_metric.update(metric_info) + if hasattr(self.train_dataloader.batch_sampler, "set_epoch"): + self.train_dataloader.batch_sampler.set_epoch(best_metric[ + "epoch"]) for epoch_id in range(best_metric["epoch"] + 1, self.config["Global"]["epochs"] + 1):