提交 9ca66900 编写于 作者: W wuzewu

Fix trainer bug

上级 e93c01e4
......@@ -55,8 +55,8 @@ class Trainer(object):
use_vdl: bool = True,
checkpoint_dir: str = None,
compare_metrics: Callable = None):
self.nranks = paddle.distributed.get_rank()
self.local_rank = paddle.distributed.get_world_size()
self.nranks = paddle.distributed.get_world_size()
self.local_rank = paddle.distributed.get_rank()
self.model = model
self.optimizer = strategy
self.checkpoint_dir = checkpoint_dir if checkpoint_dir else 'ckpt_{}'.format(time.time())
......@@ -77,7 +77,6 @@ class Trainer(object):
strategy = paddle.distributed.prepare_context()
self.model = paddle.DataParallel(self.model, strategy)
self.compare_metrics = self._compare_metrics if not compare_metrics else compare_metrics
self._load_checkpoint()
def _load_checkpoint(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册