Unverified commit 10bf8de7 authored by Feng Ni, committed by GitHub

add YOLOX codes (#5740)

Parent: fecae1ee
@@ -31,16 +31,32 @@ sys.path.insert(0, parent_path)
 from benchmark_utils import PaddleInferBenchmark
 from picodet_postprocess import PicoDetPostProcess
-from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, decode_image
+from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image
 from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
 from visualize import visualize_box_mask
 from utils import argsparser, Timer, get_current_memory_mb

 # Global dictionary
 SUPPORT_MODELS = {
-    'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
-    'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
-    'StrongBaseline', 'STGCN'
+    'YOLO',
+    'RCNN',
+    'SSD',
+    'Face',
+    'FCOS',
+    'SOLOv2',
+    'TTFNet',
+    'S2ANet',
+    'JDE',
+    'FairMOT',
+    'DeepSORT',
+    'GFL',
+    'PicoDet',
+    'CenterNet',
+    'TOOD',
+    'RetinaNet',
+    'StrongBaseline',
+    'STGCN',
+    'YOLOX',
 }
......
@@ -246,6 +246,81 @@ class LetterBoxResize(object):
         return im, im_info
class Pad(object):
def __init__(self,
size=None,
size_divisor=32,
pad_mode=0,
offsets=None,
fill_value=(127.5, 127.5, 127.5)):
"""
Pad image to a specified size or multiple of size_divisor.
Args:
size (int, Sequence): image target size, if None, pad to multiple of size_divisor, default None
size_divisor (int): size divisor, default 32
            pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2].
                If -1, use specified offsets; if 0, only pad to right and bottom;
                if 1, pad according to center; if 2, only pad left and top.
            offsets (list): [offset_x, offset_y], specify offsets while padding, only supported when pad_mode=-1
            fill_value (tuple): RGB value of the pad area, default (127.5, 127.5, 127.5)
"""
super(Pad, self).__init__()
if isinstance(size, int):
size = [size, size]
assert pad_mode in [
-1, 0, 1, 2
], 'currently only supports four modes [-1, 0, 1, 2]'
if pad_mode == -1:
assert offsets, 'if pad_mode is -1, offsets should not be None'
self.size = size
self.size_divisor = size_divisor
self.pad_mode = pad_mode
self.fill_value = fill_value
self.offsets = offsets
def apply_image(self, image, offsets, im_size, size):
x, y = offsets
im_h, im_w = im_size
h, w = size
canvas = np.ones((h, w, 3), dtype=np.float32)
canvas *= np.array(self.fill_value, dtype=np.float32)
canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
return canvas
def __call__(self, im, im_info):
im_h, im_w = im.shape[:2]
if self.size:
h, w = self.size
assert (
im_h <= h and im_w <= w
            ), '(h, w) of target size should be no less than (im_h, im_w)'
else:
h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
if h == im_h and w == im_w:
im = im.astype(np.float32)
return im, im_info
if self.pad_mode == -1:
offset_x, offset_y = self.offsets
elif self.pad_mode == 0:
offset_y, offset_x = 0, 0
elif self.pad_mode == 1:
offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
else:
offset_y, offset_x = h - im_h, w - im_w
offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
im = self.apply_image(im, offsets, im_size, size)
        return im, im_info
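For illustration, a minimal standalone check of the Pad transform above (not part of the commit; assumes numpy and the class as defined here):

import numpy as np

pad = Pad(size=None, size_divisor=32, pad_mode=0)
im = np.zeros((500, 375, 3), dtype=np.float32)
im_out, _ = pad(im, im_info={})
# each side is rounded up to the next multiple of 32; the original pixels stay
# in the top-left corner and the rest is filled with (127.5, 127.5, 127.5)
print(im_out.shape)  # (512, 384, 3)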
class WarpAffine(object):
    """Warp affine the image
    """
......
@@ -5,7 +5,7 @@
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -75,7 +75,7 @@ class DetDataset(Dataset):
         n = len(self.roidbs)
         roidb = [roidb, ] + [
             copy.deepcopy(self.roidbs[np.random.randint(n)])
-            for _ in range(3)
+            for _ in range(4)
         ]
         if isinstance(roidb, Sequence):
             for r in roidb:
......
@@ -2034,13 +2034,14 @@ class Pad(BaseOperator):
         if self.size:
             h, w = self.size
             assert (
-                im_h < h and im_w < w
+                im_h <= h and im_w <= w
             ), '(h, w) of target size should be greater than (im_h, im_w)'
         else:
             h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
             w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
         if h == im_h and w == im_w:
+            sample['image'] = im.astype(np.float32)
             return sample

         if self.pad_mode == -1:
@@ -2139,16 +2140,29 @@ class Rbox2Poly(BaseOperator):

 @register_op
 class AugmentHSV(BaseOperator):
-    def __init__(self, fraction=0.50, is_bgr=True):
-        """
-        Augment the SV channel of image data.
-        Args:
-            fraction (float): the fraction for augment. Default: 0.5.
-            is_bgr (bool): whether the image is BGR mode. Default: True.
-        """
+    """
+    Augment the SV channel of image data.
+    Args:
+        fraction (float): the fraction for augment. Default: 0.5.
+        is_bgr (bool): whether the image is BGR mode. Default: True.
+        hgain (float): H channel gains
+        sgain (float): S channel gains
+        vgain (float): V channel gains
+    """
+
+    def __init__(self,
+                 fraction=0.50,
+                 is_bgr=True,
+                 hgain=None,
+                 sgain=None,
+                 vgain=None):
         super(AugmentHSV, self).__init__()
         self.fraction = fraction
         self.is_bgr = is_bgr
+        self.hgain = hgain
+        self.sgain = sgain
+        self.vgain = vgain
+        self.use_hsvgain = False if hgain is None else True
     def apply(self, sample, context=None):
         img = sample['image']
@@ -2156,21 +2170,33 @@ class AugmentHSV(BaseOperator):
             img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
         else:
             img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
-        S = img_hsv[:, :, 1].astype(np.float32)
-        V = img_hsv[:, :, 2].astype(np.float32)
-
-        a = (random.random() * 2 - 1) * self.fraction + 1
-        S *= a
-        if a > 1:
-            np.clip(S, a_min=0, a_max=255, out=S)
-
-        a = (random.random() * 2 - 1) * self.fraction + 1
-        V *= a
-        if a > 1:
-            np.clip(V, a_min=0, a_max=255, out=V)
-
-        img_hsv[:, :, 1] = S.astype(np.uint8)
-        img_hsv[:, :, 2] = V.astype(np.uint8)
+        if self.use_hsvgain:
+            hsv_augs = np.random.uniform(
+                -1, 1, 3) * [self.hgain, self.sgain, self.vgain]
+            # random selection of h, s, v
+            hsv_augs *= np.random.randint(0, 2, 3)
+            img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180
+            img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255)
+            img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255)
+        else:
+            S = img_hsv[:, :, 1].astype(np.float32)
+            V = img_hsv[:, :, 2].astype(np.float32)
+            a = (random.random() * 2 - 1) * self.fraction + 1
+            S *= a
+            if a > 1:
+                np.clip(S, a_min=0, a_max=255, out=S)
+            a = (random.random() * 2 - 1) * self.fraction + 1
+            V *= a
+            if a > 1:
+                np.clip(V, a_min=0, a_max=255, out=V)
+            img_hsv[:, :, 1] = S.astype(np.uint8)
+            img_hsv[:, :, 2] = V.astype(np.uint8)
         if self.is_bgr:
             cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)
         else:
@@ -3018,3 +3044,373 @@ class CenterRandColor(BaseOperator):
             img = func(img, img_gray)
         sample['image'] = img
         return sample
@register_op
class Mosaic(BaseOperator):
""" Mosaic operator for image and gt_bboxes
The code is based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/data/datasets/mosaicdetection.py
1. get mosaic coords
2. clip bbox and get mosaic_labels
3. random_affine augment
    4. Mixup augment as copypaste (optional), not used in tiny/nano
Args:
prob (float): probability of using Mosaic, 1.0 as default
input_dim (list[int]): input shape
degrees (list[2]): the rotate range to apply, transform range is [min, max]
translate (list[2]): the translate range to apply, transform range is [min, max]
scale (list[2]): the scale range to apply, transform range is [min, max]
shear (list[2]): the shear range to apply, transform range is [min, max]
enable_mixup (bool): whether to enable Mixup or not
mixup_prob (float): probability of using Mixup, 1.0 as default
        mixup_scale (list[float]): scale range of Mixup
        remove_outside_box (bool): whether to remove outside boxes, False as
            default in COCO dataset, True in MOT dataset
"""
def __init__(self,
prob=1.0,
input_dim=[640, 640],
degrees=[-10, 10],
translate=[-0.1, 0.1],
scale=[0.1, 2],
shear=[-2, 2],
enable_mixup=True,
mixup_prob=1.0,
mixup_scale=[0.5, 1.5],
remove_outside_box=False):
super(Mosaic, self).__init__()
self.prob = prob
self.input_dim = input_dim
self.degrees = degrees
self.translate = translate
self.scale = scale
self.shear = shear
self.enable_mixup = enable_mixup
self.mixup_prob = mixup_prob
self.mixup_scale = mixup_scale
self.remove_outside_box = remove_outside_box
def get_mosaic_coords(self, mosaic_idx, xc, yc, w, h, input_h, input_w):
# (x1, y1, x2, y2) means coords in large image,
# small_coords means coords in small image in mosaic aug.
if mosaic_idx == 0:
# top left
x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
small_coords = w - (x2 - x1), h - (y2 - y1), w, h
elif mosaic_idx == 1:
# top right
x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
small_coords = 0, h - (y2 - y1), min(w, x2 - x1), h
elif mosaic_idx == 2:
# bottom left
x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
small_coords = w - (x2 - x1), 0, w, min(y2 - y1, h)
elif mosaic_idx == 3:
# bottom right
x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2,
yc + h)
small_coords = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
return (x1, y1, x2, y2), small_coords
def random_affine_augment(self,
img,
labels=[],
input_dim=[640, 640],
degrees=[-10, 10],
scales=[0.1, 2],
shears=[-2, 2],
translates=[-0.1, 0.1]):
# random rotation and scale
degree = random.uniform(degrees[0], degrees[1])
scale = random.uniform(scales[0], scales[1])
assert scale > 0, "Argument scale should be positive."
R = cv2.getRotationMatrix2D(angle=degree, center=(0, 0), scale=scale)
M = np.ones([2, 3])
# random shear
shear = random.uniform(shears[0], shears[1])
shear_x = math.tan(shear * math.pi / 180)
shear_y = math.tan(shear * math.pi / 180)
M[0] = R[0] + shear_y * R[1]
M[1] = R[1] + shear_x * R[0]
# random translation
translate = random.uniform(translates[0], translates[1])
translation_x = translate * input_dim[0]
translation_y = translate * input_dim[1]
M[0, 2] = translation_x
M[1, 2] = translation_y
# warpAffine
img = cv2.warpAffine(
img, M, dsize=input_dim, borderValue=(114, 114, 114))
num_gts = len(labels)
if num_gts > 0:
# warp corner points
corner_points = np.ones((4 * num_gts, 3))
corner_points[:, :2] = labels[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
4 * num_gts, 2) # x1y1, x2y2, x1y2, x2y1
# apply affine transform
            corner_points = corner_points @ M.T
corner_points = corner_points.reshape(num_gts, 8)
# create new boxes
corner_xs = corner_points[:, 0::2]
corner_ys = corner_points[:, 1::2]
new_bboxes = np.concatenate((corner_xs.min(1), corner_ys.min(1),
corner_xs.max(1), corner_ys.max(1)))
new_bboxes = new_bboxes.reshape(4, num_gts).T
# clip boxes
new_bboxes[:, 0::2] = np.clip(new_bboxes[:, 0::2], 0, input_dim[0])
new_bboxes[:, 1::2] = np.clip(new_bboxes[:, 1::2], 0, input_dim[1])
labels[:, :4] = new_bboxes
return img, labels
def __call__(self, sample, context=None):
if not isinstance(sample, Sequence):
return sample
assert len(
sample) == 5, "Mosaic needs 5 samples, 4 for mosaic and 1 for mixup."
if np.random.uniform(0., 1.) > self.prob:
return sample[0]
mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd = [], [], []
input_h, input_w = self.input_dim
yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))
mosaic_img = np.full((input_h * 2, input_w * 2, 3), 114, dtype=np.uint8)
# 1. get mosaic coords
for mosaic_idx, sp in enumerate(sample[:4]):
img = sp['image']
gt_bbox = sp['gt_bbox']
h0, w0 = img.shape[:2]
scale = min(1. * input_h / h0, 1. * input_w / w0)
img = cv2.resize(
img, (int(w0 * scale), int(h0 * scale)),
interpolation=cv2.INTER_LINEAR)
(h, w, c) = img.shape[:3]
# suffix l means large image, while s means small image in mosaic aug.
(l_x1, l_y1, l_x2, l_y2), (
s_x1, s_y1, s_x2, s_y2) = self.get_mosaic_coords(
mosaic_idx, xc, yc, w, h, input_h, input_w)
mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
padw, padh = l_x1 - s_x1, l_y1 - s_y1
# Normalized xywh to pixel xyxy format
_gt_bbox = gt_bbox.copy()
if len(gt_bbox) > 0:
_gt_bbox[:, 0] = scale * gt_bbox[:, 0] + padw
_gt_bbox[:, 1] = scale * gt_bbox[:, 1] + padh
_gt_bbox[:, 2] = scale * gt_bbox[:, 2] + padw
_gt_bbox[:, 3] = scale * gt_bbox[:, 3] + padh
mosaic_gt_bbox.append(_gt_bbox)
mosaic_gt_class.append(sp['gt_class'])
mosaic_is_crowd.append(sp['is_crowd'])
# 2. clip bbox and get mosaic_labels([gt_bbox, gt_class, is_crowd])
if len(mosaic_gt_bbox):
mosaic_gt_bbox = np.concatenate(mosaic_gt_bbox, 0)
mosaic_gt_class = np.concatenate(mosaic_gt_class, 0)
mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0)
mosaic_labels = np.concatenate([
mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype),
mosaic_is_crowd.astype(mosaic_gt_bbox.dtype)
], 1)
if self.remove_outside_box:
# for MOT dataset
flag1 = mosaic_gt_bbox[:, 0] < 2 * input_w
flag2 = mosaic_gt_bbox[:, 2] > 0
flag3 = mosaic_gt_bbox[:, 1] < 2 * input_h
flag4 = mosaic_gt_bbox[:, 3] > 0
flag_all = flag1 * flag2 * flag3 * flag4
mosaic_labels = mosaic_labels[flag_all]
else:
mosaic_labels[:, 0] = np.clip(mosaic_labels[:, 0], 0,
2 * input_w)
mosaic_labels[:, 1] = np.clip(mosaic_labels[:, 1], 0,
2 * input_h)
mosaic_labels[:, 2] = np.clip(mosaic_labels[:, 2], 0,
2 * input_w)
mosaic_labels[:, 3] = np.clip(mosaic_labels[:, 3], 0,
2 * input_h)
else:
mosaic_labels = np.zeros((1, 6))
# 3. random_affine augment
mosaic_img, mosaic_labels = self.random_affine_augment(
mosaic_img,
mosaic_labels,
input_dim=self.input_dim,
degrees=self.degrees,
translates=self.translate,
scales=self.scale,
shears=self.shear)
        # 4. Mixup augment as copypaste, https://arxiv.org/abs/2012.07177
        # optional, not used (enable_mixup=False) in tiny/nano
if (self.enable_mixup and not len(mosaic_labels) == 0 and
random.random() < self.mixup_prob):
sample_mixup = sample[4]
mixup_img = sample_mixup['image']
cp_labels = np.concatenate([
sample_mixup['gt_bbox'],
sample_mixup['gt_class'].astype(mosaic_labels.dtype),
sample_mixup['is_crowd'].astype(mosaic_labels.dtype)
], 1)
mosaic_img, mosaic_labels = self.mixup_augment(
mosaic_img, mosaic_labels, self.input_dim, cp_labels, mixup_img)
sample0 = sample[0]
sample0['image'] = mosaic_img.astype(np.uint8) # can not be float32
sample0['h'] = float(mosaic_img.shape[0])
sample0['w'] = float(mosaic_img.shape[1])
sample0['im_shape'][0] = sample0['h']
sample0['im_shape'][1] = sample0['w']
sample0['gt_bbox'] = mosaic_labels[:, :4].astype(np.float32)
sample0['gt_class'] = mosaic_labels[:, 4:5].astype(np.float32)
sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32)
return sample0
def mixup_augment(self, origin_img, origin_labels, input_dim, cp_labels,
img):
jit_factor = random.uniform(*self.mixup_scale)
FLIP = random.uniform(0, 1) > 0.5
if len(img.shape) == 3:
cp_img = np.ones(
(input_dim[0], input_dim[1], 3), dtype=np.uint8) * 114
else:
cp_img = np.ones(input_dim, dtype=np.uint8) * 114
cp_scale_ratio = min(input_dim[0] / img.shape[0],
input_dim[1] / img.shape[1])
resized_img = cv2.resize(
img, (int(img.shape[1] * cp_scale_ratio),
int(img.shape[0] * cp_scale_ratio)),
interpolation=cv2.INTER_LINEAR)
cp_img[:int(img.shape[0] * cp_scale_ratio), :int(img.shape[
1] * cp_scale_ratio)] = resized_img
cp_img = cv2.resize(cp_img, (int(cp_img.shape[1] * jit_factor),
int(cp_img.shape[0] * jit_factor)))
cp_scale_ratio *= jit_factor
if FLIP:
cp_img = cp_img[:, ::-1, :]
origin_h, origin_w = cp_img.shape[:2]
target_h, target_w = origin_img.shape[:2]
padded_img = np.zeros(
(max(origin_h, target_h), max(origin_w, target_w), 3),
dtype=np.uint8)
padded_img[:origin_h, :origin_w] = cp_img
x_offset, y_offset = 0, 0
if padded_img.shape[0] > target_h:
y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
if padded_img.shape[1] > target_w:
x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
padded_cropped_img = padded_img[y_offset:y_offset + target_h, x_offset:
x_offset + target_w]
# adjust boxes
cp_bboxes_origin_np = cp_labels[:, :4].copy()
cp_bboxes_origin_np[:, 0::2] = np.clip(cp_bboxes_origin_np[:, 0::2] *
cp_scale_ratio, 0, origin_w)
cp_bboxes_origin_np[:, 1::2] = np.clip(cp_bboxes_origin_np[:, 1::2] *
cp_scale_ratio, 0, origin_h)
if FLIP:
cp_bboxes_origin_np[:, 0::2] = (
origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1])
cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
if self.remove_outside_box:
# for MOT dataset
cp_bboxes_transformed_np[:, 0::2] -= x_offset
cp_bboxes_transformed_np[:, 1::2] -= y_offset
else:
cp_bboxes_transformed_np[:, 0::2] = np.clip(
cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w)
cp_bboxes_transformed_np[:, 1::2] = np.clip(
cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h)
cls_labels = cp_labels[:, 4:5].copy()
crd_labels = cp_labels[:, 5:6].copy()
box_labels = cp_bboxes_transformed_np
labels = np.hstack((box_labels, cls_labels, crd_labels))
if self.remove_outside_box:
labels = labels[labels[:, 0] < target_w]
labels = labels[labels[:, 2] > 0]
labels = labels[labels[:, 1] < target_h]
labels = labels[labels[:, 3] > 0]
origin_labels = np.vstack((origin_labels, labels))
origin_img = origin_img.astype(np.float32)
origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(
np.float32)
return origin_img.astype(np.uint8), origin_labels
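To make the expected input concrete, a hedged sketch of feeding the Mosaic operator above (the _fake_sample helper is hypothetical, not part of the commit): it consumes a list of five samples, four for the 2x2 mosaic grid plus one Mixup source, which is why DetDataset above now draws range(4) extra roidbs.

import numpy as np

def _fake_sample(h, w):
    # minimal fields read by Mosaic.__call__
    return {
        'image': np.random.randint(0, 255, (h, w, 3), dtype=np.uint8),
        'gt_bbox': np.array([[10., 10., w / 2., h / 2.]], dtype=np.float32),
        'gt_class': np.array([[0]], dtype=np.int32),
        'is_crowd': np.array([[0]], dtype=np.int32),
        'h': float(h),
        'w': float(w),
        'im_shape': np.array([h, w], dtype=np.float32),
    }

mosaic = Mosaic(prob=1.0, input_dim=[640, 640])
out = mosaic([_fake_sample(480, 640) for _ in range(5)])
print(out['image'].shape)  # (640, 640, 3): the 2x mosaic canvas is warped
                           # back down to input_dim by random_affine_augment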
@register_op
class PadResize(BaseOperator):
""" PadResize for image and gt_bbbox
Args:
target_size (list[int]): input shape
fill_value (float): pixel value of padded image
"""
def __init__(self, target_size, fill_value=114):
super(PadResize, self).__init__()
if isinstance(target_size, Integral):
target_size = [target_size, target_size]
self.target_size = target_size
self.fill_value = fill_value
def _resize(self, img, bboxes, labels):
ratio = min(self.target_size[0] / img.shape[0],
self.target_size[1] / img.shape[1])
w, h = int(img.shape[1] * ratio), int(img.shape[0] * ratio)
resized_img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
if len(bboxes) > 0:
bboxes *= ratio
mask = np.minimum(bboxes[:, 2] - bboxes[:, 0],
bboxes[:, 3] - bboxes[:, 1]) > 1
bboxes = bboxes[mask]
labels = labels[mask]
return resized_img, bboxes, labels
def _pad(self, img):
h, w, _ = img.shape
if h == self.target_size[0] and w == self.target_size[1]:
return img
padded_img = np.full(
(self.target_size[0], self.target_size[1], 3),
self.fill_value,
dtype=np.uint8)
padded_img[:h, :w] = img
return padded_img
def apply(self, sample, context=None):
image = sample['image']
bboxes = sample['gt_bbox']
labels = sample['gt_class']
image, bboxes, labels = self._resize(image, bboxes, labels)
sample['image'] = self._pad(image).astype(np.float32)
sample['gt_bbox'] = bboxes
sample['gt_class'] = labels
return sample
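For comparison with Pad above, a short sketch of PadResize (illustrative values, not part of the commit): the image is first rescaled to fit target_size with the aspect ratio kept, boxes are scaled by the same ratio (degenerate ones are dropped), then the canvas is padded with fill_value.

import numpy as np

pr = PadResize(target_size=640, fill_value=114)
sample = {
    'image': np.random.randint(0, 255, (320, 480, 3), dtype=np.uint8),
    'gt_bbox': np.array([[0., 0., 120., 90.]], dtype=np.float32),
    'gt_class': np.array([[3]], dtype=np.int32),
}
out = pr.apply(sample)
print(out['image'].shape)  # (640, 640, 3), ratio = 640 / 480
print(out['gt_bbox'][0])   # [  0.   0. 160. 120.]: box scaled by the ratio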
@@ -48,6 +48,7 @@ TRT_MIN_SUBGRAPH = {
     'PicoDet': 3,
     'CenterNet': 5,
     'TOOD': 5,
+    'YOLOX': 8,
 }
KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
@@ -147,6 +148,12 @@ def _dump_infer_config(config, path, image_shape, model):
             infer_cfg['min_subgraph_size'] = min_subgraph_size
             arch_state = True
             break

+    if infer_arch == 'YOLOX':
+        infer_cfg['arch'] = infer_arch
+        infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
+        arch_state = True
+
     if not arch_state:
         logger.error(
             'Architecture: {} is not supported for exporting model now.\n'.
......
@@ -28,6 +28,7 @@ from PIL import Image, ImageOps, ImageFile
 ImageFile.LOAD_TRUNCATED_IMAGES = True

 import paddle
+import paddle.nn as nn
 import paddle.distributed as dist
 from paddle.distributed import fleet
 from paddle import amp
@@ -99,6 +100,12 @@ class Trainer(object):
             self.model = self.cfg.model
             self.is_loaded_weights = True

+        if cfg.architecture == 'YOLOX':
+            for k, m in self.model.named_sublayers():
+                if isinstance(m, nn.BatchNorm2D):
+                    m.epsilon = 1e-3  # for amp(fp16)
+                    m.momentum = 0.97  # 0.03 in pytorch
+
         #normalize params for deploy
         if 'slim' in cfg and cfg['slim_type'] == 'OFA':
             self.model.model.load_meanstd(cfg['TestReader'][
@@ -117,10 +124,11 @@ class Trainer(object):
         if self.use_ema:
             ema_decay = self.cfg.get('ema_decay', 0.9998)
             cycle_epoch = self.cfg.get('cycle_epoch', -1)
+            ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
             self.ema = ModelEMA(
                 self.model,
                 decay=ema_decay,
-                use_thres_step=True,
+                ema_decay_type=ema_decay_type,
                 cycle_epoch=cycle_epoch)

         # EvalDataset build with BatchSampler to evaluate in single device
......
@@ -5,6 +5,13 @@
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 from . import meta_arch
 from . import faster_rcnn
 from . import mask_rcnn
@@ -28,6 +35,7 @@ from . import sparse_rcnn
 from . import tood
 from . import retinanet
 from . import bytetrack
+from . import yolox

 from .meta_arch import *
 from .faster_rcnn import *
@@ -53,3 +61,4 @@ from .sparse_rcnn import *
 from .tood import *
 from .retinanet import *
 from .bytetrack import *
+from .yolox import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
import random
import paddle
import paddle.nn.functional as F
import paddle.distributed as dist
from ppdet.modeling.ops import paddle_distributed_is_initialized
__all__ = ['YOLOX']
@register
class YOLOX(BaseArch):
"""
YOLOX network, see https://arxiv.org/abs/2107.08430
Args:
backbone (nn.Layer): backbone instance
neck (nn.Layer): neck instance
head (nn.Layer): head instance
for_mot (bool): whether used for MOT or not
        input_size (list[int]): initial input size, will be reset by self._preprocess()
        size_stride (int): stride of the size range
        size_range (list[int]): multi-scale range for training
        random_interval (int): interval (in iterations) at which to change self._input_size
"""
__category__ = 'architecture'
def __init__(self,
backbone='CSPDarkNet',
neck='YOLOCSPPAN',
head='YOLOXHead',
for_mot=False,
input_size=[640, 640],
size_stride=32,
size_range=[15, 25],
random_interval=10):
super(YOLOX, self).__init__()
self.backbone = backbone
self.neck = neck
self.head = head
self.for_mot = for_mot
self.input_size = input_size
self._input_size = paddle.to_tensor(input_size)
self.size_stride = size_stride
self.size_range = size_range
self.random_interval = random_interval
self._step = 0
@classmethod
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg['backbone'])
# fpn
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)
# head
kwargs = {'input_shape': neck.out_shape}
head = create(cfg['head'], **kwargs)
return {
'backbone': backbone,
'neck': neck,
"head": head,
}
def _forward(self):
if self.training:
self._preprocess()
body_feats = self.backbone(self.inputs)
neck_feats = self.neck(body_feats, self.for_mot)
if self.training:
yolox_losses = self.head(neck_feats, self.inputs)
yolox_losses.update({'size': self._input_size[0]})
return yolox_losses
else:
head_outs = self.head(neck_feats)
bbox, bbox_num = self.head.post_process(
head_outs, self.inputs['im_shape'], self.inputs['scale_factor'])
return {'bbox': bbox, 'bbox_num': bbox_num}
def get_loss(self):
return self._forward()
def get_pred(self):
return self._forward()
def _preprocess(self):
        # YOLOX multi-scale training: resize the input batch by interpolation before it enters the network.
self._get_size()
scale_y = self._input_size[0] / self.input_size[0]
scale_x = self._input_size[1] / self.input_size[1]
if scale_x != 1 or scale_y != 1:
self.inputs['image'] = F.interpolate(
self.inputs['image'],
size=self._input_size,
mode='bilinear',
align_corners=False)
gt_bboxes = self.inputs['gt_bbox']
for i in range(len(gt_bboxes)):
if len(gt_bboxes[i]) > 0:
gt_bboxes[i][:, 0::2] = gt_bboxes[i][:, 0::2] * scale_x
gt_bboxes[i][:, 1::2] = gt_bboxes[i][:, 1::2] * scale_y
self.inputs['gt_bbox'] = gt_bboxes
def _get_size(self):
# random_interval = 10 as default, every 10 iters to change self._input_size
image_ratio = self.input_size[1] * 1.0 / self.input_size[0]
if self._step % self.random_interval == 0:
size_factor = random.randint(*self.size_range)
size = [
self.size_stride * size_factor,
self.size_stride * int(size_factor * image_ratio)
]
size = paddle.to_tensor(size)
if dist.get_world_size() > 1 and paddle_distributed_is_initialized(
):
dist.barrier()
dist.broadcast(size, 0)
self._input_size = size
self._step += 1
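As a concrete reading of _get_size above (not part of the commit): every random_interval iterations a new size factor is drawn and broadcast from rank 0, so all ranks train at the same resolution. With the defaults size_stride=32 and size_range=[15, 25], the candidate square sizes are:

size_stride, size_range = 32, [15, 25]
print([size_stride * k for k in range(size_range[0], size_range[1] + 1)])
# [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]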
@@ -115,7 +115,7 @@ def check_points_inside_bboxes(points,
     Args:
         points (Tensor, float32): shape[L, 2], "xy" format, L: num_anchors
         bboxes (Tensor, float32): shape[B, n, 4], "xmin, ymin, xmax, ymax" format
-        center_radius_tensor (Tensor, float32): shape [L, 1] Default: None.
+        center_radius_tensor (Tensor, float32): shape [L, 1]. Default: None.
         eps (float): Default: 1e-9
     Returns:
         is_in_bboxes (Tensor, float32): shape[B, n, L], value=1. means selected
@@ -123,25 +123,28 @@ def check_points_inside_bboxes(points,
     points = points.unsqueeze([0, 1])
     x, y = points.chunk(2, axis=-1)
     xmin, ymin, xmax, ymax = bboxes.unsqueeze(2).chunk(4, axis=-1)
-    if center_radius_tensor is not None:
-        center_radius_tensor = center_radius_tensor.unsqueeze([0, 1])
-        bboxes_cx = (xmin + xmax) / 2.
-        bboxes_cy = (ymin + ymax) / 2.
-        xmin_sampling = bboxes_cx - center_radius_tensor
-        ymin_sampling = bboxes_cy - center_radius_tensor
-        xmax_sampling = bboxes_cx + center_radius_tensor
-        ymax_sampling = bboxes_cy + center_radius_tensor
-        xmin = paddle.maximum(xmin, xmin_sampling)
-        ymin = paddle.maximum(ymin, ymin_sampling)
-        xmax = paddle.minimum(xmax, xmax_sampling)
-        ymax = paddle.minimum(ymax, ymax_sampling)
+    # check whether `points` is in `bboxes`
     l = x - xmin
     t = y - ymin
     r = xmax - x
     b = ymax - y
-    bbox_ltrb = paddle.concat([l, t, r, b], axis=-1)
-    return (bbox_ltrb.min(axis=-1) > eps).astype(bboxes.dtype)
+    delta_ltrb = paddle.concat([l, t, r, b], axis=-1)
+    is_in_bboxes = (delta_ltrb.min(axis=-1) > eps)
+    if center_radius_tensor is not None:
+        # check whether `points` is in `center_radius`
+        center_radius_tensor = center_radius_tensor.unsqueeze([0, 1])
+        cx = (xmin + xmax) * 0.5
+        cy = (ymin + ymax) * 0.5
+        l = x - (cx - center_radius_tensor)
+        t = y - (cy - center_radius_tensor)
+        r = (cx + center_radius_tensor) - x
+        b = (cy + center_radius_tensor) - y
+        delta_ltrb_c = paddle.concat([l, t, r, b], axis=-1)
+        is_in_center = (delta_ltrb_c.min(axis=-1) > eps)
+        return (paddle.logical_and(is_in_bboxes, is_in_center),
+                paddle.logical_or(is_in_bboxes, is_in_center))

+    return is_in_bboxes.astype(bboxes.dtype)
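A small numeric illustration of the new return convention (hypothetical tensors; assumes paddle and the function above): with center_radius_tensor given, the function returns the pair (inside box AND inside center region, inside box OR inside center region), which SimOTA-style assignment can use as its foreground masks.

import paddle

points = paddle.to_tensor([[5., 5.], [20., 20.]])  # [L=2, 2]
bboxes = paddle.to_tensor([[[0., 0., 40., 40.]]])  # [B=1, n=1, 4]
radius = paddle.full([2, 1], 2.5)                  # [L, 1]
both, either = check_points_inside_bboxes(
    points, bboxes, center_radius_tensor=radius)
# the box center is (20, 20): (5, 5) lies inside the box but outside the
# 2.5-radius center region, while (20, 20) lies inside both
print(both.numpy().squeeze())    # [False  True]
print(either.numpy().squeeze())  # [ True  True]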
def compute_max_iou_anchor(ious):
......
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 from . import vgg
@@ -30,6 +30,7 @@ from . import lcnet
 from . import hardnet
 from . import esnet
 from . import cspresnet
+from . import csp_darknet

 from .vgg import *
 from .resnet import *
@@ -49,3 +50,4 @@ from .lcnet import *
 from .hardnet import *
 from .esnet import *
 from .cspresnet import *
+from .csp_darknet import *
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register, serializable
from ppdet.modeling.ops import get_activation
from ppdet.modeling.initializer import conv_init_
from ..shape_spec import ShapeSpec
__all__ = [
'CSPDarkNet', 'BaseConv', 'DWConv', 'BottleNeck', 'SPPLayer', 'SPPFLayer'
]
class BaseConv(nn.Layer):
def __init__(self,
in_channels,
out_channels,
ksize,
stride,
groups=1,
bias=False,
act="silu"):
super(BaseConv, self).__init__()
self.conv = nn.Conv2D(
in_channels,
out_channels,
kernel_size=ksize,
stride=stride,
padding=(ksize - 1) // 2,
groups=groups,
bias_attr=bias)
self.bn = nn.BatchNorm2D(
out_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
self.act = get_activation(act)
self._init_weights()
def _init_weights(self):
conv_init_(self.conv)
def forward(self, x):
return self.act(self.bn(self.conv(x)))
class DWConv(nn.Layer):
"""Depthwise Conv"""
def __init__(self,
in_channels,
out_channels,
ksize,
stride=1,
bias=False,
act="silu"):
super(DWConv, self).__init__()
self.dw_conv = BaseConv(
in_channels,
in_channels,
ksize=ksize,
stride=stride,
groups=in_channels,
bias=bias,
act=act, )
self.pw_conv = BaseConv(
in_channels,
out_channels,
ksize=1,
stride=1,
groups=1,
bias=bias,
act=act)
def forward(self, x):
return self.pw_conv(self.dw_conv(x))
class Focus(nn.Layer):
"""Focus width and height information into channel space, used in YOLOX."""
def __init__(self,
in_channels,
out_channels,
ksize=3,
stride=1,
bias=False,
act="silu"):
super(Focus, self).__init__()
self.conv = BaseConv(
in_channels * 4,
out_channels,
ksize=ksize,
stride=stride,
bias=bias,
act=act)
def forward(self, inputs):
        # inputs [bs, C, H, W] -> outputs [bs, 4C, H/2, W/2]
top_left = inputs[:, :, 0::2, 0::2]
top_right = inputs[:, :, 0::2, 1::2]
bottom_left = inputs[:, :, 1::2, 0::2]
bottom_right = inputs[:, :, 1::2, 1::2]
outputs = paddle.concat(
[top_left, bottom_left, top_right, bottom_right], 1)
return self.conv(outputs)
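A quick shape check of the Focus stem above (assumes paddle; not part of the commit): the 2x2 pixel de-interleave trades spatial resolution for channels before the first convolution.

import paddle

focus = Focus(in_channels=3, out_channels=64)
x = paddle.rand([1, 3, 640, 640])
print(focus(x).shape)  # [1, 64, 320, 320]: 3 channels -> 12, then conv to 64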
class BottleNeck(nn.Layer):
def __init__(self,
in_channels,
out_channels,
shortcut=True,
expansion=0.5,
depthwise=False,
bias=False,
act="silu"):
super(BottleNeck, self).__init__()
hidden_channels = int(out_channels * expansion)
Conv = DWConv if depthwise else BaseConv
self.conv1 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.conv2 = Conv(
hidden_channels,
out_channels,
ksize=3,
stride=1,
bias=bias,
act=act)
self.add_shortcut = shortcut and in_channels == out_channels
def forward(self, x):
y = self.conv2(self.conv1(x))
if self.add_shortcut:
y = y + x
return y
class SPPLayer(nn.Layer):
"""Spatial Pyramid Pooling (SPP) layer used in YOLOv3-SPP and YOLOX"""
def __init__(self,
in_channels,
out_channels,
kernel_sizes=(5, 9, 13),
bias=False,
act="silu"):
super(SPPLayer, self).__init__()
hidden_channels = in_channels // 2
self.conv1 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.maxpoolings = nn.LayerList([
nn.MaxPool2D(
kernel_size=ks, stride=1, padding=ks // 2)
for ks in kernel_sizes
])
conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
self.conv2 = BaseConv(
conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act)
def forward(self, x):
x = self.conv1(x)
x = paddle.concat([x] + [mp(x) for mp in self.maxpoolings], axis=1)
x = self.conv2(x)
return x
class SPPFLayer(nn.Layer):
""" Spatial Pyramid Pooling - Fast (SPPF) layer used in YOLOv5 by Glenn Jocher,
equivalent to SPP(k=(5, 9, 13))
"""
def __init__(self,
in_channels,
out_channels,
ksize=5,
bias=False,
act='silu'):
super(SPPFLayer, self).__init__()
hidden_channels = in_channels // 2
self.conv1 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.maxpooling = nn.MaxPool2D(
kernel_size=ksize, stride=1, padding=ksize // 2)
conv2_channels = hidden_channels * 4
self.conv2 = BaseConv(
conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act)
def forward(self, x):
x = self.conv1(x)
y1 = self.maxpooling(x)
y2 = self.maxpooling(y1)
y3 = self.maxpooling(y2)
concats = paddle.concat([x, y1, y2, y3], axis=1)
out = self.conv2(concats)
return out
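The stated equivalence holds because stride-1 max pools compose: two stacked k=5 pools cover a 9x9 window and three cover 13x13, so SPPF reuses intermediate results instead of pooling three times from scratch. A quick numeric check of the pooling part (assumes paddle; not part of the commit):

import paddle
import paddle.nn.functional as F

x = paddle.rand([1, 8, 32, 32])
y1 = F.max_pool2d(x, 5, stride=1, padding=2)
y2 = F.max_pool2d(y1, 5, stride=1, padding=2)
y3 = F.max_pool2d(y2, 5, stride=1, padding=2)
print(bool(paddle.all(y2 == F.max_pool2d(x, 9, stride=1, padding=4))))   # True
print(bool(paddle.all(y3 == F.max_pool2d(x, 13, stride=1, padding=6))))  # True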
class CSPLayer(nn.Layer):
"""CSP (Cross Stage Partial) layer with 3 convs, named C3 in YOLOv5"""
def __init__(self,
in_channels,
out_channels,
num_blocks=1,
shortcut=True,
expansion=0.5,
depthwise=False,
bias=False,
act="silu"):
super(CSPLayer, self).__init__()
hidden_channels = int(out_channels * expansion)
self.conv1 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.conv2 = BaseConv(
in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
self.bottlenecks = nn.Sequential(* [
BottleNeck(
hidden_channels,
hidden_channels,
shortcut=shortcut,
expansion=1.0,
depthwise=depthwise,
bias=bias,
act=act) for _ in range(num_blocks)
])
self.conv3 = BaseConv(
hidden_channels * 2,
out_channels,
ksize=1,
stride=1,
bias=bias,
act=act)
def forward(self, x):
x_1 = self.conv1(x)
x_1 = self.bottlenecks(x_1)
x_2 = self.conv2(x)
x = paddle.concat([x_1, x_2], axis=1)
x = self.conv3(x)
return x
@register
@serializable
class CSPDarkNet(nn.Layer):
"""
CSPDarkNet backbone.
Args:
arch (str): Architecture of CSPDarkNet, from {P5, P6, X}, default as X,
and 'X' means used in YOLOX, 'P5/P6' means used in YOLOv5.
        depth_mult (float): Depth multiplier, multiply number of blocks in
            CSPLayer, default as 1.0.
        width_mult (float): Width multiplier, multiply number of channels in
            each layer, default as 1.0.
depthwise (bool): Whether to use depth-wise conv layer.
act (str): Activation function type, default as 'silu'.
return_idx (list): Index of stages whose feature maps are returned.
"""
__shared__ = ['depth_mult', 'width_mult', 'act']
# in_channels, out_channels, num_blocks, add_shortcut, use_spp(use_sppf)
# 'X' means setting used in YOLOX, 'P5/P6' means setting used in YOLOv5.
arch_settings = {
'X': [[64, 128, 3, True, False], [128, 256, 9, True, False],
[256, 512, 9, True, False], [512, 1024, 3, False, True]],
'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
[256, 512, 9, True, False], [512, 1024, 3, True, True]],
'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
[256, 512, 9, True, False], [512, 768, 3, True, False],
[768, 1024, 3, True, True]],
}
def __init__(self,
arch='X',
depth_mult=1.0,
width_mult=1.0,
depthwise=False,
act='silu',
return_idx=[2, 3, 4]):
super(CSPDarkNet, self).__init__()
self.arch = arch
self.return_idx = return_idx
Conv = DWConv if depthwise else BaseConv
arch_setting = self.arch_settings[arch]
base_channels = int(arch_setting[0][0] * width_mult)
        # Note: differences between the latest YOLOv5 and the original YOLOX
        # 1. self.stem: a Conv stem (in YOLOv5) vs a Focus stem (in YOLOX)
        # 2. use SPPF (in YOLOv5) or SPP (in YOLOX)
        # 3. put SPPF before (YOLOv5) or SPP after (YOLOX) the last dark block's CSPLayer
        # 4. whether the SPP(F) stage's CSPLayer adds a shortcut: True in YOLOv5, False in YOLOX
if arch in ['P5', 'P6']:
            # in the latest YOLOv5, use Conv stem, and SPPF (fast, only a single spp kernel size)
self.stem = Conv(
3, base_channels, ksize=6, stride=2, bias=False, act=act)
spp_kernal_sizes = 5
elif arch in ['X']:
            # in the original YOLOX, use Focus stem, and SPP (three spp kernel sizes)
self.stem = Focus(
3, base_channels, ksize=3, stride=1, bias=False, act=act)
spp_kernal_sizes = (5, 9, 13)
else:
raise AttributeError("Unsupported arch type: {}".format(arch))
_out_channels = [base_channels]
layers_num = 1
self.csp_dark_blocks = []
for i, (in_channels, out_channels, num_blocks, shortcut,
use_spp) in enumerate(arch_setting):
in_channels = int(in_channels * width_mult)
out_channels = int(out_channels * width_mult)
_out_channels.append(out_channels)
num_blocks = max(round(num_blocks * depth_mult), 1)
stage = []
conv_layer = self.add_sublayer(
'layers{}.stage{}.conv_layer'.format(layers_num, i + 1),
Conv(
in_channels, out_channels, 3, 2, bias=False, act=act))
stage.append(conv_layer)
layers_num += 1
if use_spp and arch in ['X']:
# in YOLOX use SPPLayer
spp_layer = self.add_sublayer(
'layers{}.stage{}.spp_layer'.format(layers_num, i + 1),
SPPLayer(
out_channels,
out_channels,
kernel_sizes=spp_kernal_sizes,
bias=False,
act=act))
stage.append(spp_layer)
layers_num += 1
csp_layer = self.add_sublayer(
'layers{}.stage{}.csp_layer'.format(layers_num, i + 1),
CSPLayer(
out_channels,
out_channels,
num_blocks=num_blocks,
shortcut=shortcut,
depthwise=depthwise,
bias=False,
act=act))
stage.append(csp_layer)
layers_num += 1
if use_spp and arch in ['P5', 'P6']:
# in latest YOLOv5 use SPPFLayer instead of SPPLayer
sppf_layer = self.add_sublayer(
'layers{}.stage{}.sppf_layer'.format(layers_num, i + 1),
SPPFLayer(
out_channels,
out_channels,
ksize=5,
bias=False,
act=act))
stage.append(sppf_layer)
layers_num += 1
self.csp_dark_blocks.append(nn.Sequential(*stage))
self._out_channels = [_out_channels[i] for i in self.return_idx]
self.strides = [[2, 4, 8, 16, 32, 64][i] for i in self.return_idx]
def forward(self, inputs):
x = inputs['image']
outputs = []
x = self.stem(x)
for i, layer in enumerate(self.csp_dark_blocks):
x = layer(x)
if i + 1 in self.return_idx:
outputs.append(x)
return outputs
@property
def out_shape(self):
return [
ShapeSpec(
channels=c, stride=s)
for c, s in zip(self._out_channels, self.strides)
]
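A usage sketch of the backbone above at the default 'X' (YOLOX-L) scale (assumes paddle; not part of the commit):

import paddle

backbone = CSPDarkNet(arch='X', depth_mult=1.0, width_mult=1.0)
feats = backbone({'image': paddle.rand([1, 3, 640, 640])})
print([list(f.shape) for f in feats])
# [[1, 256, 80, 80], [1, 512, 40, 40], [1, 1024, 20, 20]]
print([(s.channels, s.stride) for s in backbone.out_shape])
# [(256, 8), (512, 16), (1024, 32)]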
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
@@ -5,6 +19,16 @@ from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
import math
import numpy as np
from ..initializer import bias_init_with_prob, constant_
from ..backbones.csp_darknet import BaseConv, DWConv
from ..losses import IouLoss
from ppdet.modeling.assigners.simota_assigner import SimOTAAssigner
from ppdet.modeling.bbox_utils import bbox_overlaps
__all__ = ['YOLOv3Head', 'YOLOXHead']
def _de_sigmoid(x, eps=1e-7):
    x = paddle.clip(x, eps, 1. / eps)
@@ -122,3 +146,259 @@ class YOLOv3Head(nn.Layer):
    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape], }
@register
class YOLOXHead(nn.Layer):
__shared__ = ['num_classes', 'width_mult', 'act']
__inject__ = ['assigner', 'nms']
def __init__(self,
num_classes=80,
width_mult=1.0,
depthwise=False,
in_channels=[256, 512, 1024],
feat_channels=256,
fpn_strides=(8, 16, 32),
l1_epoch=285,
act='silu',
assigner=SimOTAAssigner(use_vfl=False),
nms='MultiClassNMS',
loss_weight={'cls': 1.0,
'obj': 1.0,
'iou': 5.0,
'l1': 1.0}):
super(YOLOXHead, self).__init__()
self._dtype = paddle.framework.get_default_dtype()
self.num_classes = num_classes
        assert len(in_channels) > 0, "in_channels length should be > 0"
self.in_channels = in_channels
feat_channels = int(feat_channels * width_mult)
self.fpn_strides = fpn_strides
self.l1_epoch = l1_epoch
self.assigner = assigner
self.nms = nms
self.loss_weight = loss_weight
self.iou_loss = IouLoss(loss_weight=1.0) # default loss_weight 2.5
ConvBlock = DWConv if depthwise else BaseConv
self.stem_conv = nn.LayerList()
self.conv_cls = nn.LayerList()
self.conv_reg = nn.LayerList() # reg [x,y,w,h] + obj
for in_c in self.in_channels:
self.stem_conv.append(BaseConv(in_c, feat_channels, 1, 1, act=act))
self.conv_cls.append(
nn.Sequential(* [
ConvBlock(
feat_channels, feat_channels, 3, 1, act=act), ConvBlock(
feat_channels, feat_channels, 3, 1, act=act),
nn.Conv2D(
feat_channels,
self.num_classes,
1,
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
]))
self.conv_reg.append(
nn.Sequential(* [
ConvBlock(
feat_channels, feat_channels, 3, 1, act=act),
ConvBlock(
feat_channels, feat_channels, 3, 1, act=act),
nn.Conv2D(
feat_channels,
4 + 1, # reg [x,y,w,h] + obj
1,
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
]))
self._init_weights()
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
def _init_weights(self):
bias_cls = bias_init_with_prob(0.01)
bias_reg = paddle.full([5], math.log(5.), dtype=self._dtype)
bias_reg[:2] = 0.
bias_reg[-1] = bias_cls
for cls_, reg_ in zip(self.conv_cls, self.conv_reg):
constant_(cls_[-1].weight)
constant_(cls_[-1].bias, bias_cls)
constant_(reg_[-1].weight)
reg_[-1].bias.set_value(bias_reg)
def _generate_anchor_point(self, feat_sizes, strides, offset=0.):
anchor_points, stride_tensor = [], []
num_anchors_list = []
for feat_size, stride in zip(feat_sizes, strides):
h, w = feat_size
x = (paddle.arange(w) + offset) * stride
y = (paddle.arange(h) + offset) * stride
y, x = paddle.meshgrid(y, x)
anchor_points.append(paddle.stack([x, y], axis=-1).reshape([-1, 2]))
stride_tensor.append(
paddle.full(
[len(anchor_points[-1]), 1], stride, dtype=self._dtype))
num_anchors_list.append(len(anchor_points[-1]))
anchor_points = paddle.concat(anchor_points).astype(self._dtype)
anchor_points.stop_gradient = True
stride_tensor = paddle.concat(stride_tensor)
stride_tensor.stop_gradient = True
return anchor_points, stride_tensor, num_anchors_list
def forward(self, feats, targets=None):
assert len(feats) == len(self.fpn_strides), \
"The size of feats is not equal to size of fpn_strides"
feat_sizes = [[f.shape[-2], f.shape[-1]] for f in feats]
cls_score_list, reg_pred_list = [], []
obj_score_list = []
for i, feat in enumerate(feats):
feat = self.stem_conv[i](feat)
cls_logit = self.conv_cls[i](feat)
reg_pred = self.conv_reg[i](feat)
# cls prediction
cls_score = F.sigmoid(cls_logit)
cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
# reg prediction
reg_xywh, obj_logit = paddle.split(reg_pred, [4, 1], axis=1)
reg_xywh = reg_xywh.flatten(2).transpose([0, 2, 1])
reg_pred_list.append(reg_xywh)
# obj prediction
obj_score = F.sigmoid(obj_logit)
obj_score_list.append(obj_score.flatten(2).transpose([0, 2, 1]))
cls_score_list = paddle.concat(cls_score_list, axis=1)
reg_pred_list = paddle.concat(reg_pred_list, axis=1)
obj_score_list = paddle.concat(obj_score_list, axis=1)
# bbox decode
anchor_points, stride_tensor, _ =\
self._generate_anchor_point(feat_sizes, self.fpn_strides)
reg_xy, reg_wh = paddle.split(reg_pred_list, 2, axis=-1)
reg_xy += (anchor_points / stride_tensor)
reg_wh = paddle.exp(reg_wh) * 0.5
bbox_pred_list = paddle.concat(
[reg_xy - reg_wh, reg_xy + reg_wh], axis=-1)
if self.training:
anchor_points, stride_tensor, num_anchors_list =\
self._generate_anchor_point(feat_sizes, self.fpn_strides, 0.5)
yolox_losses = self.get_loss([
cls_score_list, bbox_pred_list, obj_score_list, anchor_points,
stride_tensor, num_anchors_list
], targets)
return yolox_losses
else:
pred_scores = (cls_score_list * obj_score_list).sqrt()
return pred_scores, bbox_pred_list, stride_tensor
def get_loss(self, head_outs, targets):
pred_cls, pred_bboxes, pred_obj,\
anchor_points, stride_tensor, num_anchors_list = head_outs
gt_labels = targets['gt_class']
gt_bboxes = targets['gt_bbox']
pred_scores = (pred_cls * pred_obj).sqrt()
# label assignment
center_and_strides = paddle.concat(
[anchor_points, stride_tensor, stride_tensor], axis=-1)
pos_num_list, label_list, bbox_target_list = [], [], []
for pred_score, pred_bbox, gt_box, gt_label in zip(
pred_scores.detach(),
pred_bboxes.detach() * stride_tensor, gt_bboxes, gt_labels):
pos_num, label, _, bbox_target = self.assigner(
pred_score, center_and_strides, pred_bbox, gt_box, gt_label)
pos_num_list.append(pos_num)
label_list.append(label)
bbox_target_list.append(bbox_target)
labels = paddle.to_tensor(np.stack(label_list, axis=0))
bbox_targets = paddle.to_tensor(np.stack(bbox_target_list, axis=0))
bbox_targets /= stride_tensor # rescale bbox
# 1. obj score loss
mask_positive = (labels != self.num_classes)
loss_obj = F.binary_cross_entropy(
pred_obj,
mask_positive.astype(pred_obj.dtype).unsqueeze(-1),
reduction='sum')
num_pos = sum(pos_num_list)
if num_pos > 0:
num_pos = paddle.to_tensor(num_pos, dtype=self._dtype).clip(min=1)
loss_obj /= num_pos
# 2. iou loss
bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
pred_bboxes_pos = paddle.masked_select(pred_bboxes,
bbox_mask).reshape([-1, 4])
assigned_bboxes_pos = paddle.masked_select(
bbox_targets, bbox_mask).reshape([-1, 4])
bbox_iou = bbox_overlaps(pred_bboxes_pos, assigned_bboxes_pos)
bbox_iou = paddle.diag(bbox_iou)
loss_iou = self.iou_loss(
pred_bboxes_pos.split(
4, axis=-1),
assigned_bboxes_pos.split(
4, axis=-1))
loss_iou = loss_iou.sum() / num_pos
# 3. cls loss
cls_mask = mask_positive.unsqueeze(-1).tile(
[1, 1, self.num_classes])
pred_cls_pos = paddle.masked_select(
pred_cls, cls_mask).reshape([-1, self.num_classes])
assigned_cls_pos = paddle.masked_select(labels, mask_positive)
assigned_cls_pos = F.one_hot(assigned_cls_pos,
self.num_classes + 1)[..., :-1]
assigned_cls_pos *= bbox_iou.unsqueeze(-1)
loss_cls = F.binary_cross_entropy(
pred_cls_pos, assigned_cls_pos, reduction='sum')
loss_cls /= num_pos
# 4. l1 loss
if targets['epoch_id'] >= self.l1_epoch:
loss_l1 = F.l1_loss(
pred_bboxes_pos, assigned_bboxes_pos, reduction='sum')
loss_l1 /= num_pos
else:
loss_l1 = paddle.zeros([1])
loss_l1.stop_gradient = False
else:
loss_cls = paddle.zeros([1])
loss_iou = paddle.zeros([1])
loss_l1 = paddle.zeros([1])
loss_cls.stop_gradient = False
loss_iou.stop_gradient = False
loss_l1.stop_gradient = False
loss = self.loss_weight['obj'] * loss_obj + \
self.loss_weight['cls'] * loss_cls + \
self.loss_weight['iou'] * loss_iou
if targets['epoch_id'] >= self.l1_epoch:
loss += (self.loss_weight['l1'] * loss_l1)
yolox_losses = {
'loss': loss,
'loss_cls': loss_cls,
'loss_obj': loss_obj,
'loss_iou': loss_iou,
'loss_l1': loss_l1,
}
return yolox_losses
def post_process(self, head_outs, img_shape, scale_factor):
pred_scores, pred_bboxes, stride_tensor = head_outs
pred_scores = pred_scores.transpose([0, 2, 1])
pred_bboxes *= stride_tensor
# scale bbox to origin image
scale_factor = scale_factor.flip(-1).tile([1, 2]).unsqueeze(1)
pred_bboxes /= scale_factor
bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
return bbox_pred, bbox_num
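The head decodes boxes in stride units: reg_xy is an offset from the anchor point's grid cell and reg_wh becomes half extents via exp. A worked example of one prediction with plain numpy (not part of the commit):

import numpy as np

stride = 32.0
anchor_xy = np.array([160.0, 96.0])          # anchor point, in pixels
reg_xywh = np.array([0.25, -0.5, 0.0, 0.0])  # raw reg output at that point
xy = reg_xywh[:2] + anchor_xy / stride       # center, in stride units
wh = np.exp(reg_xywh[2:]) * 0.5              # half width/height
x1y1 = (xy - wh) * stride                    # back to pixels, mirroring
x2y2 = (xy + wh) * stride                    # `pred_bboxes *= stride_tensor`
print(x1y1, x2y2)  # [152.  64.] [184.  96.]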
@@ -273,7 +273,8 @@ def linear_init_(module):
 def conv_init_(module):
     bound = 1 / np.sqrt(np.prod(module.weight.shape[1:]))
     uniform_(module.weight, -bound, bound)
-    uniform_(module.bias, -bound, bound)
+    if module.bias is not None:
+        uniform_(module.bias, -bound, bound)

 def bias_init_with_prob(prior_prob=0.01):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
@@ -19,8 +19,9 @@ from ppdet.core.workspace import register, serializable
 from ppdet.modeling.layers import DropBlock
 from ..backbones.darknet import ConvBNLayer
 from ..shape_spec import ShapeSpec
+from ..backbones.csp_darknet import BaseConv, DWConv, CSPLayer

-__all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN']
+__all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN', 'YOLOCSPPAN']

 def add_coord(x, data_format):
@@ -986,3 +987,102 @@ class PPYOLOPAN(nn.Layer):
     @property
     def out_shape(self):
         return [ShapeSpec(channels=c) for c in self._out_channels]
@register
@serializable
class YOLOCSPPAN(nn.Layer):
"""
YOLO CSP-PAN, used in YOLOv5 and YOLOX.
"""
__shared__ = ['depth_mult', 'act']
def __init__(self,
depth_mult=1.0,
in_channels=[256, 512, 1024],
depthwise=False,
act='silu'):
super(YOLOCSPPAN, self).__init__()
self.in_channels = in_channels
self._out_channels = in_channels
Conv = DWConv if depthwise else BaseConv
self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
# top-down fpn
self.lateral_convs = nn.LayerList()
self.fpn_blocks = nn.LayerList()
for idx in range(len(in_channels) - 1, 0, -1):
self.lateral_convs.append(
BaseConv(
int(in_channels[idx]),
int(in_channels[idx - 1]),
1,
1,
act=act))
self.fpn_blocks.append(
CSPLayer(
int(in_channels[idx - 1] * 2),
int(in_channels[idx - 1]),
round(3 * depth_mult),
shortcut=False,
depthwise=depthwise,
act=act))
# bottom-up pan
self.downsample_convs = nn.LayerList()
self.pan_blocks = nn.LayerList()
for idx in range(len(in_channels) - 1):
self.downsample_convs.append(
Conv(
int(in_channels[idx]),
int(in_channels[idx]),
3,
stride=2,
act=act))
self.pan_blocks.append(
CSPLayer(
int(in_channels[idx] * 2),
int(in_channels[idx + 1]),
round(3 * depth_mult),
shortcut=False,
depthwise=depthwise,
act=act))
    def forward(self, feats, for_mot=False):
        assert len(feats) == len(self.in_channels)

        # top-down fpn
        inner_outs = [feats[-1]]
        for idx in range(len(self.in_channels) - 1, 0, -1):
            feat_high = inner_outs[0]
            feat_low = feats[idx - 1]
            feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](
                feat_high)
            inner_outs[0] = feat_high
            upsample_feat = self.upsample(feat_high)
            inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
                paddle.concat(
                    [upsample_feat, feat_low], axis=1))
            inner_outs.insert(0, inner_out)

        # bottom-up pan
        outs = [inner_outs[0]]
        for idx in range(len(self.in_channels) - 1):
            feat_low = outs[-1]
            feat_high = inner_outs[idx + 1]
            downsample_feat = self.downsample_convs[idx](feat_low)
            out = self.pan_blocks[idx](paddle.concat(
                [downsample_feat, feat_high], axis=1))
            outs.append(out)

        return outs
@classmethod
def from_config(cls, cfg, input_shape):
return {'in_channels': [i.channels for i in input_shape], }
@property
def out_shape(self):
return [ShapeSpec(channels=c) for c in self._out_channels]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
@@ -28,7 +28,7 @@ __all__ = [
     'roi_pool', 'roi_align', 'prior_box', 'generate_proposals',
     'iou_similarity', 'box_coder', 'yolo_box', 'multiclass_nms',
     'distribute_fpn_proposals', 'collect_fpn_proposals', 'matrix_nms',
-    'batch_norm', 'mish', 'swish', 'identity'
+    'batch_norm', 'get_activation', 'mish', 'swish', 'identity'
 ]
@@ -106,6 +106,18 @@ def batch_norm(ch,
     return norm_layer
def get_activation(name="silu"):
if name == "silu":
module = nn.Silu()
elif name == "relu":
module = nn.ReLU()
elif name == "leakyrelu":
module = nn.LeakyReLU(0.1)
else:
raise AttributeError("Unsupported act type: {}".format(name))
return module
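Illustrative use of the helper above (not part of the commit):

act = get_activation('leakyrelu')  # nn.LeakyReLU with negative_slope=0.1
print(type(act).__name__)          # LeakyReLU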
@paddle.jit.not_to_static
def roi_pool(input,
             rois,
......
@@ -209,6 +209,33 @@ class BurninWarmup(object):
         return boundary, value
@serializable
class ExpWarmup(object):
"""
Warm up learning rate in exponential mode
Args:
steps (int): warm up steps.
epochs (int|None): use epochs as warm up steps, the priority
of `epochs` is higher than `steps`. Default: None.
"""
def __init__(self, steps=5, epochs=None):
super(ExpWarmup, self).__init__()
self.steps = steps
self.epochs = epochs
def __call__(self, base_lr, step_per_epoch):
boundary = []
value = []
warmup_steps = self.epochs * step_per_epoch if self.epochs is not None else self.steps
for i in range(warmup_steps + 1):
factor = (i / float(warmup_steps))**2
value.append(base_lr * factor)
if i > 0:
boundary.append(i)
return boundary, value
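Despite the class name, the warm-up factor above grows quadratically, factor = (i / warmup_steps) ** 2. A quick check of the emitted schedule (not part of the commit):

wu = ExpWarmup(epochs=5)
boundary, value = wu(base_lr=0.01, step_per_epoch=100)
print(len(value), value[250], value[-1])  # 501 0.0025 0.01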
@register
class LearningRate(object):
    """
@@ -331,7 +358,8 @@ class ModelEMA(object):
             Ema's parameter are updated with the formula:
            `ema_param = decay * ema_param + (1 - decay) * cur_param`.
             Defaults is 0.9998.
-        use_thres_step (bool): Whether set decay by thres_step or not
+        ema_decay_type (str): type in ['threshold', 'normal', 'exponential'],
+            'threshold' as default.
         cycle_epoch (int): The epoch of interval to reset ema_param and
             step. Defaults is -1, which means not reset. Its function is to
             add a regular effect to ema, which is set according to experience
@@ -341,7 +369,7 @@ class ModelEMA(object):
     def __init__(self,
                  model,
                  decay=0.9998,
-                 use_thres_step=False,
+                 ema_decay_type='threshold',
                  cycle_epoch=-1):
         self.step = 0
         self.epoch = 0
@@ -349,7 +377,7 @@ class ModelEMA(object):
         self.state_dict = dict()
         for k, v in model.state_dict().items():
             self.state_dict[k] = paddle.zeros_like(v)
-        self.use_thres_step = use_thres_step
+        self.ema_decay_type = ema_decay_type
         self.cycle_epoch = cycle_epoch

         self._model_state = {
@@ -370,8 +398,10 @@ class ModelEMA(object):
         self.step = step

     def update(self, model=None):
-        if self.use_thres_step:
+        if self.ema_decay_type == 'threshold':
             decay = min(self.decay, (1 + self.step) / (10 + self.step))
+        elif self.ema_decay_type == 'exponential':
+            decay = self.decay * (1 - math.exp(-(self.step + 1) / 2000))
         else:
             decay = self.decay
         self._decay = decay
@@ -394,7 +424,8 @@ class ModelEMA(object):
             return self.state_dict
         state_dict = dict()
         for k, v in self.state_dict.items():
-            v = v / (1 - self._decay**self.step)
+            if self.ema_decay_type != 'exponential':
+                v = v / (1 - self._decay**self.step)
             v.stop_gradient = True
             state_dict[k] = v
         self.epoch += 1
......
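How the three ema_decay_type settings ramp the effective decay in update() above (a plain-python sketch, not part of the commit); 'normal' keeps the constant decay, and 'exponential' additionally skips the bias correction in apply(), as shown in the diff above:

import math

decay = 0.9998
for step in (0, 100, 2000, 10000):
    threshold = min(decay, (1 + step) / (10 + step))
    exponential = decay * (1 - math.exp(-(step + 1) / 2000))
    print(step, round(threshold, 5), round(exponential, 5), decay)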