未验证 提交 3c5f0743 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #31 from heavengate/fix_compile_prune

extract input variable from feed
...@@ -360,10 +360,27 @@ class StaticGraphAdapter(object): ...@@ -360,10 +360,27 @@ class StaticGraphAdapter(object):
metric_list, metric_splits = flatten_list(endpoints['metric']) metric_list, metric_splits = flatten_list(endpoints['metric'])
fetch_list = endpoints['loss'] + metric_list fetch_list = endpoints['loss'] + metric_list
num_loss = len(endpoints['loss']) num_loss = len(endpoints['loss'])
# If a fetch Variable is also an input (feed) Variable, do not fetch it
# from the program; take its value directly from the feed instead.
pruned_fetch_list = []
pruned_fetch_idx_name_map = [""] * len(fetch_list)
for i, fetch_var in enumerate(fetch_list):
if fetch_var.name in feed.keys():
pruned_fetch_idx_name_map[i] = fetch_var.name
else:
pruned_fetch_list.append(fetch_var)
rets = self._executor.run(compiled_prog, rets = self._executor.run(compiled_prog,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=pruned_fetch_list,
return_numpy=False) return_numpy=False)
# Re-insert the pruned fetch_list Variables, taken from the feed, at their original indices
for i, name in enumerate(pruned_fetch_idx_name_map):
if len(name) > 0:
rets.insert(i, feed[name])
# LoDTensor cannot be fetched as numpy directly # LoDTensor cannot be fetched as numpy directly
rets = [np.array(v) for v in rets] rets = [np.array(v) for v in rets]
if self.mode == 'test': if self.mode == 'test':
......
...@@ -138,7 +138,7 @@ class YOLOv3(Model): ...@@ -138,7 +138,7 @@ class YOLOv3(Model):
act='leaky_relu')) act='leaky_relu'))
self.route_blocks.append(route) self.route_blocks.append(route)
def forward(self, img_info, inputs): def forward(self, img_id, img_shape, inputs):
outputs = [] outputs = []
boxes = [] boxes = []
scores = [] scores = []
...@@ -163,8 +163,6 @@ class YOLOv3(Model): ...@@ -163,8 +163,6 @@ class YOLOv3(Model):
for m in anchor_mask: for m in anchor_mask:
mask_anchors.append(self.anchors[2 * m]) mask_anchors.append(self.anchors[2 * m])
mask_anchors.append(self.anchors[2 * m + 1]) mask_anchors.append(self.anchors[2 * m + 1])
img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
b, s = fluid.layers.yolo_box( b, s = fluid.layers.yolo_box(
x=block_out, x=block_out,
img_size=img_shape, img_size=img_shape,
...@@ -181,7 +179,7 @@ class YOLOv3(Model): ...@@ -181,7 +179,7 @@ class YOLOv3(Model):
if self.model_mode == 'train': if self.model_mode == 'train':
return outputs return outputs
preds = [img_id[0, :], preds = [img_id,
fluid.layers.multiclass_nms( fluid.layers.multiclass_nms(
bboxes=fluid.layers.concat(boxes, axis=1), bboxes=fluid.layers.concat(boxes, axis=1),
scores=fluid.layers.concat(scores, axis=2), scores=fluid.layers.concat(scores, axis=2),
......
...@@ -186,30 +186,31 @@ class COCODataset(Dataset): ...@@ -186,30 +186,31 @@ class COCODataset(Dataset):
data = np.frombuffer(f.read(), dtype='uint8') data = np.frombuffer(f.read(), dtype='uint8')
im = cv2.imdecode(data, 1) im = cv2.imdecode(data, 1)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im_info = np.array([roidb['im_id'][0], roidb['h'], roidb['w']], dtype='int32') im_id = roidb['im_id']
im_shape = np.array([roidb['h'], roidb['w']], dtype='int32')
gt_bbox = roidb['gt_bbox'] gt_bbox = roidb['gt_bbox']
gt_class = roidb['gt_class'] gt_class = roidb['gt_class']
gt_score = roidb['gt_score'] gt_score = roidb['gt_score']
return im_info, im, gt_bbox, gt_class, gt_score return im_id, im_shape, im, gt_bbox, gt_class, gt_score
def __getitem__(self, idx): def __getitem__(self, idx):
im_info, im, gt_bbox, gt_class, gt_score = self._getitem_by_index(idx) im_id, im_shape, im, gt_bbox, gt_class, gt_score = self._getitem_by_index(idx)
if self._mixup: if self._mixup:
mixup_idx = idx + np.random.randint(1, self.__len__()) mixup_idx = idx + np.random.randint(1, self.__len__())
mixup_idx %= self.__len__() mixup_idx %= self.__len__()
_, mixup_im, mixup_bbox, mixup_class, _ = \ _, _, mixup_im, mixup_bbox, mixup_class, _ = \
self._getitem_by_index(mixup_idx) self._getitem_by_index(mixup_idx)
im, gt_bbox, gt_class, gt_score = \ im_shape, im, gt_bbox, gt_class, gt_score = \
self._mixup_image(im, gt_bbox, gt_class, mixup_im, self._mixup_image(im, gt_bbox, gt_class, mixup_im,
mixup_bbox, mixup_class) mixup_bbox, mixup_class)
if self._transform: if self._transform:
im_info, im, gt_bbox, gt_class, gt_score = \ im_id, im_shape, im, gt_bbox, gt_class, gt_score = \
self._transform(im_info, im, gt_bbox, gt_class, gt_score) self._transform(im_id, im_shape, im, gt_bbox, gt_class, gt_score)
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
def _mixup_image(self, img1, bbox1, class1, img2, bbox2, class2): def _mixup_image(self, img1, bbox1, class1, img2, bbox2, class2):
factor = np.random.beta(self._alpha, self._beta) factor = np.random.beta(self._alpha, self._beta)
...@@ -234,7 +235,9 @@ class COCODataset(Dataset): ...@@ -234,7 +235,9 @@ class COCODataset(Dataset):
score2 = np.ones_like(class2, dtype="float32") * (1.0 - factor) score2 = np.ones_like(class2, dtype="float32") * (1.0 - factor)
gt_score = np.concatenate((score1, score2), axis=0) gt_score = np.concatenate((score1, score2), axis=0)
return img, gt_bbox, gt_class, gt_score im_shape = np.array([h, w], dtype='int32')
return im_shape, img, gt_bbox, gt_class, gt_score
@property @property
def mixup(self): def mixup(self):
......
...@@ -63,7 +63,8 @@ def main(): ...@@ -63,7 +63,8 @@ def main():
device = set_device(FLAGS.device) device = set_device(FLAGS.device)
fluid.enable_dygraph(device) if FLAGS.dynamic else None fluid.enable_dygraph(device) if FLAGS.dynamic else None
inputs = [Input([None, 3], 'int32', name='img_info'), inputs = [Input([None, 1], 'int64', name='img_id'),
Input([None, 2], 'int32', name='img_shape'),
Input([None, 3, None, None], 'float32', name='image')] Input([None, 3, None, None], 'float32', name='image')]
labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'), labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'), Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
......
...@@ -145,7 +145,7 @@ class ColorDistort(object): ...@@ -145,7 +145,7 @@ class ColorDistort(object):
img += delta img += delta
return img return img
def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
if self.random_apply: if self.random_apply:
distortions = np.random.permutation([ distortions = np.random.permutation([
self.apply_brightness, self.apply_contrast, self.apply_brightness, self.apply_contrast,
...@@ -153,7 +153,7 @@ class ColorDistort(object): ...@@ -153,7 +153,7 @@ class ColorDistort(object):
]) ])
for func in distortions: for func in distortions:
im = func(im) im = func(im)
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
im = self.apply_brightness(im) im = self.apply_brightness(im)
...@@ -165,7 +165,7 @@ class ColorDistort(object): ...@@ -165,7 +165,7 @@ class ColorDistort(object):
im = self.apply_saturation(im) im = self.apply_saturation(im)
im = self.apply_hue(im) im = self.apply_hue(im)
im = self.apply_contrast(im) im = self.apply_contrast(im)
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
class RandomExpand(object): class RandomExpand(object):
...@@ -183,16 +183,16 @@ class RandomExpand(object): ...@@ -183,16 +183,16 @@ class RandomExpand(object):
self.prob = prob self.prob = prob
self.fill_value = fill_value self.fill_value = fill_value
def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
if np.random.uniform(0., 1.) < self.prob: if np.random.uniform(0., 1.) < self.prob:
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
height, width, _ = im.shape height, width, _ = im.shape
expand_ratio = np.random.uniform(1., self.ratio) expand_ratio = np.random.uniform(1., self.ratio)
h = int(height * expand_ratio) h = int(height * expand_ratio)
w = int(width * expand_ratio) w = int(width * expand_ratio)
if not h > height or not w > width: if not h > height or not w > width:
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
y = np.random.randint(0, h - height) y = np.random.randint(0, h - height)
x = np.random.randint(0, w - width) x = np.random.randint(0, w - width)
canvas = np.ones((h, w, 3), dtype=np.uint8) canvas = np.ones((h, w, 3), dtype=np.uint8)
...@@ -201,7 +201,7 @@ class RandomExpand(object): ...@@ -201,7 +201,7 @@ class RandomExpand(object):
gt_bbox += np.array([x, y, x, y], dtype=np.float32) gt_bbox += np.array([x, y, x, y], dtype=np.float32)
return [im_info, canvas, gt_bbox, gt_class, gt_score] return [im_id, im_shape, canvas, gt_bbox, gt_class, gt_score]
class RandomCrop(): class RandomCrop():
...@@ -232,9 +232,9 @@ class RandomCrop(): ...@@ -232,9 +232,9 @@ class RandomCrop():
self.allow_no_crop = allow_no_crop self.allow_no_crop = allow_no_crop
self.cover_all_box = cover_all_box self.cover_all_box = cover_all_box
def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
if len(gt_bbox) == 0: if len(gt_bbox) == 0:
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
# NOTE Original method attempts to generate one candidate for each # NOTE Original method attempts to generate one candidate for each
# threshold then randomly sample one from the resulting list. # threshold then randomly sample one from the resulting list.
...@@ -251,7 +251,7 @@ class RandomCrop(): ...@@ -251,7 +251,7 @@ class RandomCrop():
for thresh in thresholds: for thresh in thresholds:
if thresh == 'no_crop': if thresh == 'no_crop':
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
h, w, _ = im.shape h, w, _ = im.shape
found = False found = False
...@@ -286,9 +286,9 @@ class RandomCrop(): ...@@ -286,9 +286,9 @@ class RandomCrop():
gt_bbox = np.take(cropped_box, valid_ids, axis=0) gt_bbox = np.take(cropped_box, valid_ids, axis=0)
gt_class = np.take(gt_class, valid_ids, axis=0) gt_class = np.take(gt_class, valid_ids, axis=0)
gt_score = np.take(gt_score, valid_ids, axis=0) gt_score = np.take(gt_score, valid_ids, axis=0)
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
def _iou_matrix(self, a, b): def _iou_matrix(self, a, b):
tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2]) tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
...@@ -334,7 +334,7 @@ class RandomFlip(): ...@@ -334,7 +334,7 @@ class RandomFlip():
isinstance(self.is_normalized, bool)): isinstance(self.is_normalized, bool)):
raise TypeError("{}: input type is invalid.".format(self)) raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
"""Flip the image and bounding box. """Flip the image and bounding box.
Operators: Operators:
1. Flip the image numpy. 1. Flip the image numpy.
...@@ -363,20 +363,20 @@ class RandomFlip(): ...@@ -363,20 +363,20 @@ class RandomFlip():
m = "{}: invalid box, x2 should be greater than x1".format( m = "{}: invalid box, x2 should be greater than x1".format(
self) self)
raise ValueError(m) raise ValueError(m)
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
class NormalizeBox(object): class NormalizeBox(object):
"""Transform the bounding box's coordinates to [0,1].""" """Transform the bounding box's coordinates to [0,1]."""
def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
height, width, _ = im.shape height, width, _ = im.shape
for i in range(gt_bbox.shape[0]): for i in range(gt_bbox.shape[0]):
gt_bbox[i][0] = gt_bbox[i][0] / width gt_bbox[i][0] = gt_bbox[i][0] / width
gt_bbox[i][1] = gt_bbox[i][1] / height gt_bbox[i][1] = gt_bbox[i][1] / height
gt_bbox[i][2] = gt_bbox[i][2] / width gt_bbox[i][2] = gt_bbox[i][2] / width
gt_bbox[i][3] = gt_bbox[i][3] / height gt_bbox[i][3] = gt_bbox[i][3] / height
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
class PadBox(object): class PadBox(object):
...@@ -388,7 +388,7 @@ class PadBox(object): ...@@ -388,7 +388,7 @@ class PadBox(object):
""" """
self.num_max_boxes = num_max_boxes self.num_max_boxes = num_max_boxes
def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
gt_num = min(self.num_max_boxes, len(gt_bbox)) gt_num = min(self.num_max_boxes, len(gt_bbox))
num_max = self.num_max_boxes num_max = self.num_max_boxes
...@@ -406,7 +406,7 @@ class PadBox(object): ...@@ -406,7 +406,7 @@ class PadBox(object):
if gt_num > 0: if gt_num > 0:
pad_score[:gt_num] = gt_score[:gt_num, 0] pad_score[:gt_num] = gt_score[:gt_num, 0]
gt_score = pad_score gt_score = pad_score
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
class BboxXYXY2XYWH(object): class BboxXYXY2XYWH(object):
...@@ -414,10 +414,10 @@ class BboxXYXY2XYWH(object): ...@@ -414,10 +414,10 @@ class BboxXYXY2XYWH(object):
Convert bbox XYXY format to XYWH format. Convert bbox XYXY format to XYWH format.
""" """
def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
gt_bbox[:, 2:4] = gt_bbox[:, 2:4] - gt_bbox[:, :2] gt_bbox[:, 2:4] = gt_bbox[:, 2:4] - gt_bbox[:, :2]
gt_bbox[:, :2] = gt_bbox[:, :2] + gt_bbox[:, 2:4] / 2. gt_bbox[:, :2] = gt_bbox[:, :2] + gt_bbox[:, 2:4] / 2.
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
class RandomShape(object): class RandomShape(object):
...@@ -450,13 +450,13 @@ class RandomShape(object): ...@@ -450,13 +450,13 @@ class RandomShape(object):
method = np.random.choice(self.interps) if self.random_inter \ method = np.random.choice(self.interps) if self.random_inter \
else cv2.INTER_NEAREST else cv2.INTER_NEAREST
for i in range(len(samples)): for i in range(len(samples)):
im = samples[i][1] im = samples[i][2]
h, w = im.shape[:2] h, w = im.shape[:2]
scale_x = float(shape) / w scale_x = float(shape) / w
scale_y = float(shape) / h scale_y = float(shape) / h
im = cv2.resize( im = cv2.resize(
im, None, None, fx=scale_x, fy=scale_y, interpolation=method) im, None, None, fx=scale_x, fy=scale_y, interpolation=method)
samples[i][1] = im samples[i][2] = im
return samples return samples
...@@ -492,7 +492,7 @@ class NormalizeImage(object): ...@@ -492,7 +492,7 @@ class NormalizeImage(object):
3. (optional) permute channel 3. (optional) permute channel
""" """
for i in range(len(samples)): for i in range(len(samples)):
im = samples[i][1] im = samples[i][2]
im = im.astype(np.float32, copy=False) im = im.astype(np.float32, copy=False)
mean = np.array(self.mean)[np.newaxis, np.newaxis, :] mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
std = np.array(self.std)[np.newaxis, np.newaxis, :] std = np.array(self.std)[np.newaxis, np.newaxis, :]
...@@ -502,7 +502,7 @@ class NormalizeImage(object): ...@@ -502,7 +502,7 @@ class NormalizeImage(object):
im /= std im /= std
if self.channel_first: if self.channel_first:
im = im.transpose((2, 0, 1)) im = im.transpose((2, 0, 1))
samples[i][1] = im samples[i][2] = im
return samples return samples
...@@ -595,16 +595,15 @@ class ResizeImage(object): ...@@ -595,16 +595,15 @@ class ResizeImage(object):
format(type(target_size))) format(type(target_size)))
self.target_size = target_size self.target_size = target_size
def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): def __call__(self, im_id, im_shape, im, gt_bbox, gt_class, gt_score):
""" Resize the image numpy. """ Resize the image numpy.
""" """
if not isinstance(im, np.ndarray): if not isinstance(im, np.ndarray):
raise TypeError("{}: image type is not numpy.".format(self)) raise TypeError("{}: image type is not numpy.".format(self))
if len(im.shape) != 3: if len(im.shape) != 3:
raise ImageError('{}: image is not 3-dimensional.'.format(self)) raise ImageError('{}: image is not 3-dimensional.'.format(self))
im_shape = im.shape im_scale_x = float(self.target_size) / float(im.shape[1])
im_scale_x = float(self.target_size) / float(im_shape[1]) im_scale_y = float(self.target_size) / float(im.shape[0])
im_scale_y = float(self.target_size) / float(im_shape[0])
resize_w = self.target_size resize_w = self.target_size
resize_h = self.target_size resize_h = self.target_size
...@@ -616,5 +615,5 @@ class ResizeImage(object): ...@@ -616,5 +615,5 @@ class ResizeImage(object):
fy=im_scale_y, fy=im_scale_y,
interpolation=self.interp) interpolation=self.interp)
return [im_info, im, gt_bbox, gt_class, gt_score] return [im_id, im_shape, im, gt_bbox, gt_class, gt_score]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册