Unverified commit 26e7c73e, authored by Kaipeng Deng, committed by GitHub

remove fields config in input_def (#1921)

* remove fields config in input_def

Parent: 7e710c1c
...@@ -6,6 +6,7 @@ TrainDataset:
     image_dir: train2017
     anno_path: annotations/instances_train2017.json
     dataset_dir: dataset/coco
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
 
 EvalDataset:
   !COCODataSet
......
metric: COCO
num_classes: 80
TrainDataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_poly', 'is_crowd']
EvalDataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
TestDataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
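The new `data_fields` key is the heart of this commit: each dataset declares which annotation fields its samples should carry, instead of every reader listing `fields` in `inputs_def`. A minimal sketch of the filtering idea (the `build_record` helper is hypothetical, not the actual ppdet code):

```python
import numpy as np

def build_record(im_meta, gt_rec, data_fields):
    """Keep image metadata only when 'image' is requested, and keep
    only the ground-truth keys listed in data_fields."""
    rec = dict(im_meta) if 'image' in data_fields else {}
    for k, v in gt_rec.items():
        if k in data_fields:
            rec[k] = v
    return rec

im_meta = {'im_file': 'train2017/0001.jpg', 'im_id': np.array([1]),
           'h': 480, 'w': 640}
gt = {'gt_bbox': np.zeros((2, 4), dtype=np.float32),
      'gt_class': np.zeros((2, 1), dtype=np.int32),
      'is_crowd': np.zeros((2, 1), dtype=np.int32),
      'gt_poly': [None, None]}

# gt_poly is dropped because it is not listed in data_fields
rec = build_record(im_meta, gt, ['image', 'gt_bbox', 'gt_class', 'is_crowd'])
```

A mask model would list `'gt_poly'` in `data_fields` and get the polygons back; a box-only model pays no cost for loading them.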
...@@ -6,12 +6,14 @@ TrainDataset:
     dataset_dir: dataset/voc
     anno_path: trainval.txt
     label_list: label_list.txt
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
 
 EvalDataset:
   !VOCDataSet
     dataset_dir: dataset/voc
     anno_path: test.txt
     label_list: label_list.txt
+    data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']
 
 TestDataset:
   !ImageFolder
......
epoch: 240
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 160
- 200
- !LinearWarmup
start_factor: 0.3333333333333333
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
 worker_num: 2
 TrainReader:
-  inputs_def:
-    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
   sample_transforms:
-  - DecodeImage: {to_rgb: true}
+  - DecodeOp: { }
   - RandomFlipImage: {prob: 0.5}
   - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
...@@ -16,8 +14,6 @@ TrainReader:
 EvalReader:
-  inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: { }
   - NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] }
...@@ -32,8 +28,6 @@ EvalReader:
 TestReader:
-  inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: { }
   - NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] }
......
 worker_num: 2
 TrainReader:
-  inputs_def:
-    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd']
   sample_transforms:
-  - DecodeImage: {to_rgb: true}
+  - DecodeOp: { }
   - RandomFlipImage: {prob: 0.5}
   - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
...@@ -16,8 +14,6 @@ TrainReader:
 EvalReader:
-  inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: { }
   - NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] }
...@@ -32,8 +28,6 @@ EvalReader:
 TestReader:
-  inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: { }
   - NormalizeImageOp: { is_scale: true, mean: [ 0.485,0.456,0.406 ], std: [ 0.229, 0.224,0.225 ] }
......
 worker_num: 2
 TrainReader:
-  inputs_def:
-    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_poly']
   sample_transforms:
-  - DecodeImage: {to_rgb: true}
+  - DecodeOp: {}
   - RandomFlipImage: {prob: 0.5, is_mask_flip: true}
   - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
...@@ -16,8 +14,6 @@ TrainReader:
 EvalReader:
-  inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: {}
   - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
...@@ -32,8 +28,6 @@ EvalReader:
 TestReader:
-  inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: {}
   - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
......
 worker_num: 2
 TrainReader:
-  inputs_def:
-    fields: ['image', 'im_info', 'im_id', 'gt_bbox', 'gt_class', 'is_crowd', 'gt_poly']
   sample_transforms:
-  - DecodeImage: {to_rgb: true}
+  - DecodeOp: {}
   - RandomFlipImage: {prob: 0.5, is_mask_flip: true}
   - NormalizeImage: {is_channel_first: false, is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
   - ResizeImage: {target_size: 800, max_size: 1333, interp: 1, use_cv2: true}
...@@ -16,8 +14,6 @@ TrainReader:
 EvalReader:
-  inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: {}
   - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
...@@ -32,8 +28,6 @@ EvalReader:
 TestReader:
-  inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: {}
   - NormalizeImageOp: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
......
 worker_num: 2
 TrainReader:
   inputs_def:
-    fields: ['image', 'gt_bbox', 'gt_class']
     num_max_boxes: 90
   sample_transforms:
...@@ -24,8 +23,6 @@ TrainReader:
 EvalReader:
-  inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id', 'gt_bbox', 'gt_class', 'difficult']
   sample_transforms:
   - DecodeOp: {}
   - ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
...@@ -37,7 +34,6 @@ EvalReader:
 TestReader:
   inputs_def:
     image_shape: [3, 300, 300]
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: {}
   - ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1}
......
 worker_num: 2
 TrainReader:
   inputs_def:
-    fields: ['image', 'gt_bbox', 'gt_class', 'gt_score', 'im_shape', 'scale_factor']
     num_max_boxes: 50
   sample_transforms:
   - DecodeOp: {}
...@@ -26,13 +25,11 @@ TrainReader:
 EvalReader:
   inputs_def:
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
     num_max_boxes: 50
   sample_transforms:
   - DecodeOp: {}
   - ResizeOp: {target_size: [608, 608], keep_ratio: False, interp: 2}
   - NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
-  - PadBoxOp: {num_max_boxes: 50}
   - PermuteOp: {}
   batch_size: 1
   drop_empty: false
...@@ -40,7 +37,6 @@ EvalReader:
 TestReader:
   inputs_def:
     image_shape: [3, 608, 608]
-    fields: ['image', 'im_shape', 'scale_factor', 'im_id']
   sample_transforms:
   - DecodeOp: {}
   - ResizeOp: {target_size: [608, 608], keep_ratio: False, interp: 2}
......
 _BASE_: [
   './_base_/models/cascade_mask_rcnn_r50_fpn.yml',
   './_base_/optimizers/rcnn_1x.yml',
-  './_base_/datasets/coco.yml',
+  './_base_/datasets/coco_instance.yml',
   './_base_/readers/mask_fpn_reader.yml',
   './_base_/runtime.yml',
 ]

 _BASE_: [
   './_base_/models/cascade_rcnn_r50_fpn.yml',
   './_base_/optimizers/rcnn_1x.yml',
-  './_base_/datasets/coco.yml',
+  './_base_/datasets/coco_detection.yml',
   './_base_/readers/faster_fpn_reader.yml',
   './_base_/runtime.yml',
 ]

 _BASE_: [
   './_base_/models/faster_rcnn_r50.yml',
   './_base_/optimizers/rcnn_1x.yml',
-  './_base_/datasets/coco.yml',
+  './_base_/datasets/coco_detection.yml',
   './_base_/readers/faster_reader.yml',
   './_base_/runtime.yml',
 ]

 _BASE_: [
   './_base_/models/faster_rcnn_r50_fpn.yml',
   './_base_/optimizers/rcnn_1x.yml',
-  './_base_/datasets/coco.yml',
+  './_base_/datasets/coco_detection.yml',
   './_base_/readers/faster_fpn_reader.yml',
   './_base_/runtime.yml',
 ]

 _BASE_: [
   './_base_/models/mask_rcnn_r50.yml',
   './_base_/optimizers/rcnn_1x.yml',
-  './_base_/datasets/coco.yml',
+  './_base_/datasets/coco_instance.yml',
   './_base_/readers/mask_reader.yml',
   './_base_/runtime.yml',
 ]

 _BASE_: [
   './_base_/models/mask_rcnn_r50_fpn.yml',
   './_base_/optimizers/rcnn_1x.yml',
-  './_base_/datasets/coco.yml',
+  './_base_/datasets/coco_instance.yml',
   './_base_/readers/mask_fpn_reader.yml',
   './_base_/runtime.yml',
 ]
_BASE_: [
'./_base_/models/ssd_vgg16_300.yml',
'./_base_/optimizers/ssd_120e.yml',
'./_base_/datasets/coco.yml',
'./_base_/readers/ssd_reader.yml',
'./_base_/runtime.yml',
]
 _BASE_: [
   './_base_/models/yolov3_darknet53.yml',
   './_base_/optimizers/yolov3_270e.yml',
-  './_base_/datasets/coco.yml',
+  './_base_/datasets/coco_detection.yml',
   './_base_/readers/yolov3_reader.yml',
   './_base_/runtime.yml',
 ]
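These model configs now point at split dataset files: `coco_detection.yml` for box-only models and `coco_instance.yml` for mask models, each carrying its own `data_fields`. The `_BASE_` mechanism merges the listed files in order, with later entries and the local config overriding earlier ones. A simplified sketch of such a deep merge (hypothetical helper, ignoring file loading):

```python
def merge_config(base, override):
    """Recursively merge override into base; override's leaves win."""
    out = dict(base)
    for k, v in override.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = merge_config(out[k], v)
        else:
            out[k] = v
    return out

# values from a shared dataset file...
detection = {'TrainDataset': {
    'anno_path': 'annotations/instances_train2017.json',
    'data_fields': ['image', 'gt_bbox', 'gt_class', 'is_crowd']}}
# ...overridden by the local model config
local = {'TrainDataset': {'dataset_dir': 'my/coco'}}
cfg = merge_config(detection, local)
```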
...@@ -16,6 +16,7 @@ import copy
 import traceback
 import six
 import sys
+import multiprocessing as mp
 if sys.version_info >= (3, 0):
     import queue as Queue
 else:
...@@ -27,45 +28,42 @@ from paddle.io import DistributedBatchSampler
 from ppdet.core.workspace import register, serializable, create
 from . import transform
-from .transform import operator, batch_operator
 from ppdet.utils.logger import setup_logger
 logger = setup_logger('reader')


 class Compose(object):
-    def __init__(self, transforms, fields=None, from_=transform,
-                 num_classes=81):
+    def __init__(self, transforms, num_classes=81):
         self.transforms = transforms
         self.transforms_cls = []
-        output_fields = None
         for t in self.transforms:
             for k, v in t.items():
-                op_cls = getattr(from_, k)
+                op_cls = getattr(transform, k)
                 self.transforms_cls.append(op_cls(**v))
                 if hasattr(op_cls, 'num_classes'):
                     op_cls.num_classes = num_classes
-                # TODO: should be refined in the future
-                if op_cls in [
-                        transform.Gt2YoloTargetOp, transform.Gt2YoloTarget
-                ]:
-                    output_fields = ['image', 'gt_bbox']
-                    output_fields.extend([
-                        'target{}'.format(i)
-                        for i in range(len(v['anchor_masks']))
-                    ])
-
-        self.fields = fields
-        self.output_fields = output_fields if output_fields else fields

     def __call__(self, data):
-        if self.fields is not None:
-            data_new = []
-            for item in data:
-                data_new.append(dict(zip(self.fields, item)))
-            data = data_new
-
         for f in self.transforms_cls:
             try:
                 data = f(data)
+            except Exception as e:
+                stack_info = traceback.format_exc()
+                logger.warn("fail to map op [{}] with error: {} and stack:\n{}".
+                            format(f, e, str(stack_info)))
+                raise e
+        return data
+
+
+class BatchCompose(Compose):
+    def __init__(self, transforms, num_classes=81):
+        super(BatchCompose, self).__init__(transforms, num_classes)
+        self.output_fields = mp.Manager().list([])
+        self.lock = mp.Lock()
+
+    def __call__(self, data):
+        for f in self.transforms_cls:
+            try:
+                data = f(data)
...@@ -75,23 +73,27 @@ class Compose(object):
                             format(f, e, str(stack_info)))
                 raise e

-        if self.output_fields is not None:
-            data_new = []
-            for item in data:
-                batch = []
-                for k in self.output_fields:
-                    batch.append(item[k])
-                data_new.append(batch)
-            batch_size = len(data_new)
-            data_new = list(zip(*data_new))
-            if batch_size > 1:
-                data = [
-                    np.array(item).astype(item[0].dtype) for item in data_new
-                ]
-            else:
-                data = data_new
-
-        return data
+        # parse output fields from the first sample
+        # **this should be fixed once paddle.io.DataLoader supports dict**
+        # As paddle.io.DataLoader does not support dict currently,
+        # we need to parse the keys from the first sample.
+        # BatchCompose.__call__ will be called in each worker
+        # process, so a lock is needed here.
+        if len(self.output_fields) == 0:
+            self.lock.acquire()
+            if len(self.output_fields) == 0:
+                for k, v in data[0].items():
+                    # FIXME(dkp): for more elegant coding
+                    if k not in ['flipped', 'h', 'w']:
+                        self.output_fields.append(k)
+            self.lock.release()
+
+        data = [[data[i][k] for k in self.output_fields]
+                for i in range(len(data))]
+        data = list(zip(*data))
+        batch_data = [np.stack(d, axis=0) for d in data]
+        return batch_data
 class BaseDataLoader(object):
...@@ -99,8 +101,8 @@ class BaseDataLoader(object):
     def __init__(self,
                  inputs_def=None,
-                 sample_transforms=None,
-                 batch_transforms=None,
+                 sample_transforms=[],
+                 batch_transforms=[],
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
...@@ -108,21 +110,12 @@ class BaseDataLoader(object):
                  num_classes=81,
                  with_background=True,
                  **kwargs):
-        # out fields
-        self._fields = inputs_def['fields'] if inputs_def else None
         # sample transform
         self._sample_transforms = Compose(
             sample_transforms, num_classes=num_classes)

         # batch transform
-        self._batch_transforms = None
-        if batch_transforms:
-            self._batch_transforms = Compose(batch_transforms,
-                                             copy.deepcopy(self._fields),
-                                             transform, num_classes)
-            self.output_fields = self._batch_transforms.output_fields
-        else:
-            self.output_fields = self._fields
+        self._batch_transforms = BatchCompose(batch_transforms, num_classes)

         self.batch_size = batch_size
         self.shuffle = shuffle
...@@ -139,8 +132,7 @@ class BaseDataLoader(object):
         self.dataset = dataset
         self.dataset.parse_dataset(self.with_background)
         # get data
-        self.dataset.set_out(self._sample_transforms,
-                             copy.deepcopy(self._fields))
+        self.dataset.set_transform(self._sample_transforms)
         # set kwargs
         self.dataset.set_kwargs(**self.kwargs)
         # batch sampler
...@@ -177,7 +169,10 @@ class BaseDataLoader(object):
         # data structure in paddle.io.DataLoader
         try:
             data = next(self.loader)
-            return {k: v for k, v in zip(self.output_fields, data)}
+            return {
+                k: v
+                for k, v in zip(self._batch_transforms.output_fields, data)
+            }
         except StopIteration:
             self.loader = iter(self.dataloader)
             six.reraise(*sys.exc_info())
...@@ -191,8 +186,8 @@ class BaseDataLoader(object):
 class TrainReader(BaseDataLoader):
     def __init__(self,
                  inputs_def=None,
-                 sample_transforms=None,
-                 batch_transforms=None,
+                 sample_transforms=[],
+                 batch_transforms=[],
                  batch_size=1,
                  shuffle=True,
                  drop_last=True,
...@@ -210,8 +205,8 @@ class TrainReader(BaseDataLoader):
 class EvalReader(BaseDataLoader):
     def __init__(self,
                  inputs_def=None,
-                 sample_transforms=None,
-                 batch_transforms=None,
+                 sample_transforms=[],
+                 batch_transforms=[],
                  batch_size=1,
                  shuffle=False,
                  drop_last=True,
...@@ -229,8 +224,8 @@ class EvalReader(BaseDataLoader):
 class TestReader(BaseDataLoader):
     def __init__(self,
                  inputs_def=None,
-                 sample_transforms=None,
-                 batch_transforms=None,
+                 sample_transforms=[],
+                 batch_transforms=[],
                  batch_size=1,
                  shuffle=False,
                  drop_last=False,
......
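Because `paddle.io.DataLoader` cannot return dicts, the new `BatchCompose` records the field order from the first sample it sees and then emits one stacked array per field. The discovery-and-stack step can be sketched in plain Python (without the multiprocessing lock the real class uses across worker processes):

```python
import numpy as np

def collate(samples, output_fields, skip=('flipped', 'h', 'w')):
    """Stack a list of dict samples into per-field batch arrays."""
    # parse output fields from the first sample, once
    if not output_fields:
        output_fields.extend(k for k in samples[0] if k not in skip)
    # dict samples -> rows of field values -> per-field columns
    rows = [[s[k] for k in output_fields] for s in samples]
    columns = list(zip(*rows))
    return [np.stack(c, axis=0) for c in columns]

fields = []  # shared, filled on first call (mp.Manager().list() in ppdet)
batch = [{'image': np.zeros((3, 4, 4)), 'gt_class': np.array([1]),
          'h': 4, 'w': 4} for _ in range(2)]
out = collate(batch, fields)
# fields is now ['image', 'gt_class']; 'h' and 'w' are skipped
```

The reader then zips `output_fields` back with the arrays to rebuild a dict for the model, which is exactly what `BaseDataLoader.__next__` does above.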
...@@ -28,9 +28,10 @@ class COCODataSet(DetDataset):
                  dataset_dir=None,
                  image_dir=None,
                  anno_path=None,
+                 data_fields=['image'],
                  sample_num=-1):
         super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
-                                          sample_num)
+                                          data_fields, sample_num)
         self.load_image_only = False
         self.load_semantic = False
...@@ -82,13 +83,6 @@ class COCODataSet(DetDataset):
                         im_w, im_h, img_id))
                 continue

-            coco_rec = {
-                'im_file': im_path,
-                'im_id': np.array([img_id]),
-                'h': im_h,
-                'w': im_w,
-            }
-
             if not self.load_image_only:
                 ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
                 instances = coco.loadAnns(ins_anno_ids)
...@@ -121,7 +115,6 @@ class COCODataSet(DetDataset):
                 gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                 gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
-                gt_score = np.ones((num_bbox, 1), dtype=np.float32)
                 is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
                 difficult = np.zeros((num_bbox, 1), dtype=np.int32)
                 gt_poly = [None] * num_bbox
...@@ -142,15 +135,25 @@ class COCODataSet(DetDataset):
                 if has_segmentation and not any(gt_poly):
                     continue

-                coco_rec.update({
+            coco_rec = {
+                'im_file': im_path,
+                'im_id': np.array([img_id]),
+                'h': im_h,
+                'w': im_w,
+            } if 'image' in self.data_fields else {}
+            gt_rec = {
                 'is_crowd': is_crowd,
                 'gt_class': gt_class,
                 'gt_bbox': gt_bbox,
-                'gt_score': gt_score,
                 'gt_poly': gt_poly,
-                })
+            }
+            for k, v in gt_rec.items():
+                if k in self.data_fields:
+                    coco_rec[k] = v

             # TODO: remove load_semantic
-            if self.load_semantic:
+            if self.load_semantic and 'semantic' in self.data_fields:
                 seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
                                         'train2017', im_fname[:-3] + 'png')
                 coco_rec.update({'semantic': seg_path})
......
...@@ -31,6 +31,7 @@ class DetDataset(Dataset):
                  dataset_dir=None,
                  image_dir=None,
                  anno_path=None,
+                 data_fields=['image'],
                  sample_num=-1,
                  use_default_label=None,
                  **kwargs):
...@@ -38,6 +39,7 @@ class DetDataset(Dataset):
         self.dataset_dir = dataset_dir if dataset_dir is not None else ''
         self.anno_path = anno_path
         self.image_dir = image_dir if image_dir is not None else ''
+        self.data_fields = data_fields
         self.sample_num = sample_num
         self.use_default_label = use_default_label
         self._epoch = 0
...@@ -63,26 +65,19 @@ class DetDataset(Dataset):
             for _ in range(3)
         ]

-        # data augment
-        roidb = self.transform(roidb)
-        # data item
-        out = OrderedDict()
-        for k in self.fields:
-            out[k] = roidb[k]
-        return out.values()
+        return self.transform(roidb)

     def set_kwargs(self, **kwargs):
         self.mixup_epoch = kwargs.get('mixup_epoch', -1)
         self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
         self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)

+    def set_transform(self, transform):
+        self.transform = transform
+
     def set_epoch(self, epoch_id):
         self._epoch = epoch_id

-    def set_out(self, sample_transform, fields):
-        self.transform = sample_transform
-        self.fields = fields
-
     def parse_dataset(self, with_background=True):
         raise NotImplemented(
             "Need to implement parse_dataset method of Dataset")
......
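With `set_out` replaced by `set_transform`, `__getitem__` now returns the transformed sample dict directly and no longer slices it by a fixed field list; field selection moves to dataset construction (`data_fields`) and batch time (`BatchCompose`). A rough sketch of the new flow, using a hypothetical toy class rather than the real `DetDataset`:

```python
class ToyDataset:
    def __init__(self, records):
        self.records = records
        self.transform = lambda r: r  # identity until set_transform is called

    def set_transform(self, transform):
        self.transform = transform

    def __getitem__(self, idx):
        # no field slicing here any more; the sample stays a dict
        return self.transform(dict(self.records[idx]))

ds = ToyDataset([{'im_id': 0, 'gt_class': [1]}])
ds.set_transform(lambda r: {**r, 'decoded': True})
sample = ds[0]
```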
...@@ -46,12 +46,14 @@ class VOCDataSet(DetDataset):
                  dataset_dir=None,
                  image_dir=None,
                  anno_path=None,
+                 data_fields=['image'],
                  sample_num=-1,
                  label_list=None):
         super(VOCDataSet, self).__init__(
             dataset_dir=dataset_dir,
             image_dir=image_dir,
             anno_path=anno_path,
+            data_fields=data_fields,
             sample_num=sample_num)
         self.label_list = label_list
...@@ -113,7 +115,6 @@ class VOCDataSet(DetDataset):
                 gt_bbox = []
                 gt_class = []
                 gt_score = []
-                is_crowd = []
                 difficult = []
                 for i, obj in enumerate(objs):
                     cname = obj.find('name').text
...@@ -130,7 +131,6 @@ class VOCDataSet(DetDataset):
                         gt_bbox.append([x1, y1, x2, y2])
                         gt_class.append([cname2cid[cname]])
                         gt_score.append([1.])
-                        is_crowd.append([0])
                         difficult.append([_difficult])
                     else:
                         logger.warn(
...@@ -140,19 +140,25 @@ class VOCDataSet(DetDataset):
                 gt_bbox = np.array(gt_bbox).astype('float32')
                 gt_class = np.array(gt_class).astype('int32')
                 gt_score = np.array(gt_score).astype('float32')
-                is_crowd = np.array(is_crowd).astype('int32')
                 difficult = np.array(difficult).astype('int32')
                 voc_rec = {
                     'im_file': img_file,
                     'im_id': im_id,
                     'h': im_h,
-                    'w': im_w,
-                    'is_crowd': is_crowd,
+                    'w': im_w
+                } if 'image' in self.data_fields else {}
+                gt_rec = {
                     'gt_class': gt_class,
                     'gt_score': gt_score,
                     'gt_bbox': gt_bbox,
                     'difficult': difficult
                 }
+                for k, v in gt_rec.items():
+                    if k in self.data_fields:
+                        voc_rec[k] = v

                 if len(objs) != 0:
                     records.append(voc_rec)
......
...@@ -303,6 +303,11 @@ class Gt2YoloTargetOp(BaseOperator):
                             # classification
                             target[idx, 6 + cls, gj, gi] = 1.
             sample['target{}'.format(i)] = target
+
+        # remove useless gt_class and gt_score after target calculated
+        sample.pop('gt_class')
+        sample.pop('gt_score')
+
         return samples
......
...@@ -116,6 +116,7 @@ class DecodeOp(BaseOperator):
         if 'image' not in sample:
             with open(sample['im_file'], 'rb') as f:
                 sample['image'] = f.read()
+            sample.pop('im_file')

         im = sample['image']
         data = np.frombuffer(im, dtype='uint8')
...@@ -1570,9 +1571,9 @@ class MixupOp(BaseOperator):
         gt_class2 = sample[1]['gt_class']
         gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
         result['gt_class'] = gt_class
-        if 'gt_score' in sample[0]:
-            gt_score1 = sample[0]['gt_score']
-            gt_score2 = sample[1]['gt_score']
-            gt_score = np.concatenate(
-                (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
-            result['gt_score'] = gt_score
+
+        gt_score1 = np.ones_like(sample[0]['gt_class'])
+        gt_score2 = np.ones_like(sample[1]['gt_class'])
+        gt_score = np.concatenate(
+            (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
+        result['gt_score'] = gt_score
...@@ -1673,6 +1674,11 @@ class PadBoxOp(BaseOperator):
         if gt_num > 0:
             pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
         sample['difficult'] = pad_diff
+
+        if 'is_crowd' in sample:
+            pad_crowd = np.zeros((num_max, ), dtype=np.int32)
+            if gt_num > 0:
+                pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
+            sample['is_crowd'] = pad_crowd
+
         return sample
......
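`PadBoxOp` now also pads `is_crowd` when the dataset provides it, so every sample in a batch keeps a fixed shape and can be stacked by `BatchCompose`. The padding pattern in isolation (shapes assumed from the surrounding code: `is_crowd` arrives as `(gt_num, 1)` and leaves as a flat `(num_max,)` vector):

```python
import numpy as np

def pad_is_crowd(sample, num_max):
    """Zero-pad the per-box is_crowd flags to a fixed length."""
    if 'is_crowd' in sample:
        gt_num = min(sample['is_crowd'].shape[0], num_max)
        pad_crowd = np.zeros((num_max, ), dtype=np.int32)
        if gt_num > 0:
            pad_crowd[:gt_num] = sample['is_crowd'][:gt_num, 0]
        sample['is_crowd'] = pad_crowd
    return sample

s = pad_is_crowd({'is_crowd': np.array([[1], [0]], dtype=np.int32)}, num_max=5)
```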
...@@ -131,16 +131,18 @@ def run(FLAGS, cfg, place):
     dataset.set_images(test_images)
     test_loader = create('TestReader')(dataset, cfg['worker_num'])

     extra_key = ['im_shape', 'scale_factor', 'im_id']
+    if cfg.metric == 'VOC':
+        extra_key += ['gt_bbox', 'gt_class', 'difficult']

     # TODO: support other metrics
     imid2path = dataset.get_imid2path()

-    from ppdet.utils.coco_eval import get_category_info
     anno_file = dataset.get_anno()
     with_background = cfg.with_background
     use_default_label = dataset.use_default_label
+    if cfg.metric == 'COCO':
+        from ppdet.utils.coco_eval import get_category_info
+    if cfg.metric == 'VOC':
+        from ppdet.utils.voc_eval import get_category_info
     clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                 use_default_label)
......