提交 92d93ebc 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4041 modify mobilenetv2 quant scripts and fix bug

Merge pull request !4041 from chengxb7532/master
......@@ -75,15 +75,15 @@ run_gpu()
python ${BASEPATH}/../train.py \
--dataset_path=$4 \
--device_target=$1 \
--quantization_aware=True \
&> ../train.log & # dataset train folder
--pre_trained=$5 \
--quantization_aware=True &> ../train.log & # dataset train folder
}
if [ $# -gt 6 ] || [ $# -lt 4 ]
if [ $# -gt 6 ] || [ $# -lt 5 ]
then
echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
Ascend: sh run_train_quant.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
GPU: sh run_train_quant.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
"
exit 1
fi
......
......@@ -22,7 +22,6 @@ import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.py_transforms as P
from src.config import config_ascend
def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32):
......@@ -42,7 +41,7 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
rank_size = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
columns_list = ['image', 'label']
if config_ascend.data_load_mode == "mindrecord":
if config.data_load_mode == "mindrecord":
load_func = partial(de.MindDataset, dataset_path, columns_list)
else:
load_func = partial(de.ImageFolderDatasetV2, dataset_path)
......@@ -54,6 +53,13 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
num_shards=rank_size, shard_id=rank_id)
else:
ds = load_func(num_parallel_workers=8, shuffle=False)
elif device_target == "GPU":
if do_train:
from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
raise ValueError("Unsupport device_target.")
......
......@@ -56,7 +56,7 @@ if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id, save_graphs=False)
elif args_opt.platform == "GPU":
elif args_opt.device_target == "GPU":
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
......@@ -205,5 +205,5 @@ def train_on_gpu():
if __name__ == '__main__':
if args_opt.device_target == "Ascend":
train_on_ascend()
elif args_opt.platform == "GPU":
elif args_opt.device_target == "GPU":
train_on_gpu()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册