提交 41e6ceaa 编写于 作者: C CaoJian

vgg16 support imagenet dataset on Ascend

上级 75045e3e
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
if [ $# != 2 ] if [ $# != 2 ] && [ $# != 3 ]
then then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH]" echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATA_PATH] [cifar10|imagenet2012]"
exit 1 exit 1
fi fi
...@@ -32,6 +32,19 @@ then ...@@ -32,6 +32,19 @@ then
exit 1 exit 1
fi fi
dataset_type='cifar10'
if [ $# == 3 ]
then
if [ $3 != "cifar10" ] && [ $3 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
dataset_type=$3
fi
export DEVICE_NUM=8 export DEVICE_NUM=8
export RANK_SIZE=8 export RANK_SIZE=8
export RANK_TABLE_FILE=$1 export RANK_TABLE_FILE=$1
...@@ -45,8 +58,8 @@ do ...@@ -45,8 +58,8 @@ do
cp *.py ./train_parallel$i cp *.py ./train_parallel$i
cp -r src ./train_parallel$i cp -r src ./train_parallel$i
cd ./train_parallel$i || exit cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID" echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
env > env.log env > env.log
python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 &> log & python train.py --data_path=$2 --device_target="Ascend" --device_id=$i --is_distributed=1 --dataset=$dataset_type &> log &
cd .. cd ..
done done
...@@ -139,5 +139,8 @@ def vgg16(num_classes=1000, args=None, phase="train"): ...@@ -139,5 +139,8 @@ def vgg16(num_classes=1000, args=None, phase="train"):
>>> vgg16(num_classes=1000, args=args) >>> vgg16(num_classes=1000, args=args)
""" """
if args is None:
from .config import cifar_cfg
args = cifar_cfg
net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase) net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=args.batch_norm, phase=phase)
return net return net
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册