Unverified commit 10bf8de7, authored by Feng Ni, committed by GitHub

add YOLOX codes (#5740)

Parent fecae1ee
......@@ -31,16 +31,32 @@ sys.path.insert(0, parent_path)
from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, decode_image
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from visualize import visualize_box_mask
from utils import argsparser, Timer, get_current_memory_mb
# Global dictionary
SUPPORT_MODELS = {
'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
'StrongBaseline', 'STGCN'
'YOLO',
'RCNN',
'SSD',
'Face',
'FCOS',
'SOLOv2',
'TTFNet',
'S2ANet',
'JDE',
'FairMOT',
'DeepSORT',
'GFL',
'PicoDet',
'CenterNet',
'TOOD',
'RetinaNet',
'StrongBaseline',
'STGCN',
'YOLOX',
}
......
......@@ -246,6 +246,81 @@ class LetterBoxResize(object):
return im, im_info
class Pad(object):
def __init__(self,
size=None,
size_divisor=32,
pad_mode=0,
offsets=None,
fill_value=(127.5, 127.5, 127.5)):
"""
Pad image to a specified size or multiple of size_divisor.
Args:
size (int, Sequence): image target size, if None, pad to multiple of size_divisor, default None
size_divisor (int): size divisor, default 32
pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2]. If -1, use specified offsets;
if 0, only pad to right and bottom; if 1, pad according to center; if 2, only pad left and top
offsets (list): [offset_x, offset_y], specify offsets while padding, only supported when pad_mode=-1
fill_value (tuple): RGB value of the padded area, default (127.5, 127.5, 127.5)
"""
super(Pad, self).__init__()
if isinstance(size, int):
size = [size, size]
assert pad_mode in [
-1, 0, 1, 2
], 'currently only supports four modes [-1, 0, 1, 2]'
if pad_mode == -1:
assert offsets, 'if pad_mode is -1, offsets should not be None'
self.size = size
self.size_divisor = size_divisor
self.pad_mode = pad_mode
self.fill_value = fill_value
self.offsets = offsets
def apply_image(self, image, offsets, im_size, size):
x, y = offsets
im_h, im_w = im_size
h, w = size
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
return canvas
def __call__(self, im, im_info):
im_h, im_w = im.shape[:2]
if self.size:
h, w = self.size
assert (
im_h <= h and im_w <= w
), '(h, w) of target size should not be less than (im_h, im_w)'
else:
h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
if h == im_h and w == im_w:
im = im.astype(np.float32)
return im, im_info
if self.pad_mode == -1:
offset_x, offset_y = self.offsets
elif self.pad_mode == 0:
offset_y, offset_x = 0, 0
elif self.pad_mode == 1:
offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
else:
offset_y, offset_x = h - im_h, w - im_w
offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
im = self.apply_image(im, offsets, im_size, size)
return im, im_info
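For reference, a minimal usage sketch of the Pad operator above (an assumption-laden example, not part of the diff: it assumes Pad is importable from this preprocess module and is fed a float image plus the usual im_info dict):

import numpy as np
pad = Pad(size=None, size_divisor=32, pad_mode=0)
im = np.random.randint(0, 255, (500, 600, 3)).astype(np.float32)
im, im_info = pad(im, {'im_shape': np.array([500., 600.])})
print(im.shape)  # (512, 608, 3): height and width rounded up to multiples of 32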
class WarpAffine(object):
"""Warp affine the image
"""
......
......@@ -5,7 +5,7 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -75,7 +75,7 @@ class DetDataset(Dataset):
n = len(self.roidbs)
roidb = [roidb, ] + [
copy.deepcopy(self.roidbs[np.random.randint(n)])
for _ in range(3)
for _ in range(4)
]
if isinstance(roidb, Sequence):
for r in roidb:
......
......@@ -2034,13 +2034,14 @@ class Pad(BaseOperator):
if self.size:
h, w = self.size
assert (
im_h < h and im_w < w
im_h <= h and im_w <= w
), '(h, w) of target size should not be less than (im_h, im_w)'
else:
h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
if h == im_h and w == im_w:
sample['image'] = im.astype(np.float32)
return sample
if self.pad_mode == -1:
......@@ -2139,16 +2140,29 @@ class Rbox2Poly(BaseOperator):
@register_op
class AugmentHSV(BaseOperator):
def __init__(self, fraction=0.50, is_bgr=True):
"""
Augment the SV channel of image data.
Args:
fraction (float): the fraction for augment. Default: 0.5.
is_bgr (bool): whether the image is BGR mode. Default: True.
"""
"""
Augment the SV channel of image data.
Args:
fraction (float): the fraction for augment. Default: 0.5.
is_bgr (bool): whether the image is BGR mode. Default: True.
hgain (float): H channel gains
sgain (float): S channel gains
vgain (float): V channel gains
"""
def __init__(self,
fraction=0.50,
is_bgr=True,
hgain=None,
sgain=None,
vgain=None):
super(AugmentHSV, self).__init__()
self.fraction = fraction
self.is_bgr = is_bgr
self.hgain = hgain
self.sgain = sgain
self.vgain = vgain
self.use_hsvgain = False if hgain is None else True
def apply(self, sample, context=None):
img = sample['image']
......@@ -2156,21 +2170,33 @@ class AugmentHSV(BaseOperator):
img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
else:
img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
S = img_hsv[:, :, 1].astype(np.float32)
V = img_hsv[:, :, 2].astype(np.float32)
a = (random.random() * 2 - 1) * self.fraction + 1
S *= a
if a > 1:
np.clip(S, a_min=0, a_max=255, out=S)
if self.use_hsvgain:
hsv_augs = np.random.uniform(
-1, 1, 3) * [self.hgain, self.sgain, self.vgain]
# random selection of h, s, v
hsv_augs *= np.random.randint(0, 2, 3)
img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180
img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255)
img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255)
a = (random.random() * 2 - 1) * self.fraction + 1
V *= a
if a > 1:
np.clip(V, a_min=0, a_max=255, out=V)
else:
S = img_hsv[:, :, 1].astype(np.float32)
V = img_hsv[:, :, 2].astype(np.float32)
a = (random.random() * 2 - 1) * self.fraction + 1
S *= a
if a > 1:
np.clip(S, a_min=0, a_max=255, out=S)
a = (random.random() * 2 - 1) * self.fraction + 1
V *= a
if a > 1:
np.clip(V, a_min=0, a_max=255, out=V)
img_hsv[:, :, 1] = S.astype(np.uint8)
img_hsv[:, :, 2] = V.astype(np.uint8)
img_hsv[:, :, 1] = S.astype(np.uint8)
img_hsv[:, :, 2] = V.astype(np.uint8)
if self.is_bgr:
cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
else:
......@@ -3018,3 +3044,373 @@ class CenterRandColor(BaseOperator):
img = func(img, img_gray)
sample['image'] = img
return sample
@register_op
class Mosaic(BaseOperator):
""" Mosaic operator for image and gt_bboxes
The code is based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/data/datasets/mosaicdetection.py
1. get mosaic coords
2. clip bbox and get mosaic_labels
3. random_affine augment
4. Mixup augment as copypaste (optional), not used in tiny/nano
Args:
prob (float): probability of using Mosaic, 1.0 as default
input_dim (list[int]): input shape
degrees (list[2]): the rotate range to apply, transform range is [min, max]
translate (list[2]): the translate range to apply, transform range is [min, max]
scale (list[2]): the scale range to apply, transform range is [min, max]
shear (list[2]): the shear range to apply, transform range is [min, max]
enable_mixup (bool): whether to enable Mixup or not
mixup_prob (float): probability of using Mixup, 1.0 as default
mixup_scale (list[float]): scale range of Mixup
remove_outside_box (bool): whether to remove boxes outside the image, False as
default for COCO datasets, True for MOT datasets
"""
def __init__(self,
prob=1.0,
input_dim=[640, 640],
degrees=[-10, 10],
translate=[-0.1, 0.1],
scale=[0.1, 2],
shear=[-2, 2],
enable_mixup=True,
mixup_prob=1.0,
mixup_scale=[0.5, 1.5],
remove_outside_box=False):
super(Mosaic, self).__init__()
self.prob = prob
self.input_dim = input_dim
self.degrees = degrees
self.translate = translate
self.scale = scale
self.shear = shear
self.enable_mixup = enable_mixup
self.mixup_prob = mixup_prob
self.mixup_scale = mixup_scale
self.remove_outside_box = remove_outside_box
def get_mosaic_coords(self, mosaic_idx, xc, yc, w, h, input_h, input_w):
# (x1, y1, x2, y2) means coords in large image,
# small_coords means coords in small image in mosaic aug.
if mosaic_idx == 0:
# top left
x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
small_coords = w - (x2 - x1), h - (y2 - y1), w, h
elif mosaic_idx == 1:
# top right
x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
small_coords = 0, h - (y2 - y1), min(w, x2 - x1), h
elif mosaic_idx == 2:
# bottom left
x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
small_coords = w - (x2 - x1), 0, w, min(y2 - y1, h)
elif mosaic_idx == 3:
# bottom right
x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2,
yc + h)
small_coords = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
return (x1, y1, x2, y2), small_coords
def random_affine_augment(self,
img,
labels=[],
input_dim=[640, 640],
degrees=[-10, 10],
scales=[0.1, 2],
shears=[-2, 2],
translates=[-0.1, 0.1]):
# random rotation and scale
degree = random.uniform(degrees[0], degrees[1])
scale = random.uniform(scales[0], scales[1])
assert scale > 0, "Argument scale should be positive."
R = cv2.getRotationMatrix2D(angle=degree, center=(0, 0), scale=scale)
M = np.ones([2, 3])
# random shear
shear = random.uniform(shears[0], shears[1])
shear_x = math.tan(shear * math.pi / 180)
shear_y = math.tan(shear * math.pi / 180)
M[0] = R[0] + shear_y * R[1]
M[1] = R[1] + shear_x * R[0]
# random translation
translate = random.uniform(translates[0], translates[1])
translation_x = translate * input_dim[0]
translation_y = translate * input_dim[1]
M[0, 2] = translation_x
M[1, 2] = translation_y
# warpAffine
img = cv2.warpAffine(
img, M, dsize=input_dim, borderValue=(114, 114, 114))
num_gts = len(labels)
if num_gts > 0:
# warp corner points
corner_points = np.ones((4 * num_gts, 3))
corner_points[:, :2] = labels[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
4 * num_gts, 2) # x1y1, x2y2, x1y2, x2y1
# apply affine transform
corner_points = corner_points @ M.T
corner_points = corner_points.reshape(num_gts, 8)
# create new boxes
corner_xs = corner_points[:, 0::2]
corner_ys = corner_points[:, 1::2]
new_bboxes = np.concatenate((corner_xs.min(1), corner_ys.min(1),
corner_xs.max(1), corner_ys.max(1)))
new_bboxes = new_bboxes.reshape(4, num_gts).T
# clip boxes
new_bboxes[:, 0::2] = np.clip(new_bboxes[:, 0::2], 0, input_dim[0])
new_bboxes[:, 1::2] = np.clip(new_bboxes[:, 1::2], 0, input_dim[1])
labels[:, :4] = new_bboxes
return img, labels
def __call__(self, sample, context=None):
if not isinstance(sample, Sequence):
return sample
assert len(
sample) == 5, "Mosaic needs 5 samples, 4 for mosaic and 1 for mixup."
if np.random.uniform(0., 1.) > self.prob:
return sample[0]
mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd = [], [], []
input_h, input_w = self.input_dim
yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))
mosaic_img = np.full((input_h * 2, input_w * 2, 3), 114, dtype=np.uint8)
# 1. get mosaic coords
for mosaic_idx, sp in enumerate(sample[:4]):
img = sp['image']
gt_bbox = sp['gt_bbox']
h0, w0 = img.shape[:2]
scale = min(1. * input_h / h0, 1. * input_w / w0)
img = cv2.resize(
img, (int(w0 * scale), int(h0 * scale)),
interpolation=cv2.INTER_LINEAR)
(h, w, c) = img.shape[:3]
# suffix l means large image, while s means small image in mosaic aug.
(l_x1, l_y1, l_x2, l_y2), (
s_x1, s_y1, s_x2, s_y2) = self.get_mosaic_coords(
mosaic_idx, xc, yc, w, h, input_h, input_w)
mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
padw, padh = l_x1 - s_x1, l_y1 - s_y1
# scale gt bboxes (pixel xyxy format) onto the mosaic canvas
_gt_bbox = gt_bbox.copy()
if len(gt_bbox) > 0:
_gt_bbox[:, 0] = scale * gt_bbox[:, 0] + padw
_gt_bbox[:, 1] = scale * gt_bbox[:, 1] + padh
_gt_bbox[:, 2] = scale * gt_bbox[:, 2] + padw
_gt_bbox[:, 3] = scale * gt_bbox[:, 3] + padh
mosaic_gt_bbox.append(_gt_bbox)
mosaic_gt_class.append(sp['gt_class'])
mosaic_is_crowd.append(sp['is_crowd'])
# 2. clip bbox and get mosaic_labels([gt_bbox, gt_class, is_crowd])
if len(mosaic_gt_bbox):
mosaic_gt_bbox = np.concatenate(mosaic_gt_bbox, 0)
mosaic_gt_class = np.concatenate(mosaic_gt_class, 0)
mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0)
mosaic_labels = np.concatenate([
mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype),
mosaic_is_crowd.astype(mosaic_gt_bbox.dtype)
], 1)
if self.remove_outside_box:
# for MOT dataset
flag1 = mosaic_gt_bbox[:, 0] < 2 * input_w
flag2 = mosaic_gt_bbox[:, 2] > 0
flag3 = mosaic_gt_bbox[:, 1] < 2 * input_h
flag4 = mosaic_gt_bbox[:, 3] > 0
flag_all = flag1 * flag2 * flag3 * flag4
mosaic_labels = mosaic_labels[flag_all]
else:
mosaic_labels[:, 0] = np.clip(mosaic_labels[:, 0], 0,
2 * input_w)
mosaic_labels[:, 1] = np.clip(mosaic_labels[:, 1], 0,
2 * input_h)
mosaic_labels[:, 2] = np.clip(mosaic_labels[:, 2], 0,
2 * input_w)
mosaic_labels[:, 3] = np.clip(mosaic_labels[:, 3], 0,
2 * input_h)
else:
mosaic_labels = np.zeros((1, 6))
# 3. random_affine augment
mosaic_img, mosaic_labels = self.random_affine_augment(
mosaic_img,
mosaic_labels,
input_dim=self.input_dim,
degrees=self.degrees,
translates=self.translate,
scales=self.scale,
shears=self.shear)
# 4. Mixup augment as copypaste, https://arxiv.org/abs/2012.07177
# optional, not used (enable_mixup=False) in tiny/nano
if (self.enable_mixup and not len(mosaic_labels) == 0 and
random.random() < self.mixup_prob):
sample_mixup = sample[4]
mixup_img = sample_mixup['image']
cp_labels = np.concatenate([
sample_mixup['gt_bbox'],
sample_mixup['gt_class'].astype(mosaic_labels.dtype),
sample_mixup['is_crowd'].astype(mosaic_labels.dtype)
], 1)
mosaic_img, mosaic_labels = self.mixup_augment(
mosaic_img, mosaic_labels, self.input_dim, cp_labels, mixup_img)
sample0 = sample[0]
sample0['image'] = mosaic_img.astype(np.uint8) # can not be float32
sample0['h'] = float(mosaic_img.shape[0])
sample0['w'] = float(mosaic_img.shape[1])
sample0['im_shape'][0] = sample0['h']
sample0['im_shape'][1] = sample0['w']
sample0['gt_bbox'] = mosaic_labels[:, :4].astype(np.float32)
sample0['gt_class'] = mosaic_labels[:, 4:5].astype(np.float32)
sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32)
return sample0
def mixup_augment(self, origin_img, origin_labels, input_dim, cp_labels,
img):
jit_factor = random.uniform(*self.mixup_scale)
FLIP = random.uniform(0, 1) > 0.5
if len(img.shape) == 3:
cp_img = np.ones(
(input_dim[0], input_dim[1], 3), dtype=np.uint8) * 114
else:
cp_img = np.ones(input_dim, dtype=np.uint8) * 114
cp_scale_ratio = min(input_dim[0] / img.shape[0],
input_dim[1] / img.shape[1])
resized_img = cv2.resize(
img, (int(img.shape[1] * cp_scale_ratio),
int(img.shape[0] * cp_scale_ratio)),
interpolation=cv2.INTER_LINEAR)
cp_img[:int(img.shape[0] * cp_scale_ratio), :int(img.shape[
1] * cp_scale_ratio)] = resized_img
cp_img = cv2.resize(cp_img, (int(cp_img.shape[1] * jit_factor),
int(cp_img.shape[0] * jit_factor)))
cp_scale_ratio *= jit_factor
if FLIP:
cp_img = cp_img[:, ::-1, :]
origin_h, origin_w = cp_img.shape[:2]
target_h, target_w = origin_img.shape[:2]
padded_img = np.zeros(
(max(origin_h, target_h), max(origin_w, target_w), 3),
dtype=np.uint8)
padded_img[:origin_h, :origin_w] = cp_img
x_offset, y_offset = 0, 0
if padded_img.shape[0] > target_h:
y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
if padded_img.shape[1] > target_w:
x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
padded_cropped_img = padded_img[y_offset:y_offset + target_h, x_offset:
x_offset + target_w]
# adjust boxes
cp_bboxes_origin_np = cp_labels[:, :4].copy()
cp_bboxes_origin_np[:, 0::2] = np.clip(cp_bboxes_origin_np[:, 0::2] *
cp_scale_ratio, 0, origin_w)
cp_bboxes_origin_np[:, 1::2] = np.clip(cp_bboxes_origin_np[:, 1::2] *
cp_scale_ratio, 0, origin_h)
if FLIP:
cp_bboxes_origin_np[:, 0::2] = (
origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1])
cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
if self.remove_outside_box:
# for MOT dataset
cp_bboxes_transformed_np[:, 0::2] -= x_offset
cp_bboxes_transformed_np[:, 1::2] -= y_offset
else:
cp_bboxes_transformed_np[:, 0::2] = np.clip(
cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w)
cp_bboxes_transformed_np[:, 1::2] = np.clip(
cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h)
cls_labels = cp_labels[:, 4:5].copy()
crd_labels = cp_labels[:, 5:6].copy()
box_labels = cp_bboxes_transformed_np
labels = np.hstack((box_labels, cls_labels, crd_labels))
if self.remove_outside_box:
labels = labels[labels[:, 0] < target_w]
labels = labels[labels[:, 2] > 0]
labels = labels[labels[:, 1] < target_h]
labels = labels[labels[:, 3] > 0]
origin_labels = np.vstack((origin_labels, labels))
origin_img = origin_img.astype(np.float32)
origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(
np.float32)
return origin_img.astype(np.uint8), origin_labels
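A minimal usage sketch of the Mosaic operator above (assumptions: ppdet is installed, the operator is importable from ppdet.data.transform.operators, and the sample dicts follow the COCO-style keys this file uses):

import numpy as np
from ppdet.data.transform.operators import Mosaic  # assumed import path

def fake_sample(h, w):
    # a dummy COCO-style sample dict (hypothetical data)
    return {
        'image': np.random.randint(0, 255, (h, w, 3), dtype=np.uint8),
        'gt_bbox': np.array([[10., 10., 100., 100.]], dtype=np.float32),
        'gt_class': np.zeros((1, 1), dtype=np.int32),
        'is_crowd': np.zeros((1, 1), dtype=np.int32),
        'im_shape': np.array([h, w], dtype=np.float32),
    }

samples = [fake_sample(480, 640) for _ in range(5)]  # 4 for mosaic + 1 for mixup
out = Mosaic(prob=1.0, input_dim=[640, 640])(samples)
print(out['image'].shape, out['gt_bbox'].shape)  # (640, 640, 3) (N, 4)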
@register_op
class PadResize(BaseOperator):
""" PadResize for image and gt_bbbox
Args:
target_size (list[int]): input shape
fill_value (float): pixel value of padded image
"""
def __init__(self, target_size, fill_value=114):
super(PadResize, self).__init__()
if isinstance(target_size, Integral):
target_size = [target_size, target_size]
self.target_size = target_size
self.fill_value = fill_value
def _resize(self, img, bboxes, labels):
ratio = min(self.target_size[0] / img.shape[0],
self.target_size[1] / img.shape[1])
w, h = int(img.shape[1] * ratio), int(img.shape[0] * ratio)
resized_img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
if len(bboxes) > 0:
bboxes *= ratio
mask = np.minimum(bboxes[:, 2] - bboxes[:, 0],
bboxes[:, 3] - bboxes[:, 1]) > 1
bboxes = bboxes[mask]
labels = labels[mask]
return resized_img, bboxes, labels
def _pad(self, img):
h, w, _ = img.shape
if h == self.target_size[0] and w == self.target_size[1]:
return img
padded_img = np.full(
(self.target_size[0], self.target_size[1], 3),
self.fill_value,
dtype=np.uint8)
padded_img[:h, :w] = img
return padded_img
def apply(self, sample, context=None):
image = sample['image']
bboxes = sample['gt_bbox']
labels = sample['gt_class']
image, bboxes, labels = self._resize(image, bboxes, labels)
sample['image'] = self._pad(image).astype(np.float32)
sample['gt_bbox'] = bboxes
sample['gt_class'] = labels
return sample
......@@ -48,6 +48,7 @@ TRT_MIN_SUBGRAPH = {
'PicoDet': 3,
'CenterNet': 5,
'TOOD': 5,
'YOLOX': 8,
}
KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
......@@ -147,6 +148,12 @@ def _dump_infer_config(config, path, image_shape, model):
infer_cfg['min_subgraph_size'] = min_subgraph_size
arch_state = True
break
if infer_arch == 'YOLOX':
infer_cfg['arch'] = infer_arch
infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
arch_state = True
if not arch_state:
logger.error(
'Architecture: {} is not supported for exporting model now.\n'.
......
......@@ -28,6 +28,7 @@ from PIL import Image, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import paddle
import paddle.nn as nn
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle import amp
......@@ -99,6 +100,12 @@ class Trainer(object):
self.model = self.cfg.model
self.is_loaded_weights = True
if cfg.architecture == 'YOLOX':
for k, m in self.model.named_sublayers():
if isinstance(m, nn.BatchNorm2D):
m.epsilon = 1e-3 # for amp(fp16)
m.momentum = 0.97 # 0.03 in pytorch
# normalize params for deploy
if 'slim' in cfg and cfg['slim_type'] == 'OFA':
self.model.model.load_meanstd(cfg['TestReader'][
......@@ -117,10 +124,11 @@ class Trainer(object):
if self.use_ema:
ema_decay = self.cfg.get('ema_decay', 0.9998)
cycle_epoch = self.cfg.get('cycle_epoch', -1)
ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
self.ema = ModelEMA(
self.model,
decay=ema_decay,
use_thres_step=True,
ema_decay_type=ema_decay_type,
cycle_epoch=cycle_epoch)
# EvalDataset build with BatchSampler to evaluate in single device
......
......@@ -5,6 +5,13 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import meta_arch
from . import faster_rcnn
from . import mask_rcnn
......@@ -28,6 +35,7 @@ from . import sparse_rcnn
from . import tood
from . import retinanet
from . import bytetrack
from . import yolox
from .meta_arch import *
from .faster_rcnn import *
......@@ -53,3 +61,4 @@ from .sparse_rcnn import *
from .tood import *
from .retinanet import *
from .bytetrack import *
from .yolox import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
import random
import paddle
import paddle.nn.functional as F
import paddle.distributed as dist
from ppdet.modeling.ops import paddle_distributed_is_initialized
__all__ = ['YOLOX']
@register
class YOLOX(BaseArch):
"""
YOLOX network, see https://arxiv.org/abs/2107.08430
Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): neck instance
head (nn.Layer): head instance
for_mot (bool): whether used for MOT or not
input_size (list[int]): initial scale, will be reset by self._preprocess()
size_stride (int): stride of the size range
size_range (list[int]): multi-scale range for training
random_interval (int): interval of iter to change self._input_size
"""
__category__ = 'architecture'
def __init__(self,
backbone='CSPDarkNet',
neck='YOLOCSPPAN',
head='YOLOXHead',
for_mot=False,
input_size=[640, 640],
size_stride=32,
size_range=[15, 25],
random_interval=10):
super(YOLOX, self).__init__()
self.backbone = backbone
self.neck = neck
self.head = head
self.for_mot = for_mot
self.input_size = input_size
self._input_size = paddle.to_tensor(input_size)
self.size_stride = size_stride
self.size_range = size_range
self.random_interval = random_interval
self._step = 0
@classmethod
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg['backbone'])
# fpn
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
# head
kwargs = {'input_shape': neck.out_shape}
head = create(cfg['head'], **kwargs)
return {
'backbone': backbone,
'neck': neck,
"head": head,
}
def _forward(self):
if self.training:
self._preprocess()
body_feats = self.backbone(self.inputs)
neck_feats = self.neck(body_feats, self.for_mot)
if self.training:
yolox_losses = self.head(neck_feats, self.inputs)
yolox_losses.update({'size': self._input_size[0]})
return yolox_losses
else:
head_outs = self.head(neck_feats)
bbox, bbox_num = self.head.post_process(
head_outs, self.inputs['im_shape'], self.inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
def get_loss(self):
return self._forward()
def get_pred(self):
return self._forward()
def _preprocess(self):
# YOLOX multi-scale training: randomly resize the inputs by interpolation before they enter the network.
self._get_size()
scale_y = self._input_size[0] / self.input_size[0]
scale_x = self._input_size[1] / self.input_size[1]
if scale_x != 1 or scale_y != 1:
self.inputs['image'] = F.interpolate(
self.inputs['image'],
size=self._input_size,
mode='bilinear',
align_corners=False)
gt_bboxes = self.inputs['gt_bbox']
for i in range(len(gt_bboxes)):
if len(gt_bboxes[i]) > 0:
gt_bboxes[i][:, 0::2] = gt_bboxes[i][:, 0::2] * scale_x
gt_bboxes[i][:, 1::2] = gt_bboxes[i][:, 1::2] * scale_y
self.inputs['gt_bbox'] = gt_bboxes
def _get_size(self):
# change self._input_size every `random_interval` iters (10 by default)
image_ratio = self.input_size[1] * 1.0 / self.input_size[0]
if self._step % self.random_interval == 0:
size_factor = random.randint(*self.size_range)
size = [
self.size_stride * size_factor,
self.size_stride * int(size_factor * image_ratio)
]
size = paddle.to_tensor(size)
if dist.get_world_size() > 1 and paddle_distributed_is_initialized(
):
dist.barrier()
dist.broadcast(size, 0)
self._input_size = size
self._step += 1
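A standalone sketch of the size sampling done by _get_size above (a pure-Python mirror, not the class itself; with the defaults, candidate sides range from 32 * 15 = 480 to 32 * 25 = 800):

import random
size_stride, size_range, input_size = 32, [15, 25], [640, 640]
image_ratio = input_size[1] / input_size[0]
size_factor = random.randint(*size_range)              # e.g. 20
size = [size_stride * size_factor,
        size_stride * int(size_factor * image_ratio)]  # e.g. [640, 640]
print(size)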
......@@ -115,7 +115,7 @@ def check_points_inside_bboxes(points,
Args:
points (Tensor, float32): shape[L, 2], "xy" format, L: num_anchors
bboxes (Tensor, float32): shape[B, n, 4], "xmin, ymin, xmax, ymax" format
center_radius_tensor (Tensor, float32): shape [L, 1] Default: None.
center_radius_tensor (Tensor, float32): shape [L, 1]. Default: None.
eps (float): Default: 1e-9
Returns:
is_in_bboxes (Tensor, float32): shape[B, n, L], value=1. means selected
......@@ -123,25 +123,28 @@ def check_points_inside_bboxes(points,
points = points.unsqueeze([0, 1])
x, y = points.chunk(2, axis=-1)
xmin, ymin, xmax, ymax = bboxes.unsqueeze(2).chunk(4, axis=-1)
if center_radius_tensor is not None:
center_radius_tensor = center_radius_tensor.unsqueeze([0, 1])
bboxes_cx = (xmin + xmax) / 2.
bboxes_cy = (ymin + ymax) / 2.
xmin_sampling = bboxes_cx - center_radius_tensor
ymin_sampling = bboxes_cy - center_radius_tensor
xmax_sampling = bboxes_cx + center_radius_tensor
ymax_sampling = bboxes_cy + center_radius_tensor
xmin = paddle.maximum(xmin, xmin_sampling)
ymin = paddle.maximum(ymin, ymin_sampling)
xmax = paddle.minimum(xmax, xmax_sampling)
ymax = paddle.minimum(ymax, ymax_sampling)
# check whether `points` is in `bboxes`
l = x - xmin
t = y - ymin
r = xmax - x
b = ymax - y
bbox_ltrb = paddle.concat([l, t, r, b], axis=-1)
return (bbox_ltrb.min(axis=-1) > eps).astype(bboxes.dtype)
delta_ltrb = paddle.concat([l, t, r, b], axis=-1)
is_in_bboxes = (delta_ltrb.min(axis=-1) > eps)
if center_radius_tensor is not None:
# check whether `points` is in `center_radius`
center_radius_tensor = center_radius_tensor.unsqueeze([0, 1])
cx = (xmin + xmax) * 0.5
cy = (ymin + ymax) * 0.5
l = x - (cx - center_radius_tensor)
t = y - (cy - center_radius_tensor)
r = (cx + center_radius_tensor) - x
b = (cy + center_radius_tensor) - y
delta_ltrb_c = paddle.concat([l, t, r, b], axis=-1)
is_in_center = (delta_ltrb_c.min(axis=-1) > eps)
return (paddle.logical_and(is_in_bboxes, is_in_center),
paddle.logical_or(is_in_bboxes, is_in_center))
return is_in_bboxes.astype(bboxes.dtype)
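A minimal numeric check of check_points_inside_bboxes without a center radius (assumes paddle is installed):

import paddle
points = paddle.to_tensor([[5., 5.], [50., 50.]])   # [L=2, 2]
bboxes = paddle.to_tensor([[[0., 0., 10., 10.]]])   # [B=1, n=1, 4]
print(check_points_inside_bboxes(points, bboxes))   # [[[1., 0.]]]: only (5, 5) falls inside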
def compute_max_iou_anchor(ious):
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import vgg
......@@ -30,6 +30,7 @@ from . import lcnet
from . import hardnet
from . import esnet
from . import cspresnet
from . import csp_darknet
from .vgg import *
from .resnet import *
......@@ -49,3 +50,4 @@ from .lcnet import *
from .hardnet import *
from .esnet import *
from .cspresnet import *
from .csp_darknet import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from ppdet.modeling.ops import get_activation
from ppdet.modeling.initializer import conv_init_
from ..shape_spec import ShapeSpec
__all__ = [
'CSPDarkNet', 'BaseConv', 'DWConv', 'BottleNeck', 'SPPLayer', 'SPPFLayer'
]
class BaseConv(nn.Layer):
def __init__(self,
in_channels,
out_channels,
ksize,
stride,
groups=1,
bias=False,
act="silu"):
super(BaseConv, self).__init__()
self.conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size=ksize,
stride=stride,
padding=(ksize - 1) // 2,
groups=groups,
bias_attr=bias)
self.bn = nn.BatchNorm2D(
out_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self.act = get_activation(act)
self._init_weights()
def _init_weights(self):
conv_init_(self.conv)
def forward(self, x):
return self.act(self.bn(self.conv(x)))
class DWConv(nn.Layer):
"""Depthwise Conv"""
def __init__(self,
in_channels,
out_channels,
ksize,
stride=1,
bias=False,
act="silu"):
super(DWConv, self).__init__()
self.dw_conv = BaseConv(
in_channels,
in_channels,
ksize=ksize,
stride=stride,
groups=in_channels,
bias=bias,
act=act, )
self.pw_conv = BaseConv(
in_channels,
out_channels,
ksize=1,
stride=1,
groups=1,
bias=bias,
act=act)
def forward(self, x):
return self.pw_conv(self.dw_conv(x))
class Focus(nn.Layer):
"""Focus width and height information into channel space, used in YOLOX."""
def __init__(self,
in_channels,
out_channels,
ksize=3,
stride=1,
bias=False,
act="silu"):
super(Focus, self).__init__()
self.conv = BaseConv(
in_channels * 4,
out_channels,
ksize=ksize,
stride=stride,
bias=bias,
act=act)
def forward(self, inputs):
# inputs [bs, C, H, W] -> outputs [bs, 4C, H/2, W/2]
top_left = inputs[:, :, 0::2, 0::2]
top_right = inputs[:, :, 0::2, 1::2]
bottom_left = inputs[:, :, 1::2, 0::2]
bottom_right = inputs[:, :, 1::2, 1::2]
outputs = paddle.concat(
[top_left, bottom_left, top_right, bottom_right], 1)
return self.conv(outputs)
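A quick shape check for Focus (assumes paddle is installed): the 2x2 space-to-channel slicing quarters the spatial size and quadruples the channels before the conv.

import paddle
focus = Focus(in_channels=3, out_channels=64, ksize=3, stride=1)
x = paddle.rand([1, 3, 640, 640])
print(focus(x).shape)  # [1, 64, 320, 320]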
class BottleNeck(nn.Layer):
def __init__(self,
in_channels,
out_channels,
shortcut=True,
expansion=0.5,
depthwise=False,
bias=False,
act="silu"):
super(BottleNeck, self).__init__()
hidden_channels = int(out_channels * expansion)
Conv = DWConv if depthwise else BaseConv
self.conv1 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.conv2 = Conv(
hidden_channels,
out_channels,
ksize=3,
stride=1,
bias=bias,
act=act)
self.add_shortcut = shortcut and in_channels == out_channels
def forward(self, x):
y = self.conv2(self.conv1(x))
if self.add_shortcut:
y = y + x
return y
class SPPLayer(nn.Layer):
"""Spatial Pyramid Pooling (SPP) layer used in YOLOv3-SPP and YOLOX"""
def __init__(self,
in_channels,
out_channels,
kernel_sizes=(5, 9, 13),
bias=False,
act="silu"):
super(SPPLayer, self).__init__()
hidden_channels = in_channels // 2
self.conv1 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.maxpoolings = nn.LayerList([
nn.MaxPool2D(
kernel_size=ks, stride=1, padding=ks // 2)
for ks in kernel_sizes
])
conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
self.conv2 = BaseConv(
conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act)
def forward(self, x):
x = self.conv1(x)
x = paddle.concat([x] + [mp(x) for mp in self.maxpoolings], axis=1)
x = self.conv2(x)
return x
class SPPFLayer(nn.Layer):
""" Spatial Pyramid Pooling - Fast (SPPF) layer used in YOLOv5 by Glenn Jocher,
equivalent to SPP(k=(5, 9, 13))
"""
def __init__(self,
in_channels,
out_channels,
ksize=5,
bias=False,
act='silu'):
super(SPPFLayer, self).__init__()
hidden_channels = in_channels // 2
self.conv1 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.maxpooling = nn.MaxPool2D(
kernel_size=ksize, stride=1, padding=ksize // 2)
conv2_channels = hidden_channels * 4
self.conv2 = BaseConv(
conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act)
def forward(self, x):
x = self.conv1(x)
y1 = self.maxpooling(x)
y2 = self.maxpooling(y1)
y3 = self.maxpooling(y2)
concats = paddle.concat([x, y1, y2, y3], axis=1)
out = self.conv2(concats)
return out
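A small sketch verifying the equivalence claimed in the SPPF docstring: three stacked 5x5 max-pools reproduce the 5/9/13 pooling pyramid of SPP (assumes paddle is installed; the conv weights are random, so the pooled concats are compared rather than the module outputs):

import paddle
import paddle.nn as nn
x = paddle.rand([1, 8, 32, 32])
mp5 = nn.MaxPool2D(kernel_size=5, stride=1, padding=2)
y1 = mp5(x)       # effective 5x5 window
y2 = mp5(y1)      # effective 9x9 window
y3 = mp5(y2)      # effective 13x13 window
spp = [nn.MaxPool2D(k, stride=1, padding=k // 2)(x) for k in (5, 9, 13)]
print(bool(paddle.allclose(paddle.concat([x, y1, y2, y3], axis=1),
                           paddle.concat([x] + spp, axis=1))))  # True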
class CSPLayer(nn.Layer):
"""CSP (Cross Stage Partial) layer with 3 convs, named C3 in YOLOv5"""
def __init__(self,
in_channels,
out_channels,
num_blocks=1,
shortcut=True,
expansion=0.5,
depthwise=False,
bias=False,
act="silu"):
super(CSPLayer, self).__init__()
hidden_channels = int(out_channels * expansion)
self.conv1 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.conv2 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.bottlenecks = nn.Sequential(* [
BottleNeck(
hidden_channels,
hidden_channels,
shortcut=shortcut,
expansion=1.0,
depthwise=depthwise,
bias=bias,
act=act) for _ in range(num_blocks)
])
self.conv3 = BaseConv(
hidden_channels * 2,
out_channels,
ksize=1,
stride=1,
bias=bias,
act=act)
def forward(self, x):
x_1 = self.conv1(x)
x_1 = self.bottlenecks(x_1)
x_2 = self.conv2(x)
x = paddle.concat([x_1, x_2], axis=1)
x = self.conv3(x)
return x
@register
@serializable
class CSPDarkNet(nn.Layer):
"""
CSPDarkNet backbone.
Args:
arch (str): Architecture of CSPDarkNet, from {P5, P6, X}, default as X,
and 'X' means used in YOLOX, 'P5/P6' means used in YOLOv5.
depth_mult (float): Depth multiplier, multiply number of blocks in
CSPLayer, default as 1.0.
width_mult (float): Width multiplier, multiply number of channels in
each layer, default as 1.0.
depthwise (bool): Whether to use depth-wise conv layer.
act (str): Activation function type, default as 'silu'.
return_idx (list): Index of stages whose feature maps are returned.
"""
__shared__ = ['depth_mult', 'width_mult', 'act']
# in_channels, out_channels, num_blocks, add_shortcut, use_spp(use_sppf)
# 'X' means setting used in YOLOX, 'P5/P6' means setting used in YOLOv5.
arch_settings = {
'X': [[64, 128, 3, True, False], [128, 256, 9, True, False],
[256, 512, 9, True, False], [512, 1024, 3, False, True]],
'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
[256, 512, 9, True, False], [512, 1024, 3, True, True]],
'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
[256, 512, 9, True, False], [512, 768, 3, True, False],
[768, 1024, 3, True, True]],
}
def __init__(self,
arch='X',
depth_mult=1.0,
width_mult=1.0,
depthwise=False,
act='silu',
return_idx=[2, 3, 4]):
super(CSPDarkNet, self).__init__()
self.arch = arch
self.return_idx = return_idx
Conv = DWConv if depthwise else BaseConv
arch_setting = self.arch_settings[arch]
base_channels = int(arch_setting[0][0] * width_mult)
# Note: differences between the latest YOLOv5 and the original YOLOX
# 1. self.stem: Conv (in YOLOv5) vs Focus (in YOLOX)
# 2. use SPPF (in YOLOv5) or SPP (in YOLOX)
# 3. put SPPF before (YOLOv5) or SPP after (YOLOX) the last cspdark block's CSPLayer
# 4. whether the CSPLayer next to SPPF (SPP) adds a shortcut: True in YOLOv5, False in YOLOX
if arch in ['P5', 'P6']:
# in the latest YOLOv5, use a Conv stem and SPPF (fast, only a single SPP kernel size)
self.stem = Conv(
3, base_channels, ksize=6, stride=2, bias=False, act=act)
spp_kernel_sizes = 5
elif arch in ['X']:
# in the original YOLOX, use a Focus stem and SPP (three SPP kernel sizes)
self.stem = Focus(
3, base_channels, ksize=3, stride=1, bias=False, act=act)
spp_kernel_sizes = (5, 9, 13)
else:
raise AttributeError("Unsupported arch type: {}".format(arch))
_out_channels = [base_channels]
layers_num = 1
self.csp_dark_blocks = []
for i, (in_channels, out_channels, num_blocks, shortcut,
use_spp) in enumerate(arch_setting):
in_channels = int(in_channels * width_mult)
out_channels = int(out_channels * width_mult)
_out_channels.append(out_channels)
num_blocks = max(round(num_blocks * depth_mult), 1)
stage = []
conv_layer = self.add_sublayer(
'layers{}.stage{}.conv_layer'.format(layers_num, i + 1),
Conv(
in_channels, out_channels, 3, 2, bias=False, act=act))
stage.append(conv_layer)
layers_num += 1
if use_spp and arch in ['X']:
# in YOLOX use SPPLayer
spp_layer = self.add_sublayer(
'layers{}.stage{}.spp_layer'.format(layers_num, i + 1),
SPPLayer(
out_channels,
out_channels,
kernel_sizes=spp_kernel_sizes,
bias=False,
act=act))
stage.append(spp_layer)
layers_num += 1
csp_layer = self.add_sublayer(
'layers{}.stage{}.csp_layer'.format(layers_num, i + 1),
CSPLayer(
out_channels,
out_channels,
num_blocks=num_blocks,
shortcut=shortcut,
depthwise=depthwise,
bias=False,
act=act))
stage.append(csp_layer)
layers_num += 1
if use_spp and arch in ['P5', 'P6']:
# in latest YOLOv5 use SPPFLayer instead of SPPLayer
sppf_layer = self.add_sublayer(
'layers{}.stage{}.sppf_layer'.format(layers_num, i + 1),
SPPFLayer(
out_channels,
out_channels,
ksize=5,
bias=False,
act=act))
stage.append(sppf_layer)
layers_num += 1
self.csp_dark_blocks.append(nn.Sequential(*stage))
self._out_channels = [_out_channels[i] for i in self.return_idx]
self.strides = [[2, 4, 8, 16, 32, 64][i] for i in self.return_idx]
def forward(self, inputs):
x = inputs['image']
outputs = []
x = self.stem(x)
for i, layer in enumerate(self.csp_dark_blocks):
x = layer(x)
if i + 1 in self.return_idx:
outputs.append(x)
return outputs
@property
def out_shape(self):
return [
ShapeSpec(
channels=c, stride=s)
for c, s in zip(self._out_channels, self.strides)
]
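A minimal usage sketch at YOLOX-s scale (assumptions: ppdet is installed and the backbone is importable as below; depth_mult=0.33 / width_mult=0.50 are the YOLOX-s multipliers):

import paddle
from ppdet.modeling.backbones.csp_darknet import CSPDarkNet  # assumed import path
backbone = CSPDarkNet(arch='X', depth_mult=0.33, width_mult=0.50)
feats = backbone({'image': paddle.rand([1, 3, 640, 640])})
print([f.shape for f in feats])
# [[1, 128, 80, 80], [1, 256, 40, 40], [1, 512, 20, 20]]
print([(s.channels, s.stride) for s in backbone.out_shape])
# [(128, 8), (256, 16), (512, 32)]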
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......@@ -5,6 +19,16 @@ from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
import math
import numpy as np
from ..initializer import bias_init_with_prob, constant_
from ..backbones.csp_darknet import BaseConv, DWConv
from ..losses import IouLoss
from ppdet.modeling.assigners.simota_assigner import SimOTAAssigner
from ppdet.modeling.bbox_utils import bbox_overlaps
__all__ = ['YOLOv3Head', 'YOLOXHead']
def _de_sigmoid(x, eps=1e-7):
x = paddle.clip(x, eps, 1. / eps)
......@@ -122,3 +146,259 @@ class YOLOv3Head(nn.Layer):
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
@register
class YOLOXHead(nn.Layer):
__shared__ = ['num_classes', 'width_mult', 'act']
__inject__ = ['assigner', 'nms']
def __init__(self,
num_classes=80,
width_mult=1.0,
depthwise=False,
in_channels=[256, 512, 1024],
feat_channels=256,
fpn_strides=(8, 16, 32),
l1_epoch=285,
act='silu',
assigner=SimOTAAssigner(use_vfl=False),
nms='MultiClassNMS',
loss_weight={'cls': 1.0,
'obj': 1.0,
'iou': 5.0,
'l1': 1.0}):
super(YOLOXHead, self).__init__()
self._dtype = paddle.framework.get_default_dtype()
self.num_classes = num_classes
assert len(in_channels) > 0, "in_channels length should be > 0"
self.in_channels = in_channels
feat_channels = int(feat_channels * width_mult)
self.fpn_strides = fpn_strides
self.l1_epoch = l1_epoch
self.assigner = assigner
self.nms = nms
self.loss_weight = loss_weight
self.iou_loss = IouLoss(loss_weight=1.0) # default loss_weight 2.5
ConvBlock = DWConv if depthwise else BaseConv
self.stem_conv = nn.LayerList()
self.conv_cls = nn.LayerList()
self.conv_reg = nn.LayerList() # reg [x,y,w,h] + obj
for in_c in self.in_channels:
self.stem_conv.append(BaseConv(in_c, feat_channels, 1, 1, act=act))
self.conv_cls.append(
nn.Sequential(* [
ConvBlock(
feat_channels, feat_channels, 3, 1, act=act), ConvBlock(
feat_channels, feat_channels, 3, 1, act=act),
nn.Conv2D(
feat_channels,
self.num_classes,
1,
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
]))
self.conv_reg.append(
nn.Sequential(* [
ConvBlock(
feat_channels, feat_channels, 3, 1, act=act),
ConvBlock(
feat_channels, feat_channels, 3, 1, act=act),
nn.Conv2D(
feat_channels,
4 + 1, # reg [x,y,w,h] + obj
1,
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
]))
self._init_weights()
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
def _init_weights(self):
bias_cls = bias_init_with_prob(0.01)
bias_reg = paddle.full([5], math.log(5.), dtype=self._dtype)
bias_reg[:2] = 0.
bias_reg[-1] = bias_cls
for cls_, reg_ in zip(self.conv_cls, self.conv_reg):
constant_(cls_[-1].weight)
constant_(cls_[-1].bias, bias_cls)
constant_(reg_[-1].weight)
reg_[-1].bias.set_value(bias_reg)
def _generate_anchor_point(self, feat_sizes, strides, offset=0.):
anchor_points, stride_tensor = [], []
num_anchors_list = []
for feat_size, stride in zip(feat_sizes, strides):
h, w = feat_size
x = (paddle.arange(w) + offset) * stride
y = (paddle.arange(h) + offset) * stride
y, x = paddle.meshgrid(y, x)
anchor_points.append(paddle.stack([x, y], axis=-1).reshape([-1, 2]))
stride_tensor.append(
paddle.full(
[len(anchor_points[-1]), 1], stride, dtype=self._dtype))
num_anchors_list.append(len(anchor_points[-1]))
anchor_points = paddle.concat(anchor_points).astype(self._dtype)
anchor_points.stop_gradient = True
stride_tensor = paddle.concat(stride_tensor)
stride_tensor.stop_gradient = True
return anchor_points, stride_tensor, num_anchors_list
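A standalone sketch of the grid built by _generate_anchor_point for one 2x2 feature map at stride 8 with the training offset of 0.5 (a pure mirror of the loop body; assumes paddle is installed):

import paddle
stride, offset = 8, 0.5
x = (paddle.arange(2, dtype='float32') + offset) * stride
y = (paddle.arange(2, dtype='float32') + offset) * stride
yy, xx = paddle.meshgrid(y, x)
print(paddle.stack([xx, yy], axis=-1).reshape([-1, 2]).tolist())
# [[4.0, 4.0], [12.0, 4.0], [4.0, 12.0], [12.0, 12.0]]: one point per cell center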
def forward(self, feats, targets=None):
assert len(feats) == len(self.fpn_strides), \
"The size of feats is not equal to size of fpn_strides"
feat_sizes = [[f.shape[-2], f.shape[-1]] for f in feats]
cls_score_list, reg_pred_list = [], []
obj_score_list = []
for i, feat in enumerate(feats):
feat = self.stem_conv[i](feat)
cls_logit = self.conv_cls[i](feat)
reg_pred = self.conv_reg[i](feat)
# cls prediction
cls_score = F.sigmoid(cls_logit)
cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
# reg prediction
reg_xywh, obj_logit = paddle.split(reg_pred, [4, 1], axis=1)
reg_xywh = reg_xywh.flatten(2).transpose([0, 2, 1])
reg_pred_list.append(reg_xywh)
# obj prediction
obj_score = F.sigmoid(obj_logit)
obj_score_list.append(obj_score.flatten(2).transpose([0, 2, 1]))
cls_score_list = paddle.concat(cls_score_list, axis=1)
reg_pred_list = paddle.concat(reg_pred_list, axis=1)
obj_score_list = paddle.concat(obj_score_list, axis=1)
# bbox decode
anchor_points, stride_tensor, _ =\
self._generate_anchor_point(feat_sizes, self.fpn_strides)
reg_xy, reg_wh = paddle.split(reg_pred_list, 2, axis=-1)
reg_xy += (anchor_points / stride_tensor)
reg_wh = paddle.exp(reg_wh) * 0.5
bbox_pred_list = paddle.concat(
[reg_xy - reg_wh, reg_xy + reg_wh], axis=-1)
if self.training:
anchor_points, stride_tensor, num_anchors_list =\
self._generate_anchor_point(feat_sizes, self.fpn_strides, 0.5)
yolox_losses = self.get_loss([
cls_score_list, bbox_pred_list, obj_score_list, anchor_points,
stride_tensor, num_anchors_list
], targets)
return yolox_losses
else:
pred_scores = (cls_score_list * obj_score_list).sqrt()
return pred_scores, bbox_pred_list, stride_tensor
def get_loss(self, head_outs, targets):
pred_cls, pred_bboxes, pred_obj,\
anchor_points, stride_tensor, num_anchors_list = head_outs
gt_labels = targets['gt_class']
gt_bboxes = targets['gt_bbox']
pred_scores = (pred_cls * pred_obj).sqrt()
# label assignment
center_and_strides = paddle.concat(
[anchor_points, stride_tensor, stride_tensor], axis=-1)
pos_num_list, label_list, bbox_target_list = [], [], []
for pred_score, pred_bbox, gt_box, gt_label in zip(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, gt_bboxes, gt_labels):
pos_num, label, _, bbox_target = self.assigner(
pred_score, center_and_strides, pred_bbox, gt_box, gt_label)
pos_num_list.append(pos_num)
label_list.append(label)
bbox_target_list.append(bbox_target)
labels = paddle.to_tensor(np.stack(label_list, axis=0))
bbox_targets = paddle.to_tensor(np.stack(bbox_target_list, axis=0))
bbox_targets /= stride_tensor # rescale bbox
# 1. obj score loss
mask_positive = (labels != self.num_classes)
loss_obj = F.binary_cross_entropy(
pred_obj,
mask_positive.astype(pred_obj.dtype).unsqueeze(-1),
reduction='sum')
num_pos = sum(pos_num_list)
if num_pos > 0:
num_pos = paddle.to_tensor(num_pos, dtype=self._dtype).clip(min=1)
loss_obj /= num_pos
# 2. iou loss
bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
pred_bboxes_pos = paddle.masked_select(pred_bboxes,
bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = paddle.masked_select(
bbox_targets, bbox_mask).reshape([-1, 4])
bbox_iou = bbox_overlaps(pred_bboxes_pos, assigned_bboxes_pos)
bbox_iou = paddle.diag(bbox_iou)
loss_iou = self.iou_loss(
pred_bboxes_pos.split(
4, axis=-1),
assigned_bboxes_pos.split(
4, axis=-1))
loss_iou = loss_iou.sum() / num_pos
# 3. cls loss
cls_mask = mask_positive.unsqueeze(-1).tile(
[1, 1, self.num_classes])
pred_cls_pos = paddle.masked_select(
pred_cls, cls_mask).reshape([-1, self.num_classes])
assigned_cls_pos = paddle.masked_select(labels, mask_positive)
assigned_cls_pos = F.one_hot(assigned_cls_pos,
self.num_classes + 1)[..., :-1]
assigned_cls_pos *= bbox_iou.unsqueeze(-1)
loss_cls = F.binary_cross_entropy(
pred_cls_pos, assigned_cls_pos, reduction='sum')
loss_cls /= num_pos
# 4. l1 loss
if targets['epoch_id'] >= self.l1_epoch:
loss_l1 = F.l1_loss(
pred_bboxes_pos, assigned_bboxes_pos, reduction='sum')
loss_l1 /= num_pos
else:
loss_l1 = paddle.zeros([1])
loss_l1.stop_gradient = False
else:
loss_cls = paddle.zeros([1])
loss_iou = paddle.zeros([1])
loss_l1 = paddle.zeros([1])
loss_cls.stop_gradient = False
loss_iou.stop_gradient = False
loss_l1.stop_gradient = False
loss = self.loss_weight['obj'] * loss_obj + \
self.loss_weight['cls'] * loss_cls + \
self.loss_weight['iou'] * loss_iou
if targets['epoch_id'] >= self.l1_epoch:
loss += (self.loss_weight['l1'] * loss_l1)
yolox_losses = {
'loss': loss,
'loss_cls': loss_cls,
'loss_obj': loss_obj,
'loss_iou': loss_iou,
'loss_l1': loss_l1,
}
return yolox_losses
def post_process(self, head_outs, img_shape, scale_factor):
pred_scores, pred_bboxes, stride_tensor = head_outs
pred_scores = pred_scores.transpose([0, 2, 1])
pred_bboxes *= stride_tensor
# scale bbox to origin image
scale_factor = scale_factor.flip(-1).tile([1, 2]).unsqueeze(1)
pred_bboxes /= scale_factor
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
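A numeric sketch of the box decode used in forward above: predictions live in grid units, the xy offset is added to the anchor point divided by stride, exp(wh) * 0.5 gives half extents, and post_process later multiplies by stride (assumes paddle is installed):

import paddle
anchor_point, stride = paddle.to_tensor([[8., 8.]]), 8.
reg_xy = paddle.to_tensor([[0.5, 0.5]]) + anchor_point / stride  # center: [[1.5, 1.5]]
reg_wh = paddle.exp(paddle.to_tensor([[0., 0.]])) * 0.5          # half extents: [[0.5, 0.5]]
print(paddle.concat([reg_xy - reg_wh, reg_xy + reg_wh], axis=-1).tolist())
# [[1.0, 1.0, 2.0, 2.0]] in grid units; scaled by stride -> [8, 8, 16, 16] in pixels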
......@@ -273,7 +273,8 @@ def linear_init_(module):
def conv_init_(module):
bound = 1 / np.sqrt(np.prod(module.weight.shape[1:]))
uniform_(module.weight, -bound, bound)
uniform_(module.bias, -bound, bound)
if module.bias is not None:
uniform_(module.bias, -bound, bound)
def bias_init_with_prob(prior_prob=0.01):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
......@@ -19,8 +19,9 @@ from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import DropBlock
from ..backbones.darknet import ConvBNLayer
from ..shape_spec import ShapeSpec
from ..backbones.csp_darknet import BaseConv, DWConv, CSPLayer
__all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN']
__all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN', 'YOLOCSPPAN']
def add_coord(x, data_format):
......@@ -986,3 +987,102 @@ class PPYOLOPAN(nn.Layer):
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
@register
@serializable
class YOLOCSPPAN(nn.Layer):
"""
YOLO CSP-PAN, used in YOLOv5 and YOLOX.
"""
__shared__ = ['depth_mult', 'act']
def __init__(self,
depth_mult=1.0,
in_channels=[256, 512, 1024],
depthwise=False,
act='silu'):
super(YOLOCSPPAN, self).__init__()
self.in_channels = in_channels
self._out_channels = in_channels
Conv = DWConv if depthwise else BaseConv
self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
# top-down fpn
self.lateral_convs = nn.LayerList()
self.fpn_blocks = nn.LayerList()
for idx in range(len(in_channels) - 1, 0, -1):
self.lateral_convs.append(
BaseConv(
int(in_channels[idx]),
int(in_channels[idx - 1]),
1,
1,
act=act))
self.fpn_blocks.append(
CSPLayer(
int(in_channels[idx - 1] * 2),
int(in_channels[idx - 1]),
round(3 * depth_mult),
shortcut=False,
depthwise=depthwise,
act=act))
# bottom-up pan
self.downsample_convs = nn.LayerList()
self.pan_blocks = nn.LayerList()
for idx in range(len(in_channels) - 1):
self.downsample_convs.append(
Conv(
int(in_channels[idx]),
int(in_channels[idx]),
3,
stride=2,
act=act))
self.pan_blocks.append(
CSPLayer(
int(in_channels[idx] * 2),
int(in_channels[idx + 1]),
round(3 * depth_mult),
shortcut=False,
depthwise=depthwise,
act=act))
def forward(self, feats, for_mot=False):
assert len(feats) == len(self.in_channels)
# top-down fpn
inner_outs = [feats[-1]]
for idx in range(len(self.in_channels) - 1, 0, -1):
feat_high = inner_outs[0]
feat_low = feats[idx - 1]
feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](
feat_high)
inner_outs[0] = feat_high
upsample_feat = self.upsample(feat_high)
inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
paddle.concat(
[upsample_feat, feat_low], axis=1))
inner_outs.insert(0, inner_out)
# bottom-up pan
outs = [inner_outs[0]]
for idx in range(len(self.in_channels) - 1):
feat_low = outs[-1]
feat_high = inner_outs[idx + 1]
downsample_feat = self.downsample_convs[idx](feat_low)
out = self.pan_blocks[idx](paddle.concat(
[downsample_feat, feat_high], axis=1))
outs.append(out)
return outs
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
......@@ -28,7 +28,7 @@ __all__ = [
'roi_pool', 'roi_align', 'prior_box', 'generate_proposals',
'iou_similarity', 'box_coder', 'yolo_box', 'multiclass_nms',
'distribute_fpn_proposals', 'collect_fpn_proposals', 'matrix_nms',
'batch_norm', 'mish', 'swish', 'identity'
'batch_norm', 'get_activation', 'mish', 'swish', 'identity'
]
......@@ -106,6 +106,18 @@ def batch_norm(ch,
return norm_layer
def get_activation(name="silu"):
if name == "silu":
module = nn.Silu()
elif name == "relu":
module = nn.ReLU()
elif name == "leakyrelu":
module = nn.LeakyReLU(0.1)
else:
raise AttributeError("Unsupported act type: {}".format(name))
return module
@paddle.jit.not_to_static
def roi_pool(input,
rois,
......
......@@ -209,6 +209,33 @@ class BurninWarmup(object):
return boundary, value
@serializable
class ExpWarmup(object):
"""
Warm up learning rate in exponential mode
Args:
steps (int): warm up steps.
epochs (int|None): use epochs as warm up steps, the priority
of `epochs` is higher than `steps`. Default: None.
"""
def __init__(self, steps=5, epochs=None):
super(ExpWarmup, self).__init__()
self.steps = steps
self.epochs = epochs
def __call__(self, base_lr, step_per_epoch):
boundary = []
value = []
warmup_steps = self.epochs * step_per_epoch if self.epochs is not None else self.steps
for i in range(warmup_steps + 1):
factor = (i / float(warmup_steps))**2
value.append(base_lr * factor)
if i > 0:
boundary.append(i)
return boundary, value
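For reference, the warm-up factors produced by the loop above (note the ramp is quadratic, (i / warmup_steps) ** 2, despite the 'exponential' name):

warmup_steps = 5
print([round((i / float(warmup_steps)) ** 2, 2) for i in range(warmup_steps + 1)])
# [0.0, 0.04, 0.16, 0.36, 0.64, 1.0] -- each factor is multiplied by base_lr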
@register
class LearningRate(object):
"""
......@@ -331,7 +358,8 @@ class ModelEMA(object):
Ema's parameter are updated with the formula:
`ema_param = decay * ema_param + (1 - decay) * cur_param`.
Defaults is 0.9998.
use_thres_step (bool): Whether set decay by thres_step or not
ema_decay_type (str): type in ['threshold', 'normal', 'exponential'],
'threshold' as default.
cycle_epoch (int): The epoch of interval to reset ema_param and
step. Defaults is -1, which means not reset. Its function is to
add a regular effect to ema, which is set according to experience
......@@ -341,7 +369,7 @@ class ModelEMA(object):
def __init__(self,
model,
decay=0.9998,
use_thres_step=False,
ema_decay_type='threshold',
cycle_epoch=-1):
self.step = 0
self.epoch = 0
......@@ -349,7 +377,7 @@ class ModelEMA(object):
self.state_dict = dict()
for k, v in model.state_dict().items():
self.state_dict[k] = paddle.zeros_like(v)
self.use_thres_step = use_thres_step
self.ema_decay_type = ema_decay_type
self.cycle_epoch = cycle_epoch
self._model_state = {
......@@ -370,8 +398,10 @@ class ModelEMA(object):
self.step = step
def update(self, model=None):
if self.use_thres_step:
if self.ema_decay_type == 'threshold':
decay = min(self.decay, (1 + self.step) / (10 + self.step))
elif self.ema_decay_type == 'exponential':
decay = self.decay * (1 - math.exp(-(self.step + 1) / 2000))
else:
decay = self.decay
self._decay = decay
......@@ -394,7 +424,8 @@ class ModelEMA(object):
return self.state_dict
state_dict = dict()
for k, v in self.state_dict.items():
v = v / (1 - self._decay**self.step)
if self.ema_decay_type != 'exponential':
v = v / (1 - self._decay**self.step)
v.stop_gradient = True
state_dict[k] = v
self.epoch += 1
......
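A standalone sketch comparing the three ema_decay_type schedules wired up above (pure Python; 0.9998 is the default decay, and 'normal' simply uses it unchanged):

import math
base_decay = 0.9998
for step in (1, 100, 2000):
    threshold = min(base_decay, (1 + step) / (10 + step))
    exponential = base_decay * (1 - math.exp(-(step + 1) / 2000))
    print(step, round(threshold, 4), round(exponential, 4), base_decay)
# both warm-up variants use a much smaller decay at early steps, so the EMA
# tracks the raw weights quickly at first and stabilizes later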