未验证 提交 de98283b 编写于 作者: H Hui Zhang 提交者: GitHub

Merge pull request #788 from PaddlePaddle/train

fix trainer when dataloader not has batch_sampler
......@@ -181,7 +181,7 @@ class Trainer():
"""Reset the train loader seed and increment `epoch`.
"""
self.epoch += 1
if self.parallel:
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
self.train_loader.batch_sampler.set_epoch(self.epoch)
def train(self):
......@@ -191,7 +191,7 @@ class Trainer():
# save init model, i.e. 0 epoch
self.save(tag='init', infos=None)
self.lr_scheduler.step(self.epoch)
if self.parallel:
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
self.train_loader.batch_sampler.set_epoch(self.epoch)
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册