未验证 提交 15747a06 编写于 作者: W wangguanzhong 提交者: GitHub

update dist api (#2444)

上级 f49036bd
......@@ -23,7 +23,7 @@ import six
import numpy as np
import paddle
from paddle.distributed import ParallelEnv
import paddle.distributed as dist
from ppdet.utils.checkpoint import save_model
from ppdet.optimizer import ModelEMA
......@@ -81,7 +81,7 @@ class LogPrinter(Callback):
super(LogPrinter, self).__init__(model)
def on_step_end(self, status):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if dist.get_world_size() < 2 or dist.get_rank() == 0:
mode = status['mode']
if mode == 'train':
epoch_id = status['epoch_id']
......@@ -129,7 +129,7 @@ class LogPrinter(Callback):
logger.info("Eval iter: {}".format(step_id))
def on_epoch_end(self, status):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if dist.get_world_size() < 2 or dist.get_rank() == 0:
mode = status['mode']
if mode == 'eval':
sample_num = status['sample_num']
......@@ -160,7 +160,7 @@ class Checkpointer(Callback):
epoch_id = status['epoch_id']
weight = None
save_name = None
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if dist.get_world_size() < 2 or dist.get_rank() == 0:
if mode == 'train':
end_epoch = self.model.cfg.epoch
if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
......@@ -224,7 +224,7 @@ class VisualDLWriter(Callback):
def on_step_end(self, status):
mode = status['mode']
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if dist.get_world_size() < 2 or dist.get_rank() == 0:
if mode == 'train':
training_staus = status['training_staus']
for loss_name, loss_value in training_staus.get().items():
......@@ -248,7 +248,7 @@ class VisualDLWriter(Callback):
def on_epoch_end(self, status):
mode = status['mode']
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if dist.get_world_size() < 2 or dist.get_rank() == 0:
if mode == 'eval':
for metric in self.model._metrics:
for key, map_value in metric.get_results().items():
......
......@@ -21,7 +21,7 @@ import random
import numpy as np
import paddle
from paddle.distributed import ParallelEnv, fleet
from paddle.distributed import fleet
__all__ = ['init_parallel_env', 'set_random_seed', 'init_fleet_env']
......
......@@ -24,7 +24,8 @@ import numpy as np
from PIL import Image
import paddle
from paddle.distributed import ParallelEnv, fleet
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle import amp
from paddle.static import InputSpec
......@@ -83,8 +84,8 @@ class Trainer(object):
self.optimizer = create('OptimizerBuilder')(self.lr,
self.model.parameters())
self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank
self._nranks = dist.get_world_size()
self._local_rank = dist.get_rank()
self.status = {}
......
......@@ -17,7 +17,7 @@ import logging
import os
import sys
from paddle.distributed import ParallelEnv
import paddle.distributed as dist
__all__ = ['setup_logger']
......@@ -47,7 +47,7 @@ def setup_logger(name="ppdet", output=None):
"[%(asctime)s] %(name)s %(levelname)s: %(message)s",
datefmt="%m/%d %H:%M:%S")
# stdout logging: master only
local_rank = ParallelEnv().local_rank
local_rank = dist.get_rank()
if local_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
......
......@@ -27,7 +27,6 @@ import warnings
warnings.filterwarnings('ignore')
import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_version, check_config
......@@ -115,8 +114,7 @@ def main():
check_gpu(cfg.use_gpu)
check_version()
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
run(FLAGS, cfg)
......
......@@ -27,7 +27,6 @@ warnings.filterwarnings('ignore')
import glob
import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Trainer
from ppdet.utils.check import check_gpu, check_version, check_config
......@@ -140,8 +139,7 @@ def main():
check_gpu(cfg.use_gpu)
check_version()
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
run(FLAGS, cfg)
......
......@@ -29,7 +29,6 @@ import random
import numpy as np
import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
......@@ -122,8 +121,7 @@ def main():
check.check_gpu(cfg.use_gpu)
check.check_version()
place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
place = paddle.set_device(place)
place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
run(FLAGS, cfg)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册