提交 d9ecfb18 编写于 作者: C caojian05

support multi server multi process

上级 bf699955
...@@ -33,10 +33,12 @@ MINDSPORE_HCCL_CONFIG_PATH=$(realpath $1) ...@@ -33,10 +33,12 @@ MINDSPORE_HCCL_CONFIG_PATH=$(realpath $1)
export MINDSPORE_HCCL_CONFIG_PATH export MINDSPORE_HCCL_CONFIG_PATH
echo "MINDSPORE_HCCL_CONFIG_PATH=${MINDSPORE_HCCL_CONFIG_PATH}" echo "MINDSPORE_HCCL_CONFIG_PATH=${MINDSPORE_HCCL_CONFIG_PATH}"
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++)) for((i=0; i<${DEVICE_NUM}; i++))
do do
export DEVICE_ID=$i export DEVICE_ID=$i
export RANK_ID=$i export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i rm -rf ./train_parallel$i
mkdir ./train_parallel$i mkdir ./train_parallel$i
cp -r ./src ./train_parallel$i cp -r ./src ./train_parallel$i
......
...@@ -31,8 +31,7 @@ def create_dataset(data_home, repeat_num=1, training=True): ...@@ -31,8 +31,7 @@ def create_dataset(data_home, repeat_num=1, training=True):
if not training: if not training:
data_dir = os.path.join(data_home, "cifar-10-verify-bin") data_dir = os.path.join(data_home, "cifar-10-verify-bin")
rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None rank_size, rank_id = _get_rank_info()
rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)
resize_height = cfg.image_height resize_height = cfg.image_height
...@@ -65,3 +64,19 @@ def create_dataset(data_home, repeat_num=1, training=True): ...@@ -65,3 +64,19 @@ def create_dataset(data_home, repeat_num=1, training=True):
data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
return data_set return data_set
def _get_rank_info():
    """Return the ``(rank_size, rank_id)`` pair used to shard the dataset.

    The ``RANK_SIZE`` environment variable decides whether this is a
    multi-device run. An unset or empty ``RANK_SIZE`` is treated as 1.

    Returns:
        tuple: ``(rank_size, rank_id)``. When ``RANK_SIZE`` is greater than 1,
        both values are queried from
        ``mindspore.communication.management`` (``get_group_size`` and
        ``get_rank``). Otherwise both are ``None``.

    Raises:
        ValueError: if ``RANK_SIZE`` is set to text that is not an integer.
    """
    # `or 1` covers both "unset" and "exported but empty"; int("") would raise.
    rank_size = int(os.environ.get("RANK_SIZE") or 1)

    if rank_size > 1:
        # Imported lazily so the communication module is only needed
        # when a multi-device run is actually requested.
        from mindspore.communication.management import get_rank, get_group_size
        rank_size = get_group_size()
        rank_id = get_rank()
    else:
        rank_size = rank_id = None

    return rank_size, rank_id
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册