提交 5386f3c5 编写于 作者: M Maxim Berman 提交者: Francisco Massa

Use dist.get_rank() instead of local_rank to detect master process (#40)

上级 b5de47b7
......@@ -18,6 +18,12 @@ def get_world_size():
return torch.distributed.deprecated.get_world_size()
def get_rank():
    """Return this process's rank in the default process group.

    Falls back to 0 when torch.distributed has not been initialized,
    so single-process (non-distributed) runs behave as the master.
    """
    dist = torch.distributed.deprecated
    if dist.is_initialized():
        return dist.get_rank()
    return 0
def is_main_process():
if not torch.distributed.deprecated.is_initialized():
return True
......
......@@ -4,11 +4,11 @@ import os
import sys
def setup_logger(name, save_dir, local_rank):
def setup_logger(name, save_dir, distributed_rank):
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
# don't log results for the non-master process
if local_rank > 0:
if distributed_rank > 0:
return logger
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
......
......@@ -13,7 +13,7 @@ from maskrcnn_benchmark.engine.inference import inference
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize
from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
......@@ -50,7 +50,7 @@ def main():
cfg.freeze()
save_dir = ""
logger = setup_logger("maskrcnn_benchmark", save_dir, args.local_rank)
logger = setup_logger("maskrcnn_benchmark", save_dir, get_rank())
logger.info("Using {} GPUs".format(num_gpus))
logger.info(cfg)
......
......@@ -20,7 +20,7 @@ from maskrcnn_benchmark.engine.trainer import do_train
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.utils.collect_env import collect_env_info
from maskrcnn_benchmark.utils.comm import synchronize
from maskrcnn_benchmark.utils.comm import synchronize, get_rank
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
......@@ -46,7 +46,7 @@ def train(cfg, local_rank, distributed):
output_dir = cfg.OUTPUT_DIR
save_to_disk = local_rank == 0
save_to_disk = get_rank() == 0
checkpointer = DetectronCheckpointer(
cfg, model, optimizer, scheduler, output_dir, save_to_disk
)
......@@ -147,7 +147,7 @@ def main():
if output_dir:
mkdir(output_dir)
logger = setup_logger("maskrcnn_benchmark", output_dir, args.local_rank)
logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank())
logger.info("Using {} GPUs".format(num_gpus))
logger.info(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册