未验证 提交 36d3825b 编写于 作者: Q qingqing01 提交者: GitHub

Support variable image shape. (#1264)

* Support variable image shape.
* Update padding mode.
上级 56192f10
......@@ -39,7 +39,6 @@ def get_image_blob(roidb, settings):
'Failed to read image \'{}\''.format(roidb['image'])
if roidb['flipped']:
im = im[:, ::-1, :]
#print(im[10:, 10:, :])
target_size = settings.scales[scale_ind]
im, im_scale = prep_im_for_blob(im, settings.mean_value, target_size,
settings.max_size)
......@@ -57,7 +56,6 @@ def prep_im_for_blob(im, pixel_means, target_size, max_size):
"""
im = im.astype(np.float32, copy=False)
im -= pixel_means
#print(im[10:, 10:, :])
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
......@@ -74,10 +72,6 @@ def prep_im_for_blob(im, pixel_means, target_size, max_size):
fy=im_scale,
interpolation=cv2.INTER_LINEAR)
im_height, im_width, channel = im.shape
padding_im = im
#print(padding_im[10:, 10:, :])
channel_swap = (2, 0, 1) #(batch, channel, height, width)
#im = im.transpose(channel_swap)
padding_im = padding_im.transpose(channel_swap)
#print(padding_im[10:, 10:, :])
return padding_im, im_scale
im = im.transpose(channel_swap)
return im, im_scale
......@@ -52,7 +52,14 @@ class Settings(object):
np.newaxis, np.newaxis, :].astype('float32')
def coco(settings, mode, batch_size=None, shuffle=False):
def coco(settings,
mode,
batch_size=None,
total_batch_size=None,
padding_total=False,
shuffle=False):
total_batch_size = total_batch_size if total_batch_size else batch_size
assert total_batch_size % batch_size == 0
if mode == 'train':
settings.train_file_list = os.path.join(settings.data_dir,
settings.train_file_list)
......@@ -79,6 +86,21 @@ def coco(settings, mode, batch_size=None, shuffle=False):
is_crowd = roidb['is_crowd'].astype('int32')
return im, gt_boxes, gt_classes, is_crowd, im_info, im_id
def padding_minibatch(batch_data):
    """Zero-pad every image in a mini-batch to a common spatial size.

    Each sample is a sequence whose first element is an image array with
    three axes; the remaining elements are carried through unchanged.
    Images are copied into the top-left corner of a zero-filled float32
    array whose last two axes are the per-axis maxima over the batch.
    # NOTE(review): the axes are presumably (channel, height, width), as
    # the local names suggest -- confirm against the image producer.

    Args:
        batch_data: list of samples, each shaped like ``(image, ...)``.

    Returns:
        ``batch_data`` itself when it holds zero or one sample (nothing to
        align), otherwise a new list of tuples ``(padded_image, ...)``.
    """
    # With fewer than two samples there is no second shape to align with;
    # this also avoids reducing over an empty array, which numpy rejects.
    if len(batch_data) <= 1:
        return batch_data

    # Element-wise maximum of all image shapes in the batch.
    max_shape = np.array([data[0].shape for data in batch_data]).max(axis=0)

    padding_batch = []
    for data in batch_data:
        im_c, im_h, im_w = data[0].shape
        padding_im = np.zeros(
            (im_c, max_shape[1], max_shape[2]), dtype=np.float32)
        # Original pixels stay in the top-left corner; the rest stays zero.
        padding_im[:, :im_h, :im_w] = data[0]
        # tuple(...) lets list-typed samples work as well as tuples.
        padding_batch.append((padding_im, ) + tuple(data[1:]))
    return padding_batch
def reader():
if mode == "train":
roidb_perm = deque(np.random.permutation(roidbs))
......@@ -96,9 +118,21 @@ def coco(settings, mode, batch_size=None, shuffle=False):
continue
batch_out.append(
(im, gt_boxes, gt_classes, is_crowd, im_info, im_id))
if len(batch_out) == batch_size:
yield batch_out
batch_out = []
if not padding_total:
if len(batch_out) == batch_size:
yield padding_minibatch(batch_out)
batch_out = []
else:
if len(batch_out) == total_batch_size:
batch_out = padding_minibatch(batch_out)
for i in range(total_batch_size / batch_size):
sub_batch_out = []
for j in range(batch_size):
sub_batch_out.append(batch_out[i * batch_size +
j])
yield sub_batch_out
sub_batch_out = []
batch_out = []
else:
batch_out = []
for roidb in roidbs:
......@@ -113,9 +147,19 @@ def coco(settings, mode, batch_size=None, shuffle=False):
return reader
def train(settings, batch_size, shuffle=True):
return coco(settings, 'train', batch_size, shuffle)
def train(settings,
          batch_size,
          total_batch_size=None,
          padding_total=False,
          shuffle=True):
    """Create the training reader by delegating to ``coco`` in 'train' mode.

    Args:
        settings: configuration object handed straight to ``coco``.
        batch_size: number of samples per yielded batch.
        total_batch_size: forwarded unchanged to ``coco``.
        padding_total: forwarded unchanged to ``coco``.
        shuffle: forwarded unchanged to ``coco``.

    Returns:
        Whatever ``coco`` returns for these arguments.
    """
    # Every option is forwarded by keyword so the call does not depend on
    # the positional order of coco's parameters.
    reader_kwargs = dict(
        batch_size=batch_size,
        total_batch_size=total_batch_size,
        padding_total=padding_total,
        shuffle=shuffle)
    return coco(settings, 'train', **reader_kwargs)
def test(settings, batch_size):
return coco(settings, 'test', batch_size, shuffle=False)
def test(settings, batch_size, total_batch_size=None, padding_total=False):
    """Create the evaluation reader by delegating to ``coco`` in 'test' mode.

    Args:
        settings: configuration object handed straight to ``coco``.
        batch_size: number of samples per yielded batch.
        total_batch_size: forwarded unchanged to ``coco``.
        padding_total: forwarded unchanged to ``coco``. It used to be
            accepted here and then silently dropped.
            # NOTE(review): confirm the 'test' branch of the reader
            # honours this flag; that branch is not visible here.

    Returns:
        Whatever ``coco`` returns for these arguments; shuffling is always
        disabled.
    """
    return coco(
        settings,
        'test',
        batch_size,
        total_batch_size,
        padding_total,
        shuffle=False)
......@@ -26,6 +26,9 @@ add_arg('dataset', str, 'coco2017', "coco2014, coco2017, and pascalv
add_arg('data_dir', str, 'data/COCO17', "data directory")
add_arg('class_num', int, 81, "Class number.")
add_arg('use_pyreader', bool, True, "Use pyreader.")
add_arg('padding_minibatch',bool, False,
"If False, only resize image and not pad, image shape is different between"
" GPUs in one mini-batch. If True, image shape is the same in one mini-batch.")
# SOLVER
add_arg('learning_rate', float, 0.01, "Learning rate.")
add_arg('max_iter', int, 180000, "Iter number.")
......@@ -38,7 +41,7 @@ add_arg('variance', float, [1.,1.,1.,1.], "The variance of anchors."
add_arg('rpn_stride', float, 16., "Stride of the feature map that RPN is attached.")
# FAST RCNN
# TRAIN TEST
add_arg('batch_size', int, 1, "Minibatch size.")
add_arg('batch_size', int, 8, "Minibatch size of all devices.")
add_arg('max_size', int, 1333, "The max resized image size.")
add_arg('scales', int, [800], "The resized image height.")
add_arg('batch_size_per_im',int, 512, "fast rcnn head batch size")
......@@ -103,13 +106,17 @@ def train(cfg):
train_exe = fluid.ParallelExecutor(
use_cuda=bool(cfg.use_gpu), loss_name=loss.name)
assert cfg.batch_size % devices_num == 0
batch_size_per_dev = cfg.batch_size / devices_num
if cfg.use_pyreader:
train_reader = reader.train(cfg, batch_size=1, shuffle=not cfg.debug)
train_reader = reader.train(cfg, batch_size=batch_size_per_dev,
total_batch_size=cfg.batch_size,
padding_total=cfg.padding_minibatch,
shuffle=True)
py_reader = model.py_reader
py_reader.decorate_paddle_reader(train_reader)
else:
train_reader = reader.train(cfg, batch_size=cfg.batch_size, shuffle=not cfg.debug)
train_reader = reader.train(cfg, batch_size=cfg.batch_size, shuffle=True)
feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册