diff --git a/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh index efcb620cd868c3d89c9fa9abd454115aa73ec91d..4c27b5f6cfb8a791bdba12b74207358bdf23873a 100755 --- a/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh +++ b/model_zoo/official/cv/resnet/scripts/run_distribute_train.sh @@ -79,10 +79,13 @@ export RANK_SIZE=8 export MINDSPORE_HCCL_CONFIG_PATH=$PATH1 export RANK_TABLE_FILE=$PATH1 +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) + for((i=0; i<${DEVICE_NUM}; i++)) do export DEVICE_ID=$i - export RANK_ID=$i + export RANK_ID=$((rank_start + i)) rm -rf ./train_parallel$i mkdir ./train_parallel$i cp ../*.py ./train_parallel$i diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py index 1cbe5e20f3bcaebad8f6ba68999400d9264ddfc7..79730fc460d81a79b7f57bb7a65da21d715793ea 100755 --- a/model_zoo/official/cv/resnet/src/dataset.py +++ b/model_zoo/official/cv/resnet/src/dataset.py @@ -37,8 +37,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target= dataset """ if target == "Ascend": - device_num = int(os.getenv("DEVICE_NUM")) - rank_id = int(os.getenv("RANK_ID")) + device_num, rank_id = _get_rank_info() else: init("nccl") rank_id = get_rank() @@ -93,8 +92,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= dataset """ if target == "Ascend": - device_num = int(os.getenv("DEVICE_NUM")) - rank_id = int(os.getenv("RANK_ID")) + device_num, rank_id = _get_rank_info() else: init("nccl") rank_id = get_rank() @@ -153,8 +151,7 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32): Returns: dataset """ - device_num = int(os.getenv("RANK_SIZE")) - rank_id = int(os.getenv("RANK_ID")) + device_num, rank_id = _get_rank_info() if device_num == 1: ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True) @@ -203,3 +200,19 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32): ds = ds.repeat(repeat_num) return ds + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = 1 + rank_id = 0 + + return rank_size, rank_id