Unverified  Commit 19eb7f47  Authored by: W wangxinxin08  Committed by: GitHub

[Dygraph] train and eval yolov3 successfully (#1753)

* split yolov3 loss using paddle op

* add sync_bn and modify list to LayerList to use sync_bn, add missing op, support cutmix op in dataset, modify reader to use new ops

* modify code according to review

* modify code to run eval.py successfully

* modify code to run eval and train successfully

* modify code to use mixup

* rebase code on latest dygraph

* modify code to run in sync_bn mode

* modify code according to review

* modify target size of ResizeOp
Parent ee911f6b
......@@ -6,6 +6,7 @@ TrainDataset:
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
mixup_epoch: 250
EvalDataset:
!COCODataSet
......
......@@ -13,18 +13,21 @@ YOLOv3:
DarkNet:
depth: 53
return_idx: [2, 3, 4]
norm_type: sync_bn
YOLOv3FPN:
feat_channels: [1024, 768, 384]
YOLOv3Head:
anchors: [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]
anchors: [[10, 13], [16, 30], [33, 23],
[30, 61], [62, 45], [59, 119],
[116, 90], [156, 198], [373, 326]]
anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
loss: YOLOv3Loss
YOLOv3Loss:
ignore_thresh: 0.7
downsample: 32
downsample: [32, 16, 8]
label_smooth: true
BBoxPostProcess:
......
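A minimal sketch (plain Python, mirroring the parse_anchor logic further down in this diff) of how the nested anchors, anchor_masks and per-level downsample values line up; the names below are local to the sketch:

anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
           [59, 119], [116, 90], [156, 198], [373, 326]]
anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
downsample = [32, 16, 8]

# group anchors per output level, as YOLOv3Head.parse_anchor does
level_anchors = [[anchors[i] for i in mask] for mask in anchor_masks]
for level, (lvl_anchors, ratio) in enumerate(zip(level_anchors, downsample)):
    # level 0 pairs the largest anchors with the most downsampled (stride 32) feature map
    print(level, ratio, lvl_anchors)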
worker_num: 2
TrainReader:
inputs_def:
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score']
fields: ['image', 'gt_bbox', 'gt_class', 'gt_score', 'im_shape', 'scale_factor']
num_max_boxes: 50
sample_transforms:
- DecodeImage: {to_rgb: True, with_mixup: True}
- MixupImage: {alpha: 1.5, beta: 1.5}
- ColorDistort: {}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {}
- RandomFlipImage: {is_normalized: false}
- NormalizeBox: {}
- PadBox: {num_max_boxes: 50}
- BboxXYXY2XYWH: {}
- DecodeOp: {}
- MixupOp: {alpha: 1.5, beta: 1.5}
- RandomDistortOp: {}
- RandomExpandOp: {fill_value: [123.675, 116.28, 103.53]}
- RandomCropOp: {}
- RandomFlipOp: {}
batch_transforms:
- RandomShape: {sizes: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608], random_inter: True}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True, is_channel_first: false}
- Permute: {to_bgr: false, channel_first: True}
# Gt2YoloTarget is only used when use_fine_grained_loss is set to true,
# this operator will be deleted automatically if use_fine_grained_loss
# is set to false
- Gt2YoloTarget: {anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]], anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]], downsample_ratios: [32, 16, 8]}
- BatchRandomResizeOp: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeBoxOp: {}
- PadBoxOp: {num_max_boxes: 50}
- BboxXYXY2XYWHOp: {}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
- Gt2YoloTargetOp: {anchor_masks: [[6, 7, 8], [3, 4, 5], [0, 1, 2]], anchors: [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45], [59, 119], [116, 90], [156, 198], [373, 326]], downsample_ratios: [32, 16, 8]}
batch_size: 8
shuffle: true
drop_last: true
......@@ -28,24 +25,24 @@ TrainReader:
EvalReader:
inputs_def:
fields: ['image', 'im_size', 'im_id']
fields: ['image', 'im_shape', 'scale_factor', 'im_id']
num_max_boxes: 50
sample_transforms:
- DecodeImage: {to_rgb: True}
- ResizeImage: {target_size: 608, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True, is_channel_first: false}
- PadBox: {num_max_boxes: 50}
- Permute: {to_bgr: false, channel_first: True}
- DecodeOp: {}
- ResizeOp: {target_size: [608, 608], keep_ratio: False, interp: 2}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PadBoxOp: {num_max_boxes: 50}
- PermuteOp: {}
batch_size: 1
drop_empty: false
TestReader:
inputs_def:
image_shape: [3, 608, 608]
fields: ['image', 'im_size', 'im_id']
fields: ['image', 'im_shape', 'scale_factor', 'im_id']
sample_transforms:
- DecodeImage: {to_rgb: True}
- ResizeImage: {target_size: 608, interp: 2}
- NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True, is_channel_first: false}
- Permute: {to_bgr: false, channel_first: True}
- DecodeOp: {}
- ResizeOp: {target_size: [608, 608], interp: 2}
- NormalizeImageOp: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
- PermuteOp: {}
batch_size: 1
......@@ -11,17 +11,18 @@ import numpy as np
from paddle.io import DataLoader
from ppdet.core.workspace import register, serializable, create
from .sampler import DistributedBatchSampler
from .transform import operators
from .transform import batch_operators
from . import transform
from .transform import operator, batch_operator
logger = logging.getLogger(__name__)
class Compose(object):
def __init__(self, transforms, fields=None, from_=operators,
def __init__(self, transforms, fields=None, from_=transform,
num_classes=81):
self.transforms = transforms
self.transforms_cls = []
output_fields = None
for t in self.transforms:
for k, v in t.items():
op_cls = getattr(from_, k)
......@@ -29,7 +30,17 @@ class Compose(object):
if hasattr(op_cls, 'num_classes'):
op_cls.num_classes = num_classes
if op_cls in [
transform.Gt2YoloTargetOp, transform.Gt2YoloTarget
]:
output_fields = ['image', 'gt_bbox']
output_fields.extend([
'target{}'.format(i)
for i in range(len(v['anchor_masks']))
])
self.fields = fields
self.output_fields = output_fields if output_fields else fields
def __call__(self, data):
if self.fields is not None:
......@@ -47,11 +58,11 @@ class Compose(object):
format(f, e, str(stack_info)))
raise e
if self.fields is not None:
if self.output_fields is not None:
data_new = []
for item in data:
batch = []
for k in self.fields:
for k in self.output_fields:
batch.append(item[k])
data_new.append(batch)
batch_size = len(data_new)
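A small illustration of how output_fields is derived from fields when a Gt2YoloTargetOp with three anchor masks is present (values copied from the TrainReader config above; this snippet is a standalone sketch, not part of the diff):

fields = ['image', 'gt_bbox', 'gt_class', 'gt_score', 'im_shape', 'scale_factor']
anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]

# Gt2YoloTargetOp turns the raw labels into per-level targets, so Compose
# re-maps the batch output fields accordingly
output_fields = ['image', 'gt_bbox']
output_fields.extend(['target{}'.format(i) for i in range(len(anchor_masks))])
print(output_fields)  # ['image', 'gt_bbox', 'target0', 'target1', 'target2']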
......@@ -80,8 +91,7 @@ class BaseDataLoader(object):
num_classes=81,
with_background=True):
# out fields
self._fields = copy.deepcopy(inputs_def[
'fields']) if inputs_def else None
self._fields = inputs_def['fields'] if inputs_def else None
# sample transform
self._sample_transforms = Compose(
sample_transforms, num_classes=num_classes)
......@@ -89,8 +99,9 @@ class BaseDataLoader(object):
# batch transform
self._batch_transforms = None
if batch_transforms:
self._batch_transforms = Compose(batch_transforms, self._fields,
batch_operators, num_classes)
self._batch_transforms = Compose(batch_transforms,
copy.deepcopy(self._fields),
transform, num_classes)
self.batch_size = batch_size
self.shuffle = shuffle
......@@ -100,19 +111,24 @@ class BaseDataLoader(object):
def __call__(self,
dataset,
worker_num,
device,
device=None,
batch_sampler=None,
return_list=False,
use_prefetch=True):
self._dataset = dataset
self._dataset.parse_dataset(self.with_background)
# get data
self._dataset.set_out(self._sample_transforms, self._fields)
self._dataset.set_out(self._sample_transforms,
copy.deepcopy(self._fields))
# batch sampler
self._batch_sampler = DistributedBatchSampler(
self._dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
drop_last=self.drop_last)
if batch_sampler is None:
self._batch_sampler = DistributedBatchSampler(
self._dataset,
batch_size=self.batch_size,
shuffle=self.shuffle,
drop_last=self.drop_last)
else:
self._batch_sampler = batch_sampler
loader = DataLoader(
dataset=self._dataset,
......@@ -152,7 +168,7 @@ class EvalReader(BaseDataLoader):
batch_transforms=None,
batch_size=1,
shuffle=False,
drop_last=False,
drop_last=True,
drop_empty=True,
num_classes=81,
with_background=True):
......
......@@ -28,9 +28,18 @@ class COCODataSet(DetDataset):
dataset_dir=None,
image_dir=None,
anno_path=None,
mixup_epoch=-1,
cutmix_epoch=-1,
mosaic_epoch=-1,
sample_num=-1):
super(COCODataSet, self).__init__(dataset_dir, image_dir, anno_path,
sample_num)
super(COCODataSet, self).__init__(
dataset_dir,
image_dir,
anno_path,
sample_num,
mixup_epoch=mixup_epoch,
cutmix_epoch=cutmix_epoch,
mosaic_epoch=mosaic_epoch)
self.load_image_only = False
self.load_semantic = False
......
......@@ -33,6 +33,9 @@ class DetDataset(Dataset):
anno_path=None,
sample_num=-1,
use_default_label=None,
mixup_epoch=-1,
cutmix_epoch=-1,
mosaic_epoch=-1,
**kwargs):
super(DetDataset, self).__init__()
self.dataset_dir = dataset_dir if dataset_dir is not None else ''
......@@ -40,6 +43,10 @@ class DetDataset(Dataset):
self.image_dir = image_dir if image_dir is not None else ''
self.sample_num = sample_num
self.use_default_label = use_default_label
self.epoch = 0
self.mixup_epoch = mixup_epoch
self.cutmix_epoch = cutmix_epoch
self.mosaic_epoch = mosaic_epoch
def __len__(self, ):
return len(self.roidbs)
......@@ -47,6 +54,21 @@ class DetDataset(Dataset):
def __getitem__(self, idx):
# data batch
roidb = copy.deepcopy(self.roidbs[idx])
if self.mixup_epoch == 0 or self.epoch < self.mixup_epoch:
n = len(self.roidbs)
idx = np.random.randint(n)
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
elif self.cutmix_epoch == 0 or self.epoch < self.cutmix_epoch:
n = len(self.roidbs)
idx = np.random.randint(n)
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
elif self.mosaic_epoch == 0 or self.epoch < self.mosaic_epoch:
n = len(self.roidbs)
roidb = [roidb, ] + [
copy.deepcopy(self.roidbs[np.random.randint(n)])
for _ in range(3)
]
# data augment
roidb = self.transform(roidb)
# data item
......
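A standalone sketch of the epoch gating above, assuming mixup_epoch: 250 as set in the COCO dataset config earlier in this diff:

import numpy as np

mixup_epoch, num_samples = 250, 1000
for epoch in (0, 249, 250):
    use_mixup = mixup_epoch == 0 or epoch < mixup_epoch
    if use_mixup:
        # __getitem__ returns a pair of records so that MixupOp can blend them
        print(epoch, 'mixup with sample', np.random.randint(num_samples))
    else:
        print(epoch, 'single sample')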
......@@ -14,6 +14,8 @@
from . import operators
from . import batch_operators
from . import operator
from . import batch_operator
# TODO: operators and batch_operators will be replaced by operator and batch_operator
from .operators import *
......
......@@ -156,7 +156,7 @@ class BatchRandomResizeOp(BaseOperator):
def __init__(self,
target_size,
keep_ratio=True,
interp=cv2.INTER_LINEAR,
interp=cv2.INTER_NEAREST,
random_size=True,
random_interp=False):
super(BatchRandomResizeOp, self).__init__()
......
......@@ -25,7 +25,7 @@ try:
except Exception:
from collections import Sequence
from numbers import Number
from numbers import Number, Integral
import uuid
import logging
......@@ -33,6 +33,7 @@ import random
import math
import numpy as np
import os
import copy
import cv2
from PIL import Image, ImageEnhance, ImageDraw
......@@ -95,8 +96,8 @@ class BaseOperator(object):
if isinstance(sample, Sequence):
for i in range(len(sample)):
sample[i] = self.apply(sample[i], context)
sample = self.apply(sample, context)
else:
sample = self.apply(sample, context)
return sample
def __str__(self):
......@@ -140,8 +141,8 @@ class DecodeOp(BaseOperator):
"image width.".format(im.shape[1], sample['w']))
sample['w'] = im.shape[1]
sample['im_shape'] = im.shape[:2]
sample['scale_factor'] = [1., 1.]
sample['im_shape'] = np.array(im.shape[:2], dtype=np.int32)
sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
return sample
......@@ -182,6 +183,56 @@ class LightingOp(BaseOperator):
return sample
@register_op
class RandomErasingImageOp(BaseOperator):
def __init__(self, prob=0.5, lower=0.02, higher=0.4, aspect_ratio=0.3):
"""
Random Erasing Data Augmentation, see https://arxiv.org/abs/1708.04896
Args:
prob (float): probability to carry out random erasing
lower (float): lower limit of the erasing area ratio
higher (float): upper limit of the erasing area ratio
aspect_ratio (float): aspect ratio of the erasing region
"""
super(RandomErasingImageOp, self).__init__()
self.prob = prob
self.lower = lower
self.higher = higher
self.aspect_ratio = aspect_ratio
def apply(self, sample):
gt_bbox = sample['gt_bbox']
im = sample['image']
if not isinstance(im, np.ndarray):
raise TypeError("{}: image is not a numpy array.".format(self))
if len(im.shape) != 3:
raise ImageError("{}: image is not 3-dimensional.".format(self))
for idx in range(gt_bbox.shape[0]):
if self.prob <= np.random.rand():
continue
x1, y1, x2, y2 = gt_bbox[idx, :]
w_bbox = x2 - x1 + 1
h_bbox = y2 - y1 + 1
area = w_bbox * h_bbox
target_area = random.uniform(self.lower, self.higher) * area
aspect_ratio = random.uniform(self.aspect_ratio,
1 / self.aspect_ratio)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < w_bbox and h < h_bbox:
off_y1 = random.randint(0, int(h_bbox - h))
off_x1 = random.randint(0, int(w_bbox - w))
im[int(y1 + off_y1):int(y1 + off_y1 + h), int(x1 + off_x1):int(
x1 + off_x1 + w), :] = 0
sample['image'] = im
return sample
@register_op
class NormalizeImageOp(BaseOperator):
def __init__(self, mean=[0.485, 0.456, 0.406], std=[1, 1, 1],
......@@ -350,7 +401,7 @@ class RandomDistortOp(BaseOperator):
lambda img: cv2.cvtColor(self.apply_saturation(cv2.cvtColor(img, cv2.COLOR_RGB2HSV)), cv2.COLOR_HSV2RGB),
lambda img: cv2.cvtColor(self.apply_hue(cv2.cvtColor(img, cv2.COLOR_RGB2HSV)), cv2.COLOR_HSV2RGB),
]
distortions = np.random.permutation(functions)[:count]
distortions = np.random.permutation(functions)[:self.count]
for func in distortions:
img = func(img)
sample['image'] = img
......@@ -527,11 +578,11 @@ class ResizeOp(BaseOperator):
super(ResizeOp, self).__init__()
self.keep_ratio = keep_ratio
self.interp = interp
if not isinstance(target_size, (int, list, tuple)):
if not isinstance(target_size, (Integral, Sequence)):
raise TypeError(
"Type of target_size is invalid. Must be Integer or List or Tuple, now is {}".
format(type(target_size)))
if isinstance(target_size, int):
if isinstance(target_size, Integral):
target_size = [target_size, target_size]
self.target_size = target_size
......@@ -627,11 +678,15 @@ class ResizeOp(BaseOperator):
im = self.apply_image(sample['image'], [im_scale_x, im_scale_y])
sample['image'] = im
sample['im_shape'] = [resize_h, resize_w]
scale_factor = sample['scale_factor']
sample['scale_factor'] = [
scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x
]
sample['im_shape'] = np.array([resize_h, resize_w], dtype=np.int32)
if 'scale_factor' in sample:
scale_factor = sample['scale_factor']
sample['scale_factor'] = np.array(
[scale_factor[0] * im_scale_y, scale_factor[1] * im_scale_x],
dtype=np.float32)
else:
sample['scale_factor'] = np.array(
[im_scale_y, im_scale_x], dtype=np.float32)
# apply bbox
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
......@@ -641,7 +696,7 @@ class ResizeOp(BaseOperator):
# apply polygon
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape,
sample['gt_poly'] = self.apply_segm(sample['gt_poly'], im_shape[:2],
[im_scale_x, im_scale_y])
# apply semantic
......@@ -694,16 +749,16 @@ class MultiscaleTestResizeOp(BaseOperator):
self.interp = interp
self.use_flip = use_flip
if not isinstance(target_size, list):
if not isinstance(target_size, Sequence):
raise TypeError(
"Type of target_size is invalid. Must be List, now is {}".
"Type of target_size is invalid. Must be List or Tuple, now is {}".
format(type(target_size)))
self.target_size = target_size
if not isinstance(origin_target_size, list):
if not isinstance(origin_target_size, Sequence):
raise TypeError(
"Type of target_size is invalid. Must be List, now is {}".
format(type(target_size)))
"Type of origin_target_size is invalid. Must be List or Tuple, now is {}".
format(type(origin_target_size)))
self.origin_target_size = origin_target_size
......@@ -753,10 +808,10 @@ class RandomResizeOp(BaseOperator):
cv2.INTER_LANCZOS4,
]
assert isinstance(target_size, (
int, Sequence)), "target_size must be int, list or tuple"
if random_size and not isinstance(target_size, list):
Integral, Sequence)), "target_size must be Integer, List or Tuple"
if random_size and not isinstance(target_size, Sequence):
raise TypeError(
"Type of target_size is invalid when random_size is True. Must be List, now is {}".
"Type of target_size is invalid when random_size is True. Must be List or Tuple, now is {}".
format(type(target_size)))
self.target_size = target_size
self.random_size = random_size
......@@ -816,7 +871,10 @@ class RandomExpandOp(BaseOperator):
x = np.random.randint(0, w - width)
offsets, size = [x, y], [h, w]
pad = Pad(size, pad_mode=-1, offsets=offsets)
pad = Pad(size,
pad_mode=-1,
offsets=offsets,
fill_value=self.fill_value)
return pad(sample, context=context)
......@@ -1350,11 +1408,11 @@ class RandomScaledCropOp(BaseOperator):
canvas[:min(dim, resize_h), :min(dim, resize_w), :] = img[
offset_y:offset_y + dim, offset_x:offset_x + dim, :]
sample['image'] = canvas
sample['im_shape'] = [resize_h, resize_w]
sample['im_shape'] = np.array([resize_h, resize_w], dtype=np.int32)
scale_factor = sample['scale_factor']
sample['scale_factor'] = [
scale_factor[0] * scale, scale_factor[1] * scale
]
sample['scale_factor'] = np.array(
[scale_factor[0] * scale, scale_factor[1] * scale],
dtype=np.float32)
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
scale_array = np.array([scale, scale] * 2, dtype=np.float32)
......@@ -1487,30 +1545,32 @@ class MixupOp(BaseOperator):
if factor <= 0.0:
return sample[1]
im = self.apply_image(sample[0]['image'], sample[1]['image'], factor)
result = copy.deepcopy(sample[0])
result['image'] = im
# apply bbox and score
gt_bbox1 = sample[0]['gt_bbox']
gt_bbox2 = sample[1]['gt_bbox']
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
gt_class1 = sample[0]['gt_class']
gt_class2 = sample[1]['gt_class']
gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
gt_score1 = sample[0]['gt_score']
gt_score2 = sample[1]['gt_score']
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
is_crowd1 = sample[0]['is_crowd']
is_crowd2 = sample[1]['is_crowd']
is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
sample = sample[0]
sample['image'] = im
sample['gt_bbox'] = gt_bbox
sample['gt_score'] = gt_score
sample['gt_class'] = gt_class
sample['is_crowd'] = is_crowd
return sample
if 'gt_bbox' in sample[0]:
gt_bbox1 = sample[0]['gt_bbox']
gt_bbox2 = sample[1]['gt_bbox']
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
result['gt_bbox'] = gt_bbox
if 'gt_class' in sample[0]:
gt_class1 = sample[0]['gt_class']
gt_class2 = sample[1]['gt_class']
gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
result['gt_class'] = gt_class
if 'gt_score' in sample[0]:
gt_score1 = sample[0]['gt_score']
gt_score2 = sample[1]['gt_score']
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
result['gt_score'] = gt_score
if 'is_crowd' in sample[0]:
is_crowd1 = sample[0]['is_crowd']
is_crowd2 = sample[1]['is_crowd']
is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
result['is_crowd'] = is_crowd
return result
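A hedged usage sketch of the rewritten MixupOp; the field shapes follow the operators in this file, the concrete values are made up:

import numpy as np

samples = [{
    'image': np.random.rand(416, 416, 3).astype('float32'),
    'gt_bbox': np.random.rand(3, 4).astype('float32'),
    'gt_class': np.random.randint(0, 80, (3, 1)).astype('int32'),
    'gt_score': np.ones((3, 1), dtype='float32'),
    'is_crowd': np.zeros((3, 1), dtype='int32'),
} for _ in range(2)]

op = MixupOp(alpha=1.5, beta=1.5)
mixed = op(samples)
# the two images are blended with a Beta(1.5, 1.5) factor; box/class/score arrays are
# concatenated and gt_score is re-weighted by the mixup factor
print(mixed['gt_bbox'].shape)  # typically (6, 4)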
@register_op
......@@ -1578,26 +1638,26 @@ class PadBoxOp(BaseOperator):
bbox = sample['gt_bbox']
gt_num = min(self.num_max_boxes, len(bbox))
num_max = self.num_max_boxes
fields = context['fields'] if context else []
# fields = context['fields'] if context else []
pad_bbox = np.zeros((num_max, 4), dtype=np.float32)
if gt_num > 0:
pad_bbox[:gt_num, :] = bbox[:gt_num, :]
sample['gt_bbox'] = pad_bbox
if 'gt_class' in fields:
pad_class = np.zeros((num_max), dtype=np.int32)
if 'gt_class' in sample:
pad_class = np.zeros((num_max, ), dtype=np.int32)
if gt_num > 0:
pad_class[:gt_num] = sample['gt_class'][:gt_num, 0]
sample['gt_class'] = pad_class
if 'gt_score' in fields:
pad_score = np.zeros((num_max), dtype=np.float32)
if 'gt_score' in sample:
pad_score = np.zeros((num_max, ), dtype=np.float32)
if gt_num > 0:
pad_score[:gt_num] = sample['gt_score'][:gt_num, 0]
sample['gt_score'] = pad_score
# in training, for example in op ExpandImage,
# the bbox and gt_class are expanded, but the difficult is not,
# so judge by its length
if 'is_difficult' in fields:
pad_diff = np.zeros((num_max), dtype=np.int32)
if 'is_difficult' in sample:
pad_diff = np.zeros((num_max, ), dtype=np.int32)
if gt_num > 0:
pad_diff[:gt_num] = sample['difficult'][:gt_num, 0]
sample['difficult'] = pad_diff
......@@ -1751,9 +1811,9 @@ class Pad(BaseOperator):
x, y = offsets
im_h, im_w = im_size
h, w = size
canvas = np.ones((h, w, 3), dtype=np.uint8)
canvas *= np.array(self.fill_value, dtype=np.uint8)
canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.uint8)
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
return canvas
def apply(self, sample, context=None):
......
......@@ -8,6 +8,7 @@ from . import loss
from . import architecture
from . import post_process
from . import layers
from . import utils
from .ops import *
from .bbox import *
......@@ -19,3 +20,4 @@ from .loss import *
from .architecture import *
from .post_process import *
from .layers import *
from .utils import *
......@@ -40,13 +40,13 @@ class YOLOv3(BaseArch):
self.yolo_head_outs = self.yolo_head(body_feats)
def get_loss(self, ):
loss = self.yolo_head.get_loss(self.inputs, self.yolo_head_outs)
loss = self.yolo_head.get_loss(self.yolo_head_outs, self.inputs)
return loss
def get_pred(self, ):
bbox, bbox_num = self.post_process(self.yolo_head_outs,
self.yolo_head.mask_anchors,
self.inputs['im_size'])
bbox, bbox_num = self.post_process(
self.yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
outs = {
"bbox": bbox.numpy(),
"bbox_num": bbox_num.numpy(),
......
......@@ -2,8 +2,9 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from ppdet.modeling.ops import BatchNorm
__all__ = ['DarkNet', 'ConvBNLayer']
......@@ -16,6 +17,7 @@ class ConvBNLayer(nn.Layer):
stride=1,
groups=1,
padding=0,
norm_type='bn',
act="leaky",
name=None):
super(ConvBNLayer, self).__init__()
......@@ -29,14 +31,7 @@ class ConvBNLayer(nn.Layer):
groups=groups,
weight_attr=ParamAttr(name=name + '.conv.weights'),
bias_attr=False)
bn_name = name + '.bn'
self.batch_norm = nn.BatchNorm2D(
ch_out,
weight_attr=ParamAttr(
name=bn_name + '.scale', regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
name=bn_name + '.offset', regularizer=L2Decay(0.)))
self.batch_norm = BatchNorm(ch_out, norm_type=norm_type, name=name)
self.act = act
def forward(self, inputs):
......@@ -54,6 +49,7 @@ class DownSample(nn.Layer):
filter_size=3,
stride=2,
padding=1,
norm_type='bn',
name=None):
super(DownSample, self).__init__()
......@@ -64,6 +60,7 @@ class DownSample(nn.Layer):
filter_size=filter_size,
stride=stride,
padding=padding,
norm_type=norm_type,
name=name)
self.ch_out = ch_out
......@@ -73,7 +70,7 @@ class DownSample(nn.Layer):
class BasicBlock(nn.Layer):
def __init__(self, ch_in, ch_out, name=None):
def __init__(self, ch_in, ch_out, norm_type='bn', name=None):
super(BasicBlock, self).__init__()
self.conv1 = ConvBNLayer(
......@@ -82,6 +79,7 @@ class BasicBlock(nn.Layer):
filter_size=1,
stride=1,
padding=0,
norm_type=norm_type,
name=name + '.0')
self.conv2 = ConvBNLayer(
ch_in=ch_out,
......@@ -89,6 +87,7 @@ class BasicBlock(nn.Layer):
filter_size=3,
stride=1,
padding=1,
norm_type=norm_type,
name=name + '.1')
def forward(self, inputs):
......@@ -99,16 +98,18 @@ class BasicBlock(nn.Layer):
class Blocks(nn.Layer):
def __init__(self, ch_in, ch_out, count, name=None):
def __init__(self, ch_in, ch_out, count, norm_type='bn', name=None):
super(Blocks, self).__init__()
self.basicblock0 = BasicBlock(ch_in, ch_out, name=name + '.0')
self.basicblock0 = BasicBlock(
ch_in, ch_out, norm_type=norm_type, name=name + '.0')
self.res_out_list = []
for i in range(1, count):
block_name = '{}.{}'.format(name, i)
res_out = self.add_sublayer(
block_name, BasicBlock(
ch_out * 2, ch_out, name=block_name))
block_name,
BasicBlock(
ch_out * 2, ch_out, norm_type=norm_type, name=block_name))
self.res_out_list.append(res_out)
self.ch_out = ch_out
......@@ -125,11 +126,14 @@ DarkNet_cfg = {53: ([1, 2, 8, 8, 4])}
@register
@serializable
class DarkNet(nn.Layer):
__shared__ = ['norm_type']
def __init__(self,
depth=53,
freeze_at=-1,
return_idx=[2, 3, 4],
num_stages=5):
num_stages=5,
norm_type='bn'):
super(DarkNet, self).__init__()
self.depth = depth
self.freeze_at = freeze_at
......@@ -143,10 +147,14 @@ class DarkNet(nn.Layer):
filter_size=3,
stride=1,
padding=1,
norm_type=norm_type,
name='yolo_input')
self.downsample0 = DownSample(
ch_in=32, ch_out=32 * 2, name='yolo_input.downsample')
ch_in=32,
ch_out=32 * 2,
norm_type=norm_type,
name='yolo_input.downsample')
self.darknet_conv_block_list = []
self.downsample_list = []
......@@ -154,8 +162,13 @@ class DarkNet(nn.Layer):
for i, stage in enumerate(self.stages):
name = 'stage.{}'.format(i)
conv_block = self.add_sublayer(
name, Blocks(
int(ch_in[i]), 32 * (2**i), stage, name=name))
name,
Blocks(
int(ch_in[i]),
32 * (2**i),
stage,
norm_type=norm_type,
name=name))
self.darknet_conv_block_list.append(conv_block)
for i in range(num_stages - 1):
down_name = 'stage.{}.downsample'.format(i)
......@@ -164,6 +177,7 @@ class DarkNet(nn.Layer):
DownSample(
ch_in=32 * (2**(i + 1)),
ch_out=32 * (2**(i + 2)),
norm_type=norm_type,
name=down_name))
self.downsample_list.append(downsample)
......
import paddle.fluid as fluid
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.fluid.regularizer import L2Decay
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
from ..backbone.darknet import ConvBNLayer
......@@ -14,24 +13,20 @@ class YOLOv3Head(nn.Layer):
__inject__ = ['loss']
def __init__(self,
anchors=[
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90,
156, 198, 373, 326
],
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
num_classes=80,
loss='YOLOv3Loss'):
super(YOLOv3Head, self).__init__()
self.anchors = anchors
self.anchor_masks = anchor_masks
self.num_classes = num_classes
self.loss = loss
self.mask_anchors = self.parse_anchor(self.anchors, self.anchor_masks)
self.num_outputs = len(self.mask_anchors)
self.parse_anchor(anchors, anchor_masks)
self.num_outputs = len(self.anchors)
self.yolo_outputs = []
for i in range(len(self.mask_anchors)):
for i in range(len(self.anchors)):
num_filters = self.num_outputs * (self.num_classes + 5)
name = 'yolo_output.{}'.format(i)
yolo_output = self.add_sublayer(
......@@ -48,24 +43,22 @@ class YOLOv3Head(nn.Layer):
self.yolo_outputs.append(yolo_output)
def parse_anchor(self, anchors, anchor_masks):
anchor_num = len(self.anchors)
mask_anchors = []
for i in range(len(self.anchor_masks)):
mask_anchor = []
for m in self.anchor_masks[i]:
assert m < anchor_num, "anchor mask index overflow"
mask_anchor.extend(self.anchors[2 * m:2 * m + 2])
mask_anchors.append(mask_anchor)
return mask_anchors
self.anchors = [[anchors[i] for i in mask] for mask in anchor_masks]
self.mask_anchors = []
anchor_num = len(anchors)
for masks in anchor_masks:
self.mask_anchors.append([])
for mask in masks:
assert mask < anchor_num, "anchor mask index overflow"
self.mask_anchors[-1].extend(anchors[mask])
def forward(self, feats):
assert len(feats) == len(self.mask_anchors)
assert len(feats) == len(self.anchors)
yolo_outputs = []
for i, feat in enumerate(feats):
yolo_output = self.yolo_outputs[i](feat)
yolo_outputs.append(yolo_output)
return yolo_outputs
def get_loss(self, inputs, head_outputs):
return self.loss(inputs, head_outputs, self.anchors, self.anchor_masks)
def get_loss(self, inputs, targets):
return self.loss(inputs, targets, self.anchors)
......@@ -407,10 +407,13 @@ class YOLOBox(object):
def __call__(self, yolo_head_out, anchors, im_shape, scale_factor=None):
boxes_list = []
scores_list = []
im_shape = paddle.cast(im_shape, 'float32')
if scale_factor is not None:
origin_shape = im_shape / scale_factor
else:
origin_shape = im_shape
origin_shape = paddle.cast(origin_shape, 'int32')
for i, head_out in enumerate(yolo_head_out):
boxes, scores = ops.yolo_box(head_out, origin_shape, anchors[i],
self.num_classes, self.conf_thresh,
......
......@@ -13,5 +13,9 @@
# limitations under the License.
from . import yolo_loss
from . import iou_aware_loss
from . import iou_loss
from .yolo_loss import *
from .iou_aware_loss import *
from .iou_loss import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from .iou_loss import IouLoss
from ..utils import xywh2xyxy, bbox_iou, decode_yolo
@register
@serializable
class IouAwareLoss(IouLoss):
"""
iou aware loss, see https://arxiv.org/abs/1912.05992
Args:
loss_weight (float): iou aware loss weight, default is 1.0
giou (bool): whether to use GIoU when computing the iou
diou (bool): whether to use DIoU when computing the iou
ciou (bool): whether to use CIoU when computing the iou
"""
def __init__(
self,
loss_weight=1.0,
giou=False,
diou=False,
ciou=False, ):
super(IouAwareLoss, self).__init__(
loss_weight=loss_weight, giou=giou, diou=diou, ciou=ciou)
def __call__(self, ioup, pbox, gbox, anchor, downsample, scale=1.):
b = pbox.shape[0]
ioup = ioup.reshape((b, -1))
pbox = decode_yolo(pbox, anchor, downsample)
gbox = decode_yolo(gbox, anchor, downsample)
pbox = xywh2xyxy(pbox).reshape((b, -1, 4))
gbox = xywh2xyxy(gbox).reshape((b, -1, 4))
iou = bbox_iou(
pbox, gbox, giou=self.giou, diou=self.diou, ciou=self.ciou)
iou.stop_gradient = True
loss_iou_aware = F.binary_cross_entropy_with_logits(
ioup, iou, reduction='none')
loss_iou_aware = loss_iou_aware * self.loss_weight
return loss_iou_aware
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from ..utils import xywh2xyxy, bbox_iou, decode_yolo
__all__ = ['IouLoss']
@register
@serializable
class IouLoss(object):
"""
iou loss, see https://arxiv.org/abs/1908.03851
loss = 1.0 - iou * iou
Args:
loss_weight (float): iou loss weight, default is 2.5
giou (bool): whether to use GIoU when computing the iou
diou (bool): whether to use DIoU when computing the iou
ciou (bool): whether to use CIoU when computing the iou
loss_square (bool): whether to square the iou term
"""
def __init__(self,
loss_weight=2.5,
giou=False,
diou=False,
ciou=False,
loss_square=True):
self.loss_weight = loss_weight
self.giou = giou
self.diou = diou
self.ciou = ciou
self.loss_square = loss_square
def __call__(self, pbox, gbox, anchor, downsample, scale=1.):
b = pbox.shape[0]
pbox = decode_yolo(pbox, anchor, downsample)
gbox = decode_yolo(gbox, anchor, downsample)
pbox = xywh2xyxy(pbox).reshape((b, -1, 4))
gbox = xywh2xyxy(gbox).reshape((b, -1, 4))
iou = bbox_iou(
pbox, gbox, giou=self.giou, diou=self.diou, ciou=self.ciou)
if self.loss_square:
loss_iou = 1 - iou * iou
else:
loss_iou = 1 - iou
loss_iou = loss_iou * self.loss_weight
return loss_iou
......@@ -12,54 +12,156 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
from ..backbone.darknet import ConvBNLayer
from ..utils import decode_yolo, xywh2xyxy, iou_similarity
__all__ = ['YOLOv3Loss']
@register
class YOLOv3Loss(nn.Layer):
__inject__ = ['iou_loss', 'iou_aware_loss']
__shared__ = ['num_classes']
def __init__(self,
num_classes=80,
ignore_thresh=0.7,
label_smooth=False,
downsample=32,
use_fine_grained_loss=False):
downsample=[32, 16, 8],
scale_x_y=1.,
iou_loss=None,
iou_aware_loss=None):
super(YOLOv3Loss, self).__init__()
self.num_classes = num_classes
self.ignore_thresh = ignore_thresh
self.label_smooth = label_smooth
self.downsample = downsample
self.use_fine_grained_loss = use_fine_grained_loss
def forward(self, inputs, head_outputs, anchors, anchor_masks):
if self.use_fine_grained_loss:
raise NotImplementedError(
"fine grained loss not implement currently")
yolo_losses = []
for i, out in enumerate(head_outputs):
loss = fluid.layers.yolov3_loss(
x=out,
gt_box=inputs['gt_bbox'],
gt_label=inputs['gt_class'],
gt_score=inputs['gt_score'],
anchors=anchors,
anchor_mask=anchor_masks[i],
class_num=self.num_classes,
ignore_thresh=self.ignore_thresh,
downsample_ratio=self.downsample // 2**i,
use_label_smooth=self.label_smooth,
name='yolo_loss_' + str(i))
loss = paddle.mean(loss)
yolo_losses.append(loss)
return {'loss': sum(yolo_losses)}
self.scale_x_y = scale_x_y
self.iou_loss = iou_loss
self.iou_aware_loss = iou_aware_loss
def obj_loss(self, pbox, gbox, pobj, tobj, anchor, downsample):
b, h, w, na = pbox.shape[:4]
pbox = decode_yolo(pbox, anchor, downsample)
pbox = pbox.reshape((b, -1, 4))
pbox = xywh2xyxy(pbox)
gbox = xywh2xyxy(gbox)
iou = iou_similarity(pbox, gbox)
iou.stop_gradient = True
iou_max = iou.max(2) # [N, M1]
iou_mask = paddle.cast(iou_max <= self.ignore_thresh, dtype=pbox.dtype)
iou_mask.stop_gradient = True
pobj = pobj.reshape((b, -1))
tobj = tobj.reshape((b, -1))
obj_mask = paddle.cast(tobj > 0, dtype=pbox.dtype)
obj_mask.stop_gradient = True
loss_obj = F.binary_cross_entropy_with_logits(
pobj, obj_mask, reduction='none')
loss_obj_pos = (loss_obj * tobj)
loss_obj_neg = (loss_obj * (1 - obj_mask) * iou_mask)
return loss_obj_pos + loss_obj_neg
def cls_loss(self, pcls, tcls):
if self.label_smooth:
delta = min(1. / self.num_classes, 1. / 40)
pos, neg = 1 - delta, delta
# 1 for positive, 0 for negative
tcls = pos * paddle.cast(
tcls > 0., dtype=tcls.dtype) + neg * paddle.cast(
tcls <= 0., dtype=tcls.dtype)
loss_cls = F.binary_cross_entropy_with_logits(
pcls, tcls, reduction='none')
return loss_cls
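Worked numbers for the label smoothing above, assuming the default num_classes=80:

num_classes = 80
delta = min(1. / num_classes, 1. / 40)   # = 0.0125
pos, neg = 1 - delta, delta              # positive target 0.9875, negative target 0.0125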
def yolov3_loss(self, x, t, gt_box, anchor, downsample, scale=1.,
eps=1e-10):
na = len(anchor)
b, c, h, w = x.shape
no = c // na
x = x.reshape((b, na, no, h, w)).transpose((0, 3, 4, 1, 2))
xy, wh, obj = x[:, :, :, :, 0:2], x[:, :, :, :, 2:4], x[:, :, :, :, 4:5]
if self.iou_aware_loss:
ioup, pcls = x[:, :, :, :, 5:6], x[:, :, :, :, 6:]
else:
pcls = x[:, :, :, :, 5:]
t = t.transpose((0, 3, 4, 1, 2))
txy, twh, tscale = t[:, :, :, :, 0:2], t[:, :, :, :, 2:4], t[:, :, :, :,
4:5]
tobj, tcls = t[:, :, :, :, 5:6], t[:, :, :, :, 6:]
tscale_obj = tscale * tobj
loss = dict()
if abs(scale - 1.) < eps:
loss_xy = tscale_obj * F.binary_cross_entropy_with_logits(
xy, txy, reduction='none')
else:
xy = scale * F.sigmoid(xy) - 0.5 * (scale - 1.)
loss_xy = tscale_obj * paddle.abs(xy - txy)
loss_xy = loss_xy.sum([1, 2, 3, 4]).mean()
loss_wh = tscale_obj * paddle.abs(wh - twh)
loss_wh = loss_wh.sum([1, 2, 3, 4]).mean()
loss['loss_loc'] = loss_xy + loss_wh
x[:, :, :, :, 0:2] = scale * F.sigmoid(x[:, :, :, :, 0:2]) - 0.5 * (
scale - 1.)
box, tbox = x[:, :, :, :, 0:4], t[:, :, :, :, 0:4]
if self.iou_loss is not None:
# box and tbox are not modified in place by self.iou_loss, so there is no need to clone them
loss_iou = self.iou_loss(box, tbox, anchor, downsample, scale)
loss_iou = loss_iou * tscale_obj.reshape((b, -1))
loss_iou = loss_iou.sum(-1).mean()
loss['loss_iou'] = loss_iou
if self.iou_aware_loss is not None:
# box and tbox are not modified in place by self.iou_aware_loss, so there is no need to clone them
loss_iou_aware = self.iou_aware_loss(ioup, box, tbox, anchor,
downsample, scale)
loss_iou_aware = loss_iou_aware * tobj.reshape((b, -1))
loss_iou_aware = loss_iou_aware.sum(-1).mean()
loss['loss_iou_aware'] = loss_iou_aware
loss_obj = self.obj_loss(box, gt_box, obj, tobj, anchor, downsample)
loss_obj = loss_obj.sum(-1).mean()
loss['loss_obj'] = loss_obj
loss_cls = self.cls_loss(pcls, tcls) * tobj
loss_cls = loss_cls.sum([1, 2, 3, 4]).mean()
loss['loss_cls'] = loss_cls
return loss
def forward(self, inputs, targets, anchors):
np = len(inputs)
gt_targets = [targets['target{}'.format(i)] for i in range(np)]
gt_box = targets['gt_bbox']
yolo_losses = dict()
for x, t, anchor, downsample in zip(inputs, gt_targets, anchors,
self.downsample):
yolo_loss = self.yolov3_loss(x, t, gt_box, anchor, downsample)
for k, v in yolo_loss.items():
if k in yolo_losses:
yolo_losses[k] += v
else:
yolo_losses[k] = v
loss = 0
for k, v in yolo_losses.items():
loss += v
yolo_losses['loss'] = loss
return yolo_losses
......@@ -21,7 +21,7 @@ from ..backbone.darknet import ConvBNLayer
class YoloDetBlock(nn.Layer):
def __init__(self, ch_in, channel, name):
def __init__(self, ch_in, channel, norm_type, name):
super(YoloDetBlock, self).__init__()
self.ch_in = ch_in
self.channel = channel
......@@ -45,6 +45,7 @@ class YoloDetBlock(nn.Layer):
ch_out=ch_out,
filter_size=filter_size,
padding=(filter_size - 1) // 2,
norm_type=norm_type,
name=name + post_name))
self.tip = ConvBNLayer(
......@@ -52,6 +53,7 @@ class YoloDetBlock(nn.Layer):
ch_out=channel * 2,
filter_size=3,
padding=1,
norm_type=norm_type,
name=name + '.tip')
def forward(self, inputs):
......@@ -63,7 +65,9 @@ class YoloDetBlock(nn.Layer):
@register
@serializable
class YOLOv3FPN(nn.Layer):
def __init__(self, feat_channels=[1024, 768, 384]):
__shared__ = ['norm_type']
def __init__(self, feat_channels=[1024, 768, 384], norm_type='bn'):
super(YOLOv3FPN, self).__init__()
assert len(feat_channels) > 0, "feat_channels length should > 0"
self.feat_channels = feat_channels
......@@ -75,7 +79,10 @@ class YOLOv3FPN(nn.Layer):
yolo_block = self.add_sublayer(
name,
YoloDetBlock(
feat_channels[i], channel=512 // (2**i), name=name))
feat_channels[i],
channel=512 // (2**i),
norm_type=norm_type,
name=name))
self.yolo_blocks.append(yolo_block)
if i < self.num_blocks - 1:
......@@ -88,6 +95,7 @@ class YOLOv3FPN(nn.Layer):
filter_size=1,
stride=1,
padding=0,
norm_type=norm_type,
name=name))
self.routes.append(route)
......
......@@ -14,6 +14,9 @@
import paddle
import paddle.nn.functional as F
import paddle.nn as nn
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid import core
......@@ -26,21 +29,33 @@ import numpy as np
from functools import reduce
__all__ = [
'roi_pool',
'roi_align',
'prior_box',
'anchor_generator',
'generate_proposals',
'iou_similarity',
'box_coder',
'yolo_box',
'multiclass_nms',
'distribute_fpn_proposals',
'collect_fpn_proposals',
'matrix_nms',
'roi_pool', 'roi_align', 'prior_box', 'anchor_generator',
'generate_proposals', 'iou_similarity', 'box_coder', 'yolo_box',
'multiclass_nms', 'distribute_fpn_proposals', 'collect_fpn_proposals',
'matrix_nms', 'BatchNorm'
]
class BatchNorm(nn.Layer):
def __init__(self, ch, norm_type='bn', name=None):
super(BatchNorm, self).__init__()
bn_name = name + '.bn'
if norm_type == 'sync_bn':
batch_norm = nn.SyncBatchNorm
else:
batch_norm = nn.BatchNorm2D
self.batch_norm = batch_norm(
ch,
weight_attr=ParamAttr(
name=bn_name + '.scale', regularizer=L2Decay(0.)),
bias_attr=ParamAttr(
name=bn_name + '.offset', regularizer=L2Decay(0.)))
def forward(self, x):
return self.batch_norm(x)
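A hedged usage sketch of the BatchNorm wrapper above (the channel count and name are made up; it assumes the class definition above):

import paddle

norm = BatchNorm(64, norm_type='bn', name='backbone.stage.0')  # 'sync_bn' selects nn.SyncBatchNorm instead
x = paddle.rand([2, 64, 32, 32])
print(norm(x).shape)  # [2, 64, 32, 32]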
def roi_pool(input,
rois,
output_size,
......
import numpy as np
import paddle.fluid as fluid
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......
from __future__ import print_function
import random
import unittest
import numpy as np
import copy
# add python path of PaddleDetection to sys.path
import os
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
if parent_path not in sys.path:
sys.path.append(parent_path)
from ppdet.data.transform import *
def gen_sample(h, w, nt, nc, random_score=True, channel_first=False):
im = np.random.randint(0, 256, size=(h, w, 3)).astype('float32')
if channel_first:
im = im.transpose((2, 0, 1))
gt_bbox = np.random.random(size=(nt, 4)).astype('float32')
gt_class = np.random.randint(0, nc, size=(nt, 1)).astype('int32')
if random_score:
gt_score = np.random.random(size=(nt, 1))
else:
gt_score = np.ones(shape=(nt, 1)).astype('float32')
is_crowd = np.zeros_like(gt_class)
sample = {
'image': im,
'gt_bbox': gt_bbox,
'gt_class': gt_class,
'gt_score': gt_score,
'is_crowd': is_crowd,
'h': h,
'w': w
}
return sample
class TestTransformOp(unittest.TestCase):
def setUp(self):
self.h, self.w = np.random.randint(1, 1024, size=2)
self.nt = np.random.randint(1, 50)
self.nc = 80
def assertAllClose(self, x, y, msg, atol=1e-5, rtol=1e-3):
self.assertTrue(np.allclose(x, y, atol=atol, rtol=rtol), msg=msg)
class TestResizeOp(TestTransformOp):
def test_resize(self):
sample = gen_sample(self.h, self.w, self.nt, self.nc)
orig_op = Resize(target_dim=608, interp=2)
curr_op = ResizeOp(target_size=608, keep_ratio=False, interp=2)
orig_res = orig_op(copy.deepcopy(sample))
curr_res = curr_op(copy.deepcopy(sample))
fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
for k in fields:
self.assertAllClose(orig_res[k], curr_res[k], msg=k)
# only for specified random seed
# class TestMixupOp(TestTransformOp):
# def setUp(self):
# self.h, self.w = np.random.randint(1024, size=2)
# self.nt = np.random.randint(50)
# self.nc = 80
# def test_mixup(self):
# curr_sample = [gen_sample(self.h, self.w, self.nt, self.nc) for _ in range(2)]
# orig_sample = copy.deepcopy(curr_sample[0])
# orig_sample['mixup'] = copy.deepcopy(curr_sample[1])
# orig_op = MixupImage(alpha=1.5, beta=1.5)
# curr_op = MixupOp(alpha=1.5, beta=1.5)
# orig_res = orig_op(orig_sample)
# curr_res = curr_op(curr_sample)
# fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
# for k in fields:
# self.assertAllClose(orig_res[k], curr_res[k], msg=k)
# only for specified random seed
# class TestRandomDistortOp(TestTransformOp):
# def test_random_distort(self):
# sample = gen_sample(self.h, self.w, self.nt, self.nc)
# orig_op = ColorDistort(hsv_format=True, random_apply=False)
# curr_op = RandomDistortOp(random_apply=False)
# orig_res = orig_op(copy.deepcopy(sample))
# curr_res = curr_op(copy.deepcopy(sample))
# fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
# for k in fields:
# self.assertAllClose(orig_res[k], curr_res[k], msg=k)
# only for specified random seed
# class TestRandomExpandOp(TestTransformOp):
# def test_random_expand(self):
# sample = gen_sample(self.h, self.w, self.nt, self.nc)
# orig_op = RandomExpand(fill_value=[123.675, 116.28, 103.53])
# curr_op = RandomExpandOp(fill_value=[123.675, 116.28, 103.53])
# orig_res = orig_op(copy.deepcopy(sample))
# curr_res = curr_op(copy.deepcopy(sample))
# fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
# for k in fields:
# self.assertAllClose(orig_res[k], curr_res[k], msg=k)
# only for specified random seed
# class TestRandomCropOp(TestTransformOp):
# def test_random_crop(self):
# sample = gen_sample(self.h, self.w, self.nt, self.nc)
# orig_op = RandomCrop()
# curr_op = RandomCropOp()
# orig_res = orig_op(copy.deepcopy(sample))
# curr_res = curr_op(copy.deepcopy(sample))
# fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
# for k in fields:
# self.assertAllClose(orig_res[k], curr_res[k], msg=k)
# only for specified random seed
# class TestRandomFlipOp(TestTransformOp):
# def test_random_flip(self):
# sample = gen_sample(self.h, self.w, self.nt, self.nc)
# orig_op = RandomFlipImage(is_normalized=False)
# curr_op = RandomFlipOp()
# orig_res = orig_op(copy.deepcopy(sample))
# curr_res = curr_op(copy.deepcopy(sample))
# fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
# for k in fields:
# self.assertAllClose(orig_res[k], curr_res[k], msg=k)
# only for specified random seed
# class TestBatchRandomResizeOp(TestTransformOp):
# def test_batch_random_resize(self):
# sample = [gen_sample(self.h, self.w, self.nt, self.nc) for _ in range(10)]
# orig_op = RandomShape(sizes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608], random_inter=True, resize_box=True)
# curr_op = BatchRandomResizeOp(target_size=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608], random_size=True, random_interp=True, keep_ratio=False)
# orig_ress = orig_op(copy.deepcopy(sample))
# curr_ress = curr_op(copy.deepcopy(sample))
# fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
# for orig_res, curr_res in zip(orig_ress, curr_ress):
# for k in fields:
# self.assertAllClose(orig_res[k], curr_res[k], msg=k)
class TestNormalizeBoxOp(TestTransformOp):
def test_normalize_box(self):
sample = gen_sample(self.h, self.w, self.nt, self.nc)
orig_op = NormalizeBox()
curr_op = NormalizeBoxOp()
orig_res = orig_op(copy.deepcopy(sample))
curr_res = curr_op(copy.deepcopy(sample))
fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
for k in fields:
self.assertAllClose(orig_res[k], curr_res[k], msg=k)
class TestPadBoxOp(TestTransformOp):
def test_pad_box(self):
sample = gen_sample(self.h, self.w, self.nt, self.nc)
orig_op = PadBox(num_max_boxes=50)
curr_op = PadBoxOp(num_max_boxes=50)
orig_res = orig_op(copy.deepcopy(sample))
curr_res = curr_op(copy.deepcopy(sample))
fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
for k in fields:
self.assertAllClose(orig_res[k], curr_res[k], msg=k)
class TestBboxXYXY2XYWHOp(TestTransformOp):
def test_bbox_xyxy2xywh(self):
sample = gen_sample(self.h, self.w, self.nt, self.nc)
orig_op = BboxXYXY2XYWH()
curr_op = BboxXYXY2XYWHOp()
orig_res = orig_op(copy.deepcopy(sample))
curr_res = curr_op(copy.deepcopy(sample))
fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
for k in fields:
self.assertAllClose(orig_res[k], curr_res[k], msg=k)
class TestNormalizeImageOp(TestTransformOp):
def test_normalize_image(self):
sample = gen_sample(self.h, self.w, self.nt, self.nc)
orig_op = NormalizeImage(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False)
curr_op = NormalizeImageOp(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True)
orig_res = orig_op(copy.deepcopy(sample))
curr_res = curr_op(copy.deepcopy(sample))
fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
for k in fields:
self.assertAllClose(orig_res[k], curr_res[k], msg=k)
class TestPermuteOp(TestTransformOp):
def test_permute(self):
sample = gen_sample(self.h, self.w, self.nt, self.nc)
orig_op = Permute(to_bgr=False, channel_first=True)
curr_op = PermuteOp()
orig_res = orig_op(copy.deepcopy(sample))
curr_res = curr_op(copy.deepcopy(sample))
fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
for k in fields:
self.assertAllClose(orig_res[k], curr_res[k], msg=k)
class TestGt2YoloTargetOp(TestTransformOp):
def test_gt2yolotarget(self):
sample = [
gen_sample(
self.h, self.w, self.nt, self.nc, channel_first=True)
for _ in range(10)
]
orig_op = Gt2YoloTarget(
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
downsample_ratios=[32, 16, 8])
curr_op = Gt2YoloTargetOp(
anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]],
downsample_ratios=[32, 16, 8])
orig_ress = orig_op(copy.deepcopy(sample))
curr_ress = curr_op(copy.deepcopy(sample))
fields = ['image', 'gt_bbox', 'gt_class', 'gt_score']
for orig_res, curr_res in zip(orig_ress, curr_ress):
for k in fields:
self.assertAllClose(orig_res[k], curr_res[k], msg=k)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from scipy.special import logit
from scipy.special import expit
import paddle
from paddle import fluid
from paddle.fluid import core
# add python path of PaddleDetection to sys.path
import os
import sys
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 4)))
if parent_path not in sys.path:
sys.path.append(parent_path)
from ppdet.modeling.loss import YOLOv3Loss
from ppdet.data.transform.op_helper import jaccard_overlap
import random
import numpy as np
def _split_ioup(output, an_num, num_classes):
"""
Split output feature map to output, predicted iou
along channel dimension
"""
ioup = fluid.layers.slice(output, axes=[1], starts=[0], ends=[an_num])
ioup = fluid.layers.sigmoid(ioup)
oriout = fluid.layers.slice(
output, axes=[1], starts=[an_num], ends=[an_num * (num_classes + 6)])
return (ioup, oriout)
def _split_output(output, an_num, num_classes):
"""
Split output feature map to x, y, w, h, objectness, classification
along channel dimension
"""
x = fluid.layers.strided_slice(
output,
axes=[1],
starts=[0],
ends=[output.shape[1]],
strides=[5 + num_classes])
y = fluid.layers.strided_slice(
output,
axes=[1],
starts=[1],
ends=[output.shape[1]],
strides=[5 + num_classes])
w = fluid.layers.strided_slice(
output,
axes=[1],
starts=[2],
ends=[output.shape[1]],
strides=[5 + num_classes])
h = fluid.layers.strided_slice(
output,
axes=[1],
starts=[3],
ends=[output.shape[1]],
strides=[5 + num_classes])
obj = fluid.layers.strided_slice(
output,
axes=[1],
starts=[4],
ends=[output.shape[1]],
strides=[5 + num_classes])
clss = []
stride = output.shape[1] // an_num
for m in range(an_num):
clss.append(
fluid.layers.slice(
output,
axes=[1],
starts=[stride * m + 5],
ends=[stride * m + 5 + num_classes]))
cls = fluid.layers.transpose(
fluid.layers.stack(
clss, axis=1), perm=[0, 1, 3, 4, 2])
return (x, y, w, h, obj, cls)
def _split_target(target):
"""
split target to x, y, w, h, objectness, classification
along dimension 2
target is in shape [N, an_num, 6 + class_num, H, W]
"""
tx = target[:, :, 0, :, :]
ty = target[:, :, 1, :, :]
tw = target[:, :, 2, :, :]
th = target[:, :, 3, :, :]
tscale = target[:, :, 4, :, :]
tobj = target[:, :, 5, :, :]
tcls = fluid.layers.transpose(target[:, :, 6:, :, :], perm=[0, 1, 3, 4, 2])
tcls.stop_gradient = True
return (tx, ty, tw, th, tscale, tobj, tcls)
def _calc_obj_loss(output, obj, tobj, gt_box, batch_size, anchors, num_classes,
downsample, ignore_thresh, scale_x_y):
# If a prediction bbox overlaps any gt_bbox by more than ignore_thresh,
# its objectness loss will be ignored; the process is as follows:
# 1. get pred bbox, which is same with YOLOv3 infer mode, use yolo_box here
# NOTE: img_size is set as 1.0 to get normalized pred bbox
bbox, prob = fluid.layers.yolo_box(
x=output,
img_size=fluid.layers.ones(
shape=[batch_size, 2], dtype="int32"),
anchors=anchors,
class_num=num_classes,
conf_thresh=0.,
downsample_ratio=downsample,
clip_bbox=False,
scale_x_y=scale_x_y)
# 2. split pred bbox and gt bbox by sample, calculate IoU between pred bbox
# and gt bbox in each sample
if batch_size > 1:
preds = fluid.layers.split(bbox, batch_size, dim=0)
gts = fluid.layers.split(gt_box, batch_size, dim=0)
else:
preds = [bbox]
gts = [gt_box]
probs = [prob]
ious = []
for pred, gt in zip(preds, gts):
def box_xywh2xyxy(box):
x = box[:, 0]
y = box[:, 1]
w = box[:, 2]
h = box[:, 3]
return fluid.layers.stack(
[
x - w / 2.,
y - h / 2.,
x + w / 2.,
y + h / 2.,
], axis=1)
pred = fluid.layers.squeeze(pred, axes=[0])
gt = box_xywh2xyxy(fluid.layers.squeeze(gt, axes=[0]))
ious.append(fluid.layers.iou_similarity(pred, gt))
iou = fluid.layers.stack(ious, axis=0)
# 3. Get iou_mask by IoU between gt bbox and prediction bbox,
# Get obj_mask by tobj(holds gt_score), calculate objectness loss
max_iou = fluid.layers.reduce_max(iou, dim=-1)
iou_mask = fluid.layers.cast(max_iou <= ignore_thresh, dtype="float32")
output_shape = fluid.layers.shape(output)
an_num = len(anchors) // 2
iou_mask = fluid.layers.reshape(iou_mask, (-1, an_num, output_shape[2],
output_shape[3]))
iou_mask.stop_gradient = True
# NOTE: tobj holds gt_score, obj_mask holds object existence mask
obj_mask = fluid.layers.cast(tobj > 0., dtype="float32")
obj_mask.stop_gradient = True
# For positive objectness grids, objectness loss should be calculated
# For negative objectness grids, objectness loss is calculated only where iou_mask == 1.0
loss_obj = fluid.layers.sigmoid_cross_entropy_with_logits(obj, obj_mask)
loss_obj_pos = fluid.layers.reduce_sum(loss_obj * tobj, dim=[1, 2, 3])
loss_obj_neg = fluid.layers.reduce_sum(
loss_obj * (1.0 - obj_mask) * iou_mask, dim=[1, 2, 3])
return loss_obj_pos, loss_obj_neg
def fine_grained_loss(output,
target,
gt_box,
batch_size,
num_classes,
anchors,
ignore_thresh,
downsample,
scale_x_y=1.,
eps=1e-10):
an_num = len(anchors) // 2
x, y, w, h, obj, cls = _split_output(output, an_num, num_classes)
tx, ty, tw, th, tscale, tobj, tcls = _split_target(target)
tscale_tobj = tscale * tobj
scale_x_y = scale_x_y
if (abs(scale_x_y - 1.0) < eps):
loss_x = fluid.layers.sigmoid_cross_entropy_with_logits(
x, tx) * tscale_tobj
loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
loss_y = fluid.layers.sigmoid_cross_entropy_with_logits(
y, ty) * tscale_tobj
loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
else:
dx = scale_x_y * fluid.layers.sigmoid(x) - 0.5 * (scale_x_y - 1.0)
dy = scale_x_y * fluid.layers.sigmoid(y) - 0.5 * (scale_x_y - 1.0)
loss_x = fluid.layers.abs(dx - tx) * tscale_tobj
loss_x = fluid.layers.reduce_sum(loss_x, dim=[1, 2, 3])
loss_y = fluid.layers.abs(dy - ty) * tscale_tobj
loss_y = fluid.layers.reduce_sum(loss_y, dim=[1, 2, 3])
# NOTE: the loss function of (w, h) is refined to an L1 loss
loss_w = fluid.layers.abs(w - tw) * tscale_tobj
loss_w = fluid.layers.reduce_sum(loss_w, dim=[1, 2, 3])
loss_h = fluid.layers.abs(h - th) * tscale_tobj
loss_h = fluid.layers.reduce_sum(loss_h, dim=[1, 2, 3])
loss_obj_pos, loss_obj_neg = _calc_obj_loss(
output, obj, tobj, gt_box, batch_size, anchors, num_classes, downsample,
ignore_thresh, scale_x_y)
loss_cls = fluid.layers.sigmoid_cross_entropy_with_logits(cls, tcls)
loss_cls = fluid.layers.elementwise_mul(loss_cls, tobj, axis=0)
loss_cls = fluid.layers.reduce_sum(loss_cls, dim=[1, 2, 3, 4])
loss_xys = fluid.layers.reduce_mean(loss_x + loss_y)
loss_whs = fluid.layers.reduce_mean(loss_w + loss_h)
loss_objs = fluid.layers.reduce_mean(loss_obj_pos + loss_obj_neg)
loss_clss = fluid.layers.reduce_mean(loss_cls)
losses_all = {
"loss_xy": fluid.layers.sum(loss_xys),
"loss_wh": fluid.layers.sum(loss_whs),
"loss_loc": fluid.layers.sum(loss_xys) + fluid.layers.sum(loss_whs),
"loss_obj": fluid.layers.sum(loss_objs),
"loss_cls": fluid.layers.sum(loss_clss),
}
return losses_all, x, y, tx, ty
def gt2yolotarget(gt_bbox, gt_class, gt_score, anchors, mask, num_classes, size,
stride):
grid_h, grid_w = size
h, w = grid_h * stride, grid_w * stride
an_hw = np.array(anchors) / np.array([[w, h]])
target = np.zeros(
(len(mask), 6 + num_classes, grid_h, grid_w), dtype=np.float32)
for b in range(gt_bbox.shape[0]):
gx, gy, gw, gh = gt_bbox[b, :]
cls = gt_class[b]
score = gt_score[b]
if gw <= 0. or gh <= 0. or score <= 0.:
continue
# find best match anchor index
best_iou = 0.
best_idx = -1
for an_idx in range(an_hw.shape[0]):
iou = jaccard_overlap([0., 0., gw, gh],
[0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
if iou > best_iou:
best_iou = iou
best_idx = an_idx
gi = int(gx * grid_w)
gj = int(gy * grid_h)
# the gt box should be regressed in this layer if the best matching
# anchor index is in this layer's anchor mask
if best_idx in mask:
best_n = mask.index(best_idx)
# x, y, w, h, scale
target[best_n, 0, gj, gi] = gx * grid_w - gi
target[best_n, 1, gj, gi] = gy * grid_h - gj
target[best_n, 2, gj, gi] = np.log(gw * w / anchors[best_idx][0])
target[best_n, 3, gj, gi] = np.log(gh * h / anchors[best_idx][1])
target[best_n, 4, gj, gi] = 2.0 - gw * gh
# objectness records gt_score
# if target[best_n, 5, gj, gi] > 0:
# print('find 1 duplicate')
target[best_n, 5, gj, gi] = score
# classification
target[best_n, 6 + cls, gj, gi] = 1.
return target
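# A minimal usage sketch (not part of the original file): build the per-level
# target tensor for a single normalized gt box on the stride-32 level of a
# 608x608 input (19x19 grid). The box and class values below are made up.
def _demo_gt2yolotarget():
    anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
               [59, 119], [116, 90], [156, 198], [373, 326]]
    gt_bbox = np.array([[0.5, 0.5, 0.6, 0.4]], dtype=np.float32)  # cx, cy, w, h in [0, 1]
    gt_class = np.array([3], dtype=np.int32)
    gt_score = np.array([1.0], dtype=np.float32)
    # returns an array of shape [3, 6 + num_classes, 19, 19]; the box above
    # best-matches anchor index 8, so its entries land in anchor slot 2.
    return gt2yolotarget(gt_bbox, gt_class, gt_score, anchors,
                         mask=[6, 7, 8], num_classes=80, size=(19, 19),
                         stride=32)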
class TestYolov3LossOp(unittest.TestCase):
def setUp(self):
self.initTestCase()
x = np.random.uniform(0, 1, self.x_shape).astype('float64')
gtbox = np.random.random(size=self.gtbox_shape).astype('float64')
gtlabel = np.random.randint(0, self.class_num, self.gtbox_shape[:2])
gtmask = np.random.randint(0, 2, self.gtbox_shape[:2])
gtbox = gtbox * gtmask[:, :, np.newaxis]
gtlabel = gtlabel * gtmask
gtscore = np.ones(self.gtbox_shape[:2]).astype('float64')
if self.gtscore:
gtscore = np.random.random(self.gtbox_shape[:2]).astype('float64')
target = []
for box, label, score in zip(gtbox, gtlabel, gtscore):
target.append(
gt2yolotarget(box, label, score, self.anchors, self.anchor_mask,
self.class_num, (self.h, self.w
), self.downsample_ratio))
self.target = np.array(target).astype('float64')
self.mask_anchors = []
for i in self.anchor_mask:
self.mask_anchors.extend(self.anchors[i])
self.x = x
self.gtbox = gtbox
self.gtlabel = gtlabel
self.gtscore = gtscore
def initTestCase(self):
self.b = 8
self.h = 19
self.w = 19
self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]
self.anchor_mask = [6, 7, 8]
self.na = len(self.anchor_mask)
self.class_num = 80
self.ignore_thresh = 0.7
self.downsample_ratio = 32
self.x_shape = (self.b, len(self.anchor_mask) * (5 + self.class_num),
self.h, self.w)
self.gtbox_shape = (self.b, 40, 4)
self.gtscore = True
self.use_label_smooth = False
self.scale_x_y = 1.
def test_loss(self):
x, gtbox, gtlabel, gtscore, target = self.x, self.gtbox, self.gtlabel, self.gtscore, self.target
yolo_loss = YOLOv3Loss(
ignore_thresh=self.ignore_thresh,
label_smooth=self.use_label_smooth,
num_classes=self.class_num,
downsample=self.downsample_ratio,
scale_x_y=self.scale_x_y)
x = paddle.to_tensor(x.astype(np.float32))
gtbox = paddle.to_tensor(gtbox.astype(np.float32))
gtlabel = paddle.to_tensor(gtlabel.astype(np.float32))
gtscore = paddle.to_tensor(gtscore.astype(np.float32))
t = paddle.to_tensor(target.astype(np.float32))
anchor = [self.anchors[i] for i in self.anchor_mask]
(yolo_loss1, px, py, tx, ty) = fine_grained_loss(
output=x,
target=t,
gt_box=gtbox,
batch_size=self.b,
num_classes=self.class_num,
anchors=self.mask_anchors,
ignore_thresh=self.ignore_thresh,
downsample=self.downsample_ratio,
scale_x_y=self.scale_x_y)
yolo_loss2 = yolo_loss.yolov3_loss(
x, t, gtbox, anchor, self.downsample_ratio, self.scale_x_y)
for k in yolo_loss2:
self.assertAlmostEqual(
yolo_loss1[k].numpy()[0],
yolo_loss2[k].numpy()[0],
delta=1e-2,
msg=k)
class TestYolov3LossNoGTScore(TestYolov3LossOp):
def initTestCase(self):
self.b = 1
self.h = 76
self.w = 76
self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]
self.anchor_mask = [0, 1, 2]
self.na = len(self.anchor_mask)
self.class_num = 80
self.ignore_thresh = 0.7
self.downsample_ratio = 8
self.x_shape = (self.b, len(self.anchor_mask) * (5 + self.class_num),
self.h, self.w)
self.gtbox_shape = (self.b, 40, 4)
self.gtscore = False
self.use_label_smooth = False
self.scale_x_y = 1.
class TestYolov3LossWithScaleXY(TestYolov3LossOp):
def initTestCase(self):
self.b = 5
self.h = 38
self.w = 38
self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
[59, 119], [116, 90], [156, 198], [373, 326]]
self.anchor_mask = [3, 4, 5]
self.na = len(self.anchor_mask)
self.class_num = 80
self.ignore_thresh = 0.7
self.downsample_ratio = 16
self.x_shape = (self.b, len(self.anchor_mask) * (5 + self.class_num),
self.h, self.w)
self.gtbox_shape = (self.b, 40, 4)
self.gtscore = True
self.use_label_smooth = False
self.scale_x_y = 1.2
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import bbox_util
from .bbox_util import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn.functional as F
import math
def xywh2xyxy(box):
out = paddle.zeros_like(box)
out[:, :, 0:2] = box[:, :, 0:2] - box[:, :, 2:4] / 2
out[:, :, 2:4] = box[:, :, 0:2] + box[:, :, 2:4] / 2
return out
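# A minimal usage sketch (not part of the original file): convert one box from
# center form [cx, cy, w, h] to corner form [x1, y1, x2, y2].
def _demo_xywh2xyxy():
    box = paddle.to_tensor([[[4., 4., 2., 2.]]])  # shape [N=1, M=1, 4]
    # expected result: [[[3., 3., 5., 5.]]]
    return xywh2xyxy(box)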
def make_grid(h, w, dtype):
yv, xv = paddle.meshgrid([paddle.arange(h), paddle.arange(w)])
return paddle.stack((xv, yv), 2).cast(dtype=dtype)
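# A minimal usage sketch (not part of the original file): the (x, y) cell
# offsets for a 2x3 grid, laid out as [h, w, 2].
def _demo_make_grid():
    grid = make_grid(2, 3, 'float32')
    # grid[0] == [[0, 0], [1, 0], [2, 0]], grid[1] == [[0, 1], [1, 1], [2, 1]]
    return grid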
def decode_yolo(box, anchor, downsample_ratio):
"""decode yolo box
Args:
box (Tensor): pred with the shape [b, h, w, na, 4]
anchor (list): anchor with the shape [na, 2]
downsample_ratio (int): downsample ratio, default 32
Return:
box (Tensor): decoded box, with the shape [b, h, w, na, 4]
"""
h, w, na = box.shape[1:4]
grid = make_grid(h, w, box.dtype).reshape((1, h, w, 1, 2))
box[:, :, :, :, 0:2] = box[:, :, :, :, :2] + grid
box[:, :, :, :, 0] = box[:, :, :, :, 0] / w
box[:, :, :, :, 1] = box[:, :, :, :, 1] / h
anchor = paddle.to_tensor(anchor)
anchor = paddle.cast(anchor, box.dtype)
anchor = anchor.reshape((1, 1, 1, na, 2))
box[:, :, :, :, 2:4] = paddle.exp(box[:, :, :, :, 2:4]) * anchor
box[:, :, :, :, 2] = box[:, :, :, :, 2] / (downsample_ratio * w)
box[:, :, :, :, 3] = box[:, :, :, :, 3] / (downsample_ratio * h)
return box
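# A minimal usage sketch (not part of the original file): decode a zero-valued
# raw head output for a hypothetical 13x13 level with the three coarsest COCO
# anchors. Only the shapes and the normalization are meaningful here.
def _demo_decode_yolo():
    raw = paddle.zeros([2, 13, 13, 3, 4])          # [b, h, w, na, 4]
    anchors = [[116, 90], [156, 198], [373, 326]]  # pixel-space [w, h] pairs
    decoded = decode_yolo(raw, anchors, downsample_ratio=32)
    # x, y come out divided by the grid size; w, h by the input resolution
    # (downsample_ratio * grid size), so all values lie in a normalized range.
    return decoded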
def iou_similarity(box1, box2, eps=1e-9):
"""Calculate iou of box1 and box2
Args:
box1 (Tensor): box with the shape [N, M1, 4]
box2 (Tensor): box with the shape [N, M2, 4]
Return:
iou (Tensor): iou between box1 and box2 with the shape [N, M1, M2]
"""
box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4]
box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4]
px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
x1y1 = paddle.maximum(px1y1, gx1y1)
x2y2 = paddle.minimum(px2y2, gx2y2)
overlap = (x2y2 - x1y1).clip(0).prod(-1)
area1 = (px2y2 - px1y1).clip(0).prod(-1)
area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
union = area1 + area2 - overlap + eps
return overlap / union
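# A minimal usage sketch (not part of the original file): pairwise IoU between
# two predicted boxes and one gt box, all in corner [x1, y1, x2, y2] form.
def _demo_iou_similarity():
    preds = paddle.to_tensor([[[0., 0., 2., 2.], [1., 1., 3., 3.]]])  # [1, 2, 4]
    gts = paddle.to_tensor([[[0., 0., 2., 2.]]])                      # [1, 1, 4]
    # result shape is [1, 2, 1]; the first pair overlaps fully (IoU ~1.0),
    # the second shares 1 of 7 units of union area (IoU ~0.143).
    return iou_similarity(preds, gts)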
def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
"""calculate the iou of box1 and box2
Args:
box1 (Tensor): box1 with the shape (N, M, 4)
box2 (Tensor): box2 with the shape (N, M, 4)
giou (bool): whether use giou or not, default False
diou (bool): whether use diou or not, default False
ciou (bool): whether use ciou or not, default False
eps (float): epsilon to avoid divide by zero
Return:
iou (Tensor): iou of box1 and box2, with the shape (N, M)
"""
px1y1, px2y2 = box1[:, :, 0:2], box1[:, :, 2:4]
gx1y1, gx2y2 = box2[:, :, 0:2], box2[:, :, 2:4]
x1y1 = paddle.maximum(px1y1, gx1y1)
x2y2 = paddle.minimum(px2y2, gx2y2)
overlap = (x2y2 - x1y1).clip(0).prod(-1)
area1 = (px2y2 - px1y1).clip(0).prod(-1)
area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
union = area1 + area2 - overlap + eps
iou = overlap / union
if giou or ciou or diou:
# convex w, h
cwh = paddle.maximum(px2y2, gx2y2) - paddle.minimum(px1y1, gx1y1)
if ciou or diou:
# convex diagonal squared
c2 = (cwh**2).sum(2) + eps
# center distance
rho2 = ((px1y1 + px2y2 - gx1y1 - gx2y2)**2).sum(2) / 4
if diou:
return iou - rho2 / c2
elif ciou:
wh1 = px2y2 - px1y1
wh2 = gx2y2 - gx1y1
w1, h1 = wh1[:, :, 0], wh1[:, :, 1] + eps
w2, h2 = wh2[:, :, 0], wh2[:, :, 1] + eps
v = (4 / math.pi**2) * paddle.pow(
paddle.atan(w1 / h1) - paddle.atan(w2 / h2), 2)
alpha = v / (1 + eps - iou + v)
alpha.stop_gradient = True
return iou - (rho2 / c2 + v * alpha)
else:
c_area = cwh.prod(2) + eps
return iou - (c_area - union) / c_area
else:
return iou
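# A minimal usage sketch (not part of the original file): element-wise GIoU for
# matched box pairs of shape [N, M, 4].
def _demo_bbox_iou():
    box1 = paddle.to_tensor([[[0., 0., 2., 2.]]])
    box2 = paddle.to_tensor([[[1., 1., 3., 3.]]])
    # plain IoU is 1/7; GIoU further subtracts (enclosing - union) / enclosing
    # = 2/9, giving roughly 0.143 - 0.222 = -0.079.
    return bbox_iou(box1, box2, giou=True)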
......@@ -30,7 +30,6 @@ import datetime
import numpy as np
from collections import deque
import paddle
from paddle import fluid
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.stats import TrainingStats
from ppdet.utils.check import check_gpu, check_version, check_config
......@@ -122,7 +121,6 @@ def run(FLAGS, cfg, place):
dataset, cfg['worker_num'], place)
# Model
main_arch = cfg.architecture
model = create(cfg.architecture)
# Optimizer
......@@ -137,19 +135,28 @@ def run(FLAGS, cfg, place):
cfg.get('load_static_weights', False),
FLAGS.weight_type)
if getattr(model.backbone, 'norm_type', None) == 'sync_bn':
assert cfg.use_gpu and ParallelEnv(
).nranks > 1, 'you should use bn rather than sync_bn while using a single gpu'
# sync_bn = (getattr(model.backbone, 'norm_type', None) == 'sync_bn' and
# cfg.use_gpu and ParallelEnv().nranks > 1)
# if sync_bn:
# model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# Parallel Model
if ParallelEnv().nranks > 1:
model = paddle.DataParallel(model)
fields = train_loader.collate_fn.output_fields
# Run Train
start_iter = 0
time_stat = deque(maxlen=cfg.log_iter)
start_time = time.time()
end_time = time.time()
# Run Train
start_epoch = optimizer.state_dict()['LR_Scheduler']['last_epoch']
for e_id in range(int(cfg.epoch)):
cur_eid = e_id + start_epoch
for epoch_id in range(int(cfg.epoch)):
cur_eid = epoch_id + start_epoch
train_loader.dataset.epoch = epoch_id
for iter_id, data in enumerate(train_loader):
start_time = end_time
end_time = time.time()
......@@ -161,8 +168,7 @@ def run(FLAGS, cfg, place):
# Model Forward
model.train()
outputs = model(data, cfg['TrainReader']['inputs_def']['fields'],
'train')
outputs = model(data, fields, 'train')
# Model Backward
loss = outputs['loss']
......@@ -179,7 +185,7 @@ def run(FLAGS, cfg, place):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
# Log state
if e_id == 0 and iter_id == 0:
if epoch_id == 0 and iter_id == 0:
train_stats = TrainingStats(cfg.log_iter, outputs.keys())
train_stats.update(outputs)
logs = train_stats.log()
......