提交 77443f64 编写于 作者: W wuzewu

Update task input

上级 a2575523
......@@ -27,27 +27,14 @@ def predict(args):
# define batch reader
data_reader = ObjectDetectionReader(dataset=dataset, model_type='rcnn')
input_dict, output_dict, program = module.context(trainable=True)
pred_input_dict, pred_output_dict, pred_program = module.context(
trainable=False, phase='predict')
feed_list = [
input_dict["image"].name, input_dict["im_info"].name,
input_dict['gt_bbox'].name, input_dict['gt_class'].name,
input_dict['is_crowd'].name
]
pred_feed_list = [
pred_input_dict['image'].name, pred_input_dict['im_info'].name,
pred_input_dict['im_shape'].name
]
feature = [
output_dict['head_feat'], output_dict['rpn_cls_loss'],
output_dict['rpn_reg_loss'], output_dict['generate_proposal_labels']
]
pred_feature = [pred_output_dict['head_feat'], pred_output_dict['rois']]
config = hub.RunConfig(
......@@ -62,8 +49,6 @@ def predict(args):
task = hub.FasterRCNNTask(
data_reader=data_reader,
num_classes=dataset.num_labels,
feed_list=feed_list,
feature=feature,
predict_feed_list=pred_feed_list,
predict_feature=pred_feature,
config=config)
......
......@@ -116,7 +116,7 @@ class DetectionTask(BaseTask):
if metrics_choices == "default":
metrics_choices = ["ap"]
main_program = feature[0].block.program
main_program = feature[0].block.program if feature else None
super(DetectionTask, self).__init__(
data_reader=data_reader,
main_program=main_program,
......
......@@ -28,12 +28,12 @@ from paddlehub.finetune.task.detection_task import DetectionTask
class FasterRCNNTask(DetectionTask):
def __init__(self,
data_reader,
num_classes,
feed_list,
feature,
predict_feed_list=None,
data_reader,
feature=None,
feed_list=None,
predict_feature=None,
predict_feed_list=None,
startup_program=None,
config=None,
metrics_choices="default"):
......
......@@ -24,10 +24,10 @@ from paddlehub.finetune.task.detection_task import DetectionTask
class SSDTask(DetectionTask):
def __init__(self,
data_reader,
feature,
num_classes,
feed_list,
feature,
data_reader,
multi_box_head_config,
startup_program=None,
config=None,
......
......@@ -26,10 +26,10 @@ from paddlehub.finetune.task.detection_task import DetectionTask
class YOLOTask(DetectionTask):
def __init__(self,
data_reader,
feature,
num_classes,
feed_list,
feature,
data_reader,
startup_program=None,
config=None,
metrics_choices="default"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册