From 5bd9a50a0f18808fa77c63e6b6315d96e59f8058 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Fri, 19 Apr 2019 16:45:13 +0800 Subject: [PATCH] update image classification demo options --- demo/image-classification/README.md | 4 +-- demo/image-classification/img_classifier.py | 14 ++++++++- demo/image-classification/predict.py | 14 +++++++-- demo/image-classification/run_classifier.sh | 32 +-------------------- demo/image-classification/run_predict.sh | 26 +---------------- 5 files changed, 28 insertions(+), 62 deletions(-) diff --git a/demo/image-classification/README.md b/demo/image-classification/README.md index cbde58fd..81c1d1e1 100644 --- a/demo/image-classification/README.md +++ b/demo/image-classification/README.md @@ -34,8 +34,8 @@ $ pip install --upgrade paddlepaddle --num_epoch: finetune迭代的轮数。默认为1 --module: 使用哪个Module作为finetune的特征提取器,脚本支持{resnet50/resnet101/resnet152/mobilenet/nasnet/pnasnet}等模型。默认为resnet50 --checkpoint_dir: 模型保存路径,PaddleHub会自动保存验证集上表现最好的模型。默认为paddlehub_finetune_ckpt ---dataset: 使用什么数据集进行finetune, 脚本支持分别是{flowers/dogcat}。默认为flowers ---use_gpu: 使用使用GPU进行训练,如果机器支持GPU且安装了GPU版本的PaddlePaddle,我们建议您打开这个开关。默认关闭 +--dataset: 使用什么数据集进行finetune, 脚本支持分别是{flowers/dogcat/stanforddogs/indoor67/food101}。默认为flowers +--use_gpu: 是否使用GPU进行训练,如果机器支持GPU且安装了GPU版本的PaddlePaddle,我们建议您打开这个开关。默认关闭 ``` ## 进行预测 diff --git a/demo/image-classification/img_classifier.py b/demo/image-classification/img_classifier.py index b3cc57c4..acef9d22 100644 --- a/demo/image-classification/img_classifier.py +++ b/demo/image-classification/img_classifier.py @@ -29,7 +29,19 @@ def finetune(args): module = hub.Module(name=args.module) input_dict, output_dict, program = module.context(trainable=True) - dataset = hub.dataset.Flowers() + if args.dataset.lower() == "flowers": + dataset = hub.dataset.Flowers() + elif args.dataset.lower() == "dogcat": + dataset = hub.dataset.DogCat() + elif args.dataset.lower() == "indoor67": + dataset = hub.dataset.Indoor67() + elif args.dataset.lower() == "food101": + dataset = hub.dataset.Food101() + elif args.dataset.lower() == "stanforddogs": + dataset = hub.dataset.StanfordDogs() + else: + raise ValueError("%s dataset is not defined" % args.dataset) + data_reader = hub.reader.ImageClassificationReader( image_width=module.get_expected_image_width(), image_height=module.get_expected_image_height(), diff --git a/demo/image-classification/predict.py b/demo/image-classification/predict.py index cb13d80c..32cdc1f1 100644 --- a/demo/image-classification/predict.py +++ b/demo/image-classification/predict.py @@ -25,10 +25,18 @@ module_map = { def predict(args): - if args.dataset == "dogcat": - dataset = hub.dataset.DogCat() - elif args.dataset == "flowers": + if args.dataset.lower() == "flowers": dataset = hub.dataset.Flowers() + elif args.dataset.lower() == "dogcat": + dataset = hub.dataset.DogCat() + elif args.dataset.lower() == "indoor67": + dataset = hub.dataset.Indoor67() + elif args.dataset.lower() == "food101": + dataset = hub.dataset.Food101() + elif args.dataset.lower() == "stanforddogs": + dataset = hub.dataset.StanfordDogs() + else: + raise ValueError("%s dataset is not defined" % args.dataset) label_map = dataset.label_dict() num_labels = len(label_map) diff --git a/demo/image-classification/run_classifier.sh b/demo/image-classification/run_classifier.sh index 779722fb..30028daf 100644 --- a/demo/image-classification/run_classifier.sh +++ b/demo/image-classification/run_classifier.sh @@ -1,31 +1 @@ -cuda_visible_devices=0 -module=resnet50 -num_epoch=1 -batch_size=16 -use_gpu=False -checkpoint_dir=paddlehub_finetune_ckpt - -while getopts "gm:n:b:c:d:" options -do - case "$options" in - m) - module=$OPTARG;; - n) - num_epoch=$OPTARG;; - b) - batch_size=$OPTARG;; - c) - checkpoint_dir=$OPTARG;; - d) - cuda_visible_devices=$OPTARG;; - g) - use_gpu=True;; - ?) - echo "unknown options" - exit 1;; - esac -done - -export CUDA_VISIBLE_DEVICES=${cuda_visible_devices} - -python -u img_classifier.py --use_gpu ${use_gpu} --batch_size ${batch_size} --checkpoint_dir ${checkpoint_dir} --num_epoch ${num_epoch} --module ${module} +python -u img_classifier.py $@ diff --git a/demo/image-classification/run_predict.sh b/demo/image-classification/run_predict.sh index f55e2392..68eed34c 100644 --- a/demo/image-classification/run_predict.sh +++ b/demo/image-classification/run_predict.sh @@ -1,25 +1 @@ -cuda_visible_devices=0 -module=resnet50 -use_gpu=False -checkpoint_dir=paddlehub_finetune_ckpt - -while getopts "gm:c:d:" options -do - case "$options" in - m) - module=$OPTARG;; - c) - checkpoint_dir=$OPTARG;; - d) - cuda_visible_devices=$OPTARG;; - g) - use_gpu=True;; - ?) - echo "unknown options" - exit 1;; - esac -done - -export CUDA_VISIBLE_DEVICES=${cuda_visible_devices} - -python -u predict.py --use_gpu ${use_gpu} --checkpoint_dir ${checkpoint_dir} --module ${module} +python -u predict.py $@ -- GitLab