From 77443f64cd807f5fe5007ee02cc06b6d00648347 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Tue, 28 Apr 2020 16:59:25 +0800 Subject: [PATCH] Update task input --- demo/object_detection/predict_faster_rcnn.py | 15 --------------- paddlehub/finetune/task/detection_task.py | 2 +- paddlehub/finetune/task/faster_rcnn_task.py | 8 ++++---- paddlehub/finetune/task/ssd_task.py | 4 ++-- paddlehub/finetune/task/yolo_task.py | 4 ++-- 5 files changed, 9 insertions(+), 24 deletions(-) diff --git a/demo/object_detection/predict_faster_rcnn.py b/demo/object_detection/predict_faster_rcnn.py index 0758f0ca..60e6bffd 100644 --- a/demo/object_detection/predict_faster_rcnn.py +++ b/demo/object_detection/predict_faster_rcnn.py @@ -27,27 +27,14 @@ def predict(args): # define batch reader data_reader = ObjectDetectionReader(dataset=dataset, model_type='rcnn') - - input_dict, output_dict, program = module.context(trainable=True) pred_input_dict, pred_output_dict, pred_program = module.context( trainable=False, phase='predict') - feed_list = [ - input_dict["image"].name, input_dict["im_info"].name, - input_dict['gt_bbox'].name, input_dict['gt_class'].name, - input_dict['is_crowd'].name - ] - pred_feed_list = [ pred_input_dict['image'].name, pred_input_dict['im_info'].name, pred_input_dict['im_shape'].name ] - feature = [ - output_dict['head_feat'], output_dict['rpn_cls_loss'], - output_dict['rpn_reg_loss'], output_dict['generate_proposal_labels'] - ] - pred_feature = [pred_output_dict['head_feat'], pred_output_dict['rois']] config = hub.RunConfig( @@ -62,8 +49,6 @@ def predict(args): task = hub.FasterRCNNTask( data_reader=data_reader, num_classes=dataset.num_labels, - feed_list=feed_list, - feature=feature, predict_feed_list=pred_feed_list, predict_feature=pred_feature, config=config) diff --git a/paddlehub/finetune/task/detection_task.py b/paddlehub/finetune/task/detection_task.py index 11c2afe8..82f368c9 100644 --- a/paddlehub/finetune/task/detection_task.py +++ b/paddlehub/finetune/task/detection_task.py @@ -116,7 +116,7 @@ class DetectionTask(BaseTask): if metrics_choices == "default": metrics_choices = ["ap"] - main_program = feature[0].block.program + main_program = feature[0].block.program if feature else None super(DetectionTask, self).__init__( data_reader=data_reader, main_program=main_program, diff --git a/paddlehub/finetune/task/faster_rcnn_task.py b/paddlehub/finetune/task/faster_rcnn_task.py index 2d4916cf..4cd079df 100644 --- a/paddlehub/finetune/task/faster_rcnn_task.py +++ b/paddlehub/finetune/task/faster_rcnn_task.py @@ -28,12 +28,12 @@ from paddlehub.finetune.task.detection_task import DetectionTask class FasterRCNNTask(DetectionTask): def __init__(self, - data_reader, num_classes, - feed_list, - feature, - predict_feed_list=None, + data_reader, + feature=None, + feed_list=None, predict_feature=None, + predict_feed_list=None, startup_program=None, config=None, metrics_choices="default"): diff --git a/paddlehub/finetune/task/ssd_task.py b/paddlehub/finetune/task/ssd_task.py index 4b99e7c1..c571f345 100644 --- a/paddlehub/finetune/task/ssd_task.py +++ b/paddlehub/finetune/task/ssd_task.py @@ -24,10 +24,10 @@ from paddlehub.finetune.task.detection_task import DetectionTask class SSDTask(DetectionTask): def __init__(self, - data_reader, + feature, num_classes, feed_list, - feature, + data_reader, multi_box_head_config, startup_program=None, config=None, diff --git a/paddlehub/finetune/task/yolo_task.py b/paddlehub/finetune/task/yolo_task.py index 07b499ba..a13ba3e1 100644 --- a/paddlehub/finetune/task/yolo_task.py +++ b/paddlehub/finetune/task/yolo_task.py @@ -26,10 +26,10 @@ from paddlehub.finetune.task.detection_task import DetectionTask class YOLOTask(DetectionTask): def __init__(self, - data_reader, + feature, num_classes, feed_list, - feature, + data_reader, startup_program=None, config=None, metrics_choices="default"): -- GitLab