Unverified commit f5aba69c authored by wangguanzhong, committed by GitHub

update dist api (#2443)

Parent: 0d14d704
@@ -23,7 +23,7 @@ import six
 import numpy as np
 import paddle
-from paddle.distributed import ParallelEnv
+import paddle.distributed as dist
 from ppdet.utils.checkpoint import save_model
 from ppdet.optimizer import ModelEMA
@@ -81,7 +81,7 @@ class LogPrinter(Callback):
         super(LogPrinter, self).__init__(model)

     def on_step_end(self, status):
-        if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
             mode = status['mode']
             if mode == 'train':
                 epoch_id = status['epoch_id']
@@ -129,7 +129,7 @@ class LogPrinter(Callback):
                 logger.info("Eval iter: {}".format(step_id))

     def on_epoch_end(self, status):
-        if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
             mode = status['mode']
             if mode == 'eval':
                 sample_num = status['sample_num']
@@ -160,7 +160,7 @@ class Checkpointer(Callback):
         epoch_id = status['epoch_id']
         weight = None
         save_name = None
-        if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
             if mode == 'train':
                 end_epoch = self.model.cfg.epoch
                 if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
@@ -224,7 +224,7 @@ class VisualDLWriter(Callback):
     def on_step_end(self, status):
         mode = status['mode']
-        if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
             if mode == 'train':
                 training_staus = status['training_staus']
                 for loss_name, loss_value in training_staus.get().items():
@@ -248,7 +248,7 @@ class VisualDLWriter(Callback):
     def on_epoch_end(self, status):
         mode = status['mode']
-        if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
             if mode == 'eval':
                 for metric in self.model._metrics:
                     for key, map_value in metric.get_results().items():
...
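Every hunk above makes the same substitution: the object-style ParallelEnv() reads are replaced by the functional paddle.distributed API, so the master-only guard becomes dist.get_world_size() < 2 or dist.get_rank() == 0. A minimal sketch of that guard, runnable as-is on a single process (the callback body is illustrative, not taken from this diff):

    import paddle.distributed as dist

    def is_master():
        # In a single-process run the world size defaults to 1, so the
        # guard passes without any distributed setup having happened.
        return dist.get_world_size() < 2 or dist.get_rank() == 0

    def on_step_end(status):
        # Only rank 0 performs side effects such as logging, saving
        # checkpoints, or writing VisualDL records.
        if is_master():
            print("step finished, mode:", status['mode'])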
@@ -21,7 +21,7 @@ import random
 import numpy as np
 import paddle
-from paddle.distributed import ParallelEnv, fleet
+from paddle.distributed import fleet

 __all__ = ['init_parallel_env', 'set_random_seed', 'init_fleet_env']
...
@@ -24,7 +24,8 @@ import numpy as np
 from PIL import Image

 import paddle
-from paddle.distributed import ParallelEnv, fleet
+import paddle.distributed as dist
+from paddle.distributed import fleet
 from paddle import amp
 from paddle.static import InputSpec
@@ -84,8 +85,8 @@ class Trainer(object):
             self.optimizer = create('OptimizerBuilder')(self.lr,
                                                         self.model.parameters())

-        self._nranks = ParallelEnv().nranks
-        self._local_rank = ParallelEnv().local_rank
+        self._nranks = dist.get_world_size()
+        self._local_rank = dist.get_rank()

         self.status = {}
...
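Trainer now caches the two values once at construction instead of instantiating ParallelEnv repeatedly. As a hedged sketch of how such a cached world size is typically used downstream (the DataParallel wrapping below is an assumed surrounding pattern, not part of this diff):

    import paddle
    import paddle.distributed as dist

    class TinyTrainer:
        def __init__(self, model):
            self._nranks = dist.get_world_size()
            self._local_rank = dist.get_rank()
            if self._nranks > 1:
                # Assumption: multi-card runs initialize the parallel
                # environment and wrap the model for gradient sync.
                dist.init_parallel_env()
                model = paddle.DataParallel(model)
            self.model = model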
@@ -17,7 +17,7 @@ import logging
 import os
 import sys

-from paddle.distributed import ParallelEnv
+import paddle.distributed as dist

 __all__ = ['setup_logger']
@@ -47,7 +47,7 @@ def setup_logger(name="ppdet", output=None):
         "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
         datefmt="%m/%d %H:%M:%S")
     # stdout logging: master only
-    local_rank = ParallelEnv().local_rank
+    local_rank = dist.get_rank()
     if local_rank == 0:
         ch = logging.StreamHandler(stream=sys.stdout)
         ch.setLevel(logging.DEBUG)
...
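The logger applies the same rank gate so that multi-card jobs print each message once rather than once per process. A self-contained sketch of the master-only stdout handler, mirroring the formatter shown in the hunk:

    import logging
    import sys

    import paddle.distributed as dist

    def setup_logger(name="ppdet"):
        logger = logging.getLogger(name)
        logger.setLevel(logging.DEBUG)
        if dist.get_rank() == 0:
            # Master only: attach the stdout handler on rank 0 so the
            # other ranks stay silent.
            ch = logging.StreamHandler(stream=sys.stdout)
            ch.setLevel(logging.DEBUG)
            ch.setFormatter(
                logging.Formatter(
                    "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                    datefmt="%m/%d %H:%M:%S"))
            logger.addHandler(ch)
        return logger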
@@ -27,7 +27,6 @@ import warnings
 warnings.filterwarnings('ignore')

 import paddle
-from paddle.distributed import ParallelEnv

 from ppdet.core.workspace import load_config, merge_config
 from ppdet.utils.check import check_gpu, check_version, check_config
@@ -115,8 +114,7 @@ def main():
     check_gpu(cfg.use_gpu)
     check_version()

-    place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
-    place = paddle.set_device(place)
+    place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')

     run(FLAGS, cfg)
...
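This and the two tools scripts below drop the manual 'gpu:{dev_id}' string. When processes are started by the distributed launcher, paddle.set_device('gpu') resolves the device assigned to the current process on its own, so the two-line place dance collapses into a single call. A minimal sketch, with cfg.use_gpu stubbed for illustration:

    import paddle

    use_gpu = True  # stand-in for cfg.use_gpu
    # 'gpu' picks up the device assigned to this process by the
    # launcher; no explicit 'gpu:{}'.format(dev_id) is needed.
    place = paddle.set_device('gpu' if use_gpu else 'cpu')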
@@ -27,7 +27,6 @@ warnings.filterwarnings('ignore')
 import glob

 import paddle
-from paddle.distributed import ParallelEnv

 from ppdet.core.workspace import load_config, merge_config
 from ppdet.engine import Trainer
 from ppdet.utils.check import check_gpu, check_version, check_config
@@ -140,8 +139,7 @@ def main():
     check_gpu(cfg.use_gpu)
     check_version()

-    place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
-    place = paddle.set_device(place)
+    place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')

     run(FLAGS, cfg)
...
@@ -29,7 +29,6 @@ import random
 import numpy as np
 import paddle
-from paddle.distributed import ParallelEnv

 from ppdet.core.workspace import load_config, merge_config, create
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
@@ -122,8 +121,7 @@ def main():
     check.check_gpu(cfg.use_gpu)
     check.check_version()

-    place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu'
-    place = paddle.set_device(place)
+    place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')

     run(FLAGS, cfg)
...