未验证 提交 2f6b2ec2 编写于 作者: W wangguanzhong 提交者: GitHub

support fcos on voc (#902)

上级 f1b91931
...@@ -75,7 +75,7 @@ OptimizerBuilder: ...@@ -75,7 +75,7 @@ OptimizerBuilder:
TrainReader: TrainReader:
inputs_def: inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score', 'im_info'] fields: ['image', 'im_info', 'fcos_target']
dataset: dataset:
!COCODataSet !COCODataSet
image_dir: train2017 image_dir: train2017
......
...@@ -74,7 +74,7 @@ OptimizerBuilder: ...@@ -74,7 +74,7 @@ OptimizerBuilder:
TrainReader: TrainReader:
inputs_def: inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score', 'im_info'] fields: ['image', 'im_info', 'fcos_target']
dataset: dataset:
!COCODataSet !COCODataSet
image_dir: train2017 image_dir: train2017
......
...@@ -74,7 +74,7 @@ OptimizerBuilder: ...@@ -74,7 +74,7 @@ OptimizerBuilder:
TrainReader: TrainReader:
inputs_def: inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score', 'im_info'] fields: ['image', 'im_info', 'fcos_target']
dataset: dataset:
!COCODataSet !COCODataSet
image_dir: train2017 image_dir: train2017
......
...@@ -107,7 +107,7 @@ class FCOS(object): ...@@ -107,7 +107,7 @@ class FCOS(object):
'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1} 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}
} }
# yapf: disable # yapf: disable
if 'gt_bbox' in fields: if 'fcos_target' in fields:
targets_def = { targets_def = {
'labels0': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0}, 'labels0': {'shape': [None, None, None, 1], 'dtype': 'int32', 'lod_level': 0},
'reg_target0': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0}, 'reg_target0': {'shape': [None, None, None, 4], 'dtype': 'float32', 'lod_level': 0},
...@@ -152,16 +152,15 @@ class FCOS(object): ...@@ -152,16 +152,15 @@ class FCOS(object):
def build_inputs( def build_inputs(
self, self,
image_shape=[3, None, None], image_shape=[3, None, None],
fields=[ fields=['image', 'im_info', 'fcos_target'], # for-train
'image', 'im_shape', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd'
], # for-train
use_dataloader=True, use_dataloader=True,
iterable=False): iterable=False):
inputs_def = self._inputs_def(image_shape, fields) inputs_def = self._inputs_def(image_shape, fields)
if "gt_bbox" in fields: if "fcos_target" in fields:
for i in range(len(self.fcos_head.fpn_stride)): for i in range(len(self.fcos_head.fpn_stride)):
fields.extend( fields.extend(
['labels%d' % i, 'reg_target%d' % i, 'centerness%d' % i]) ['labels%d' % i, 'reg_target%d' % i, 'centerness%d' % i])
fields.remove('fcos_target')
feed_vars = OrderedDict([(key, fluid.data( feed_vars = OrderedDict([(key, fluid.data(
name=key, name=key,
shape=inputs_def[key]['shape'], shape=inputs_def[key]['shape'],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册