提交 cb8b44b2 编写于 作者: W wuzewu

Update parallel rank calling method

上级 ee2be095
...@@ -20,7 +20,6 @@ from collections import defaultdict ...@@ -20,7 +20,6 @@ from collections import defaultdict
from typing import Any, Callable, Generic, List from typing import Any, Callable, Generic, List
import paddle import paddle
from paddle.distributed import ParallelEnv
from visualdl import LogWriter from visualdl import LogWriter
from paddlehub.utils.log import logger, processing from paddlehub.utils.log import logger, processing
...@@ -56,8 +55,8 @@ class Trainer(object): ...@@ -56,8 +55,8 @@ class Trainer(object):
use_vdl: bool = True, use_vdl: bool = True,
checkpoint_dir: str = None, checkpoint_dir: str = None,
compare_metrics: Callable = None): compare_metrics: Callable = None):
self.nranks = ParallelEnv().nranks self.nranks = paddle.distributed.get_rank()
self.local_rank = ParallelEnv().local_rank self.local_rank = paddle.distributed.get_world_size()
self.model = model self.model = model
self.optimizer = strategy self.optimizer = strategy
self.checkpoint_dir = checkpoint_dir if checkpoint_dir else 'ckpt_{}'.format(time.time()) self.checkpoint_dir = checkpoint_dir if checkpoint_dir else 'ckpt_{}'.format(time.time())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册