提交 600c2652 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3941 update gpu resnet101 scripts

Merge pull request !3941 from panfengfeng/fix_gpu_scripts_resnet101
......@@ -52,7 +52,7 @@ class DatasetHelper:
sink_size (int): Control the amount of data each sink.
If sink_size=-1, sink the complete dataset each epoch.
If sink_size>0, sink sink_size data each epoch. Default: -1.
epoch_num (int): Control the number of epoch data to send.
epoch_num (int): Control the number of epoch data to send. Default: 1.
Examples:
>>> dataset_helper = DatasetHelper(dataset)
......
......@@ -44,6 +44,9 @@ ImageNet2012
├── run_distribute_train.sh # launch distributed training(8 pcs)
├── run_eval.sh # launch evaluation
└── run_standalone_train.sh # launch standalone training(1 pcs)
├── run_distribute_train_gpu.sh # launch gpu distributed training(8 pcs)
├── run_eval_gpu.sh # launch gpu evaluation
└── run_standalone_train_gpu.sh # launch gpu standalone training(1 pcs)
├── src
├── config.py # parameter configuration
├── dataset.py # data preprocessing
......@@ -241,11 +244,11 @@ result: {'top_5_accuracy': 0.9429417413572343, 'top_1_accuracy': 0.7853513124199
### Running on GPU
```
# distributed training example
mpirun -n 8 python train.py --net=resnet50 --dataset=cifar10 --dataset_path=~/cifar-10-batches-bin --device_target="GPU" --run_distribute=True
sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
# standalone training example
python train.py --net=resnet50 --dataset=cifar10 --dataset_path=~/cifar-10-batches-bin --device_target="GPU"
sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)
# infer example
python eval.py --net=resnet50 --dataset=cifar10 --dataset_path=~/cifar10-10-verify-bin --device_target="GPU" --checkpoint_path=resnet-90_195.ckpt
sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]
```
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]
then
echo "Usage: sh run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
then
echo "error: the selected net is neither resnet50 nor resnet101"
exit 1
fi
if [ $2 != "cifar10" ] && [ $2 != "imagenet2012" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet2012"
exit 1
fi
if [ $1 == "resnet101" ] && [ $2 == "cifar10" ]
then
echo "error: evaluating resnet101 with cifar10 dataset is unsupported now!"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $3)
PATH2=$(get_real_path $4)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --net=$1 --dataset=$2 --dataset_path=$PATH1 --checkpoint_path=$PATH2 --device_target="GPU" &> log &
cd ..
......@@ -16,7 +16,7 @@
if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: sh run_standalone_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: sh run_standalone_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
......
......@@ -157,13 +157,18 @@ if __name__ == '__main__':
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False,
num_classes=config.class_num)
## fp32 training
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# # Mixed precision
# loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
# opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, config.loss_scale)
# model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2")
if args_opt.net == "resnet101":
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay,
config.loss_scale)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
# Mixed precision
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=True)
else:
## fp32 training
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册