Unverified commit 3787cffa, authored by Kaipeng Deng, committed by GitHub

update DataLoader & speed up Mask RCNN (#2435)

* update DataLoader & speed up Mask RCNN
Parent 0e6468c7
@@ -12,6 +12,7 @@ TrainReader:
   shuffle: true
   drop_last: true
   collate_batch: false
+  use_shared_memory: true
 
 EvalReader:
   sample_transforms:
......
@@ -12,6 +12,7 @@ TrainReader:
   shuffle: true
   drop_last: true
   collate_batch: false
+  use_shared_memory: true
 
 EvalReader:
......
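The new `use_shared_memory: true` option in both reader configs maps to the flag of the same name on `paddle.io.DataLoader`: worker processes hand batches to the main process through a shared-memory queue instead of serializing them, which is part of the Mask RCNN speed-up. A minimal sketch of where the flag ends up (`ToyDataset` is a hypothetical stand-in, not a ppdet dataset):

```python
# Sketch only: ToyDataset is a hypothetical stand-in, used to show the
# paddle.io.DataLoader flag that `use_shared_memory: true` toggles.
import numpy as np
from paddle.io import DataLoader, Dataset

class ToyDataset(Dataset):
    def __init__(self, n=8):
        self.n = n

    def __getitem__(self, idx):
        # an image-like array plus its index, as numpy data
        return np.zeros((3, 32, 32), dtype='float32'), np.array([idx])

    def __len__(self):
        return self.n

loader = DataLoader(
    ToyDataset(),
    batch_size=2,
    num_workers=2,
    use_shared_memory=True,  # the flag the config line above enables
    return_list=True)

images, ids = next(iter(loader))
print(images.shape)  # [2, 3, 32, 32]
```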
@@ -24,8 +24,8 @@ else:
     import Queue
 import numpy as np
-from paddle.io import DataLoader
-from paddle.io import DistributedBatchSampler
+from paddle.io import DataLoader, DistributedBatchSampler
+from paddle.fluid.dataloader.collate import default_collate_fn
 from ppdet.core.workspace import register, serializable, create
 from . import transform
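`default_collate_fn` is imported from a private `paddle.fluid` path, so it may move between releases. As relied on later in this diff, it is assumed to take a list of dict samples and return one dict with each field stacked along a new batch axis; roughly, for numpy-backed samples:

```python
# Rough sketch of the assumed behavior of default_collate_fn on dict samples;
# the real helper lives in paddle.fluid.dataloader.collate and covers more
# sample types than shown here.
import numpy as np

def collate_like_default(samples):
    # samples: list of {field_name: np.ndarray}, same shape per field
    return {k: np.stack([s[k] for s in samples], axis=0)
            for k in samples[0].keys()}

samples = [{'image': np.zeros((3, 4, 4), 'float32'),
            'im_shape': np.array([4., 4.], 'float32')} for _ in range(2)]
batch = collate_like_default(samples)
print(batch['image'].shape, batch['im_shape'].shape)  # (2, 3, 4, 4) (2, 2)
```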
@@ -44,11 +44,9 @@ class Compose(object):
         for t in self.transforms:
             for k, v in t.items():
                 op_cls = getattr(transform, k)
-                f = op_cls(**v)
-                if hasattr(f, 'num_classes'):
-                    f.num_classes = num_classes
-                self.transforms_cls.append(f)
+                self.transforms_cls.append(op_cls(**v))
+                if hasattr(op_cls, 'num_classes'):
+                    op_cls.num_classes = num_classes
 
     def __call__(self, data):
         for f in self.transforms_cls:
@@ -56,8 +54,9 @@ class Compose(object):
                 data = f(data)
             except Exception as e:
                 stack_info = traceback.format_exc()
-                logger.warn("fail to map op [{}] with error: {} and stack:\n{}".
-                            format(f, e, str(stack_info)))
+                logger.warn("fail to map sample transform [{}] "
+                            "with error: {} and stack:\n{}".format(
+                                f, e, str(stack_info)))
                 raise e
         return data
@@ -66,8 +65,6 @@ class Compose(object):
 class BatchCompose(Compose):
     def __init__(self, transforms, num_classes=80, collate_batch=True):
         super(BatchCompose, self).__init__(transforms, num_classes)
-        self.output_fields = mp.Manager().list([])
-        self.lock = mp.Lock()
         self.collate_batch = collate_batch
 
     def __call__(self, data):
@@ -76,54 +73,31 @@ class BatchCompose(Compose):
                 data = f(data)
             except Exception as e:
                 stack_info = traceback.format_exc()
-                logger.warn("fail to map op [{}] with error: {} and stack:\n{}".
-                            format(f, e, str(stack_info)))
+                logger.warn("fail to map batch transform [{}] "
+                            "with error: {} and stack:\n{}".format(
+                                f, e, str(stack_info)))
                 raise e
 
-        # accessing ListProxy in the main process (no worker subprocess)
-        # may incur errors in some environments; ListProxy falls back to
-        # a plain list if no worker process is started, while this
-        # `__call__` will be called in the main process
-        global MAIN_PID
-        if os.getpid() == MAIN_PID and \
-                isinstance(self.output_fields, mp.managers.ListProxy):
-            self.output_fields = []
-
-        # parse output fields from the first sample
-        # **this should be fixed once paddle.io.DataLoader supports dict**
-        # Since paddle.io.DataLoader does not support dict currently,
-        # we need to parse the keys from the first sample;
-        # BatchCompose.__call__ will be called in each worker
-        # process, so a lock is needed here.
-        if len(self.output_fields) == 0:
-            self.lock.acquire()
-            if len(self.output_fields) == 0:
-                for k, v in data[0].items():
-                    # FIXME(dkp): for more elegant coding
-                    if k not in ['flipped', 'h', 'w']:
-                        self.output_fields.append(k)
-            self.lock.release()
-
-        batch_data = []
+        # If collate_batch=True, all fields are collated into a batch
+        # and converted to paddle Tensors.
+        # If collate_batch=False, `image`, `im_shape` and `scale_factor`
+        # are collated into a batch, but `gt` data (such as gt_bbox,
+        # gt_class, gt_poly, etc.) is not collated and is converted to
+        # list[Tensor] or list[list].
+        # remove keys which are not needed by the model
+        extra_key = ['h', 'w', 'flipped']
+        for k in extra_key:
+            for sample in data:
+                if k in sample:
+                    sample.pop(k)
 
-        # batch data; if a user-defined batch function is needed,
-        # use the user-defined one here
         if self.collate_batch:
-            data = [[data[i][k] for k in self.output_fields]
-                    for i in range(len(data))]
-            data = list(zip(*data))
-            batch_data = [np.stack(d, axis=0) for d in data]
+            batch_data = default_collate_fn(data)
         else:
-            for k in self.output_fields:
+            batch_data = {}
+            for k in data[0].keys():
                 tmp_data = []
                 for i in range(len(data)):
                     tmp_data.append(data[i][k])
                 if not 'gt_' in k and not 'is_crowd' in k:
                     tmp_data = np.stack(tmp_data, axis=0)
-                batch_data.append(tmp_data)
+                batch_data[k] = tmp_data
         return batch_data
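The `collate_batch=False` branch above is what lets Mask RCNN-style models consume variable-length ground truth: fixed-shape fields such as `image` are stacked into one array, while `gt_*` and `is_crowd` fields are kept as per-image lists. A self-contained sketch of that branch with plain numpy data (field names follow the ppdet convention):

```python
# Standalone sketch of the collate_batch=False path above, using numpy only.
import numpy as np

def collate_no_batch(samples):
    batch = {}
    for k in samples[0].keys():
        values = [s[k] for s in samples]
        # stack fixed-shape fields; keep variable-length gt fields as lists
        if 'gt_' not in k and 'is_crowd' not in k:
            values = np.stack(values, axis=0)
        batch[k] = values
    return batch

samples = [
    {'image': np.zeros((3, 8, 8), 'float32'), 'gt_bbox': np.zeros((2, 4), 'float32')},
    {'image': np.ones((3, 8, 8), 'float32'), 'gt_bbox': np.zeros((5, 4), 'float32')},
]
batch = collate_no_batch(samples)
print(batch['image'].shape)                  # (2, 3, 8, 8): stacked
print([b.shape for b in batch['gt_bbox']])   # [(2, 4), (5, 4)]: left as a list
```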
@@ -227,15 +201,8 @@ class BaseDataLoader(object):
         return self
 
     def __next__(self):
-        # pack {field_name: field_data} here,
-        # looking forward to dict support in paddle.io.DataLoader
         try:
-            data = next(self.loader)
-            return {
-                k: v
-                for k, v in zip(self._batch_transforms.output_fields, data)
-            }
+            return next(self.loader)
         except StopIteration:
             self.loader = iter(self.dataloader)
             six.reraise(*sys.exc_info())
......