diff --git a/demo/object_detection/predict_faster_rcnn.py b/demo/object_detection/predict_faster_rcnn.py index 0758f0caa1b5632ab2510d8cea0b9ff1a6173fd6..60e6bffdf9b8b3d8edaedbeb70e7982e1d2a8f41 100644 --- a/demo/object_detection/predict_faster_rcnn.py +++ b/demo/object_detection/predict_faster_rcnn.py @@ -27,27 +27,14 @@ def predict(args): # define batch reader data_reader = ObjectDetectionReader(dataset=dataset, model_type='rcnn') - - input_dict, output_dict, program = module.context(trainable=True) pred_input_dict, pred_output_dict, pred_program = module.context( trainable=False, phase='predict') - feed_list = [ - input_dict["image"].name, input_dict["im_info"].name, - input_dict['gt_bbox'].name, input_dict['gt_class'].name, - input_dict['is_crowd'].name - ] - pred_feed_list = [ pred_input_dict['image'].name, pred_input_dict['im_info'].name, pred_input_dict['im_shape'].name ] - feature = [ - output_dict['head_feat'], output_dict['rpn_cls_loss'], - output_dict['rpn_reg_loss'], output_dict['generate_proposal_labels'] - ] - pred_feature = [pred_output_dict['head_feat'], pred_output_dict['rois']] config = hub.RunConfig( @@ -62,8 +49,6 @@ def predict(args): task = hub.FasterRCNNTask( data_reader=data_reader, num_classes=dataset.num_labels, - feed_list=feed_list, - feature=feature, predict_feed_list=pred_feed_list, predict_feature=pred_feature, config=config) diff --git a/paddlehub/finetune/task/detection_task.py b/paddlehub/finetune/task/detection_task.py index 11c2afe81ef4a2eedadd2c7abca36aa4680d2a91..82f368c903a0d4ea9cba5e1ff3ebd553daca1e8e 100644 --- a/paddlehub/finetune/task/detection_task.py +++ b/paddlehub/finetune/task/detection_task.py @@ -116,7 +116,7 @@ class DetectionTask(BaseTask): if metrics_choices == "default": metrics_choices = ["ap"] - main_program = feature[0].block.program + main_program = feature[0].block.program if feature else None super(DetectionTask, self).__init__( data_reader=data_reader, main_program=main_program, diff --git a/paddlehub/finetune/task/faster_rcnn_task.py b/paddlehub/finetune/task/faster_rcnn_task.py index 2d4916cfcadab127c3d980a690d9cdb9968a13ea..4cd079df72b9a0904c016b89c420a42c81d6525a 100644 --- a/paddlehub/finetune/task/faster_rcnn_task.py +++ b/paddlehub/finetune/task/faster_rcnn_task.py @@ -28,12 +28,12 @@ from paddlehub.finetune.task.detection_task import DetectionTask class FasterRCNNTask(DetectionTask): def __init__(self, - data_reader, num_classes, - feed_list, - feature, - predict_feed_list=None, + data_reader, + feature=None, + feed_list=None, predict_feature=None, + predict_feed_list=None, startup_program=None, config=None, metrics_choices="default"): diff --git a/paddlehub/finetune/task/ssd_task.py b/paddlehub/finetune/task/ssd_task.py index 4b99e7c14a590321e1fccc4c4a2531a298605f81..c571f34551b306569a52197c58b8be73e8b57ba8 100644 --- a/paddlehub/finetune/task/ssd_task.py +++ b/paddlehub/finetune/task/ssd_task.py @@ -24,10 +24,10 @@ from paddlehub.finetune.task.detection_task import DetectionTask class SSDTask(DetectionTask): def __init__(self, - data_reader, + feature, num_classes, feed_list, - feature, + data_reader, multi_box_head_config, startup_program=None, config=None, diff --git a/paddlehub/finetune/task/yolo_task.py b/paddlehub/finetune/task/yolo_task.py index 07b499ba8b67d84f3b13e2f857b755cffeeb163d..a13ba3e1e99e28d9c804a67827e2c32ef09a7f26 100644 --- a/paddlehub/finetune/task/yolo_task.py +++ b/paddlehub/finetune/task/yolo_task.py @@ -26,10 +26,10 @@ from paddlehub.finetune.task.detection_task import DetectionTask class YOLOTask(DetectionTask): def __init__(self, - data_reader, + feature, num_classes, feed_list, - feature, + data_reader, startup_program=None, config=None, metrics_choices="default"):