From 00f7a936bf461bf197d7fdcfe1ca10422ed6ddba Mon Sep 17 00:00:00 2001
From: gengdongjie
Date: Wed, 29 Jul 2020 16:47:23 +0800
Subject: [PATCH] add resnet50 support multi node training

---
Review notes (text between the "---" separator and the first "diff --git"
line is not part of the commit and is ignored when the patch is applied):

* Line breaks, indentation and blank context lines were restored from the
  hunk line counts; confirm whitespace against the target tree before
  applying.
* run_distribute_train.sh exports SERVER_ID=0 unconditionally, so
  rank_start = DEVICE_NUM * SERVER_ID is always 0 unless the script is
  edited; presumably each node is expected to set its own value by hand
  (TODO confirm).
* create_dataset1/create_dataset2 used to read DEVICE_NUM and RANK_ID from
  the environment on Ascend; through _get_rank_info() they now depend on
  RANK_SIZE (default 1) and, when it is greater than 1, on
  get_group_size()/get_rank().
* _get_rank_info() calls get_group_size(), but no hunk below adds an import
  for it; presumably it is already imported in dataset.py (TODO confirm).

 .../cv/resnet/scripts/run_distribute_train.sh |  5 +++-
 model_zoo/official/cv/resnet/src/dataset.py   | 25 ++++++++++++++-----
 2 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
index efcb620cd..4c27b5f6c 100755
--- a/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
+++ b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
@@ -79,10 +79,13 @@ export RANK_SIZE=8
 export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
 export RANK_TABLE_FILE=$PATH1
 
+export SERVER_ID=0
+rank_start=$((DEVICE_NUM * SERVER_ID))
+
 for((i=0; i<${DEVICE_NUM}; i++))
 do
     export DEVICE_ID=$i
-    export RANK_ID=$i
+    export RANK_ID=$((rank_start + i))
     rm -rf ./train_parallel$i
     mkdir ./train_parallel$i
     cp ../*.py ./train_parallel$i
diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py
index 1cbe5e20f..79730fc46 100755
--- a/model_zoo/official/cv/resnet/src/dataset.py
+++ b/model_zoo/official/cv/resnet/src/dataset.py
@@ -37,8 +37,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
         dataset
     """
     if target == "Ascend":
-        device_num = int(os.getenv("DEVICE_NUM"))
-        rank_id = int(os.getenv("RANK_ID"))
+        device_num, rank_id = _get_rank_info()
     else:
         init("nccl")
         rank_id = get_rank()
@@ -93,8 +92,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
         dataset
     """
     if target == "Ascend":
-        device_num = int(os.getenv("DEVICE_NUM"))
-        rank_id = int(os.getenv("RANK_ID"))
+        device_num, rank_id = _get_rank_info()
     else:
         init("nccl")
         rank_id = get_rank()
@@ -153,8 +151,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
     Returns:
         dataset
     """
-    device_num = int(os.getenv("RANK_SIZE"))
-    rank_id = int(os.getenv("RANK_ID"))
+    device_num, rank_id = _get_rank_info()
 
     if device_num == 1:
         ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
@@ -203,3 +200,19 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32):
     ds = ds.repeat(repeat_num)
 
     return ds
+
+
+def _get_rank_info():
+    """
+    get rank size and rank id
+    """
+    rank_size = int(os.environ.get("RANK_SIZE", 1))
+
+    if rank_size > 1:
+        rank_size = get_group_size()
+        rank_id = get_rank()
+    else:
+        rank_size = 1
+        rank_id = 0
+
+    return rank_size, rank_id
-- 
GitLab