diff --git a/deploy/python/infer.py b/deploy/python/infer.py
index c8737ccfdb5a87364ca9a4c250dd0204df5798fc..09eec1153796ae43f41f0881db22029ed6056e45 100644
--- a/deploy/python/infer.py
+++ b/deploy/python/infer.py
@@ -31,16 +31,32 @@ sys.path.insert(0, parent_path)
 
 from benchmark_utils import PaddleInferBenchmark
 from picodet_postprocess import PicoDetPostProcess
-from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, decode_image
+from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image
 from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
 from visualize import visualize_box_mask
 from utils import argsparser, Timer, get_current_memory_mb
 
 # Global dictionary
 SUPPORT_MODELS = {
-    'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
-    'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
-    'StrongBaseline', 'STGCN'
+    'YOLO',
+    'RCNN',
+    'SSD',
+    'Face',
+    'FCOS',
+    'SOLOv2',
+    'TTFNet',
+    'S2ANet',
+    'JDE',
+    'FairMOT',
+    'DeepSORT',
+    'GFL',
+    'PicoDet',
+    'CenterNet',
+    'TOOD',
+    'RetinaNet',
+    'StrongBaseline',
+    'STGCN',
+    'YOLOX',
 }
diff --git a/deploy/python/preprocess.py b/deploy/python/preprocess.py
index 998367a349b3e14045991c2e1d2af2e6ec94e03d..b8cf256d508e66798063ada9f62d45dbee8fca07 100644
--- a/deploy/python/preprocess.py
+++ b/deploy/python/preprocess.py
@@ -246,6 +246,81 @@ class LetterBoxResize(object):
         return im, im_info
 
 
+class Pad(object):
+    def __init__(self,
+                 size=None,
+                 size_divisor=32,
+                 pad_mode=0,
+                 offsets=None,
+                 fill_value=(127.5, 127.5, 127.5)):
+        """
+        Pad image to a specified size or to a multiple of size_divisor.
+        Args:
+            size (int, Sequence): image target size; if None, pad to a multiple of size_divisor. Default: None
+            size_divisor (int): size divisor, default 32
+            pad_mode (int): pad mode, currently only supports four modes [-1, 0, 1, 2].
+                if -1, use specified offsets; if 0, only pad to right and bottom;
+                if 1, pad according to center; if 2, only pad left and top
+            offsets (list): [offset_x, offset_y], specify offsets while padding, only supported when pad_mode=-1
+            fill_value (tuple): RGB value of the padded area, default (127.5, 127.5, 127.5)
+        """
+        super(Pad, self).__init__()
+        if isinstance(size, int):
+            size = [size, size]
+
+        assert pad_mode in [
+            -1, 0, 1, 2
+        ], 'currently only supports four modes [-1, 0, 1, 2]'
+        if pad_mode == -1:
+            assert offsets, 'if pad_mode is -1, offsets should not be None'
+
+        self.size = size
+        self.size_divisor = size_divisor
+        self.pad_mode = pad_mode
+        self.fill_value = fill_value
+        self.offsets = offsets
+
+    def apply_image(self, image, offsets, im_size, size):
+        x, y = offsets
+        im_h, im_w = im_size
+        h, w = size
+        canvas = np.ones((h, w, 3), dtype=np.float32)
+        canvas *= np.array(self.fill_value, dtype=np.float32)
+        canvas[y:y + im_h, x:x + im_w, :] = image.astype(np.float32)
+        return canvas
+
+    def __call__(self, im, im_info):
+        im_h, im_w = im.shape[:2]
+        if self.size:
+            h, w = self.size
+            assert (
+                im_h <= h and im_w <= w
+            ), '(h, w) of target size should not be less than (im_h, im_w)'
+        else:
+            h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
+            w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
+
+        if h == im_h and w == im_w:
+            im = im.astype(np.float32)
+            return im, im_info
+
+        if self.pad_mode == -1:
+            offset_x, offset_y = self.offsets
+        elif self.pad_mode == 0:
+            offset_y, offset_x = 0, 0
+        elif self.pad_mode == 1:
+            offset_y, offset_x = (h - im_h) // 2, (w - im_w) // 2
+        else:
+            offset_y, offset_x = h - im_h, w - im_w
+
+        offsets, im_size, size = [offset_x, offset_y], [im_h, im_w], [h, w]
+        im = self.apply_image(im, offsets, im_size, size)
+
+        return im, im_info
+
+
 class WarpAffine(object):
     """Warp affine the image
     """
diff --git a/ppdet/data/source/dataset.py b/ppdet/data/source/dataset.py
index e086dc97d3029356fc1865ae8b46f94ea21a75a0..b8193c1c92eb47f26db27e9fd601c1c657e8dd63 100644
--- a/ppdet/data/source/dataset.py
+++ b/ppdet/data/source/dataset.py
@@ -5,7 +5,7 @@
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -75,7 +75,7 @@ class DetDataset(Dataset):
                 n = len(self.roidbs)
                 roidb = [roidb, ] + [
                     copy.deepcopy(self.roidbs[np.random.randint(n)])
-                    for _ in range(3)
+                    for _ in range(4)
                 ]
                 if isinstance(roidb, Sequence):
                     for r in roidb:
diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py
index 7eae8db0ba67a87a1337bd4a289d505aee365aaa..c3fb37e262d327234f505a6fe8bb681ad1894fcb 100644
--- a/ppdet/data/transform/operators.py
+++ b/ppdet/data/transform/operators.py
@@ -2034,13 +2034,14 @@ class Pad(BaseOperator):
         if self.size:
             h, w = self.size
             assert (
-                im_h < h and im_w < w
+                im_h <= h and im_w <= w
             ), '(h, w) of target size should be greater than (im_h, im_w)'
         else:
             h = int(np.ceil(im_h / self.size_divisor) * self.size_divisor)
             w = int(np.ceil(im_w / self.size_divisor) * self.size_divisor)
 
         if h == im_h and w == im_w:
+            sample['image'] = im.astype(np.float32)
             return sample
 
         if self.pad_mode == -1:
@@ -2139,16 +2140,29 @@ class Rbox2Poly(BaseOperator):
 
 @register_op
 class AugmentHSV(BaseOperator):
-    def __init__(self, fraction=0.50, is_bgr=True):
-        """
-        Augment the SV channel of image data.
- Args: - fraction (float): the fraction for augment. Default: 0.5. - is_bgr (bool): whether the image is BGR mode. Default: True. - """ + """ + Augment the SV channel of image data. + Args: + fraction (float): the fraction for augment. Default: 0.5. + is_bgr (bool): whether the image is BGR mode. Default: True. + hgain (float): H channel gains + sgain (float): S channel gains + vgain (float): V channel gains + """ + + def __init__(self, + fraction=0.50, + is_bgr=True, + hgain=None, + sgain=None, + vgain=None): super(AugmentHSV, self).__init__() self.fraction = fraction self.is_bgr = is_bgr + self.hgain = hgain + self.sgain = sgain + self.vgain = vgain + self.use_hsvgain = False if hgain is None else True def apply(self, sample, context=None): img = sample['image'] @@ -2156,21 +2170,33 @@ class AugmentHSV(BaseOperator): img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) else: img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) - S = img_hsv[:, :, 1].astype(np.float32) - V = img_hsv[:, :, 2].astype(np.float32) - a = (random.random() * 2 - 1) * self.fraction + 1 - S *= a - if a > 1: - np.clip(S, a_min=0, a_max=255, out=S) + if self.use_hsvgain: + hsv_augs = np.random.uniform( + -1, 1, 3) * [self.hgain, self.sgain, self.vgain] + # random selection of h, s, v + hsv_augs *= np.random.randint(0, 2, 3) + img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180 + img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255) + img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255) - a = (random.random() * 2 - 1) * self.fraction + 1 - V *= a - if a > 1: - np.clip(V, a_min=0, a_max=255, out=V) + else: + S = img_hsv[:, :, 1].astype(np.float32) + V = img_hsv[:, :, 2].astype(np.float32) + + a = (random.random() * 2 - 1) * self.fraction + 1 + S *= a + if a > 1: + np.clip(S, a_min=0, a_max=255, out=S) + + a = (random.random() * 2 - 1) * self.fraction + 1 + V *= a + if a > 1: + np.clip(V, a_min=0, a_max=255, out=V) + + img_hsv[:, :, 1] = S.astype(np.uint8) + img_hsv[:, :, 2] = V.astype(np.uint8) - img_hsv[:, :, 1] = S.astype(np.uint8) - img_hsv[:, :, 2] = V.astype(np.uint8) if self.is_bgr: cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) else: @@ -3018,3 +3044,373 @@ class CenterRandColor(BaseOperator): img = func(img, img_gray) sample['image'] = img return sample + + +@register_op +class Mosaic(BaseOperator): + """ Mosaic operator for image and gt_bboxes + The code is based on https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/data/datasets/mosaicdetection.py + + 1. get mosaic coords + 2. clip bbox and get mosaic_labels + 3. random_affine augment + 4. 
Mixup augment as copypaste (optional), not used in tiny/nano
+
+    Args:
+        prob (float): probability of using Mosaic, 1.0 as default
+        input_dim (list[int]): input shape
+        degrees (list[2]): the rotate range to apply, transform range is [min, max]
+        translate (list[2]): the translate range to apply, transform range is [min, max]
+        scale (list[2]): the scale range to apply, transform range is [min, max]
+        shear (list[2]): the shear range to apply, transform range is [min, max]
+        enable_mixup (bool): whether to enable Mixup or not
+        mixup_prob (float): probability of using Mixup, 1.0 as default
+        mixup_scale (list[float]): scale range of Mixup
+        remove_outside_box (bool): whether to remove outside boxes, False as
+            default in COCO dataset, True in MOT dataset
+    """
+
+    def __init__(self,
+                 prob=1.0,
+                 input_dim=[640, 640],
+                 degrees=[-10, 10],
+                 translate=[-0.1, 0.1],
+                 scale=[0.1, 2],
+                 shear=[-2, 2],
+                 enable_mixup=True,
+                 mixup_prob=1.0,
+                 mixup_scale=[0.5, 1.5],
+                 remove_outside_box=False):
+        super(Mosaic, self).__init__()
+        self.prob = prob
+        self.input_dim = input_dim
+        self.degrees = degrees
+        self.translate = translate
+        self.scale = scale
+        self.shear = shear
+        self.enable_mixup = enable_mixup
+        self.mixup_prob = mixup_prob
+        self.mixup_scale = mixup_scale
+        self.remove_outside_box = remove_outside_box
+
+    def get_mosaic_coords(self, mosaic_idx, xc, yc, w, h, input_h, input_w):
+        # (x1, y1, x2, y2) means coords in large image,
+        # small_coords means coords in small image in mosaic aug.
+        if mosaic_idx == 0:
+            # top left
+            x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
+            small_coords = w - (x2 - x1), h - (y2 - y1), w, h
+        elif mosaic_idx == 1:
+            # top right
+            x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
+            small_coords = 0, h - (y2 - y1), min(w, x2 - x1), h
+        elif mosaic_idx == 2:
+            # bottom left
+            x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
+            small_coords = w - (x2 - x1), 0, w, min(y2 - y1, h)
+        elif mosaic_idx == 3:
+            # bottom right
+            x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2,
+                                                                   yc + h)
+            small_coords = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
+
+        return (x1, y1, x2, y2), small_coords
+
+    def random_affine_augment(self,
+                              img,
+                              labels=[],
+                              input_dim=[640, 640],
+                              degrees=[-10, 10],
+                              scales=[0.1, 2],
+                              shears=[-2, 2],
+                              translates=[-0.1, 0.1]):
+        # random rotation and scale
+        degree = random.uniform(degrees[0], degrees[1])
+        scale = random.uniform(scales[0], scales[1])
+        assert scale > 0, "Argument scale should be positive."
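+        # compose rotation/scale into the 2x3 matrix R; shear and translation are folded into M below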
+ R = cv2.getRotationMatrix2D(angle=degree, center=(0, 0), scale=scale) + M = np.ones([2, 3]) + + # random shear + shear = random.uniform(shears[0], shears[1]) + shear_x = math.tan(shear * math.pi / 180) + shear_y = math.tan(shear * math.pi / 180) + M[0] = R[0] + shear_y * R[1] + M[1] = R[1] + shear_x * R[0] + + # random translation + translate = random.uniform(translates[0], translates[1]) + translation_x = translate * input_dim[0] + translation_y = translate * input_dim[1] + M[0, 2] = translation_x + M[1, 2] = translation_y + + # warpAffine + img = cv2.warpAffine( + img, M, dsize=input_dim, borderValue=(114, 114, 114)) + + num_gts = len(labels) + if num_gts > 0: + # warp corner points + corner_points = np.ones((4 * num_gts, 3)) + corner_points[:, :2] = labels[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape( + 4 * num_gts, 2) # x1y1, x2y2, x1y2, x2y1 + # apply affine transform + corner_points = corner_points @M.T + corner_points = corner_points.reshape(num_gts, 8) + + # create new boxes + corner_xs = corner_points[:, 0::2] + corner_ys = corner_points[:, 1::2] + new_bboxes = np.concatenate((corner_xs.min(1), corner_ys.min(1), + corner_xs.max(1), corner_ys.max(1))) + new_bboxes = new_bboxes.reshape(4, num_gts).T + + # clip boxes + new_bboxes[:, 0::2] = np.clip(new_bboxes[:, 0::2], 0, input_dim[0]) + new_bboxes[:, 1::2] = np.clip(new_bboxes[:, 1::2], 0, input_dim[1]) + labels[:, :4] = new_bboxes + + return img, labels + + def __call__(self, sample, context=None): + if not isinstance(sample, Sequence): + return sample + + assert len( + sample) == 5, "Mosaic needs 5 samples, 4 for mosaic and 1 for mixup." + if np.random.uniform(0., 1.) > self.prob: + return sample[0] + + mosaic_gt_bbox, mosaic_gt_class, mosaic_is_crowd = [], [], [] + input_h, input_w = self.input_dim + yc = int(random.uniform(0.5 * input_h, 1.5 * input_h)) + xc = int(random.uniform(0.5 * input_w, 1.5 * input_w)) + mosaic_img = np.full((input_h * 2, input_w * 2, 3), 114, dtype=np.uint8) + + # 1. get mosaic coords + for mosaic_idx, sp in enumerate(sample[:4]): + img = sp['image'] + gt_bbox = sp['gt_bbox'] + h0, w0 = img.shape[:2] + scale = min(1. * input_h / h0, 1. * input_w / w0) + img = cv2.resize( + img, (int(w0 * scale), int(h0 * scale)), + interpolation=cv2.INTER_LINEAR) + (h, w, c) = img.shape[:3] + + # suffix l means large image, while s means small image in mosaic aug. + (l_x1, l_y1, l_x2, l_y2), ( + s_x1, s_y1, s_x2, s_y2) = self.get_mosaic_coords( + mosaic_idx, xc, yc, w, h, input_h, input_w) + + mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2] + padw, padh = l_x1 - s_x1, l_y1 - s_y1 + + # Normalized xywh to pixel xyxy format + _gt_bbox = gt_bbox.copy() + if len(gt_bbox) > 0: + _gt_bbox[:, 0] = scale * gt_bbox[:, 0] + padw + _gt_bbox[:, 1] = scale * gt_bbox[:, 1] + padh + _gt_bbox[:, 2] = scale * gt_bbox[:, 2] + padw + _gt_bbox[:, 3] = scale * gt_bbox[:, 3] + padh + + mosaic_gt_bbox.append(_gt_bbox) + mosaic_gt_class.append(sp['gt_class']) + mosaic_is_crowd.append(sp['is_crowd']) + + # 2. 
clip bbox and get mosaic_labels([gt_bbox, gt_class, is_crowd])
+        if len(mosaic_gt_bbox):
+            mosaic_gt_bbox = np.concatenate(mosaic_gt_bbox, 0)
+            mosaic_gt_class = np.concatenate(mosaic_gt_class, 0)
+            mosaic_is_crowd = np.concatenate(mosaic_is_crowd, 0)
+            mosaic_labels = np.concatenate([
+                mosaic_gt_bbox, mosaic_gt_class.astype(mosaic_gt_bbox.dtype),
+                mosaic_is_crowd.astype(mosaic_gt_bbox.dtype)
+            ], 1)
+            if self.remove_outside_box:
+                # for MOT dataset
+                flag1 = mosaic_gt_bbox[:, 0] < 2 * input_w
+                flag2 = mosaic_gt_bbox[:, 2] > 0
+                flag3 = mosaic_gt_bbox[:, 1] < 2 * input_h
+                flag4 = mosaic_gt_bbox[:, 3] > 0
+                flag_all = flag1 * flag2 * flag3 * flag4
+                mosaic_labels = mosaic_labels[flag_all]
+            else:
+                mosaic_labels[:, 0] = np.clip(mosaic_labels[:, 0], 0,
+                                              2 * input_w)
+                mosaic_labels[:, 1] = np.clip(mosaic_labels[:, 1], 0,
+                                              2 * input_h)
+                mosaic_labels[:, 2] = np.clip(mosaic_labels[:, 2], 0,
+                                              2 * input_w)
+                mosaic_labels[:, 3] = np.clip(mosaic_labels[:, 3], 0,
+                                              2 * input_h)
+        else:
+            mosaic_labels = np.zeros((1, 6))
+
+        # 3. random_affine augment
+        mosaic_img, mosaic_labels = self.random_affine_augment(
+            mosaic_img,
+            mosaic_labels,
+            input_dim=self.input_dim,
+            degrees=self.degrees,
+            translates=self.translate,
+            scales=self.scale,
+            shears=self.shear)
+
+        # 4. Mixup augment as copypaste, https://arxiv.org/abs/2012.07177
+        # optional, not used (enable_mixup=False) in tiny/nano
+        if (self.enable_mixup and not len(mosaic_labels) == 0 and
+                random.random() < self.mixup_prob):
+            sample_mixup = sample[4]
+            mixup_img = sample_mixup['image']
+            cp_labels = np.concatenate([
+                sample_mixup['gt_bbox'],
+                sample_mixup['gt_class'].astype(mosaic_labels.dtype),
+                sample_mixup['is_crowd'].astype(mosaic_labels.dtype)
+            ], 1)
+            mosaic_img, mosaic_labels = self.mixup_augment(
+                mosaic_img, mosaic_labels, self.input_dim, cp_labels, mixup_img)
+
+        sample0 = sample[0]
+        sample0['image'] = mosaic_img.astype(np.uint8)  # cannot be float32
+        sample0['h'] = float(mosaic_img.shape[0])
+        sample0['w'] = float(mosaic_img.shape[1])
+        sample0['im_shape'][0] = sample0['h']
+        sample0['im_shape'][1] = sample0['w']
+        sample0['gt_bbox'] = mosaic_labels[:, :4].astype(np.float32)
+        sample0['gt_class'] = mosaic_labels[:, 4:5].astype(np.float32)
+        sample0['is_crowd'] = mosaic_labels[:, 5:6].astype(np.float32)
+        return sample0
+
+    def mixup_augment(self, origin_img, origin_labels, input_dim, cp_labels,
+                      img):
+        jit_factor = random.uniform(*self.mixup_scale)
+        FLIP = random.uniform(0, 1) > 0.5
+        if len(img.shape) == 3:
+            cp_img = np.ones(
+                (input_dim[0], input_dim[1], 3), dtype=np.uint8) * 114
+        else:
+            cp_img = np.ones(input_dim, dtype=np.uint8) * 114
+
+        cp_scale_ratio = min(input_dim[0] / img.shape[0],
+                             input_dim[1] / img.shape[1])
+        resized_img = cv2.resize(
+            img, (int(img.shape[1] * cp_scale_ratio),
+                  int(img.shape[0] * cp_scale_ratio)),
+            interpolation=cv2.INTER_LINEAR)
+
+        cp_img[:int(img.shape[0] * cp_scale_ratio), :int(img.shape[
+            1] * cp_scale_ratio)] = resized_img
+
+        cp_img = cv2.resize(cp_img, (int(cp_img.shape[1] * jit_factor),
+                                     int(cp_img.shape[0] * jit_factor)))
+        cp_scale_ratio *= jit_factor
+
+        if FLIP:
+            cp_img = cp_img[:, ::-1, :]
+
+        origin_h, origin_w = cp_img.shape[:2]
+        target_h, target_w = origin_img.shape[:2]
+        padded_img = np.zeros(
+            (max(origin_h, target_h), max(origin_w, target_w), 3),
+            dtype=np.uint8)
+        padded_img[:origin_h, :origin_w] = cp_img
+
+        x_offset, y_offset = 0, 0
+        if padded_img.shape[0] > target_h:
+            y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
+        if padded_img.shape[1] > target_w:
+            x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
+        padded_cropped_img = padded_img[y_offset:y_offset + target_h, x_offset:
+                                        x_offset + target_w]
+
+        # adjust boxes
+        cp_bboxes_origin_np = cp_labels[:, :4].copy()
+        cp_bboxes_origin_np[:, 0::2] = np.clip(cp_bboxes_origin_np[:, 0::2] *
+                                               cp_scale_ratio, 0, origin_w)
+        cp_bboxes_origin_np[:, 1::2] = np.clip(cp_bboxes_origin_np[:, 1::2] *
+                                               cp_scale_ratio, 0, origin_h)
+
+        if FLIP:
+            cp_bboxes_origin_np[:, 0::2] = (
+                origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1])
+        cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
+        if self.remove_outside_box:
+            # for MOT dataset
+            cp_bboxes_transformed_np[:, 0::2] -= x_offset
+            cp_bboxes_transformed_np[:, 1::2] -= y_offset
+        else:
+            cp_bboxes_transformed_np[:, 0::2] = np.clip(
+                cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w)
+            cp_bboxes_transformed_np[:, 1::2] = np.clip(
+                cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h)
+
+        cls_labels = cp_labels[:, 4:5].copy()
+        crd_labels = cp_labels[:, 5:6].copy()
+        box_labels = cp_bboxes_transformed_np
+        labels = np.hstack((box_labels, cls_labels, crd_labels))
+        if self.remove_outside_box:
+            labels = labels[labels[:, 0] < target_w]
+            labels = labels[labels[:, 2] > 0]
+            labels = labels[labels[:, 1] < target_h]
+            labels = labels[labels[:, 3] > 0]
+
+        origin_labels = np.vstack((origin_labels, labels))
+        origin_img = origin_img.astype(np.float32)
+        origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(
+            np.float32)
+
+        return origin_img.astype(np.uint8), origin_labels
+
+
+@register_op
+class PadResize(BaseOperator):
+    """ PadResize for image and gt_bbox
+
+    Args:
+        target_size (list[int]): input shape
+        fill_value (float): pixel value of padded image
+    """
+
+    def __init__(self, target_size, fill_value=114):
+        super(PadResize, self).__init__()
+        if isinstance(target_size, Integral):
+            target_size = [target_size, target_size]
+        self.target_size = target_size
+        self.fill_value = fill_value
+
+    def _resize(self, img, bboxes, labels):
+        ratio = min(self.target_size[0] / img.shape[0],
+                    self.target_size[1] / img.shape[1])
+        w, h = int(img.shape[1] * ratio), int(img.shape[0] * ratio)
+        resized_img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
+
+        if len(bboxes) > 0:
+            bboxes *= ratio
+            mask = np.minimum(bboxes[:, 2] - bboxes[:, 0],
+                              bboxes[:, 3] - bboxes[:, 1]) > 1
+            bboxes = bboxes[mask]
+            labels = labels[mask]
+        return resized_img, bboxes, labels
+
+    def _pad(self, img):
+        h, w, _ = img.shape
+        if h == self.target_size[0] and w == self.target_size[1]:
+            return img
+        padded_img = np.full(
+            (self.target_size[0], self.target_size[1], 3),
+            self.fill_value,
+            dtype=np.uint8)
+        padded_img[:h, :w] = img
+        return padded_img
+
+    def apply(self, sample, context=None):
+        image = sample['image']
+        bboxes = sample['gt_bbox']
+        labels = sample['gt_class']
+        image, bboxes, labels = self._resize(image, bboxes, labels)
+        sample['image'] = self._pad(image).astype(np.float32)
+        sample['gt_bbox'] = bboxes
+        sample['gt_class'] = labels
+        return sample
diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py
index bddb4af07ca91f5b82389fdeba39ba77c1a1344e..12fb1ea6ffe369fef21274eb9a5221cf4e221812 100644
--- a/ppdet/engine/export_utils.py
+++ b/ppdet/engine/export_utils.py
@@ -48,6 +48,7 @@ TRT_MIN_SUBGRAPH = {
     'PicoDet': 3,
     'CenterNet': 5,
     'TOOD': 5,
+    'YOLOX': 8,
 }
 
 KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
@@ -147,6 +148,12 @@ def _dump_infer_config(config, path, image_shape, model):
             infer_cfg['min_subgraph_size'] = min_subgraph_size
             arch_state = True
             break
+
+    if infer_arch == 'YOLOX':
+        infer_cfg['arch'] = infer_arch
+        infer_cfg['min_subgraph_size'] = TRT_MIN_SUBGRAPH[infer_arch]
+        arch_state = True
+
     if not arch_state:
         logger.error(
             'Architecture: {} is not supported for exporting model now.\n'.
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 622b9c3b290cb53fc40c4d07341ac115e8b73922..f15b992112beef3a93158bd15e1622cf63398a33 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -28,6 +28,7 @@ from PIL import Image, ImageOps, ImageFile
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 import paddle
+import paddle.nn as nn
 import paddle.distributed as dist
 from paddle.distributed import fleet
 from paddle import amp
@@ -99,6 +100,12 @@ class Trainer(object):
             self.model = self.cfg.model
             self.is_loaded_weights = True
 
+        if cfg.architecture == 'YOLOX':
+            for k, m in self.model.named_sublayers():
+                if isinstance(m, nn.BatchNorm2D):
+                    m.epsilon = 1e-3  # for amp(fp16)
+                    m.momentum = 0.97  # 0.03 in pytorch
+
         #normalize params for deploy
         if 'slim' in cfg and cfg['slim_type'] == 'OFA':
             self.model.model.load_meanstd(cfg['TestReader'][
@@ -117,10 +124,11 @@ class Trainer(object):
         if self.use_ema:
             ema_decay = self.cfg.get('ema_decay', 0.9998)
             cycle_epoch = self.cfg.get('cycle_epoch', -1)
+            ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
             self.ema = ModelEMA(
                 self.model,
                 decay=ema_decay,
-                use_thres_step=True,
+                ema_decay_type=ema_decay_type,
                 cycle_epoch=cycle_epoch)
 
         # EvalDataset build with BatchSampler to evaluate in single device
diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index 71c53067f72f96c1d24da19aa4313449e91f4b95..9360a5b7b15596302ae54fc7b375e83718820da0 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -5,6 +5,13 @@
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from . import meta_arch
 from . import faster_rcnn
 from . import mask_rcnn
@@ -28,6 +35,7 @@ from . import sparse_rcnn
 from . import tood
 from . import retinanet
 from . import bytetrack
+from . import yolox
 
 from .meta_arch import *
 from .faster_rcnn import *
@@ -53,3 +61,4 @@ from .sparse_rcnn import *
 from .tood import *
 from .retinanet import *
 from .bytetrack import *
+from .yolox import *
diff --git a/ppdet/modeling/architectures/yolox.py b/ppdet/modeling/architectures/yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..5965473e793fcbfa580be8fb18a14b7a1c950305
--- /dev/null
+++ b/ppdet/modeling/architectures/yolox.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ppdet.core.workspace import register, create +from .meta_arch import BaseArch + +import random +import paddle +import paddle.nn.functional as F +import paddle.distributed as dist +from ppdet.modeling.ops import paddle_distributed_is_initialized + +__all__ = ['YOLOX'] + + +@register +class YOLOX(BaseArch): + """ + YOLOX network, see https://arxiv.org/abs/2107.08430 + + Args: + backbone (nn.Layer): backbone instance + neck (nn.Layer): neck instance + head (nn.Layer): head instance + for_mot (bool): whether used for MOT or not + input_size (list[int]): initial scale, will be reset by self._preprocess() + size_stride (int): stride of the size range + size_range (list[int]): multi-scale range for training + random_interval (int): interval of iter to change self._input_size + """ + __category__ = 'architecture' + + def __init__(self, + backbone='CSPDarkNet', + neck='YOLOCSPPAN', + head='YOLOXHead', + for_mot=False, + input_size=[640, 640], + size_stride=32, + size_range=[15, 25], + random_interval=10): + super(YOLOX, self).__init__() + self.backbone = backbone + self.neck = neck + self.head = head + self.for_mot = for_mot + + self.input_size = input_size + self._input_size = paddle.to_tensor(input_size) + self.size_stride = size_stride + self.size_range = size_range + self.random_interval = random_interval + self._step = 0 + + @classmethod + def from_config(cls, cfg, *args, **kwargs): + # backbone + backbone = create(cfg['backbone']) + + # fpn + kwargs = {'input_shape': backbone.out_shape} + neck = create(cfg['neck'], **kwargs) + + # head + kwargs = {'input_shape': neck.out_shape} + head = create(cfg['head'], **kwargs) + + return { + 'backbone': backbone, + 'neck': neck, + "head": head, + } + + def _forward(self): + if self.training: + self._preprocess() + body_feats = self.backbone(self.inputs) + neck_feats = self.neck(body_feats, self.for_mot) + + if self.training: + yolox_losses = self.head(neck_feats, self.inputs) + yolox_losses.update({'size': self._input_size[0]}) + return yolox_losses + else: + head_outs = self.head(neck_feats) + bbox, bbox_num = self.head.post_process( + head_outs, self.inputs['im_shape'], self.inputs['scale_factor']) + return {'bbox': bbox, 'bbox_num': bbox_num} + + def get_loss(self): + return self._forward() + + def get_pred(self): + return self._forward() + + def _preprocess(self): + # YOLOX multi-scale training, interpolate resize before inputs of the network. 
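+        # every `random_interval` steps, _get_size() samples a new multiple of size_stride as the target size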
+ self._get_size() + scale_y = self._input_size[0] / self.input_size[0] + scale_x = self._input_size[1] / self.input_size[1] + if scale_x != 1 or scale_y != 1: + self.inputs['image'] = F.interpolate( + self.inputs['image'], + size=self._input_size, + mode='bilinear', + align_corners=False) + gt_bboxes = self.inputs['gt_bbox'] + for i in range(len(gt_bboxes)): + if len(gt_bboxes[i]) > 0: + gt_bboxes[i][:, 0::2] = gt_bboxes[i][:, 0::2] * scale_x + gt_bboxes[i][:, 1::2] = gt_bboxes[i][:, 1::2] * scale_y + self.inputs['gt_bbox'] = gt_bboxes + + def _get_size(self): + # random_interval = 10 as default, every 10 iters to change self._input_size + image_ratio = self.input_size[1] * 1.0 / self.input_size[0] + if self._step % self.random_interval == 0: + size_factor = random.randint(*self.size_range) + size = [ + self.size_stride * size_factor, + self.size_stride * int(size_factor * image_ratio) + ] + size = paddle.to_tensor(size) + if dist.get_world_size() > 1 and paddle_distributed_is_initialized( + ): + dist.barrier() + dist.broadcast(size, 0) + self._input_size = size + self._step += 1 diff --git a/ppdet/modeling/assigners/utils.py b/ppdet/modeling/assigners/utils.py index 0c5e1d94fc06dd2abdf135638f9b17c8ba2ff8cf..f2d9be99509b62d585a9f4db999818b6d25a2d94 100644 --- a/ppdet/modeling/assigners/utils.py +++ b/ppdet/modeling/assigners/utils.py @@ -115,7 +115,7 @@ def check_points_inside_bboxes(points, Args: points (Tensor, float32): shape[L, 2], "xy" format, L: num_anchors bboxes (Tensor, float32): shape[B, n, 4], "xmin, ymin, xmax, ymax" format - center_radius_tensor (Tensor, float32): shape [L, 1] Default: None. + center_radius_tensor (Tensor, float32): shape [L, 1]. Default: None. eps (float): Default: 1e-9 Returns: is_in_bboxes (Tensor, float32): shape[B, n, L], value=1. means selected @@ -123,25 +123,28 @@ def check_points_inside_bboxes(points, points = points.unsqueeze([0, 1]) x, y = points.chunk(2, axis=-1) xmin, ymin, xmax, ymax = bboxes.unsqueeze(2).chunk(4, axis=-1) - if center_radius_tensor is not None: - center_radius_tensor = center_radius_tensor.unsqueeze([0, 1]) - bboxes_cx = (xmin + xmax) / 2. - bboxes_cy = (ymin + ymax) / 2. 
- xmin_sampling = bboxes_cx - center_radius_tensor - ymin_sampling = bboxes_cy - center_radius_tensor - xmax_sampling = bboxes_cx + center_radius_tensor - ymax_sampling = bboxes_cy + center_radius_tensor - - xmin = paddle.maximum(xmin, xmin_sampling) - ymin = paddle.maximum(ymin, ymin_sampling) - xmax = paddle.minimum(xmax, xmax_sampling) - ymax = paddle.minimum(ymax, ymax_sampling) + # check whether `points` is in `bboxes` l = x - xmin t = y - ymin r = xmax - x b = ymax - y - bbox_ltrb = paddle.concat([l, t, r, b], axis=-1) - return (bbox_ltrb.min(axis=-1) > eps).astype(bboxes.dtype) + delta_ltrb = paddle.concat([l, t, r, b], axis=-1) + is_in_bboxes = (delta_ltrb.min(axis=-1) > eps) + if center_radius_tensor is not None: + # check whether `points` is in `center_radius` + center_radius_tensor = center_radius_tensor.unsqueeze([0, 1]) + cx = (xmin + xmax) * 0.5 + cy = (ymin + ymax) * 0.5 + l = x - (cx - center_radius_tensor) + t = y - (cy - center_radius_tensor) + r = (cx + center_radius_tensor) - x + b = (cy + center_radius_tensor) - y + delta_ltrb_c = paddle.concat([l, t, r, b], axis=-1) + is_in_center = (delta_ltrb_c.min(axis=-1) > eps) + return (paddle.logical_and(is_in_bboxes, is_in_center), + paddle.logical_or(is_in_bboxes, is_in_center)) + + return is_in_bboxes.astype(bboxes.dtype) def compute_max_iou_anchor(ious): diff --git a/ppdet/modeling/backbones/__init__.py b/ppdet/modeling/backbones/__init__.py index bbc696dbacd048ce0ca752580ecfc54e11a1433f..e619e7ec663d8df745e690de8dc83a2c8391d7bc 100644 --- a/ppdet/modeling/backbones/__init__.py +++ b/ppdet/modeling/backbones/__init__.py @@ -1,15 +1,15 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and # limitations under the License. from . import vgg @@ -30,6 +30,7 @@ from . import lcnet from . import hardnet from . import esnet from . import cspresnet +from . import csp_darknet from .vgg import * from .resnet import * @@ -49,3 +50,4 @@ from .lcnet import * from .hardnet import * from .esnet import * from .cspresnet import * +from .csp_darknet import * diff --git a/ppdet/modeling/backbones/csp_darknet.py b/ppdet/modeling/backbones/csp_darknet.py new file mode 100644 index 0000000000000000000000000000000000000000..45b06076044b53f9f75516e38d1c52d18b1d885b --- /dev/null +++ b/ppdet/modeling/backbones/csp_darknet.py @@ -0,0 +1,403 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.regularizer import L2Decay
+from ppdet.core.workspace import register, serializable
+from ppdet.modeling.ops import get_activation
+from ppdet.modeling.initializer import conv_init_
+from ..shape_spec import ShapeSpec
+
+__all__ = [
+    'CSPDarkNet', 'BaseConv', 'DWConv', 'BottleNeck', 'SPPLayer', 'SPPFLayer'
+]
+
+
+class BaseConv(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 ksize,
+                 stride,
+                 groups=1,
+                 bias=False,
+                 act="silu"):
+        super(BaseConv, self).__init__()
+        self.conv = nn.Conv2D(
+            in_channels,
+            out_channels,
+            kernel_size=ksize,
+            stride=stride,
+            padding=(ksize - 1) // 2,
+            groups=groups,
+            bias_attr=bias)
+        self.bn = nn.BatchNorm2D(
+            out_channels,
+            weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+        self.act = get_activation(act)
+
+        self._init_weights()
+
+    def _init_weights(self):
+        conv_init_(self.conv)
+
+    def forward(self, x):
+        return self.act(self.bn(self.conv(x)))
+
+
+class DWConv(nn.Layer):
+    """Depthwise Conv"""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 ksize,
+                 stride=1,
+                 bias=False,
+                 act="silu"):
+        super(DWConv, self).__init__()
+        self.dw_conv = BaseConv(
+            in_channels,
+            in_channels,
+            ksize=ksize,
+            stride=stride,
+            groups=in_channels,
+            bias=bias,
+            act=act, )
+        self.pw_conv = BaseConv(
+            in_channels,
+            out_channels,
+            ksize=1,
+            stride=1,
+            groups=1,
+            bias=bias,
+            act=act)
+
+    def forward(self, x):
+        return self.pw_conv(self.dw_conv(x))
+
+
+class Focus(nn.Layer):
+    """Focus width and height information into channel space, used in YOLOX."""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 ksize=3,
+                 stride=1,
+                 bias=False,
+                 act="silu"):
+        super(Focus, self).__init__()
+        self.conv = BaseConv(
+            in_channels * 4,
+            out_channels,
+            ksize=ksize,
+            stride=stride,
+            bias=bias,
+            act=act)
+
+    def forward(self, inputs):
+        # inputs [bs, C, H, W] -> outputs [bs, 4C, H/2, W/2]
+        top_left = inputs[:, :, 0::2, 0::2]
+        top_right = inputs[:, :, 0::2, 1::2]
+        bottom_left = inputs[:, :, 1::2, 0::2]
+        bottom_right = inputs[:, :, 1::2, 1::2]
+        outputs = paddle.concat(
+            [top_left, bottom_left, top_right, bottom_right], 1)
+        return self.conv(outputs)
+
+
+class BottleNeck(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 shortcut=True,
+                 expansion=0.5,
+                 depthwise=False,
+                 bias=False,
+                 act="silu"):
+        super(BottleNeck, self).__init__()
+        hidden_channels = int(out_channels * expansion)
+        Conv = DWConv if depthwise else BaseConv
+        self.conv1 = BaseConv(
+            in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
+        self.conv2 = Conv(
+            hidden_channels,
+            out_channels,
+            ksize=3,
+            stride=1,
+            bias=bias,
+            act=act)
+        self.add_shortcut = shortcut and in_channels == out_channels
+
+    def forward(self, x):
+        y = self.conv2(self.conv1(x))
+        if self.add_shortcut:
+            y = y + x
+        return y
+
+
+class SPPLayer(nn.Layer):
+    """Spatial Pyramid Pooling (SPP) layer used in YOLOv3-SPP and YOLOX"""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_sizes=(5, 9, 13),
+                 bias=False,
+                 act="silu"):
+        super(SPPLayer, self).__init__()
+        hidden_channels = in_channels // 2
+        self.conv1 = BaseConv(
+            in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
+        self.maxpoolings = nn.LayerList([
+            nn.MaxPool2D(
+                kernel_size=ks, stride=1, padding=ks // 2)
+            for ks in kernel_sizes
+        ])
+        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
+        self.conv2 = BaseConv(
+            conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = paddle.concat([x] + [mp(x) for mp in self.maxpoolings], axis=1)
+        x = self.conv2(x)
+        return x
+
+
+class SPPFLayer(nn.Layer):
+    """ Spatial Pyramid Pooling - Fast (SPPF) layer used in YOLOv5 by Glenn Jocher,
+        equivalent to SPP(k=(5, 9, 13))
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 ksize=5,
+                 bias=False,
+                 act='silu'):
+        super(SPPFLayer, self).__init__()
+        hidden_channels = in_channels // 2
+        self.conv1 = BaseConv(
+            in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
+        self.maxpooling = nn.MaxPool2D(
+            kernel_size=ksize, stride=1, padding=ksize // 2)
+        conv2_channels = hidden_channels * 4
+        self.conv2 = BaseConv(
+            conv2_channels, out_channels, ksize=1, stride=1, bias=bias, act=act)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        y1 = self.maxpooling(x)
+        y2 = self.maxpooling(y1)
+        y3 = self.maxpooling(y2)
+        concats = paddle.concat([x, y1, y2, y3], axis=1)
+        out = self.conv2(concats)
+        return out
+
+
+class CSPLayer(nn.Layer):
+    """CSP (Cross Stage Partial) layer with 3 convs, named C3 in YOLOv5"""
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 num_blocks=1,
+                 shortcut=True,
+                 expansion=0.5,
+                 depthwise=False,
+                 bias=False,
+                 act="silu"):
+        super(CSPLayer, self).__init__()
+        hidden_channels = int(out_channels * expansion)
+        self.conv1 = BaseConv(
+            in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
+        self.conv2 = BaseConv(
+            in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act)
+        self.bottlenecks = nn.Sequential(* [
+            BottleNeck(
+                hidden_channels,
+                hidden_channels,
+                shortcut=shortcut,
+                expansion=1.0,
+                depthwise=depthwise,
+                bias=bias,
+                act=act) for _ in range(num_blocks)
+        ])
+        self.conv3 = BaseConv(
+            hidden_channels * 2,
+            out_channels,
+            ksize=1,
+            stride=1,
+            bias=bias,
+            act=act)
+
+    def forward(self, x):
+        x_1 = self.conv1(x)
+        x_1 = self.bottlenecks(x_1)
+        x_2 = self.conv2(x)
+        x = paddle.concat([x_1, x_2], axis=1)
+        x = self.conv3(x)
+        return x
+
+
+@register
+@serializable
+class CSPDarkNet(nn.Layer):
+    """
+    CSPDarkNet backbone.
+    Args:
+        arch (str): Architecture of CSPDarkNet, from {P5, P6, X}, default as X,
+            and 'X' means used in YOLOX, 'P5/P6' means used in YOLOv5.
+        depth_mult (float): Depth multiplier, multiply number of blocks in
+            CSPLayer, default as 1.0.
+        width_mult (float): Width multiplier, multiply number of channels in
+            each layer, default as 1.0.
+        depthwise (bool): Whether to use depth-wise conv layer.
+        act (str): Activation function type, default as 'silu'.
+        return_idx (list): Index of stages whose feature maps are returned.
+    """
+
+    __shared__ = ['depth_mult', 'width_mult', 'act']
+
+    # in_channels, out_channels, num_blocks, add_shortcut, use_spp(use_sppf)
+    # 'X' means setting used in YOLOX, 'P5/P6' means setting used in YOLOv5.
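+    # channels are scaled by width_mult and num_blocks by depth_mult in __init__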
+    arch_settings = {
+        'X': [[64, 128, 3, True, False], [128, 256, 9, True, False],
+              [256, 512, 9, True, False], [512, 1024, 3, False, True]],
+        'P5': [[64, 128, 3, True, False], [128, 256, 6, True, False],
+               [256, 512, 9, True, False], [512, 1024, 3, True, True]],
+        'P6': [[64, 128, 3, True, False], [128, 256, 6, True, False],
+               [256, 512, 9, True, False], [512, 768, 3, True, False],
+               [768, 1024, 3, True, True]],
+    }
+
+    def __init__(self,
+                 arch='X',
+                 depth_mult=1.0,
+                 width_mult=1.0,
+                 depthwise=False,
+                 act='silu',
+                 return_idx=[2, 3, 4]):
+        super(CSPDarkNet, self).__init__()
+        self.arch = arch
+        self.return_idx = return_idx
+        Conv = DWConv if depthwise else BaseConv
+
+        arch_setting = self.arch_settings[arch]
+        base_channels = int(arch_setting[0][0] * width_mult)
+
+        # Note: differences between the latest YOLOv5 and the original YOLOX
+        # 1. self.stem: use Conv (in YOLOv5) or Focus (in YOLOX)
+        # 2. use SPPF (in YOLOv5) or SPP (in YOLOX)
+        # 3. put SPPF before (YOLOv5) or SPP after (YOLOX) the last cspdark block's CSPLayer
+        # 4. whether the SPPF(SPP) CSPLayer adds a shortcut, True in YOLOv5, False in YOLOX
+        if arch in ['P5', 'P6']:
+            # in the latest YOLOv5, use Conv stem, and SPPF (fast, only a single spp kernel size)
+            self.stem = Conv(
+                3, base_channels, ksize=6, stride=2, bias=False, act=act)
+            spp_kernel_sizes = 5
+        elif arch in ['X']:
+            # in the original YOLOX, use Focus stem, and SPP (three spp kernel sizes)
+            self.stem = Focus(
+                3, base_channels, ksize=3, stride=1, bias=False, act=act)
+            spp_kernel_sizes = (5, 9, 13)
+        else:
+            raise AttributeError("Unsupported arch type: {}".format(arch))
+
+        _out_channels = [base_channels]
+        layers_num = 1
+        self.csp_dark_blocks = []
+
+        for i, (in_channels, out_channels, num_blocks, shortcut,
+                use_spp) in enumerate(arch_setting):
+            in_channels = int(in_channels * width_mult)
+            out_channels = int(out_channels * width_mult)
+            _out_channels.append(out_channels)
+            num_blocks = max(round(num_blocks * depth_mult), 1)
+            stage = []
+
+            conv_layer = self.add_sublayer(
+                'layers{}.stage{}.conv_layer'.format(layers_num, i + 1),
+                Conv(
+                    in_channels, out_channels, 3, 2, bias=False, act=act))
+            stage.append(conv_layer)
+            layers_num += 1
+
+            if use_spp and arch in ['X']:
+                # in YOLOX use SPPLayer
+                spp_layer = self.add_sublayer(
+                    'layers{}.stage{}.spp_layer'.format(layers_num, i + 1),
+                    SPPLayer(
+                        out_channels,
+                        out_channels,
+                        kernel_sizes=spp_kernel_sizes,
+                        bias=False,
+                        act=act))
+                stage.append(spp_layer)
+                layers_num += 1
+
+            csp_layer = self.add_sublayer(
+                'layers{}.stage{}.csp_layer'.format(layers_num, i + 1),
+                CSPLayer(
+                    out_channels,
+                    out_channels,
+                    num_blocks=num_blocks,
+                    shortcut=shortcut,
+                    depthwise=depthwise,
+                    bias=False,
+                    act=act))
+            stage.append(csp_layer)
+            layers_num += 1
+
+            if use_spp and arch in ['P5', 'P6']:
+                # in latest YOLOv5 use SPPFLayer instead of SPPLayer
+                sppf_layer = self.add_sublayer(
+                    'layers{}.stage{}.sppf_layer'.format(layers_num, i + 1),
+                    SPPFLayer(
+                        out_channels,
+                        out_channels,
+                        ksize=5,
+                        bias=False,
+                        act=act))
+                stage.append(sppf_layer)
+                layers_num += 1
+
+            self.csp_dark_blocks.append(nn.Sequential(*stage))
+
+        self._out_channels = [_out_channels[i] for i in self.return_idx]
+        self.strides = [[2, 4, 8, 16, 32, 64][i] for i in self.return_idx]
+
+    def forward(self, inputs):
+        x = inputs['image']
+        outputs = []
+        x = self.stem(x)
+        for i, layer in enumerate(self.csp_dark_blocks):
+            x = layer(x)
+            if i + 1 in self.return_idx:
+                outputs.append(x)
+        return outputs
+
+    @property
+    def out_shape(self):
+        return [
ShapeSpec( + channels=c, stride=s) + for c, s in zip(self._out_channels, self.strides) + ] diff --git a/ppdet/modeling/heads/yolo_head.py b/ppdet/modeling/heads/yolo_head.py index 7b4e9bc3353edfe42c0c35db6fb38b67de03c730..911ae8e19810f703efaca5f17430f43b090722fe 100644 --- a/ppdet/modeling/heads/yolo_head.py +++ b/ppdet/modeling/heads/yolo_head.py @@ -1,3 +1,17 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import paddle import paddle.nn as nn import paddle.nn.functional as F @@ -5,6 +19,16 @@ from paddle import ParamAttr from paddle.regularizer import L2Decay from ppdet.core.workspace import register +import math +import numpy as np +from ..initializer import bias_init_with_prob, constant_ +from ..backbones.csp_darknet import BaseConv, DWConv +from ..losses import IouLoss +from ppdet.modeling.assigners.simota_assigner import SimOTAAssigner +from ppdet.modeling.bbox_utils import bbox_overlaps + +__all__ = ['YOLOv3Head', 'YOLOXHead'] + def _de_sigmoid(x, eps=1e-7): x = paddle.clip(x, eps, 1. / eps) @@ -122,3 +146,259 @@ class YOLOv3Head(nn.Layer): @classmethod def from_config(cls, cfg, input_shape): return {'in_channels': [i.channels for i in input_shape], } + + +@register +class YOLOXHead(nn.Layer): + __shared__ = ['num_classes', 'width_mult', 'act'] + __inject__ = ['assigner', 'nms'] + + def __init__(self, + num_classes=80, + width_mult=1.0, + depthwise=False, + in_channels=[256, 512, 1024], + feat_channels=256, + fpn_strides=(8, 16, 32), + l1_epoch=285, + act='silu', + assigner=SimOTAAssigner(use_vfl=False), + nms='MultiClassNMS', + loss_weight={'cls': 1.0, + 'obj': 1.0, + 'iou': 5.0, + 'l1': 1.0}): + super(YOLOXHead, self).__init__() + self._dtype = paddle.framework.get_default_dtype() + self.num_classes = num_classes + assert len(in_channels) > 0, "in_channels length should > 0" + self.in_channels = in_channels + feat_channels = int(feat_channels * width_mult) + self.fpn_strides = fpn_strides + self.l1_epoch = l1_epoch + self.assigner = assigner + self.nms = nms + self.loss_weight = loss_weight + self.iou_loss = IouLoss(loss_weight=1.0) # default loss_weight 2.5 + + ConvBlock = DWConv if depthwise else BaseConv + + self.stem_conv = nn.LayerList() + self.conv_cls = nn.LayerList() + self.conv_reg = nn.LayerList() # reg [x,y,w,h] + obj + for in_c in self.in_channels: + self.stem_conv.append(BaseConv(in_c, feat_channels, 1, 1, act=act)) + + self.conv_cls.append( + nn.Sequential(* [ + ConvBlock( + feat_channels, feat_channels, 3, 1, act=act), ConvBlock( + feat_channels, feat_channels, 3, 1, act=act), + nn.Conv2D( + feat_channels, + self.num_classes, + 1, + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + ])) + + self.conv_reg.append( + nn.Sequential(* [ + ConvBlock( + feat_channels, feat_channels, 3, 1, act=act), + ConvBlock( + feat_channels, feat_channels, 3, 1, act=act), + nn.Conv2D( + feat_channels, + 4 + 1, # reg [x,y,w,h] + obj + 1, + bias_attr=ParamAttr(regularizer=L2Decay(0.0))) + ])) + + self._init_weights() + + 
@classmethod + def from_config(cls, cfg, input_shape): + return {'in_channels': [i.channels for i in input_shape], } + + def _init_weights(self): + bias_cls = bias_init_with_prob(0.01) + bias_reg = paddle.full([5], math.log(5.), dtype=self._dtype) + bias_reg[:2] = 0. + bias_reg[-1] = bias_cls + for cls_, reg_ in zip(self.conv_cls, self.conv_reg): + constant_(cls_[-1].weight) + constant_(cls_[-1].bias, bias_cls) + constant_(reg_[-1].weight) + reg_[-1].bias.set_value(bias_reg) + + def _generate_anchor_point(self, feat_sizes, strides, offset=0.): + anchor_points, stride_tensor = [], [] + num_anchors_list = [] + for feat_size, stride in zip(feat_sizes, strides): + h, w = feat_size + x = (paddle.arange(w) + offset) * stride + y = (paddle.arange(h) + offset) * stride + y, x = paddle.meshgrid(y, x) + anchor_points.append(paddle.stack([x, y], axis=-1).reshape([-1, 2])) + stride_tensor.append( + paddle.full( + [len(anchor_points[-1]), 1], stride, dtype=self._dtype)) + num_anchors_list.append(len(anchor_points[-1])) + anchor_points = paddle.concat(anchor_points).astype(self._dtype) + anchor_points.stop_gradient = True + stride_tensor = paddle.concat(stride_tensor) + stride_tensor.stop_gradient = True + return anchor_points, stride_tensor, num_anchors_list + + def forward(self, feats, targets=None): + assert len(feats) == len(self.fpn_strides), \ + "The size of feats is not equal to size of fpn_strides" + + feat_sizes = [[f.shape[-2], f.shape[-1]] for f in feats] + cls_score_list, reg_pred_list = [], [] + obj_score_list = [] + for i, feat in enumerate(feats): + feat = self.stem_conv[i](feat) + cls_logit = self.conv_cls[i](feat) + reg_pred = self.conv_reg[i](feat) + # cls prediction + cls_score = F.sigmoid(cls_logit) + cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1])) + # reg prediction + reg_xywh, obj_logit = paddle.split(reg_pred, [4, 1], axis=1) + reg_xywh = reg_xywh.flatten(2).transpose([0, 2, 1]) + reg_pred_list.append(reg_xywh) + # obj prediction + obj_score = F.sigmoid(obj_logit) + obj_score_list.append(obj_score.flatten(2).transpose([0, 2, 1])) + + cls_score_list = paddle.concat(cls_score_list, axis=1) + reg_pred_list = paddle.concat(reg_pred_list, axis=1) + obj_score_list = paddle.concat(obj_score_list, axis=1) + + # bbox decode + anchor_points, stride_tensor, _ =\ + self._generate_anchor_point(feat_sizes, self.fpn_strides) + reg_xy, reg_wh = paddle.split(reg_pred_list, 2, axis=-1) + reg_xy += (anchor_points / stride_tensor) + reg_wh = paddle.exp(reg_wh) * 0.5 + bbox_pred_list = paddle.concat( + [reg_xy - reg_wh, reg_xy + reg_wh], axis=-1) + + if self.training: + anchor_points, stride_tensor, num_anchors_list =\ + self._generate_anchor_point(feat_sizes, self.fpn_strides, 0.5) + yolox_losses = self.get_loss([ + cls_score_list, bbox_pred_list, obj_score_list, anchor_points, + stride_tensor, num_anchors_list + ], targets) + return yolox_losses + else: + pred_scores = (cls_score_list * obj_score_list).sqrt() + return pred_scores, bbox_pred_list, stride_tensor + + def get_loss(self, head_outs, targets): + pred_cls, pred_bboxes, pred_obj,\ + anchor_points, stride_tensor, num_anchors_list = head_outs + gt_labels = targets['gt_class'] + gt_bboxes = targets['gt_bbox'] + pred_scores = (pred_cls * pred_obj).sqrt() + # label assignment + center_and_strides = paddle.concat( + [anchor_points, stride_tensor, stride_tensor], axis=-1) + pos_num_list, label_list, bbox_target_list = [], [], [] + for pred_score, pred_bbox, gt_box, gt_label in zip( + pred_scores.detach(), + pred_bboxes.detach() * 
stride_tensor, gt_bboxes, gt_labels): + pos_num, label, _, bbox_target = self.assigner( + pred_score, center_and_strides, pred_bbox, gt_box, gt_label) + pos_num_list.append(pos_num) + label_list.append(label) + bbox_target_list.append(bbox_target) + labels = paddle.to_tensor(np.stack(label_list, axis=0)) + bbox_targets = paddle.to_tensor(np.stack(bbox_target_list, axis=0)) + bbox_targets /= stride_tensor # rescale bbox + + # 1. obj score loss + mask_positive = (labels != self.num_classes) + loss_obj = F.binary_cross_entropy( + pred_obj, + mask_positive.astype(pred_obj.dtype).unsqueeze(-1), + reduction='sum') + + num_pos = sum(pos_num_list) + + if num_pos > 0: + num_pos = paddle.to_tensor(num_pos, dtype=self._dtype).clip(min=1) + loss_obj /= num_pos + + # 2. iou loss + bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4]) + pred_bboxes_pos = paddle.masked_select(pred_bboxes, + bbox_mask).reshape([-1, 4]) + assigned_bboxes_pos = paddle.masked_select( + bbox_targets, bbox_mask).reshape([-1, 4]) + bbox_iou = bbox_overlaps(pred_bboxes_pos, assigned_bboxes_pos) + bbox_iou = paddle.diag(bbox_iou) + + loss_iou = self.iou_loss( + pred_bboxes_pos.split( + 4, axis=-1), + assigned_bboxes_pos.split( + 4, axis=-1)) + loss_iou = loss_iou.sum() / num_pos + + # 3. cls loss + cls_mask = mask_positive.unsqueeze(-1).tile( + [1, 1, self.num_classes]) + pred_cls_pos = paddle.masked_select( + pred_cls, cls_mask).reshape([-1, self.num_classes]) + assigned_cls_pos = paddle.masked_select(labels, mask_positive) + assigned_cls_pos = F.one_hot(assigned_cls_pos, + self.num_classes + 1)[..., :-1] + assigned_cls_pos *= bbox_iou.unsqueeze(-1) + loss_cls = F.binary_cross_entropy( + pred_cls_pos, assigned_cls_pos, reduction='sum') + loss_cls /= num_pos + + # 4. l1 loss + if targets['epoch_id'] >= self.l1_epoch: + loss_l1 = F.l1_loss( + pred_bboxes_pos, assigned_bboxes_pos, reduction='sum') + loss_l1 /= num_pos + else: + loss_l1 = paddle.zeros([1]) + loss_l1.stop_gradient = False + else: + loss_cls = paddle.zeros([1]) + loss_iou = paddle.zeros([1]) + loss_l1 = paddle.zeros([1]) + loss_cls.stop_gradient = False + loss_iou.stop_gradient = False + loss_l1.stop_gradient = False + + loss = self.loss_weight['obj'] * loss_obj + \ + self.loss_weight['cls'] * loss_cls + \ + self.loss_weight['iou'] * loss_iou + + if targets['epoch_id'] >= self.l1_epoch: + loss += (self.loss_weight['l1'] * loss_l1) + + yolox_losses = { + 'loss': loss, + 'loss_cls': loss_cls, + 'loss_obj': loss_obj, + 'loss_iou': loss_iou, + 'loss_l1': loss_l1, + } + return yolox_losses + + def post_process(self, head_outs, img_shape, scale_factor): + pred_scores, pred_bboxes, stride_tensor = head_outs + pred_scores = pred_scores.transpose([0, 2, 1]) + pred_bboxes *= stride_tensor + # scale bbox to origin image + scale_factor = scale_factor.flip(-1).tile([1, 2]).unsqueeze(1) + pred_bboxes /= scale_factor + bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores) + return bbox_pred, bbox_num diff --git a/ppdet/modeling/initializer.py b/ppdet/modeling/initializer.py index b7a135dcca7309be09fa4819e21d77a02d332d57..b482f133dd9ac1e2568f5c971f004117c56a5368 100644 --- a/ppdet/modeling/initializer.py +++ b/ppdet/modeling/initializer.py @@ -273,7 +273,8 @@ def linear_init_(module): def conv_init_(module): bound = 1 / np.sqrt(np.prod(module.weight.shape[1:])) uniform_(module.weight, -bound, bound) - uniform_(module.bias, -bound, bound) + if module.bias is not None: + uniform_(module.bias, -bound, bound) def bias_init_with_prob(prior_prob=0.01): diff --git 
a/ppdet/modeling/necks/yolo_fpn.py b/ppdet/modeling/necks/yolo_fpn.py index bc06b14dba4fb5a2e5570162a6151927cb2df8da..1aaf593f0611d2af9e23c4bfd4c4c2cc326b5ba2 100644 --- a/ppdet/modeling/necks/yolo_fpn.py +++ b/ppdet/modeling/necks/yolo_fpn.py @@ -1,15 +1,15 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and # limitations under the License. import paddle @@ -19,8 +19,9 @@ from ppdet.core.workspace import register, serializable from ppdet.modeling.layers import DropBlock from ..backbones.darknet import ConvBNLayer from ..shape_spec import ShapeSpec +from ..backbones.csp_darknet import BaseConv, DWConv, CSPLayer -__all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN'] +__all__ = ['YOLOv3FPN', 'PPYOLOFPN', 'PPYOLOTinyFPN', 'PPYOLOPAN', 'YOLOCSPPAN'] def add_coord(x, data_format): @@ -986,3 +987,102 @@ class PPYOLOPAN(nn.Layer): @property def out_shape(self): return [ShapeSpec(channels=c) for c in self._out_channels] + + +@register +@serializable +class YOLOCSPPAN(nn.Layer): + """ + YOLO CSP-PAN, used in YOLOv5 and YOLOX. 
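+    A top-down FPN path first fuses deep features into shallower levels, then a bottom-up PAN path aggregates them back.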
+    """
+    __shared__ = ['depth_mult', 'act']
+
+    def __init__(self,
+                 depth_mult=1.0,
+                 in_channels=[256, 512, 1024],
+                 depthwise=False,
+                 act='silu'):
+        super(YOLOCSPPAN, self).__init__()
+        self.in_channels = in_channels
+        self._out_channels = in_channels
+        Conv = DWConv if depthwise else BaseConv
+
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+
+        # top-down fpn
+        self.lateral_convs = nn.LayerList()
+        self.fpn_blocks = nn.LayerList()
+        for idx in range(len(in_channels) - 1, 0, -1):
+            self.lateral_convs.append(
+                BaseConv(
+                    int(in_channels[idx]),
+                    int(in_channels[idx - 1]),
+                    1,
+                    1,
+                    act=act))
+            self.fpn_blocks.append(
+                CSPLayer(
+                    int(in_channels[idx - 1] * 2),
+                    int(in_channels[idx - 1]),
+                    round(3 * depth_mult),
+                    shortcut=False,
+                    depthwise=depthwise,
+                    act=act))
+
+        # bottom-up pan
+        self.downsample_convs = nn.LayerList()
+        self.pan_blocks = nn.LayerList()
+        for idx in range(len(in_channels) - 1):
+            self.downsample_convs.append(
+                Conv(
+                    int(in_channels[idx]),
+                    int(in_channels[idx]),
+                    3,
+                    stride=2,
+                    act=act))
+            self.pan_blocks.append(
+                CSPLayer(
+                    int(in_channels[idx] * 2),
+                    int(in_channels[idx + 1]),
+                    round(3 * depth_mult),
+                    shortcut=False,
+                    depthwise=depthwise,
+                    act=act))
+
+    def forward(self, feats, for_mot=False):
+        assert len(feats) == len(self.in_channels)
+
+        # top-down fpn
+        inner_outs = [feats[-1]]
+        for idx in range(len(self.in_channels) - 1, 0, -1):
+            feat_high = inner_outs[0]
+            feat_low = feats[idx - 1]
+            feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](
+                feat_high)
+            inner_outs[0] = feat_high
+
+            upsample_feat = self.upsample(feat_high)
+            inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
+                paddle.concat(
+                    [upsample_feat, feat_low], axis=1))
+            inner_outs.insert(0, inner_out)
+
+        # bottom-up pan
+        outs = [inner_outs[0]]
+        for idx in range(len(self.in_channels) - 1):
+            feat_low = outs[-1]
+            feat_high = inner_outs[idx + 1]
+            downsample_feat = self.downsample_convs[idx](feat_low)
+            out = self.pan_blocks[idx](paddle.concat(
+                [downsample_feat, feat_high], axis=1))
+            outs.append(out)
+
+        return outs
+
+    @classmethod
+    def from_config(cls, cfg, input_shape):
+        return {'in_channels': [i.channels for i in input_shape], }
+
+    @property
+    def out_shape(self):
+        return [ShapeSpec(channels=c) for c in self._out_channels]
diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py
index 67e8960ca115c1c6b5338fa7c84a26c1d4873fbf..ab7724301f4f7fe6126bd9a95a3728f232363769 100644
--- a/ppdet/modeling/ops.py
+++ b/ppdet/modeling/ops.py
@@ -1,15 +1,15 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and # limitations under the License. import paddle @@ -29,7 +29,7 @@ __all__ = [ 'roi_pool', 'roi_align', 'prior_box', 'generate_proposals', 'iou_similarity', 'box_coder', 'yolo_box', 'multiclass_nms', 'distribute_fpn_proposals', 'collect_fpn_proposals', 'matrix_nms', - 'batch_norm', 'mish', 'swish', 'identity' + 'batch_norm', 'get_activation', 'mish', 'swish', 'identity' ] @@ -107,6 +107,18 @@ def batch_norm(ch, return norm_layer +def get_activation(name="silu"): + if name == "silu": + module = nn.Silu() + elif name == "relu": + module = nn.ReLU() + elif name == "leakyrelu": + module = nn.LeakyReLU(0.1) + else: + raise AttributeError("Unsupported act type: {}".format(name)) + return module + + @paddle.jit.not_to_static def roi_pool(input, rois, @@ -1012,7 +1024,7 @@ def multiclass_nms(bboxes, nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta, 'normalized', normalized) output, index, nms_rois_num = _C_ops.multiclass_nms3(bboxes, scores, - rois_num, *attrs) + rois_num, *attrs) if not return_index: index = None return output, nms_rois_num, index diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index 1305b76fed48c6acd3967928ffc82b74ef5a881d..98a5e5f9caa4bbf436009b9f3860222648bf716b 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -209,6 +209,33 @@ class BurninWarmup(object): return boundary, value +@serializable +class ExpWarmup(object): + """ + Warm up learning rate in exponential mode + Args: + steps (int): warm up steps. + epochs (int|None): use epochs as warm up steps, the priority + of `epochs` is higher than `steps`. Default: None. + """ + + def __init__(self, steps=5, epochs=None): + super(ExpWarmup, self).__init__() + self.steps = steps + self.epochs = epochs + + def __call__(self, base_lr, step_per_epoch): + boundary = [] + value = [] + warmup_steps = self.epochs * step_per_epoch if self.epochs is not None else self.steps + for i in range(warmup_steps + 1): + factor = (i / float(warmup_steps))**2 + value.append(base_lr * factor) + if i > 0: + boundary.append(i) + return boundary, value + + @register class LearningRate(object): """ @@ -331,7 +358,8 @@ class ModelEMA(object): Ema's parameter are updated with the formula: `ema_param = decay * ema_param + (1 - decay) * cur_param`. Defaults is 0.9998. - use_thres_step (bool): Whether set decay by thres_step or not + ema_decay_type (str): type in ['threshold', 'normal', 'exponential'], + 'threshold' as default. cycle_epoch (int): The epoch of interval to reset ema_param and step. Defaults is -1, which means not reset. 
Its function is to add a regular effect to ema, which is set according to experience @@ -341,7 +369,7 @@ class ModelEMA(object): def __init__(self, model, decay=0.9998, - use_thres_step=False, + ema_decay_type='threshold', cycle_epoch=-1): self.step = 0 self.epoch = 0 @@ -349,7 +377,7 @@ class ModelEMA(object): self.state_dict = dict() for k, v in model.state_dict().items(): self.state_dict[k] = paddle.zeros_like(v) - self.use_thres_step = use_thres_step + self.ema_decay_type = ema_decay_type self.cycle_epoch = cycle_epoch self._model_state = { @@ -370,8 +398,10 @@ class ModelEMA(object): self.step = step def update(self, model=None): - if self.use_thres_step: + if self.ema_decay_type == 'threshold': decay = min(self.decay, (1 + self.step) / (10 + self.step)) + elif self.ema_decay_type == 'exponential': + decay = self.decay * (1 - math.exp(-(self.step + 1) / 2000)) else: decay = self.decay self._decay = decay @@ -394,7 +424,8 @@ class ModelEMA(object): return self.state_dict state_dict = dict() for k, v in self.state_dict.items(): - v = v / (1 - self._decay**self.step) + if self.ema_decay_type != 'exponential': + v = v / (1 - self._decay**self.step) v.stop_gradient = True state_dict[k] = v self.epoch += 1
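As a side note on the new `ema_decay_type` option: the following is a minimal, standalone sketch of the three decay schedules that `ModelEMA.update()` now implements (the helper name `ema_decay` and the printed steps are illustrative, not part of the patch):

    import math

    def ema_decay(step, base_decay=0.9998, ema_decay_type='threshold'):
        # mirrors ModelEMA.update(): ramp the decay up early in training
        if ema_decay_type == 'threshold':
            return min(base_decay, (1 + step) / (10 + step))
        elif ema_decay_type == 'exponential':
            return base_decay * (1 - math.exp(-(step + 1) / 2000))
        return base_decay  # 'normal': constant decay

    for step in (0, 100, 2000, 10000):
        print(step, ema_decay(step), ema_decay(step, ema_decay_type='exponential'))

Both warmup-style schedules keep early EMA snapshots close to the raw weights and approach `base_decay` as the step count grows, which is why the bias correction `v / (1 - decay**step)` is skipped for the 'exponential' type in the final hunk.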