diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 2ab7eac03a14f643e730341cc51eec183b9c524e..866be552da1537ac5a32720b1fdb3cbd35e7509b 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -181,7 +181,7 @@ class Trainer(): """Reset the train loader seed and increment `epoch`. """ self.epoch += 1 - if self.parallel: + if self.parallel and hasattr(self.train_loader, "batch_sampler"): self.train_loader.batch_sampler.set_epoch(self.epoch) def train(self): @@ -191,7 +191,7 @@ class Trainer(): # save init model, i.e. 0 epoch self.save(tag='init', infos=None) self.lr_scheduler.step(self.epoch) - if self.parallel: + if self.parallel and hasattr(self.train_loader, "batch_sampler"): self.train_loader.batch_sampler.set_epoch(self.epoch) logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")