diff --git a/ppdet/engine/callbacks.py b/ppdet/engine/callbacks.py index 1419661ca6e5213f0012ca04f60964ac5092998a..ca4ad613616056f41842f2dd2e9f1341bb556e4e 100644 --- a/ppdet/engine/callbacks.py +++ b/ppdet/engine/callbacks.py @@ -23,7 +23,7 @@ import six import numpy as np import paddle -from paddle.distributed import ParallelEnv +import paddle.distributed as dist from ppdet.utils.checkpoint import save_model from ppdet.optimizer import ModelEMA @@ -81,7 +81,7 @@ class LogPrinter(Callback): super(LogPrinter, self).__init__(model) def on_step_end(self, status): - if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: + if dist.get_world_size() < 2 or dist.get_rank() == 0: mode = status['mode'] if mode == 'train': epoch_id = status['epoch_id'] @@ -129,7 +129,7 @@ class LogPrinter(Callback): logger.info("Eval iter: {}".format(step_id)) def on_epoch_end(self, status): - if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: + if dist.get_world_size() < 2 or dist.get_rank() == 0: mode = status['mode'] if mode == 'eval': sample_num = status['sample_num'] @@ -160,7 +160,7 @@ class Checkpointer(Callback): epoch_id = status['epoch_id'] weight = None save_name = None - if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: + if dist.get_world_size() < 2 or dist.get_rank() == 0: if mode == 'train': end_epoch = self.model.cfg.epoch if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1: @@ -224,7 +224,7 @@ class VisualDLWriter(Callback): def on_step_end(self, status): mode = status['mode'] - if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: + if dist.get_world_size() < 2 or dist.get_rank() == 0: if mode == 'train': training_staus = status['training_staus'] for loss_name, loss_value in training_staus.get().items(): @@ -248,7 +248,7 @@ class VisualDLWriter(Callback): def on_epoch_end(self, status): mode = status['mode'] - if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: + if dist.get_world_size() < 2 or dist.get_rank() == 0: if mode == 'eval': for metric in self.model._metrics: for key, map_value in metric.get_results().items(): diff --git a/ppdet/engine/env.py b/ppdet/engine/env.py index ba0b7edd61bf39d5df9e647cefdded867f6ca86f..cfeea08c98c081083033120c9d3fbb5c02efdd35 100644 --- a/ppdet/engine/env.py +++ b/ppdet/engine/env.py @@ -21,7 +21,7 @@ import random import numpy as np import paddle -from paddle.distributed import ParallelEnv, fleet +from paddle.distributed import fleet __all__ = ['init_parallel_env', 'set_random_seed', 'init_fleet_env'] diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 37a244d61b1c709869b37e5c009c78c5ead5b7f8..4614abdc7bb4f8e0e2ed3865c421648e0ea91133 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -24,7 +24,8 @@ import numpy as np from PIL import Image import paddle -from paddle.distributed import ParallelEnv, fleet +import paddle.distributed as dist +from paddle.distributed import fleet from paddle import amp from paddle.static import InputSpec @@ -84,8 +85,8 @@ class Trainer(object): self.optimizer = create('OptimizerBuilder')(self.lr, self.model.parameters()) - self._nranks = ParallelEnv().nranks - self._local_rank = ParallelEnv().local_rank + self._nranks = dist.get_world_size() + self._local_rank = dist.get_rank() self.status = {} diff --git a/ppdet/utils/logger.py b/ppdet/utils/logger.py index 9f02313ed57b9ca19668ecd7ceecbddc8fa693e1..99b82f995e4ec1a8ec19b8253aa1c2b3948d1e2d 100644 --- a/ppdet/utils/logger.py +++ b/ppdet/utils/logger.py @@ -17,7 +17,7 @@ import logging import os import sys -from paddle.distributed import ParallelEnv +import paddle.distributed as dist __all__ = ['setup_logger'] @@ -47,7 +47,7 @@ def setup_logger(name="ppdet", output=None): "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S") # stdout logging: master only - local_rank = ParallelEnv().local_rank + local_rank = dist.get_rank() if local_rank == 0: ch = logging.StreamHandler(stream=sys.stdout) ch.setLevel(logging.DEBUG) diff --git a/tools/eval.py b/tools/eval.py index 8b0064762d4cbbad8bc0b6461ee289a60d1eec72..21ee29d160cfb1331489dcc09a597e302f6a09c8 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -27,7 +27,6 @@ import warnings warnings.filterwarnings('ignore') import paddle -from paddle.distributed import ParallelEnv from ppdet.core.workspace import load_config, merge_config from ppdet.utils.check import check_gpu, check_version, check_config @@ -115,8 +114,7 @@ def main(): check_gpu(cfg.use_gpu) check_version() - place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu' - place = paddle.set_device(place) + place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu') run(FLAGS, cfg) diff --git a/tools/infer.py b/tools/infer.py index 9226e1eea84c0ca15d14b917623b193b5edcb779..a2507680b812b34e04007c4bff8e366330bf4760 100755 --- a/tools/infer.py +++ b/tools/infer.py @@ -27,7 +27,6 @@ warnings.filterwarnings('ignore') import glob import paddle -from paddle.distributed import ParallelEnv from ppdet.core.workspace import load_config, merge_config from ppdet.engine import Trainer from ppdet.utils.check import check_gpu, check_version, check_config @@ -140,8 +139,7 @@ def main(): check_gpu(cfg.use_gpu) check_version() - place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu' - place = paddle.set_device(place) + place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu') run(FLAGS, cfg) diff --git a/tools/train.py b/tools/train.py index e7efcd07a30a729435437ae5db1a780f5fc6d7b2..cdbe87ab2d1603527ddf60397a66ad53670a04f3 100755 --- a/tools/train.py +++ b/tools/train.py @@ -29,7 +29,6 @@ import random import numpy as np import paddle -from paddle.distributed import ParallelEnv from ppdet.core.workspace import load_config, merge_config, create from ppdet.utils.checkpoint import load_weight, load_pretrain_weight @@ -122,8 +121,7 @@ def main(): check.check_gpu(cfg.use_gpu) check.check_version() - place = 'gpu:{}'.format(ParallelEnv().dev_id) if cfg.use_gpu else 'cpu' - place = paddle.set_device(place) + place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu') run(FLAGS, cfg)