Commit aa20f1c4 authored by mindspore-ci-bot, committed by Gitee

!5558 thor optimizer training parameters for GPU

Merge pull request !5558 from wangmin0104/master
@@ -32,13 +32,21 @@ then
 exit 1
 fi
-BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
-cd $BASE_PATH/../ || exit
+get_real_path(){
+    if [ "${1:0:1}" == "/" ]; then
+        echo "$1"
+    else
+        echo "$(realpath -m $PWD/$1)"
+    fi
+}
+
+PATH1=$(get_real_path $1)
+PATH2=$(get_real_path $2)
 ulimit -u unlimited
 export DEVICE_NUM=$3
 export RANK_SIZE=$3
-export RANK_TABLE_FILE=$1
+export RANK_TABLE_FILE=$PATH1
 for((i=0; i<${DEVICE_NUM}; i++))
 do
@@ -46,12 +54,12 @@ do
     export RANK_ID=$i
     rm -rf ./train_parallel$i
     mkdir ./train_parallel$i
-    cp *.py ./train_parallel$i
-    cp -r ./src ./train_parallel$i
+    cp ../*.py ./train_parallel$i
+    cp -r ../src ./train_parallel$i
     cd ./train_parallel$i || exit
     echo "start training for rank $RANK_ID, device $DEVICE_ID"
     env > env.log
-    python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
+    python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 > log 2>&1 &
     cd ..
 done
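The new get_real_path helper normalizes the rank-table and dataset arguments to absolute paths before the launcher cd's into each ./train_parallel$i directory (and copies sources from ../), so relative paths passed on the command line keep resolving correctly. A minimal sketch of the same normalization, written in Python purely for illustration (the actual helper is the shell function above):

# Illustrative sketch only -- the real helper is the shell function in the diff above.
import os

def get_real_path(path: str) -> str:
    # Absolute paths pass through unchanged; relative paths are resolved
    # against the current working directory, mirroring `realpath -m $PWD/$1`.
    if path.startswith("/"):
        return path
    return os.path.normpath(os.path.join(os.getcwd(), path))

# e.g. get_real_path("data/imagenet") -> "<current working dir>/data/imagenet"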
@@ -46,17 +46,17 @@ config_gpu = ed({
     "loss_scale": 128,
     "momentum": 0.9,
     "weight_decay": 5e-4,
-    "epoch_size": 45,
+    "epoch_size": 40,
     "save_checkpoint": True,
     "save_checkpoint_epochs": 1,
     "keep_checkpoint_max": 15,
     "save_checkpoint_path": "./",
     "use_label_smooth": True,
     "label_smooth_factor": 0.1,
-    "lr_init": 0.04,
-    "lr_decay": 5,
-    "lr_end_epoch": 58,
-    "damping_init": 0.02,
-    "damping_decay": 0.87,
+    "lr_init": 0.05672,
+    "lr_decay": 4.9687,
+    "lr_end_epoch": 50,
+    "damping_init": 0.02345,
+    "damping_decay": 0.5467,
     "frequency": 834,
 })
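The retuned values above (lr_init, lr_decay, lr_end_epoch, damping_init, damping_decay) drive the THOR learning-rate and damping schedules for the GPU run. The actual schedule functions live in the model's training code and are not part of this diff; the sketch below is only one plausible shape for them, assuming a polynomial decay of the learning rate toward lr_end_epoch, a damping term that decays by damping_decay on an assumed 10-epoch period, and an illustrative steps_per_epoch.

# Hedged illustration -- not the code changed by this PR.
import numpy as np

def lr_schedule(lr_init, lr_decay, lr_end_epoch, steps_per_epoch):
    """Polynomial-style decay from lr_init toward zero at lr_end_epoch (assumed form)."""
    total_steps = steps_per_epoch * lr_end_epoch
    lr = []
    for step in range(total_steps):
        epoch = (step + 1) / steps_per_epoch
        lr.append(lr_init * max(1.0 - epoch / lr_end_epoch, 0.0) ** lr_decay)
    return np.array(lr, dtype=np.float32)

def damping_schedule(damping_init, damping_decay, epochs, steps_per_epoch):
    """Exponential decay of the THOR damping term; the 10-epoch period is an assumption."""
    total_steps = steps_per_epoch * epochs
    damping = [damping_init * damping_decay ** ((step + 1) / steps_per_epoch / 10)
               for step in range(total_steps)]
    return np.array(damping, dtype=np.float32)

# With this commit's GPU values (steps_per_epoch depends on dataset and batch size):
# lr_schedule(0.05672, 4.9687, 50, steps_per_epoch=5004)
# damping_schedule(0.02345, 0.5467, 40, steps_per_epoch=5004)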
@@ -109,6 +109,7 @@ if __name__ == '__main__':
         init()
         context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                           mirror_mean=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
         ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
     # create dataset
...
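The one functional change to the training entry point is the set_all_reduce_fusion_split_indices([107]) call: it partitions the gradient list into fusion buckets, so the all-reduce for the first 107 gradients can be issued as a single fused operation and overlap with the remaining backward computation. A small illustration of what a split-index list means, using a made-up parameter count:

# Illustration only; the parameter count below is made up.
def fusion_buckets(num_params, split_indices):
    """Return (start, end) index ranges of parameters fused into one all-reduce."""
    bounds = [0] + list(split_indices) + [num_params]
    return [(bounds[i], bounds[i + 1]) for i in range(len(bounds) - 1)]

# e.g. for a network with 160 parameter tensors:
# fusion_buckets(160, [107]) -> [(0, 107), (107, 160)]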