diff --git a/model.py b/model.py
index ba80bea0c137158bebbee2537bae3788d0229800..9c7f2a2f8a4a53ef50239f0c14de740b2fb9fea7 100644
--- a/model.py
+++ b/model.py
@@ -1125,19 +1125,19 @@ class Model(fluid.dygraph.Layer):
         if not isinstance(test_loader, Iterable):
             loader = test_loader()
 
-        outputs = None
+        outputs = []
         for data in tqdm.tqdm(loader):
             if not fluid.in_dygraph_mode():
                 data = data[0]
 
-            outs = self.test(*data)
+            assert len(data) == len(self._inputs) + len(self._labels), \
+                "data fields number mismatch"
+            inputs_data = data[:len(self._inputs)]
 
-            if outputs is None:
-                outputs = outs
-            else:
-                outputs = [
-                    np.vstack([x, outs[i]]) for i, x in enumerate(outputs)
-                ]
+            outputs.append(self.test(inputs_data))
+
+        # sample list to batched data
+        outputs = list(zip(*outputs))
 
         self._test_dataloader = None
         if test_loader is not None and self._adapter._nranks > 1 \
@@ -1180,11 +1180,16 @@ class Model(fluid.dygraph.Layer):
             else:
                 batch_size = data[0].shape[0]
 
+            assert len(data) == len(self._inputs) + len(self._labels), \
+                "data fields number mismatch"
+            inputs_data = data[:len(self._inputs)]
+            labels_data = data[len(self._inputs):]
+
             callbacks.on_batch_begin(mode, step, logs)
             if mode == 'train':
-                outs = self.train(*data)
+                outs = self.train(inputs_data, labels_data)
             else:
-                outs = self.eval(*data)
+                outs = self.eval(inputs_data, labels_data)
 
             # losses
             loss = outs[0] if self._metrics else outs
diff --git a/yolov3.py b/yolov3.py
index 6c609f24dce60293ee42a599324d595a6875a0f6..5fe8c5c4529b435c9ed3f1d951bf01206389c49a 100644
--- a/yolov3.py
+++ b/yolov3.py
@@ -18,233 +18,29 @@ from __future__ import print_function
 import argparse
 import contextlib
 import os
-import random
-import time
-from functools import partial
-
-import cv2
 import numpy as np
-from pycocotools.coco import COCO
-
-import paddle
-import paddle.fluid as fluid
-from paddle.fluid.dygraph.nn import Conv2D
-from paddle.fluid.param_attr import ParamAttr
-from paddle.fluid.regularizer import L2Decay
-
-from model import Model, Loss, Input
-from resnet import ResNet, ConvBNLayer
-
-import logging
-FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
-logging.basicConfig(level=logging.INFO, format=FORMAT)
-logger = logging.getLogger(__name__)
-
-
-# XXX transfer learning
-class ResNetBackBone(ResNet):
-    def __init__(self, depth=50):
-        super(ResNetBackBone, self).__init__(depth=depth)
-        delattr(self, 'fc')
-
-    def forward(self, inputs):
-        x = self.conv(inputs)
-        x = self.pool(x)
-        outputs = []
-        for layer in self.layers:
-            x = layer(x)
-            outputs.append(x)
-        return outputs
-
-
-class YoloDetectionBlock(fluid.dygraph.Layer):
-    def __init__(self, num_channels, num_filters):
-        super(YoloDetectionBlock, self).__init__()
-
-        assert num_filters % 2 == 0, \
-            "num_filters {} cannot be divided by 2".format(num_filters)
-
-        self.conv0 = ConvBNLayer(
-            num_channels=num_channels,
-            num_filters=num_filters,
-            filter_size=1,
-            act='leaky_relu')
-        self.conv1 = ConvBNLayer(
-            num_channels=num_filters,
-            num_filters=num_filters * 2,
-            filter_size=3,
-            act='leaky_relu')
-        self.conv2 = ConvBNLayer(
-            num_channels=num_filters * 2,
-            num_filters=num_filters,
-            filter_size=1,
-            act='leaky_relu')
-        self.conv3 = ConvBNLayer(
-            num_channels=num_filters,
-            num_filters=num_filters * 2,
-            filter_size=3,
-            act='leaky_relu')
-        self.route = ConvBNLayer(
-            num_channels=num_filters * 2,
-            num_filters=num_filters,
-            filter_size=1,
-            act='leaky_relu')
-        self.tip = ConvBNLayer(
-            num_channels=num_filters,
-            num_filters=num_filters * 2,
-            filter_size=3,
-            act='leaky_relu')
-
-    def 
forward(self, inputs): - out = self.conv0(inputs) - out = self.conv1(out) - out = self.conv2(out) - out = self.conv3(out) - route = self.route(out) - tip = self.tip(route) - return route, tip - - -class YOLOv3(Model): - def __init__(self, num_classes=80): - super(YOLOv3, self).__init__() - self.num_classes = num_classes - self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, - 59, 119, 116, 90, 156, 198, 373, 326] - self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] - self.valid_thresh = 0.005 - self.nms_thresh = 0.45 - self.nms_topk = 400 - self.nms_posk = 100 - self.draw_thresh = 0.5 - - self.backbone = ResNetBackBone() - self.block_outputs = [] - self.yolo_blocks = [] - self.route_blocks = [] - - for idx, num_chan in enumerate([2048, 1280, 640]): - yolo_block = self.add_sublayer( - "detecton_block_{}".format(idx), - YoloDetectionBlock(num_chan, num_filters=512 // (2**idx))) - self.yolo_blocks.append(yolo_block) - - num_filters = len(self.anchor_masks[idx]) * (self.num_classes + 5) - - block_out = self.add_sublayer( - "block_out_{}".format(idx), - Conv2D(num_channels=1024 // (2**idx), - num_filters=num_filters, - filter_size=1, - param_attr=ParamAttr( - initializer=fluid.initializer.Normal(0., 0.02)), - bias_attr=ParamAttr( - initializer=fluid.initializer.Constant(0.0), - regularizer=L2Decay(0.)))) - self.block_outputs.append(block_out) - if idx < 2: - route = self.add_sublayer( - "route_{}".format(idx), - ConvBNLayer(num_channels=512 // (2**idx), - num_filters=256 // (2**idx), - filter_size=1, - act='leaky_relu')) - self.route_blocks.append(route) - - def forward(self, inputs, img_info): - outputs = [] - boxes = [] - scores = [] - downsample = 32 - - feats = self.backbone(inputs) - feats = feats[::-1][:len(self.anchor_masks)] - route = None - for idx, feat in enumerate(feats): - if idx > 0: - feat = fluid.layers.concat(input=[route, feat], axis=1) - route, tip = self.yolo_blocks[idx](feat) - block_out = self.block_outputs[idx](tip) - outputs.append(block_out) - if idx < 2: - route = self.route_blocks[idx](route) - route = fluid.layers.resize_nearest(route, scale=2) +from paddle import fluid +from paddle.fluid.optimizer import Momentum +from paddle.fluid.io import DataLoader - if self.mode == 'test': - anchor_mask = self.anchor_masks[idx] - mask_anchors = [] - for m in anchor_mask: - mask_anchors.append(self.anchors[2 * m]) - mask_anchors.append(self.anchors[2 * m + 1]) - img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3]) - img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1]) - b, s = fluid.layers.yolo_box( - x=block_out, - img_size=img_shape, - anchors=mask_anchors, - class_num=self.num_classes, - conf_thresh=self.valid_thresh, - downsample_ratio=downsample) +from model import Model, Input, set_device +from distributed import DistributedBatchSampler +from yolov3.coco import * +from yolov3.transforms import * +from yolov3.modeling import * +from yolov3.coco_metric import * - boxes.append(b) - scores.append(fluid.layers.transpose(s, perm=[0, 2, 1])) +NUM_MAX_BOXES = 50 - downsample //= 2 - if self.mode != 'test': - return outputs - - return [img_id, fluid.layers.multiclass_nms( - bboxes=fluid.layers.concat(boxes, axis=1), - scores=fluid.layers.concat(scores, axis=2), - score_threshold=self.valid_thresh, - nms_top_k=self.nms_topk, - keep_top_k=self.nms_posk, - nms_threshold=self.nms_thresh, - background_label=-1)] - - -class YoloLoss(Loss): - def __init__(self, num_classes=80): - super(YoloLoss, self).__init__() - self.num_classes = num_classes - 
self.ignore_thresh = 0.7 - self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, - 59, 119, 116, 90, 156, 198, 373, 326] - self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] - - def forward(self, outputs, labels): - downsample = 32 - gt_box, gt_label, gt_score = labels - losses = [] - - for idx, out in enumerate(outputs): - anchor_mask = self.anchor_masks[idx] - loss = fluid.layers.yolov3_loss( - x=out, - gt_box=gt_box, - gt_label=gt_label, - gt_score=gt_score, - anchor_mask=anchor_mask, - downsample_ratio=downsample, - anchors=self.anchors, - class_num=self.num_classes, - ignore_thresh=self.ignore_thresh, - use_label_smooth=True) - loss = fluid.layers.reduce_mean(loss) - losses.append(loss) - downsample //= 2 - return losses - - -def make_optimizer(parameter_list=None): +def make_optimizer(step_per_epoch, parameter_list=None): base_lr = FLAGS.lr - warm_up_iter = 4000 + warm_up_iter = 1000 momentum = 0.9 weight_decay = 5e-4 - boundaries = [400000, 450000] + boundaries = [step_per_epoch * e for e in [200, 250]] values = [base_lr * (0.1 ** i) for i in range(len(boundaries) + 1)] learning_rate = fluid.layers.piecewise_decay( boundaries=boundaries, @@ -262,307 +58,151 @@ def make_optimizer(parameter_list=None): return optimizer -def _iou_matrix(a, b): - tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2]) - br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) - area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2) - area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) - area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) - area_o = (area_a[:, np.newaxis] + area_b - area_i) - return area_i / (area_o + 1e-10) - - -def _crop_box_with_center_constraint(box, crop): - cropped_box = box.copy() - cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2]) - cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:]) - cropped_box[:, :2] -= crop[:2] - cropped_box[:, 2:] -= crop[:2] - centers = (box[:, :2] + box[:, 2:]) / 2 - valid = np.logical_and( - crop[:2] <= centers, centers < crop[2:]).all(axis=1) - valid = np.logical_and( - valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1)) - return cropped_box, np.where(valid)[0] - - -def random_crop(inputs): - aspect_ratios = [.5, 2.] - thresholds = [.0, .1, .3, .5, .7, .9] - scaling = [.3, 1.] 
- - img, img_ids, gt_box, gt_label = inputs - h, w = img.shape[:2] - - if len(gt_box) == 0: - return inputs - - np.random.shuffle(thresholds) - for thresh in thresholds: - found = False - for i in range(50): - scale = np.random.uniform(*scaling) - min_ar, max_ar = aspect_ratios - ar = np.random.uniform(max(min_ar, scale**2), - min(max_ar, scale**-2)) - crop_h = int(h * scale / np.sqrt(ar)) - crop_w = int(w * scale * np.sqrt(ar)) - crop_y = np.random.randint(0, h - crop_h) - crop_x = np.random.randint(0, w - crop_w) - crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] - iou = _iou_matrix(gt_box, np.array([crop_box], dtype=np.float32)) - if iou.max() < thresh: - continue - - cropped_box, valid_ids = _crop_box_with_center_constraint( - gt_box, np.array(crop_box, dtype=np.float32)) - if valid_ids.size > 0: - found = True - break - - if found: - x1, y1, x2, y2 = crop_box - img = img[y1:y2, x1:x2, :] - gt_box = np.take(cropped_box, valid_ids, axis=0) - gt_label = np.take(gt_label, valid_ids, axis=0) - return img, img_ids, gt_box, gt_label - - return inputs - - -# XXX mix up, color distort and random expand are skipped for simplicity -def sample_transform(inputs, mode='train', num_max_boxes=50): - if mode == 'train': - img, img_id, gt_box, gt_label = random_crop(inputs) - else: - img, img_id, gt_box, gt_label = inputs - - h, w = img.shape[:2] - # random flip - if mode == 'train' and np.random.uniform(0., 1.) > .5: - img = img[:, ::-1, :] - if len(gt_box) > 0: - swap = gt_box.copy() - gt_box[:, 0] = w - swap[:, 2] - 1 - gt_box[:, 2] = w - swap[:, 0] - 1 - - if len(gt_label) == 0: - gt_box = np.zeros([num_max_boxes, 4], dtype=np.float32) - gt_label = np.zeros([num_max_boxes], dtype=np.int32) - return img, gt_box, gt_label - - gt_box = gt_box[:num_max_boxes, :] - gt_label = gt_label[:num_max_boxes, 0] - # normalize boxes - gt_box /= np.array([w, h] * 2, dtype=np.float32) - gt_box[:, 2:] = gt_box[:, 2:] - gt_box[:, :2] - gt_box[:, :2] = gt_box[:, :2] + gt_box[:, 2:] / 2. - - pad = num_max_boxes - gt_label.size - gt_box = np.pad(gt_box, ((0, pad), (0, 0)), mode='constant') - gt_label = np.pad(gt_label, ((0, pad)), mode='constant') - - return img, img_id, gt_box, gt_label - - -def batch_transform(batch, mode='train'): - if mode == 'train': - d = np.random.choice( - [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]) - interp = np.random.choice(range(5)) - else: - d = 608 - interp = cv2.INTER_CUBIC - # transpose batch - imgs, img_ids, gt_boxes, gt_labels = list(zip(*batch)) - img_shapes = np.array([[im.shape[0], im.shape[1]] for im in imgs]).astype('int32') - imgs = np.array([cv2.resize( - img, (d, d), interpolation=interp) for img in imgs]) - - # transpose, permute and normalize - imgs = imgs.astype(np.float32)[..., ::-1] - mean = np.array([123.675, 116.28, 103.53], dtype=np.float32) - std = np.array([58.395, 57.120, 57.375], dtype=np.float32) - invstd = 1. 
/ std - imgs -= mean - imgs *= invstd - imgs = imgs.transpose((0, 3, 1, 2)) - - img_ids = np.array(img_ids) - img_info = np.concatenate([img_ids, img_shapes], axis=1) - gt_boxes = np.array(gt_boxes) - gt_labels = np.array(gt_labels) - # XXX since mix up is not used, scores are all ones - gt_scores = np.ones_like(gt_labels, dtype=np.float32) - return [imgs, img_info], [gt_boxes, gt_labels, gt_scores] - - -def coco2017(root_dir, mode='train'): - json_path = os.path.join( - root_dir, 'annotations/instances_{}2017.json'.format(mode)) - coco = COCO(json_path) - img_ids = coco.getImgIds() - imgs = coco.loadImgs(img_ids) - class_map = {v: i + 1 for i, v in enumerate(coco.getCatIds())} - samples = [] - - for img in imgs: - img_path = os.path.join( - root_dir, '{}2017'.format(mode), img['file_name']) - file_path = img_path - width = img['width'] - height = img['height'] - ann_ids = coco.getAnnIds(imgIds=img['id'], iscrowd=False) - anns = coco.loadAnns(ann_ids) - - gt_box = [] - gt_label = [] - - for ann in anns: - x1, y1, w, h = ann['bbox'] - x2 = x1 + w - 1 - y2 = y1 + h - 1 - x1 = np.clip(x1, 0, width - 1) - x2 = np.clip(x2, 0, width - 1) - y1 = np.clip(y1, 0, height - 1) - y2 = np.clip(y2, 0, height - 1) - if ann['area'] <= 0 or x2 < x1 or y2 < y1: - continue - gt_label.append(ann['category_id']) - gt_box.append([x1, y1, x2, y2]) - - gt_box = np.array(gt_box, dtype=np.float32) - gt_label = np.array([class_map[cls] for cls in gt_label], - dtype=np.int32)[:, np.newaxis] - im_id = np.array([img['id']], dtype=np.int32) - - if gt_label.size == 0 and not mode == 'train': - continue - samples.append((file_path, im_id.copy(), gt_box.copy(), gt_label.copy())) - - def iterator(): - if mode == 'train': - np.random.shuffle(samples) - for file_path, im_id, gt_box, gt_label in samples: - img = cv2.imread(file_path) - yield img, im_id, gt_box, gt_label - - return iterator - - -# XXX coco metrics not included for simplicity -def run(model, loader, mode='train'): - total_loss = 0. - total_time = 0. 
-    device_ids = list(range(FLAGS.num_devices))
-    start = time.time()
-
-    for idx, batch in enumerate(loader()):
-        losses = getattr(model, mode)(batch[0], batch[1])
-
-        total_loss += np.sum(losses)
-        if idx > 1:  # skip first two steps
-            total_time += time.time() - start
-            if idx % 10 == 0:
-                logger.info("{:04d}: loss {:0.3f} time: {:0.3f}".format(
-                    idx, total_loss / (idx + 1), total_time / max(1, (idx - 1))))
-        start = time.time()
-
-
 def main():
-    @contextlib.contextmanager
-    def null_guard():
-        yield
-
-    epoch = FLAGS.epoch
-    batch_size = FLAGS.batch_size
-    guard = fluid.dygraph.guard() if FLAGS.dynamic else null_guard()
+    device = set_device(FLAGS.device)
+    if FLAGS.dynamic:
+        fluid.enable_dygraph(device)
+
+    inputs = [Input([None, 3], 'int32', name='img_info'),
+              Input([None, 3, None, None], 'float32', name='image')]
+    labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
+              Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
+              Input([None, NUM_MAX_BOXES], 'float32', name='gt_score')]
+
+    if not FLAGS.eval_only:  # training mode
+        train_transform = Compose([ColorDistort(),
+                                   RandomExpand(),
+                                   RandomCrop(),
+                                   RandomFlip(),
+                                   NormalizeBox(),
+                                   PadBox(),
+                                   BboxXYXY2XYWH()])
+        train_collate_fn = BatchCompose([RandomShape(), NormalizeImage()])
+        dataset = COCODataset(dataset_dir=FLAGS.data,
+                              anno_path='annotations/instances_train2017.json',
+                              image_dir='train2017',
+                              with_background=False,
+                              mixup=True,
+                              transform=train_transform)
+        batch_sampler = DistributedBatchSampler(dataset,
+                                                batch_size=FLAGS.batch_size,
+                                                shuffle=True,
+                                                drop_last=True)
+        loader = DataLoader(dataset,
+                            batch_sampler=batch_sampler,
+                            places=device,
+                            feed_list=[i.forward() for i in inputs + labels] \
+                                if not FLAGS.dynamic else None,
+                            num_workers=FLAGS.num_workers,
+                            return_list=True,
+                            collate_fn=train_collate_fn)
+    else:  # evaluation mode
+        eval_transform = Compose([ResizeImage(target_size=608),
+                                  NormalizeBox(),
+                                  PadBox(),
+                                  BboxXYXY2XYWH()])
+        eval_collate_fn = BatchCompose([NormalizeImage()])
+        dataset = COCODataset(dataset_dir=FLAGS.data,
+                              anno_path='annotations/instances_val2017.json',
+                              image_dir='val2017',
+                              with_background=False,
+                              transform=eval_transform)
+        # batch_size can only be 1 in evaluation for YOLOv3,
+        # as the prediction bbox is a LoDTensor
+        batch_sampler = DistributedBatchSampler(dataset,
+                                                batch_size=1,
+                                                shuffle=False,
+                                                drop_last=False)
+        loader = DataLoader(dataset,
+                            batch_sampler=batch_sampler,
+                            places=device,
+                            feed_list=[i.forward() for i in inputs + labels] \
+                                if not FLAGS.dynamic else None,
+                            num_workers=FLAGS.num_workers,
+                            return_list=True,
+                            collate_fn=eval_collate_fn)
+
+    model = YOLOv3(num_classes=dataset.num_classes,
+                   model_mode='eval' if FLAGS.eval_only else 'train')
+    if FLAGS.pretrain_weights is not None:
+        model.load(FLAGS.pretrain_weights, skip_mismatch=True, reset_optimizer=True)
+
+    optim = make_optimizer(len(batch_sampler), parameter_list=model.parameters())
+
+    model.prepare(optim,
+                  YoloLoss(num_classes=dataset.num_classes),
+                  inputs=inputs, labels=labels,
+                  device=FLAGS.device)
+
+    # NOTE: we implement the COCO metric for the YOLOv3 model here, separately
+    # from the 'prepare' and 'fit' framework, for the following reasons:
+    # 1. the YOLOv3 network structure differs between 'train' and 'eval'
+    #    mode: in 'eval' mode, the output is the predicted bboxes rather
+    #    than the feature maps used for the YoloLoss calculation
+    # 2. the COCO metric also behaves differently from the defined Metric
+    #    interface: it should not accumulate on each iteration, but only
+    #    once at the end of an epoch
+    if FLAGS.eval_only:
+        if FLAGS.weights is not None:
+            model.load(FLAGS.weights)
+        preds = model.predict(loader)
+        _, _, _, img_ids, bboxes = preds
 
-    train_loader = fluid.io.xmap_readers(
-        batch_transform,
-        paddle.batch(
-            fluid.io.xmap_readers(
-                sample_transform,
-                coco2017(FLAGS.data, 'train'),
-                process_num=8,
-                buffer_size=4 * batch_size),
-            batch_size=batch_size,
-            drop_last=True),
-        process_num=2, buffer_size=4)
+        anno_path = os.path.join(FLAGS.data, 'annotations/instances_val2017.json')
+        coco_metric = COCOMetric(anno_path=anno_path, with_background=False)
+        for img_id, bbox in zip(img_ids, bboxes):
+            coco_metric.update(img_id, bbox)
+        coco_metric.accumulate()
+        coco_metric.reset()
+        return
 
-    val_sample_transform = partial(sample_transform, mode='val')
-    val_batch_transform = partial(batch_transform, mode='val')
+    if FLAGS.resume is not None:
+        model.load(FLAGS.resume)
 
-    val_loader = fluid.io.xmap_readers(
-        val_batch_transform,
-        paddle.batch(
-            fluid.io.xmap_readers(
-                val_sample_transform,
-                coco2017(FLAGS.data, 'val'),
-                process_num=8,
-                buffer_size=4 * batch_size),
-            batch_size=1),
-        process_num=2, buffer_size=4)
+    model.fit(train_data=loader,
+              epochs=FLAGS.epoch - FLAGS.no_mixup_epoch,
+              save_dir="yolo_checkpoint/mixup",
+              save_freq=10)
 
-    if not os.path.exists('yolo_checkpoints'):
-        os.mkdir('yolo_checkpoints')
-
-    with guard:
-        NUM_CLASSES = 7
-        NUM_MAX_BOXES = 50
-        model = YOLOv3(num_classes=NUM_CLASSES)
-        # XXX transfer learning
-        if FLAGS.pretrain_weights is not None:
-            model.backbone.load(FLAGS.pretrain_weights)
-        if FLAGS.weights is not None:
-            model.load(FLAGS.weights)
-        optim = make_optimizer(parameter_list=model.parameters())
-        anno_path = os.path.join(FLAGS.data, 'annotations', 'instances_val2017.json')
-        inputs = [Input([None, 3, None, None], 'float32', name='image'),
-                  Input([None, 3], 'int32', name='img_info')]
-        labels = [Input([None, NUM_MAX_BOXES, 4], 'float32', name='gt_bbox'),
-                  Input([None, NUM_MAX_BOXES], 'int32', name='gt_label'),
-                  Input([None, NUM_MAX_BOXES], 'float32', name='gt_score')]
-        model.prepare(optim,
-                      YoloLoss(num_classes=NUM_CLASSES),
-                      # For YOLOv3, output variable in train/eval is different,
-                      # which is not supported by metric, add by callback later?
-                      # metrics=COCOMetric(anno_path, with_background=False)
-                      inputs=inputs,
-                      labels=labels)
-
-        for e in range(epoch):
-            logger.info("======== train epoch {} ========".format(e))
-            run(model, train_loader)
-            model.save('yolo_checkpoints/{:02d}'.format(e))
-            logger.info("======== eval epoch {} ========".format(e))
-            run(model, val_loader, mode='eval')
-            # should be called in fit()
-            for metric in model._metrics:
-                metric.accumulate()
-                metric.reset()
+    # do not use the image mixup transform in the last FLAGS.no_mixup_epoch epochs
+    dataset.mixup = False
+    model.fit(train_data=loader,
+              epochs=FLAGS.no_mixup_epoch,
+              save_dir="yolo_checkpoint/no_mixup",
+              save_freq=5)
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("Yolov3 Training on COCO")
     parser.add_argument('data', metavar='DIR', help='path to COCO dataset')
+    parser.add_argument(
+        "--device", type=str, default='gpu', help="device to use, gpu or cpu")
     parser.add_argument(
         "-d", "--dynamic", action='store_true', help="enable dygraph mode")
+    parser.add_argument(
+        "--eval_only", action='store_true', help="run evaluation only")
     parser.add_argument(
         "-e", "--epoch", default=300, type=int, help="number of epochs")
+    parser.add_argument(
+        "--no_mixup_epoch", default=30, type=int,
+        help="number of the last N epochs without image mixup")
     parser.add_argument(
         '--lr', '--learning-rate', default=0.001, type=float, metavar='LR',
         help='initial learning rate')
     parser.add_argument(
-        "-b", "--batch_size", default=64, type=int, help="batch size")
+        "-b", "--batch_size", default=8, type=int, help="batch size")
     parser.add_argument(
-        "-n", "--num_devices", default=8, type=int, help="number of devices")
+        "-n", "--num_devices", default=1, type=int, help="number of devices")
+    parser.add_argument(
+        "-j", "--num_workers", default=4, type=int, help="reader worker number")
     parser.add_argument(
         "-p", "--pretrain_weights", default=None, type=str,
         help="path to pretrained weights")
     parser.add_argument(
-        "-w", "--weights", default=None, type=str,
+        "-r", "--resume", default=None, type=str,
         help="path to model weights")
+    parser.add_argument(
+        "-w", "--weights", default=None, type=str,
+        help="path to weights for evaluation")
     FLAGS = parser.parse_args()
     assert FLAGS.data, "error: must provide data path"
     main()
diff --git a/yolov3/coco.py b/yolov3/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..34809246c1f90d3ad029842c19ae5f2c3eba08b0
--- /dev/null
+++ b/yolov3/coco.py
@@ -0,0 +1,275 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import cv2
+import numpy as np
+from pycocotools.coco import COCO
+
+from paddle.fluid.io import Dataset
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['COCODataset']
+
+
+class COCODataset(Dataset):
+    """
+    Load dataset with MS-COCO format.
+
+    Args:
+        dataset_dir (str): root directory for dataset.
+        image_dir (str): directory for images.
+        anno_path (str): coco annotation file path.
+        sample_num (int): number of samples to load, -1 means all.
+        with_background (bool): whether to load background as a class,
+            default True.
+        transform (callable): callable transform to perform on samples,
+            default None.
+        mixup (bool): whether to return image mixup samples, default False.
+        alpha (float): alpha factor of beta distribution to generate
+            mixup score, used only when mixup is True, default 1.5
+        beta (float): beta factor of beta distribution to generate
+            mixup score, used only when mixup is True, default 1.5
+    """
+
+    def __init__(self,
+                 dataset_dir='',
+                 image_dir='',
+                 anno_path='',
+                 sample_num=-1,
+                 with_background=True,
+                 transform=None,
+                 mixup=False,
+                 alpha=1.5,
+                 beta=1.5):
+        # roidbs is a list of dicts whose structure is:
+        # {
+        #     'im_file': im_fname,   # image file name
+        #     'im_id': im_id,        # image id
+        #     'h': im_h,             # height of image
+        #     'w': im_w,             # width
+        #     'is_crowd': is_crowd,
+        #     'gt_class': gt_class,
+        #     'gt_bbox': gt_bbox,
+        #     'gt_score': gt_score,
+        #     'difficult': difficult
+        # }
+        self._anno_path = os.path.join(dataset_dir, anno_path)
+        self._image_dir = os.path.join(dataset_dir, image_dir)
+        assert os.path.exists(self._anno_path), \
+            "anno_path {} does not exist".format(anno_path)
+        assert os.path.exists(self._image_dir), \
+            "image_dir {} does not exist".format(image_dir)
+
+        self._sample_num = sample_num
+        self._with_background = with_background
+        self._transform = transform
+        self._mixup = mixup
+        self._alpha = alpha
+        self._beta = beta
+
+        # load in dataset roidbs
+        self._load_roidb_and_cname2cid()
+
+    def _load_roidb_and_cname2cid(self):
+        assert self._anno_path.endswith('.json'), \
+            'invalid coco annotation file: ' + self._anno_path
+        coco = COCO(self._anno_path)
+        img_ids = coco.getImgIds()
+        cat_ids = coco.getCatIds()
+        records = []
+        ct = 0
+
+        # when with_background = True, map category to classid, like:
+        #   background:0, first_class:1, second_class:2, ...
+        catid2clsid = dict({
+            catid: i + int(self._with_background)
+            for i, catid in enumerate(cat_ids)
+        })
+        cname2cid = dict({
+            coco.loadCats(catid)[0]['name']: clsid
+            for catid, clsid in catid2clsid.items()
+        })
+
+        for img_id in img_ids:
+            img_anno = coco.loadImgs(img_id)[0]
+            im_fname = img_anno['file_name']
+            im_w = float(img_anno['width'])
+            im_h = float(img_anno['height'])
+
+            ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
+            instances = coco.loadAnns(ins_anno_ids)
+
+            bboxes = []
+            for inst in instances:
+                x, y, box_w, box_h = inst['bbox']
+                x1 = max(0, x)
+                y1 = max(0, y)
+                x2 = min(im_w - 1, x1 + max(0, box_w - 1))
+                y2 = min(im_h - 1, y1 + max(0, box_h - 1))
+                if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
+                    inst['clean_bbox'] = [x1, y1, x2, y2]
+                    bboxes.append(inst)
+                else:
+                    logger.warning(
+                        'Found an invalid bbox in annotations: im_id: {}, '
+                        'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
+                            img_id, float(inst['area']), x1, y1, x2, y2))
+            num_bbox = len(bboxes)
+
+            gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
+            gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
+            gt_score = np.ones((num_bbox, 1), dtype=np.float32)
+            is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
+            difficult = np.zeros((num_bbox, 1), dtype=np.int32)
+            gt_poly = [None] * num_bbox
+
+            for i, box in enumerate(bboxes):
+                catid = box['category_id']
+                gt_class[i][0] = catid2clsid[catid]
+                gt_bbox[i, :] = box['clean_bbox']
+                is_crowd[i][0] = box['iscrowd']
+                if 'segmentation' in box:
+                    gt_poly[i] = box['segmentation']
+
+            im_fname = os.path.join(self._image_dir,
+                                    im_fname) if self._image_dir else im_fname
+            coco_rec = {
+                'im_file': im_fname,
+                'im_id': np.array([img_id]),
+                'h': im_h,
+                'w': im_w,
+                'is_crowd': is_crowd,
+                'gt_class': gt_class,
+                'gt_bbox': gt_bbox,
+                'gt_score': gt_score,
+                'gt_poly': gt_poly,
+            }
+
+            records.append(coco_rec)
+            ct += 1
+            if self._sample_num > 0 and ct >= self._sample_num:
+                break
+        assert len(records) > 0, 'no coco record found in %s' % (self._anno_path)
+        logger.info('{} samples in file {}'.format(ct, self._anno_path))
+        self._roidbs, self._cname2cid = records, cname2cid
+
+    @property
+    def num_classes(self):
+        return len(self._cname2cid)
+
+    def __len__(self):
+        return len(self._roidbs)
+
+    def _getitem_by_index(self, idx):
+        roidb = self._roidbs[idx]
+        with open(roidb['im_file'], 'rb') as f:
+            data = np.frombuffer(f.read(), dtype='uint8')
+        im = cv2.imdecode(data, 1)
+        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        im_info = np.array([roidb['im_id'][0], roidb['h'], roidb['w']], dtype='int32')
+        gt_bbox = roidb['gt_bbox']
+        gt_class = roidb['gt_class']
+        gt_score = roidb['gt_score']
+        return im_info, im, gt_bbox, gt_class, gt_score
+
+    def __getitem__(self, idx):
+        im_info, im, gt_bbox, gt_class, gt_score = self._getitem_by_index(idx)
+
+        if self._mixup:
+            mixup_idx = idx + np.random.randint(1, self.__len__())
+            mixup_idx %= self.__len__()
+            _, mixup_im, mixup_bbox, mixup_class, _ = \
+                self._getitem_by_index(mixup_idx)
+
+            im, gt_bbox, gt_class, gt_score = \
+                self._mixup_image(im, gt_bbox, gt_class, mixup_im,
+                                  mixup_bbox, mixup_class)
+
+        if self._transform:
+            im_info, im, gt_bbox, gt_class, gt_score = \
+                self._transform(im_info, im, gt_bbox, gt_class, gt_score)
+
+        return [im_info, im, gt_bbox, gt_class, gt_score]
+
+    def _mixup_image(self, img1, bbox1, class1, img2, bbox2, class2):
+        factor = np.random.beta(self._alpha, self._beta)
+        factor = max(0.0, min(1.0, factor))
+        if factor >= 1.0:
+            return img1, bbox1, class1, np.ones_like(class1, dtype="float32")
+        if factor <= 0.0:
+            return img2, bbox2, class2, np.ones_like(class2, dtype="float32")
+
+        h = max(img1.shape[0], img2.shape[0])
+        w = max(img1.shape[1], img2.shape[1])
+        img = np.zeros((h, w, img1.shape[2]), 'float32')
+        img[:img1.shape[0], :img1.shape[1], :] = \
+            img1.astype('float32') * factor
+        img[:img2.shape[0], :img2.shape[1], :] += \
+            img2.astype('float32') * (1.0 - factor)
+
+        gt_bbox = np.concatenate((bbox1, bbox2), axis=0)
+        gt_class = np.concatenate((class1, class2), axis=0)
+
+        score1 = np.ones_like(class1, dtype="float32") * factor
+        score2 = np.ones_like(class2, dtype="float32") * (1.0 - factor)
+        gt_score = np.concatenate((score1, score2), axis=0)
+
+        return img, gt_bbox, gt_class, gt_score
+
+    @property
+    def mixup(self):
+        return self._mixup
+
+    @mixup.setter
+    def mixup(self, value):
+        if not isinstance(value, bool):
+            raise ValueError("mixup should be a boolean value")
+        logger.info("{} set mixup to {}".format(self, value))
+        self._mixup = value
+
+
+def pascalvoc_label(with_background=True):
+    labels_map = {
+        'aeroplane': 1,
+        'bicycle': 2,
+        'bird': 3,
+        'boat': 4,
+        'bottle': 5,
+        'bus': 6,
+        'car': 7,
+        'cat': 8,
+        'chair': 9,
+        'cow': 10,
+        'diningtable': 11,
+        'dog': 12,
+        'horse': 13,
+        'motorbike': 14,
+        'person': 15,
+        'pottedplant': 16,
+        'sheep': 17,
+        'sofa': 18,
+        'train': 19,
+        'tvmonitor': 20
+    }
+    if not with_background:
+        labels_map = {k: v - 1 for k, v in labels_map.items()}
+    return labels_map
diff --git a/yolov3/coco_metric.py b/yolov3/coco_metric.py
index ec7bcac24b3dde91d3ae85e39e7bf9e5151f43ec..4fb3103fc03cc07f4722e164fcdc02f7e259e1e0 100644
--- a/yolov3/coco_metric.py
+++ b/yolov3/coco_metric.py
@@ -17,8 +17,6 @@ import json
 from pycocotools.cocoeval import COCOeval
 from pycocotools.coco import COCO
 
-from metrics import Metric
-
 import logging
 FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -31,7 +29,7 @@ OUTFILE = './bbox.json'
 
 
 # considered to change to a callback later
-class COCOMetric(Metric):
+class COCOMetric():
     """
     Metric for MS-COCO dataset; only supports update with batch size 1.
@@ -43,26 +41,24 @@
     """
 
     def __init__(self, anno_path, with_background=True, **kwargs):
-        super(COCOMetric, self).__init__(**kwargs)
         self.anno_path = anno_path
         self.with_background = with_background
         self.bbox_results = []
 
         self.coco_gt = COCO(anno_path)
         cat_ids = self.coco_gt.getCatIds()
-        self.clsid2catid = dict(
-            {i + int(with_background): catid
-             for i, catid in enumerate(cat_ids)})
+        self.clsid2catid = dict(
+            {i + int(with_background): catid
+             for i, catid in enumerate(cat_ids)})
 
-    def update(self, preds, *args, **kwargs):
-        im_ids, bboxes = preds
-        assert im_ids.shape[0] == 1, \
+    def update(self, img_id, bboxes):
+        assert img_id.shape[0] == 1, \
             "COCOMetric can only update with batch size = 1"
         if bboxes.shape[1] != 6:
             # no bbox detected in this batch
             return
 
-        im_id = int(im_ids)
+        img_id = int(img_id)
         for i in range(bboxes.shape[0]):
             dt = bboxes[i, :]
             clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
@@ -72,7 +68,7 @@ class COCOMetric(Metric):
             h = ymax - ymin + 1
             bbox = [xmin, ymin, w, h]
             coco_res = {
-                'image_id': im_id,
+                'image_id': img_id,
                 'category_id': catid,
                 'bbox': bbox,
                 'score': score
@@ -83,30 +79,30 @@ class COCOMetric(Metric):
         self.bbox_results = []
 
     def accumulate(self):
-        if len(self.bbox_results) == 0:
-            logger.warning("The number of valid bbox detected is zero.\n \
-                Please use reasonable model and check input data.\n \
-                stop COCOMetric accumulate!")
-            return [0.0]
-        with open(OUTFILE, 'w') as f:
-            json.dump(self.bbox_results, f)
-
-        map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
-        # flush coco evaluation result
-        sys.stdout.flush()
+        if len(self.bbox_results) == 0:
+            logger.warning("The number of valid bboxes detected is zero.\n \
+                Please use a reasonable model and check the input data.\n \
+                stop COCOMetric accumulate!")
+            return [0.0]
+        with open(OUTFILE, 'w') as f:
+            json.dump(self.bbox_results, f)
+
+        map_stats = self.cocoapi_eval(OUTFILE, 'bbox', coco_gt=self.coco_gt)
+        # flush coco evaluation result
+        sys.stdout.flush()
         self.result = map_stats[0]
-        return self.result
+        return [self.result]
 
     def cocoapi_eval(self, jsonfile, style, coco_gt=None, anno_file=None):
-        assert coco_gt != None or anno_file != None
-
-        if coco_gt == None:
-            coco_gt = COCO(anno_file)
-        logger.info("Start evaluate...")
-        coco_dt = coco_gt.loadRes(jsonfile)
-        coco_eval = COCOeval(coco_gt, coco_dt, style)
-        coco_eval.evaluate()
-        coco_eval.accumulate()
-        coco_eval.summarize()
-        return coco_eval.stats
+        assert coco_gt is not None or anno_file is not None
+
+        if coco_gt is None:
+            coco_gt = COCO(anno_file)
+        logger.info("Start evaluation...")
+        coco_dt = coco_gt.loadRes(jsonfile)
+        coco_eval = COCOeval(coco_gt, coco_dt, style)
+        coco_eval.evaluate()
+        coco_eval.accumulate()
+        coco_eval.summarize()
+        return coco_eval.stats
diff --git a/yolov3/darknet.py b/yolov3/darknet.py
new file mode 100755
index 0000000000000000000000000000000000000000..9220bca95c8e05f41531204503b2cee2355d9781
--- /dev/null
+++ b/yolov3/darknet.py
@@ -0,0 +1,177 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and +#limitations under the License. + +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.regularizer import L2Decay + +from paddle.fluid.dygraph.nn import Conv2D, BatchNorm +from paddle.fluid.dygraph.base import to_variable + +__all__ = ['DarkNet53', 'ConvBNLayer'] + + +class ConvBNLayer(fluid.dygraph.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size=3, + stride=1, + groups=1, + padding=0, + act="leaky"): + super(ConvBNLayer, self).__init__() + + self.conv = Conv2D( + num_channels=ch_in, + num_filters=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=groups, + param_attr=ParamAttr( + initializer=fluid.initializer.Normal(0., 0.02)), + bias_attr=False, + act=None) + self.batch_norm = BatchNorm( + num_channels=ch_out, + param_attr=ParamAttr( + initializer=fluid.initializer.Normal(0., 0.02), + regularizer=L2Decay(0.)), + bias_attr=ParamAttr( + initializer=fluid.initializer.Constant(0.0), + regularizer=L2Decay(0.))) + + self.act = act + + def forward(self, inputs): + out = self.conv(inputs) + out = self.batch_norm(out) + if self.act == 'leaky': + out = fluid.layers.leaky_relu(x=out, alpha=0.1) + return out + +class DownSample(fluid.dygraph.Layer): + def __init__(self, + ch_in, + ch_out, + filter_size=3, + stride=2, + padding=1): + + super(DownSample, self).__init__() + + self.conv_bn_layer = ConvBNLayer( + ch_in=ch_in, + ch_out=ch_out, + filter_size=filter_size, + stride=stride, + padding=padding) + self.ch_out = ch_out + def forward(self, inputs): + out = self.conv_bn_layer(inputs) + return out + +class BasicBlock(fluid.dygraph.Layer): + def __init__(self, ch_in, ch_out): + super(BasicBlock, self).__init__() + + self.conv1 = ConvBNLayer( + ch_in=ch_in, + ch_out=ch_out, + filter_size=1, + stride=1, + padding=0) + self.conv2 = ConvBNLayer( + ch_in=ch_out, + ch_out=ch_out*2, + filter_size=3, + stride=1, + padding=1) + def forward(self, inputs): + conv1 = self.conv1(inputs) + conv2 = self.conv2(conv1) + out = fluid.layers.elementwise_add(x=inputs, y=conv2, act=None) + return out + +class LayerWarp(fluid.dygraph.Layer): + def __init__(self, ch_in, ch_out, count): + super(LayerWarp,self).__init__() + + self.basicblock0 = BasicBlock(ch_in, ch_out) + self.res_out_list = [] + for i in range(1,count): + res_out = self.add_sublayer("basic_block_%d" % (i), + BasicBlock( + ch_out*2, + ch_out)) + self.res_out_list.append(res_out) + self.ch_out = ch_out + def forward(self,inputs): + y = self.basicblock0(inputs) + for basic_block_i in self.res_out_list: + y = basic_block_i(y) + return y + + +DarkNet_cfg = {53: ([1, 2, 8, 8, 4])} + + +class DarkNet53(fluid.dygraph.Layer): + def __init__(self, ch_in=3): + super(DarkNet53, self).__init__() + self.stages = DarkNet_cfg[53] + self.stages = self.stages[0:5] + + self.conv0 = ConvBNLayer( + ch_in=ch_in, + ch_out=32, + filter_size=3, + stride=1, + padding=1) + + self.downsample0 = DownSample( + ch_in=32, + ch_out=32 * 2) + self.darknet53_conv_block_list = [] + self.downsample_list = [] + ch_in = [64,128,256,512,1024] + for i, stage in enumerate(self.stages): + conv_block = self.add_sublayer( + "stage_%d" % (i), + LayerWarp( + int(ch_in[i]), + 32*(2**i), + stage)) + self.darknet53_conv_block_list.append(conv_block) + for i in range(len(self.stages) - 1): + downsample = self.add_sublayer( + "stage_%d_downsample" % i, + DownSample( + ch_in = 32*(2**(i+1)), + ch_out = 32*(2**(i+2)))) + 
self.downsample_list.append(downsample)
+
+    def forward(self, inputs):
+
+        out = self.conv0(inputs)
+        out = self.downsample0(out)
+        blocks = []
+        for i, conv_block_i in enumerate(self.darknet53_conv_block_list):
+            out = conv_block_i(out)
+            blocks.append(out)
+            if i < len(self.stages) - 1:
+                out = self.downsample_list[i](out)
+        return blocks[-1:-4:-1]
+
diff --git a/yolov3/modeling.py b/yolov3/modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..810aa6df17bcc50c301f5e8c6a5b8a3e92c835bd
--- /dev/null
+++ b/yolov3/modeling.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.nn import Conv2D
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.regularizer import L2Decay
+
+from model import Model, Loss
+from .darknet import DarkNet53, ConvBNLayer
+
+__all__ = ['YoloLoss', 'YOLOv3']
+
+
+class YoloDetectionBlock(fluid.dygraph.Layer):
+    def __init__(self, ch_in, channel):
+        super(YoloDetectionBlock, self).__init__()
+
+        assert channel % 2 == 0, \
+            "channel {} cannot be divided by 2".format(channel)
+
+        self.conv0 = ConvBNLayer(
+            ch_in=ch_in,
+            ch_out=channel,
+            filter_size=1,
+            stride=1,
+            padding=0)
+        self.conv1 = ConvBNLayer(
+            ch_in=channel,
+            ch_out=channel*2,
+            filter_size=3,
+            stride=1,
+            padding=1)
+        self.conv2 = ConvBNLayer(
+            ch_in=channel*2,
+            ch_out=channel,
+            filter_size=1,
+            stride=1,
+            padding=0)
+        self.conv3 = ConvBNLayer(
+            ch_in=channel,
+            ch_out=channel*2,
+            filter_size=3,
+            stride=1,
+            padding=1)
+        self.route = ConvBNLayer(
+            ch_in=channel*2,
+            ch_out=channel,
+            filter_size=1,
+            stride=1,
+            padding=0)
+        self.tip = ConvBNLayer(
+            ch_in=channel,
+            ch_out=channel*2,
+            filter_size=3,
+            stride=1,
+            padding=1)
+
+    def forward(self, inputs):
+        out = self.conv0(inputs)
+        out = self.conv1(out)
+        out = self.conv2(out)
+        out = self.conv3(out)
+        route = self.route(out)
+        tip = self.tip(route)
+        return route, tip
+
+
+class YOLOv3(Model):
+    def __init__(self, num_classes=80, model_mode='train'):
+        super(YOLOv3, self).__init__()
+        self.num_classes = num_classes
+        assert str.lower(model_mode) in ['train', 'eval'], \
+            "model_mode should be 'train' or 'eval', but got " \
+            "{}".format(model_mode)
+        self.model_mode = str.lower(model_mode)
+        self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
+                        59, 119, 116, 90, 156, 198, 373, 326]
+        self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+        self.valid_thresh = 0.005
+        self.nms_thresh = 0.45
+        self.nms_topk = 400
+        self.nms_posk = 100
+        self.draw_thresh = 0.5
+
+        self.block = DarkNet53()
+        self.block_outputs = []
+        self.yolo_blocks = []
+        self.route_blocks = []
+
+        for idx, num_chan in enumerate([1024, 768, 384]):
+            yolo_block = self.add_sublayer(
+                "yolo_detection_block_{}".format(idx),
+                YoloDetectionBlock(num_chan, 512 // (2**idx)))
+            self.yolo_blocks.append(yolo_block)
+
+            num_filters = len(self.anchor_masks[idx]) * (self.num_classes + 5)
+
+            block_out = self.add_sublayer(
+                "block_out_{}".format(idx),
+                Conv2D(num_channels=1024 // (2**idx),
+                       num_filters=num_filters,
+                       filter_size=1,
+                       act=None,
+                       param_attr=ParamAttr(
+                           initializer=fluid.initializer.Normal(0., 0.02)),
+                       bias_attr=ParamAttr(
+                           initializer=fluid.initializer.Constant(0.0),
+                           regularizer=L2Decay(0.))))
+            self.block_outputs.append(block_out)
+            if idx < 2:
+                route = self.add_sublayer(
+                    "route2_{}".format(idx),
+                    ConvBNLayer(ch_in=512 // (2**idx),
+                                ch_out=256 // (2**idx),
+                                filter_size=1,
+                                act='leaky_relu'))
+                self.route_blocks.append(route)
+
+    def forward(self, img_info, inputs):
+        outputs = []
+        boxes = []
+        scores = []
+        downsample = 32
+
+        feats = self.block(inputs)
+        route = None
+        for idx, feat in enumerate(feats):
+            if idx > 0:
+                feat = fluid.layers.concat(input=[route, feat], axis=1)
+            route, tip = self.yolo_blocks[idx](feat)
+            block_out = self.block_outputs[idx](tip)
+            outputs.append(block_out)
+
+            if idx < 2:
+                route = self.route_blocks[idx](route)
+                route = fluid.layers.resize_nearest(route, scale=2)
+
+            if self.model_mode == 'eval':
+                anchor_mask = self.anchor_masks[idx]
+                mask_anchors = []
+                for m in anchor_mask:
+                    mask_anchors.append(self.anchors[2 * m])
+                    mask_anchors.append(self.anchors[2 * m + 1])
+                img_shape = fluid.layers.slice(img_info, axes=[1], starts=[1], ends=[3])
+                img_id = fluid.layers.slice(img_info, axes=[1], starts=[0], ends=[1])
+                b, s = fluid.layers.yolo_box(
+                    x=block_out,
+                    img_size=img_shape,
+                    anchors=mask_anchors,
+                    class_num=self.num_classes,
+                    conf_thresh=self.valid_thresh,
+                    downsample_ratio=downsample)
+
+                boxes.append(b)
+                scores.append(fluid.layers.transpose(s, perm=[0, 2, 1]))
+
+            downsample //= 2
+
+        if self.model_mode == 'train':
+            return outputs
+
+        return outputs + [img_id[0, :], fluid.layers.multiclass_nms(
+            bboxes=fluid.layers.concat(boxes, axis=1),
+            scores=fluid.layers.concat(scores, axis=2),
+            score_threshold=self.valid_thresh,
+            nms_top_k=self.nms_topk,
+            keep_top_k=self.nms_posk,
+            nms_threshold=self.nms_thresh,
+            background_label=-1)]
+
+
+class YoloLoss(Loss):
+    def __init__(self, num_classes=80, num_max_boxes=50):
+        super(YoloLoss, self).__init__()
+        self.num_classes = num_classes
+        self.num_max_boxes = num_max_boxes
+        self.ignore_thresh = 0.7
+        self.anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45,
+                        59, 119, 116, 90, 156, 198, 373, 326]
+        self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
+
+    def forward(self, outputs, labels):
+        downsample = 32
+        gt_box, gt_label, gt_score = labels
+        losses = []
+
+        for idx, out in enumerate(outputs):
+            # in 'eval' mode, outputs also carry img_id and the predicted
+            # bboxes; only the first three feature maps contribute to the loss
+            if idx == 3:
+                break
+            anchor_mask = self.anchor_masks[idx]
+            loss = fluid.layers.yolov3_loss(
+                x=out,
+                gt_box=gt_box,
+                gt_label=gt_label,
+                gt_score=gt_score,
+                anchor_mask=anchor_mask,
+                downsample_ratio=downsample,
+                anchors=self.anchors,
+                class_num=self.num_classes,
+                ignore_thresh=self.ignore_thresh,
+                use_label_smooth=True)
+            loss = fluid.layers.reduce_mean(loss)
+            losses.append(loss)
+            downsample //= 2
+        return losses
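As a quick cross-check of the head arithmetic above (illustrative sketch only, not part of the patch): each detection head predicts, per anchor, 4 box offsets, 1 objectness score and num_classes class scores, so its 1x1 output convolution needs len(anchor_mask) * (num_classes + 5) filters.

# Illustrative only: output channels and feature map sizes of the three
# heads, for the default num_classes=80 and a 608x608 input.
num_classes = 80
anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
for idx, mask in enumerate(anchor_masks):
    stride = 32 // (2 ** idx)                 # heads run at strides 32, 16, 8
    channels = len(mask) * (num_classes + 5)  # 4 box + 1 objectness + classes
    print(stride, channels, 608 // stride)    # 32 255 19 / 16 255 38 / 8 255 76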
diff --git a/yolov3/transforms.py b/yolov3/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5fbe46cbbfdb39efe3025a351b407b82dbf33c4
--- /dev/null
+++ b/yolov3/transforms.py
@@ -0,0 +1,620 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import cv2
+import traceback
+import numpy as np
+
+import logging
+logger = logging.getLogger(__name__)
+
+__all__ = ['ColorDistort', 'RandomExpand', 'RandomCrop', 'RandomFlip',
+           'NormalizeBox', 'PadBox', 'RandomShape', 'NormalizeImage',
+           'BboxXYXY2XYWH', 'ResizeImage', 'Compose', 'BatchCompose']
+
+
+class Compose(object):
+    def __init__(self, transforms=[]):
+        self.transforms = transforms
+
+    def __call__(self, *data):
+        for f in self.transforms:
+            try:
+                data = f(*data)
+            except Exception as e:
+                stack_info = traceback.format_exc()
+                logger.info("failed to perform transform [{}] with error: "
+                            "{} and stack:\n{}".format(f, e, str(stack_info)))
+                raise e
+        return data
+
+
+class BatchCompose(object):
+    def __init__(self, transforms=[]):
+        self.transforms = transforms
+
+    def __call__(self, data):
+        for f in self.transforms:
+            try:
+                data = f(data)
+            except Exception as e:
+                stack_info = traceback.format_exc()
+                logger.info("failed to perform batch transform [{}] with "
+                            "error: {} and stack:\n{}".format(f, e, str(stack_info)))
+                raise e
+
+        # sample list to batch data
+        batch = list(zip(*data))
+
+        return batch
+
+
+class ColorDistort(object):
+    """Random color distortion.
+
+    Args:
+        hue (list): hue settings,
+            in [lower, upper, probability] format.
+        saturation (list): saturation settings,
+            in [lower, upper, probability] format.
+        contrast (list): contrast settings,
+            in [lower, upper, probability] format.
+        brightness (list): brightness settings,
+            in [lower, upper, probability] format.
+        random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
+            order.
+    """
+
+    def __init__(self,
+                 hue=[-18, 18, 0.5],
+                 saturation=[0.5, 1.5, 0.5],
+                 contrast=[0.5, 1.5, 0.5],
+                 brightness=[0.5, 1.5, 0.5],
+                 random_apply=True):
+        self.hue = hue
+        self.saturation = saturation
+        self.contrast = contrast
+        self.brightness = brightness
+        self.random_apply = random_apply
+
+    def apply_hue(self, img):
+        low, high, prob = self.hue
+        if np.random.uniform(0., 1.) < prob:
+            return img
+
+        img = img.astype(np.float32)
+
+        # XXX works, but results differ from the HSV version
+        delta = np.random.uniform(low, high)
+        u = np.cos(delta * np.pi)
+        w = np.sin(delta * np.pi)
+        bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
+        tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
+                         [0.211, -0.523, 0.311]])
+        ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
+                          [1.0, -1.107, 1.705]])
+        t = np.dot(np.dot(ityiq, bt), tyiq).T
+        img = np.dot(img, t)
+        return img
+
+    def apply_saturation(self, img):
+        low, high, prob = self.saturation
+        if np.random.uniform(0., 1.) < prob:
+            return img
+        delta = np.random.uniform(low, high)
+
+        img = img.astype(np.float32)
+        gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
+        gray = gray.sum(axis=2, keepdims=True)
+        gray *= (1.0 - delta)
+        img *= delta
+        img += gray
+        return img
+
+    def apply_contrast(self, img):
+        low, high, prob = self.contrast
+        if np.random.uniform(0., 1.) < prob:
+            return img
+        delta = np.random.uniform(low, high)
+
+        img = img.astype(np.float32)
+        img *= delta
+        return img
+
+    def apply_brightness(self, img):
+        low, high, prob = self.brightness
+        if np.random.uniform(0., 1.) < prob:
+            return img
+        delta = np.random.uniform(low, high)
+
+        img = img.astype(np.float32)
+        img += delta
+        return img
+
+    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+        if self.random_apply:
+            distortions = np.random.permutation([
+                self.apply_brightness, self.apply_contrast,
+                self.apply_saturation, self.apply_hue
+            ])
+            for func in distortions:
+                im = func(im)
+            return [im_info, im, gt_bbox, gt_class, gt_score]
+
+        im = self.apply_brightness(im)
+
+        if np.random.randint(0, 2):
+            im = self.apply_contrast(im)
+            im = self.apply_saturation(im)
+            im = self.apply_hue(im)
+        else:
+            im = self.apply_saturation(im)
+            im = self.apply_hue(im)
+            im = self.apply_contrast(im)
+        return [im_info, im, gt_bbox, gt_class, gt_score]
+
+
+class RandomExpand(object):
+    """Randomly expand the canvas.
+
+    Args:
+        ratio (float): maximum expansion ratio.
+        prob (float): probability to expand.
+        fill_value (list): color value used to fill the canvas, in RGB order.
+    """
+
+    def __init__(self, ratio=4., prob=0.5, fill_value=[123.675, 116.28, 103.53]):
+        assert ratio > 1.01, "expand ratio must be larger than 1.01"
+        self.ratio = ratio
+        self.prob = prob
+        self.fill_value = fill_value
+
+    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+        if np.random.uniform(0., 1.) < self.prob:
+            return [im_info, im, gt_bbox, gt_class, gt_score]
+
+        height, width, _ = im.shape
+        expand_ratio = np.random.uniform(1., self.ratio)
+        h = int(height * expand_ratio)
+        w = int(width * expand_ratio)
+        if not h > height or not w > width:
+            return [im_info, im, gt_bbox, gt_class, gt_score]
+        y = np.random.randint(0, h - height)
+        x = np.random.randint(0, w - width)
+        canvas = np.ones((h, w, 3), dtype=np.uint8)
+        canvas *= np.array(self.fill_value, dtype=np.uint8)
+        canvas[y:y + height, x:x + width, :] = im.astype(np.uint8)
+
+        gt_bbox += np.array([x, y, x, y], dtype=np.float32)
+
+        return [im_info, canvas, gt_bbox, gt_class, gt_score]
+
+
+class RandomCrop():
+    """Randomly crop the image and bboxes.
+
+    Args:
+        aspect_ratio (list): aspect ratio of the cropped region,
+            in [min, max] format.
+        thresholds (list): iou thresholds to decide a valid bbox crop.
+        scaling (list): ratio between the cropped region and the original
+            image, in [min, max] format.
+        num_attempts (int): number of tries before giving up.
+        allow_no_crop (bool): allow returning without actually cropping.
+        cover_all_box (bool): ensure all bboxes are covered in the final crop.
+    """
+
+    def __init__(self,
+                 aspect_ratio=[.5, 2.],
+                 thresholds=[.0, .1, .3, .5, .7, .9],
+                 scaling=[.3, 1.],
+                 num_attempts=50,
+                 allow_no_crop=True,
+                 cover_all_box=False):
+        self.aspect_ratio = aspect_ratio
+        self.thresholds = thresholds
+        self.scaling = scaling
+        self.num_attempts = num_attempts
+        self.allow_no_crop = allow_no_crop
+        self.cover_all_box = cover_all_box
+
+    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+        if len(gt_bbox) == 0:
+            return [im_info, im, gt_bbox, gt_class, gt_score]
+
+        # NOTE Original method attempts to generate one candidate for each
+        # threshold then randomly sample one from the resulting list.
+        # Here a short circuit approach is taken, i.e., randomly choose a
+        # threshold and attempt to find a valid crop, and simply return the
+        # first one found.
+        # The probability is not exactly the same, somewhat resembling the
+        # "Monty Hall" problem. Actually carrying out the attempts will affect
+        # observability (just like opening doors in the "Monty Hall" game).
+        thresholds = list(self.thresholds)
+        if self.allow_no_crop:
+            thresholds.append('no_crop')
+        np.random.shuffle(thresholds)
+
+        for thresh in thresholds:
+            if thresh == 'no_crop':
+                return [im_info, im, gt_bbox, gt_class, gt_score]
+
+            h, w, _ = im.shape
+            found = False
+            for i in range(self.num_attempts):
+                scale = np.random.uniform(*self.scaling)
+                min_ar, max_ar = self.aspect_ratio
+                aspect_ratio = np.random.uniform(
+                    max(min_ar, scale**2), min(max_ar, scale**-2))
+                crop_h = int(h * scale / np.sqrt(aspect_ratio))
+                crop_w = int(w * scale * np.sqrt(aspect_ratio))
+                crop_y = np.random.randint(0, h - crop_h)
+                crop_x = np.random.randint(0, w - crop_w)
+                crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
+                iou = self._iou_matrix(
+                    gt_bbox, np.array(
+                        [crop_box], dtype=np.float32))
+                if iou.max() < thresh:
+                    continue
+
+                if self.cover_all_box and iou.min() < thresh:
+                    continue
+
+                cropped_box, valid_ids = self._crop_box_with_center_constraint(
+                    gt_bbox, np.array(
+                        crop_box, dtype=np.float32))
+                if valid_ids.size > 0:
+                    found = True
+                    break
+
+            if found:
+                im = self._crop_image(im, crop_box)
+                gt_bbox = np.take(cropped_box, valid_ids, axis=0)
+                gt_class = np.take(gt_class, valid_ids, axis=0)
+                gt_score = np.take(gt_score, valid_ids, axis=0)
+                return [im_info, im, gt_bbox, gt_class, gt_score]
+
+        return [im_info, im, gt_bbox, gt_class, gt_score]
+
+    def _iou_matrix(self, a, b):
+        tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+        br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+        area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2)
+        area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+        area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+        area_o = (area_a[:, np.newaxis] + area_b - area_i)
+        return area_i / (area_o + 1e-10)
+
+    def _crop_box_with_center_constraint(self, box, crop):
+        cropped_box = box.copy()
+
+        cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2])
+        cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:])
+        cropped_box[:, :2] -= crop[:2]
+        cropped_box[:, 2:] -= crop[:2]
+
+        centers = (box[:, :2] + box[:, 2:]) / 2
+        valid = np.logical_and(crop[:2] <= centers,
+                               centers < crop[2:]).all(axis=1)
+        valid = np.logical_and(
+            valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1))
+
+        return cropped_box, np.where(valid)[0]
+
+    def _crop_image(self, img, crop):
+        x1, y1, x2, y2 = crop
+        return img[y1:y2, x1:x2, :]
+
+
+class RandomFlip():
+    def __init__(self, prob=0.5, is_normalized=False):
+        """
+        Args:
+            prob (float): the probability of flipping the image
+            is_normalized (bool): whether the bbox coordinates are scaled
+                to [0, 1]
+        """
+        self.prob = prob
+        self.is_normalized = is_normalized
+        if not (isinstance(self.prob, float) and
+                isinstance(self.is_normalized, bool)):
+            raise TypeError("{}: input type is invalid.".format(self))
+
+    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+        """Flip the image and bounding boxes.
+        Operators:
+            1. Flip the image numpy.
+            2. Transform the bboxes' x coordinates.
+              (Must judge whether the coordinates are normalized!)
+ """ + + if not isinstance(im, np.ndarray): + raise TypeError("{}: image is not a numpy array.".format(self)) + if len(im.shape) != 3: + raise ImageError("{}: image is not 3-dimensional.".format(self)) + height, width, _ = im.shape + if np.random.uniform(0, 1) < self.prob: + im = im[:, ::-1, :] + if gt_bbox.shape[0] > 0: + oldx1 = gt_bbox[:, 0].copy() + oldx2 = gt_bbox[:, 2].copy() + if self.is_normalized: + gt_bbox[:, 0] = 1 - oldx2 + gt_bbox[:, 2] = 1 - oldx1 + else: + gt_bbox[:, 0] = width - oldx2 - 1 + gt_bbox[:, 2] = width - oldx1 - 1 + if gt_bbox.shape[0] != 0 and ( + gt_bbox[:, 2] < gt_bbox[:, 0]).all(): + m = "{}: invalid box, x2 should be greater than x1".format( + self) + raise ValueError(m) + return [im_info, im, gt_bbox, gt_class, gt_score] + + +class NormalizeBox(object): + """Transform the bounding box's coornidates to [0,1].""" + + def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): + height, width, _ = im.shape + for i in range(gt_bbox.shape[0]): + gt_bbox[i][0] = gt_bbox[i][0] / width + gt_bbox[i][1] = gt_bbox[i][1] / height + gt_bbox[i][2] = gt_bbox[i][2] / width + gt_bbox[i][3] = gt_bbox[i][3] / height + return [im_info, im, gt_bbox, gt_class, gt_score] + + +class PadBox(object): + def __init__(self, num_max_boxes=50): + """ + Pad zeros to bboxes if number of bboxes is less than num_max_boxes. + Args: + num_max_boxes (int): the max number of bboxes + """ + self.num_max_boxes = num_max_boxes + + def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): + gt_num = min(self.num_max_boxes, len(gt_bbox)) + num_max = self.num_max_boxes + + pad_bbox = np.zeros((num_max, 4), dtype=np.float32) + if gt_num > 0: + pad_bbox[:gt_num, :] = gt_bbox[:gt_num, :] + gt_bbox = pad_bbox + + pad_class = np.zeros((num_max), dtype=np.int32) + if gt_num > 0: + pad_class[:gt_num] = gt_class[:gt_num, 0] + gt_class = pad_class + + pad_score = np.zeros((num_max), dtype=np.float32) + if gt_num > 0: + pad_score[:gt_num] = gt_score[:gt_num, 0] + gt_score = pad_score + return [im_info, im, gt_bbox, gt_class, gt_score] + + +class BboxXYXY2XYWH(object): + """ + Convert bbox XYXY format to XYWH format. + """ + + def __call__(self, im_info, im, gt_bbox, gt_class, gt_score): + gt_bbox[:, 2:4] = gt_bbox[:, 2:4] - gt_bbox[:, :2] + gt_bbox[:, :2] = gt_bbox[:, :2] + gt_bbox[:, 2:4] / 2. + return [im_info, im, gt_bbox, gt_class, gt_score] + + +class RandomShape(object): + """ + Randomly reshape a batch. If random_inter is True, also randomly + select one an interpolation algorithm [cv2.INTER_NEAREST, cv2.INTER_LINEAR, + cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is + False, use cv2.INTER_NEAREST. + + Args: + sizes (list): list of int, random choose a size from these + random_inter (bool): whether to randomly interpolation, defalut true. 
+ """ + + def __init__(self, + sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608], + random_inter=True): + self.sizes = sizes + self.random_inter = random_inter + self.interps = [ + cv2.INTER_NEAREST, + cv2.INTER_LINEAR, + cv2.INTER_AREA, + cv2.INTER_CUBIC, + cv2.INTER_LANCZOS4, + ] if random_inter else [] + + def __call__(self, samples): + shape = np.random.choice(self.sizes) + method = np.random.choice(self.interps) if self.random_inter \ + else cv2.INTER_NEAREST + for i in range(len(samples)): + im = samples[i][1] + h, w = im.shape[:2] + scale_x = float(shape) / w + scale_y = float(shape) / h + im = cv2.resize( + im, None, None, fx=scale_x, fy=scale_y, interpolation=method) + samples[i][1] = im + return samples + + +class NormalizeImage(object): + def __init__(self, + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225], + scale=True, + channel_first=True): + """ + Args: + mean (list): the pixel mean + std (list): the pixel variance + scale (bool): whether scale image to [0, 1] + channel_first (bool): whehter change [h, w, c] to [c, h, w] + """ + self.mean = mean + self.std = std + self.scale = scale + self.channel_first = channel_first + if not (isinstance(self.mean, list) and isinstance(self.std, list) and + isinstance(self.scale, bool)): + raise TypeError("{}: input type is invalid.".format(self)) + from functools import reduce + if reduce(lambda x, y: x * y, self.std) == 0: + raise ValueError('{}: std is invalid!'.format(self)) + + def __call__(self, samples): + """Normalize the image. + Operators: + 1. (optional) Scale the image to [0,1] + 2. Each pixel minus mean and is divided by std + 3. (optional) permute channel + """ + for i in range(len(samples)): + im = samples[i][1] + im = im.astype(np.float32, copy=False) + mean = np.array(self.mean)[np.newaxis, np.newaxis, :] + std = np.array(self.std)[np.newaxis, np.newaxis, :] + if self.scale: + im = im / 255.0 + im -= mean + im /= std + if self.channel_first: + im = im.transpose((2, 0, 1)) + samples[i][1] = im + return samples + + +def _iou_matrix(a, b): + tl_i = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + br_i = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + area_i = np.prod(br_i - tl_i, axis=2) * (tl_i < br_i).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + area_o = (area_a[:, np.newaxis] + area_b - area_i) + return area_i / (area_o + 1e-10) + + +def _crop_box_with_center_constraint(box, crop): + cropped_box = box.copy() + cropped_box[:, :2] = np.maximum(box[:, :2], crop[:2]) + cropped_box[:, 2:] = np.minimum(box[:, 2:], crop[2:]) + cropped_box[:, :2] -= crop[:2] + cropped_box[:, 2:] -= crop[:2] + centers = (box[:, :2] + box[:, 2:]) / 2 + valid = np.logical_and( + crop[:2] <= centers, centers < crop[2:]).all(axis=1) + valid = np.logical_and( + valid, (cropped_box[:, :2] < cropped_box[:, 2:]).all(axis=1)) + return cropped_box, np.where(valid)[0] + + +def random_crop(inputs): + aspect_ratios = [.5, 2.] + thresholds = [.0, .1, .3, .5, .7, .9] + scaling = [.3, 1.] 
+
+
+def random_crop(inputs):
+    aspect_ratios = [.5, 2.]
+    thresholds = [.0, .1, .3, .5, .7, .9]
+    scaling = [.3, 1.]
+
+    img, img_ids, gt_box, gt_label = inputs
+    h, w = img.shape[:2]
+
+    if len(gt_box) == 0:
+        return inputs
+
+    np.random.shuffle(thresholds)
+    for thresh in thresholds:
+        found = False
+        for i in range(50):
+            scale = np.random.uniform(*scaling)
+            min_ar, max_ar = aspect_ratios
+            ar = np.random.uniform(max(min_ar, scale**2),
+                                   min(max_ar, scale**-2))
+            crop_h = int(h * scale / np.sqrt(ar))
+            crop_w = int(w * scale * np.sqrt(ar))
+            crop_y = np.random.randint(0, h - crop_h)
+            crop_x = np.random.randint(0, w - crop_w)
+            crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
+            iou = _iou_matrix(gt_box, np.array([crop_box], dtype=np.float32))
+            if iou.max() < thresh:
+                continue
+
+            cropped_box, valid_ids = _crop_box_with_center_constraint(
+                gt_box, np.array(crop_box, dtype=np.float32))
+            if valid_ids.size > 0:
+                found = True
+                break
+
+        if found:
+            x1, y1, x2, y2 = crop_box
+            img = img[y1:y2, x1:x2, :]
+            gt_box = np.take(cropped_box, valid_ids, axis=0)
+            gt_label = np.take(gt_label, valid_ids, axis=0)
+            return img, img_ids, gt_box, gt_label
+
+    return inputs
+
+
+class ResizeImage(object):
+    def __init__(self, target_size=0, interp=cv2.INTER_CUBIC):
+        """
+        Rescale the image to the specified target size (both height and
+        width). If target_size is a list, a size is randomly selected from
+        it on each call, i.e. multi-scale training is adopted.
+
+        Args:
+            target_size (int|list): the size the image is resized to;
+                multi-scale training is adopted when the type is list.
+            interp (int): the interpolation method
+        """
+        self.interp = int(interp)
+        if not (isinstance(target_size, int) or isinstance(target_size, list)):
+            raise TypeError(
+                "Type of target_size is invalid. Must be Integer or List, "
+                "now is {}".format(type(target_size)))
+        self.target_size = target_size
+
+    def __call__(self, im_info, im, gt_bbox, gt_class, gt_score):
+        """Resize the image (numpy array)."""
+        if not isinstance(im, np.ndarray):
+            raise TypeError("{}: image type is not numpy.".format(self))
+        if len(im.shape) != 3:
+            raise ImageError('{}: image is not 3-dimensional.'.format(self))
+        # Randomly pick one size when multi-scale training is configured.
+        if isinstance(self.target_size, list):
+            target_size = np.random.choice(self.target_size)
+        else:
+            target_size = self.target_size
+        im_shape = im.shape
+        im_scale_x = float(target_size) / float(im_shape[1])
+        im_scale_y = float(target_size) / float(im_shape[0])
+
+        im = cv2.resize(
+            im,
+            None,
+            None,
+            fx=im_scale_x,
+            fy=im_scale_y,
+            interpolation=self.interp)
+
+        return [im_info, im, gt_bbox, gt_class, gt_score]
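+
+# How these operators are typically chained, as a sketch only (`Compose`
+# here is a hypothetical helper, not part of this patch, that threads the
+# sample fields through each operator in turn):
+#
+#     sample_transforms = Compose([
+#         RandomFlip(prob=0.5),
+#         NormalizeBox(),
+#         PadBox(num_max_boxes=50),
+#         BboxXYXY2XYWH(),
+#     ])
+#     batch_transforms = Compose([
+#         RandomShape(sizes=[320, 416, 608]),
+#         NormalizeImage(),
+#     ])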