提交 cb8b44b2 编写于 作者: W wuzewu

Update parallel rank calling method

上级 ee2be095
......@@ -20,7 +20,6 @@ from collections import defaultdict
from typing import Any, Callable, Generic, List
import paddle
from paddle.distributed import ParallelEnv
from visualdl import LogWriter
from paddlehub.utils.log import logger, processing
......@@ -56,8 +55,8 @@ class Trainer(object):
use_vdl: bool = True,
checkpoint_dir: str = None,
compare_metrics: Callable = None):
self.nranks = ParallelEnv().nranks
self.local_rank = ParallelEnv().local_rank
self.nranks = paddle.distributed.get_rank()
self.local_rank = paddle.distributed.get_world_size()
self.model = model
self.optimizer = strategy
self.checkpoint_dir = checkpoint_dir if checkpoint_dir else 'ckpt_{}'.format(time.time())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册