Commit 6a680bcd authored by zhangxuefei

add autofinetune-cv demo

Parent 9a433ca4
# PaddleHub Hyperparameter Optimization: Image Classification
This demo shows how to use PaddleHub's hyperparameter optimization feature, Auto Fine-tune, to find a well-performing combination of hyperparameters.
Using PaddleHub Auto Fine-tune requires two files in a prescribed format: a YAML file, hparam.yaml, describing the hyperparameters to be optimized, and a Python script, train.py, that performs the fine-tuning.
Taking the fine-tuning of an image classification task as the example:
## hparam.yaml
hparam.yaml specifies, for each hyperparameter to be searched, its name, its type (int for discrete or float for continuous hyperparameters), and its search range.
From this information PaddleHub builds a hyperparameter space and searches within it: each sampled set of hyperparameters is passed to train.py, the resulting evaluation score is collected, and the search direction is adjusted automatically based on that score until the configured number of search rounds is reached.
In this demo the hyperparameters to be optimized are learning_rate and batch_size.
## img_cls.py
Uses MobileNet as the pre-trained model and fine-tunes it on the Flowers dataset.
`NOTE`: For details on PaddleHub hyperparameter optimization, see the [tutorial](https://github.com/PaddlePaddle/PaddleHub/blob/release/v1.2/tutorial/autofinetune.md)
param_list:
- name : learning_rate
  init_value : 0.001
  type : float
  lower_than : 0.05
  greater_than : 0.00005
- name : batch_size
  init_value : 12
  type : int
  lower_than : 20
  greater_than : 10
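Here `greater_than` and `lower_than` give the lower and upper bound of each search range, `type` marks the hyperparameter as discrete (int) or continuous (float), and `init_value` is the value the search starts from. If you want to sanity-check the file before launching a search, a minimal sketch with PyYAML (an extra dependency, not something Auto Fine-tune itself requires) could look like this:

```python
# sanity_check_hparam.py -- illustrative only, not part of the demo
import yaml

with open("hparam.yaml") as f:
    conf = yaml.safe_load(f)

for param in conf["param_list"]:
    # greater_than is the lower bound of the range, lower_than the upper bound
    print("{name}: type={type}, init={init_value}, "
          "range=({greater_than}, {lower_than})".format(**param))
```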
# coding:utf-8
import argparse
import os
import ast
import shutil

import paddle.fluid as fluid
import paddlehub as hub
from paddlehub.common.logger import logger

parser = argparse.ArgumentParser(__doc__)
parser.add_argument(
    "--epochs", type=int, default=5, help="Number of epochs for fine-tuning.")
parser.add_argument(
    "--use_gpu",
    type=ast.literal_eval,
    default=True,
    help="Whether to use GPU for fine-tuning.")
parser.add_argument(
    "--checkpoint_dir", type=str, default=None, help="Path to save log data.")
# The names of the hyperparameters to be searched must be consistent with hparam.yaml
parser.add_argument(
    "--batch_size",
    type=int,
    default=16,
    help="Total number of examples in a training batch.")
parser.add_argument(
    "--learning_rate", type=float, default=1e-4, help="learning_rate.")
# saved_params_dir and model_path are needed by auto finetune
parser.add_argument(
    "--saved_params_dir",
    type=str,
    default="",
    help="Directory for saving the model.")
parser.add_argument(
    "--model_path", type=str, default="", help="Path to load the model from.")
def is_path_valid(path):
    if path == "":
        return False
    path = os.path.abspath(path)
    dirname = os.path.dirname(path)
    if not os.path.exists(dirname):
        os.mkdir(dirname)
    return True
def finetune(args):
    # Load the PaddleHub mobilenet_v2_imagenet pretrained module
    module = hub.Module(name="mobilenet_v2_imagenet")
    input_dict, output_dict, program = module.context(trainable=True)

    # Download the dataset and use ImageClassificationReader to read it
    dataset = hub.dataset.Flowers()
    data_reader = hub.reader.ImageClassificationReader(
        image_width=module.get_expected_image_width(),
        image_height=module.get_expected_image_height(),
        images_mean=module.get_pretrained_images_mean(),
        images_std=module.get_pretrained_images_std(),
        dataset=dataset)

    feature_map = output_dict["feature_map"]
    img = input_dict["image"]
    feed_list = [img.name]

    # Select the fine-tuning strategy, set up the run config and fine-tune
    strategy = hub.DefaultFinetuneStrategy(learning_rate=args.learning_rate)
    config = hub.RunConfig(
        use_cuda=args.use_gpu,
        num_epoch=args.epochs,
        batch_size=args.batch_size,
        checkpoint_dir=args.checkpoint_dir,
        strategy=strategy)

    # Construct the transfer learning network
    task = hub.ImageClassifierTask(
        data_reader=data_reader,
        feed_list=feed_list,
        feature=feature_map,
        num_classes=dataset.num_labels,
        config=config)

    # Load the model from the given model path, if any
    if args.model_path != "":
        with task.phase_guard(phase="train"):
            task.init_if_necessary()
            task.load_parameters(args.model_path)
            logger.info("PaddleHub has loaded model from %s" % args.model_path)

    task.finetune()
    run_states = task.eval()
    eval_avg_score, eval_avg_loss, eval_run_speed = task._calculate_metrics(
        run_states)

    # Move ckpt/best_model to the given saved parameters directory
    best_model_dir = os.path.join(config.checkpoint_dir, "best_model")
    if is_path_valid(args.saved_params_dir) and os.path.exists(best_model_dir):
        shutil.copytree(best_model_dir, args.saved_params_dir)
        shutil.rmtree(config.checkpoint_dir)

    # The accuracy on the dev set will be read back by auto finetune
    print("AutoFinetuneEval" + "\t" + str(float(eval_avg_score["acc"])))


if __name__ == "__main__":
    args = parser.parse_args()
    finetune(args)
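Before handing the script to Auto Fine-tune, it is worth running it once on its own to confirm that it fine-tunes end to end and prints the final `AutoFinetuneEval` line that the search reads back. The flag values and directory names below are only an illustration; any values inside the ranges declared in hparam.yaml will do:

```shell
python img_cls.py \
    --epochs=1 \
    --batch_size=16 \
    --learning_rate=1e-4 \
    --checkpoint_dir=./ckpt_smoke_test \
    --saved_params_dir=./saved_params_smoke_test
```

When `hub autofinetune` launches the script, it fills in `--learning_rate` and `--batch_size` with the sampled values, which is why their names must match the entries in hparam.yaml. The search itself is started with: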
OUTPUT=result/
hub autofinetune img_cls.py \
--param_file=hparam.yaml \
--cuda=['6'] \
--popsize=5 \
--round=10 \
--output_dir=${OUTPUT} \
--evaluate_choice=fulltrail \
--tuning_strategy=pshe2