提交 a8c87185 编写于 作者: W wuzewu

update detection demo

上级 7e967e03
...@@ -12,21 +12,21 @@ from paddlehub.dataset.base_cv_dataset import ObjectDetectionDataset ...@@ -12,21 +12,21 @@ from paddlehub.dataset.base_cv_dataset import ObjectDetectionDataset
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for fine-tuning.") parser.add_argument("--use_gpu", type=ast.literal_eval, default=True, help="Whether use GPU for fine-tuning.")
parser.add_argument("--checkpoint_dir", type=str, default="ssd_finetune_ckpt", help="Path to save log data.") parser.add_argument("--checkpoint_dir", type=str, default="yolo_finetune_ckpt", help="Path to save log data.")
parser.add_argument("--batch_size", type=int, default=8, help="Total examples' number in batch for training.") parser.add_argument("--batch_size", type=int, default=8, help="Total examples' number in batch for training.")
parser.add_argument("--module", type=str, default="ssd_vgg16_512_coco2017", help="Module used as feature extractor.") parser.add_argument("--module", type=str, default="yolov3_darknet53_coco2017", help="Module used as feature extractor.")
parser.add_argument("--dataset", type=str, default="coco_10", help="Dataset to finetune.") parser.add_argument("--dataset", type=str, default="coco_10", help="Dataset to finetune.")
# yapf: enable. # yapf: enable.
def predict(args): def predict(args):
module = hub.Module(name=args.module) module = hub.Module(name=args.module)
dataset = hub.dataset.Coco10('ssd') dataset = hub.dataset.Coco10('yolo')
print("dataset.num_labels:", dataset.num_labels) print("dataset.num_labels:", dataset.num_labels)
# define batch reader # define batch reader
data_reader = ObjectDetectionReader(dataset=dataset, model_type='ssd') data_reader = ObjectDetectionReader(dataset=dataset, model_type='yolo')
input_dict, output_dict, program = module.context(trainable=True) input_dict, output_dict, program = module.context(trainable=True)
feed_list = [input_dict["image"].name, input_dict["im_size"].name] feed_list = [input_dict["image"].name, input_dict["im_size"].name]
...@@ -41,12 +41,11 @@ def predict(args): ...@@ -41,12 +41,11 @@ def predict(args):
checkpoint_dir=args.checkpoint_dir, checkpoint_dir=args.checkpoint_dir,
strategy=hub.finetune.strategy.DefaultFinetuneStrategy()) strategy=hub.finetune.strategy.DefaultFinetuneStrategy())
task = hub.SSDTask( task = hub.YOLOTask(
data_reader=data_reader, data_reader=data_reader,
num_classes=dataset.num_labels, num_classes=dataset.num_labels,
feed_list=feed_list, feed_list=feed_list,
feature=feature, feature=feature,
multi_box_head_config=module.multi_box_head_config,
config=config) config=config)
data = [ data = [
......
...@@ -46,11 +46,11 @@ def finetune(args): ...@@ -46,11 +46,11 @@ def finetune(args):
] ]
feature = [ feature = [
output_dict['head_feat'], output_dict['rpn_cls_loss'], output_dict['head_features'], output_dict['rpn_cls_loss'],
output_dict['rpn_reg_loss'], output_dict['generate_proposal_labels'] output_dict['rpn_reg_loss'], output_dict['generate_proposal_labels']
] ]
pred_feature = [pred_output_dict['head_feat'], pred_output_dict['rois']] pred_feature = [pred_output_dict['head_features'], pred_output_dict['rois']]
config = hub.RunConfig( config = hub.RunConfig(
log_interval=10, log_interval=10,
......
...@@ -32,7 +32,7 @@ def finetune(args): ...@@ -32,7 +32,7 @@ def finetune(args):
input_dict, output_dict, program = module.context(trainable=True) input_dict, output_dict, program = module.context(trainable=True)
feed_list = [input_dict["image"].name, input_dict["im_size"].name] feed_list = [input_dict["image"].name, input_dict["im_size"].name]
feature = output_dict['head_features'] feature = output_dict['body_features']
config = hub.RunConfig( config = hub.RunConfig(
log_interval=10, log_interval=10,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册