提交 5bd9a50a 编写于 作者: W wuzewu

update image classification demo options

上级 5a829a9e
......@@ -34,8 +34,8 @@ $ pip install --upgrade paddlepaddle
--num_epoch: finetune迭代的轮数。默认为1
--module: 使用哪个Module作为finetune的特征提取器,脚本支持{resnet50/resnet101/resnet152/mobilenet/nasnet/pnasnet}等模型。默认为resnet50
--checkpoint_dir: 模型保存路径,PaddleHub会自动保存验证集上表现最好的模型。默认为paddlehub_finetune_ckpt
--dataset: 使用什么数据集进行finetune, 脚本支持分别是{flowers/dogcat}。默认为flowers
--use_gpu: 使用使用GPU进行训练,如果机器支持GPU且安装了GPU版本的PaddlePaddle,我们建议您打开这个开关。默认关闭
--dataset: 使用什么数据集进行finetune, 脚本支持分别是{flowers/dogcat/stanforddogs/indoor67/food101}。默认为flowers
--use_gpu: 是否使用GPU进行训练,如果机器支持GPU且安装了GPU版本的PaddlePaddle,我们建议您打开这个开关。默认关闭
```
## 进行预测
......
......@@ -29,7 +29,19 @@ def finetune(args):
module = hub.Module(name=args.module)
input_dict, output_dict, program = module.context(trainable=True)
if args.dataset.lower() == "flowers":
dataset = hub.dataset.Flowers()
elif args.dataset.lower() == "dogcat":
dataset = hub.dataset.DogCat()
elif args.dataset.lower() == "indoor67":
dataset = hub.dataset.Indoor67()
elif args.dataset.lower() == "food101":
dataset = hub.dataset.Food101()
elif args.dataset.lower() == "stanforddogs":
dataset = hub.dataset.StanfordDogs()
else:
raise ValueError("%s dataset is not defined" % args.dataset)
data_reader = hub.reader.ImageClassificationReader(
image_width=module.get_expected_image_width(),
image_height=module.get_expected_image_height(),
......
......@@ -25,10 +25,18 @@ module_map = {
def predict(args):
if args.dataset == "dogcat":
dataset = hub.dataset.DogCat()
elif args.dataset == "flowers":
if args.dataset.lower() == "flowers":
dataset = hub.dataset.Flowers()
elif args.dataset.lower() == "dogcat":
dataset = hub.dataset.DogCat()
elif args.dataset.lower() == "indoor67":
dataset = hub.dataset.Indoor67()
elif args.dataset.lower() == "food101":
dataset = hub.dataset.Food101()
elif args.dataset.lower() == "stanforddogs":
dataset = hub.dataset.StanfordDogs()
else:
raise ValueError("%s dataset is not defined" % args.dataset)
label_map = dataset.label_dict()
num_labels = len(label_map)
......
cuda_visible_devices=0
module=resnet50
num_epoch=1
batch_size=16
use_gpu=False
checkpoint_dir=paddlehub_finetune_ckpt
while getopts "gm:n:b:c:d:" options
do
case "$options" in
m)
module=$OPTARG;;
n)
num_epoch=$OPTARG;;
b)
batch_size=$OPTARG;;
c)
checkpoint_dir=$OPTARG;;
d)
cuda_visible_devices=$OPTARG;;
g)
use_gpu=True;;
?)
echo "unknown options"
exit 1;;
esac
done
export CUDA_VISIBLE_DEVICES=${cuda_visible_devices}
python -u img_classifier.py --use_gpu ${use_gpu} --batch_size ${batch_size} --checkpoint_dir ${checkpoint_dir} --num_epoch ${num_epoch} --module ${module}
python -u img_classifier.py $@
cuda_visible_devices=0
module=resnet50
use_gpu=False
checkpoint_dir=paddlehub_finetune_ckpt
while getopts "gm:c:d:" options
do
case "$options" in
m)
module=$OPTARG;;
c)
checkpoint_dir=$OPTARG;;
d)
cuda_visible_devices=$OPTARG;;
g)
use_gpu=True;;
?)
echo "unknown options"
exit 1;;
esac
done
export CUDA_VISIBLE_DEVICES=${cuda_visible_devices}
python -u predict.py --use_gpu ${use_gpu} --checkpoint_dir ${checkpoint_dir} --module ${module}
python -u predict.py $@
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册