diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 65a5d1d5e9de3d6d6b5effb5e98df41abbf22785..c51b85a0ee683de9981c838a50c58447f206bf73 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -68,7 +68,7 @@ from ppcls.arch.backbone.model_zoo.van import VAN_tiny from ppcls.arch.backbone.variant_models.resnet_variant import ResNet50_last_stage_stride1 from ppcls.arch.backbone.variant_models.vgg_variant import VGG19Sigmoid from ppcls.arch.backbone.variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh -from ppcls.arch.backbone.model_zoo.ir_net import IR_18, IR_34, IR_50, IR_101, IR_152, IR_200, IR_SE_50, IR_SE_101, IR_SE_152, IR_SE_200 +from ppcls.arch.backbone.model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200 # help whl get all the models' api (class type) and components' api (func type) diff --git a/ppcls/arch/backbone/model_zoo/ir_net.py b/ppcls/arch/backbone/model_zoo/adaface_ir_net.py similarity index 97% rename from ppcls/arch/backbone/model_zoo/ir_net.py rename to ppcls/arch/backbone/model_zoo/adaface_ir_net.py index ecec581195944627ae02da1556513d85157e5471..47de152b646e6f824e5a888692b770d9e146223b 100644 --- a/ppcls/arch/backbone/model_zoo/ir_net.py +++ b/ppcls/arch/backbone/model_zoo/adaface_ir_net.py @@ -450,14 +450,14 @@ class Backbone(Layer): return x -def IR_18(input_size=(112, 112)): +def AdaFace_IR_18(input_size=(112, 112)): """ Constructs a ir-18 model. """ model = Backbone(input_size, 18, 'ir') return model -def IR_34(input_size=(112, 112)): +def AdaFace_IR_34(input_size=(112, 112)): """ Constructs a ir-34 model. """ model = Backbone(input_size, 34, 'ir') @@ -465,7 +465,7 @@ def IR_34(input_size=(112, 112)): return model -def IR_50(input_size=(112, 112)): +def AdaFace_IR_50(input_size=(112, 112)): """ Constructs a ir-50 model. """ model = Backbone(input_size, 50, 'ir') @@ -473,7 +473,7 @@ def IR_50(input_size=(112, 112)): return model -def IR_101(input_size=(112, 112)): +def AdaFace_IR_101(input_size=(112, 112)): """ Constructs a ir-101 model. """ model = Backbone(input_size, 100, 'ir') @@ -481,7 +481,7 @@ def IR_101(input_size=(112, 112)): return model -def IR_152(input_size=(112, 112)): +def AdaFace_IR_152(input_size=(112, 112)): """ Constructs a ir-152 model. """ model = Backbone(input_size, 152, 'ir') @@ -489,7 +489,7 @@ def IR_152(input_size=(112, 112)): return model -def IR_200(input_size=(112, 112)): +def AdaFace_IR_200(input_size=(112, 112)): """ Constructs a ir-200 model. """ model = Backbone(input_size, 200, 'ir') @@ -497,7 +497,7 @@ def IR_200(input_size=(112, 112)): return model -def IR_SE_50(input_size=(112, 112)): +def AdaFace_IR_SE_50(input_size=(112, 112)): """ Constructs a ir_se-50 model. """ model = Backbone(input_size, 50, 'ir_se') @@ -505,7 +505,7 @@ def IR_SE_50(input_size=(112, 112)): return model -def IR_SE_101(input_size=(112, 112)): +def AdaFace_IR_SE_101(input_size=(112, 112)): """ Constructs a ir_se-101 model. """ model = Backbone(input_size, 100, 'ir_se') @@ -513,7 +513,7 @@ def IR_SE_101(input_size=(112, 112)): return model -def IR_SE_152(input_size=(112, 112)): +def AdaFace_IR_SE_152(input_size=(112, 112)): """ Constructs a ir_se-152 model. """ model = Backbone(input_size, 152, 'ir_se') @@ -521,7 +521,7 @@ def IR_SE_152(input_size=(112, 112)): return model -def IR_SE_200(input_size=(112, 112)): +def AdaFace_IR_SE_200(input_size=(112, 112)): """ Constructs a ir_se-200 model. """ model = Backbone(input_size, 200, 'ir_se') diff --git a/ppcls/configs/metric_learning/ir18_adaface.yaml b/ppcls/configs/metric_learning/adaface_ir18.yaml similarity index 80% rename from ppcls/configs/metric_learning/ir18_adaface.yaml rename to ppcls/configs/metric_learning/adaface_ir18.yaml index 008aed4220f824d7864edacaa31d78393466e7ba..2cbfe5da43763701b244b2422bf9ad82b19ef4d6 100644 --- a/ppcls/configs/metric_learning/ir18_adaface.yaml +++ b/ppcls/configs/metric_learning/adaface_ir18.yaml @@ -21,7 +21,7 @@ Arch: infer_output_key: "features" infer_add_softmax: False Backbone: - name: "IR_18" + name: "AdaFace_IR_18" input_size: [112, 112] Head: name: "AdaMargin" @@ -57,10 +57,21 @@ DataLoader: name: "AdaFaceDataset" root_dir: "dataset/face/" label_path: "dataset/face/train_filter_label.txt" - low_res_augmentation_prob: 0.2 - crop_augmentation_prob: 0.2 - photometric_augmentation_prob: 0.2 transform: + - CropWithPadding: + prob: 0.2 + padding_num: 0 + size: [112, 112] + scale: [0.2, 1.0] + ratio: [0.75, 1.3333333333333333] + - RandomInterpolationAugment: + prob: 0.2 + - ColorJitter: + prob: 0.2 + brightness: 0.5 + contrast: 0.5 + saturation: 0.5 + hue: 0 - RandomHorizontalFlip: - ToTensor: - Normalize: diff --git a/ppcls/data/dataloader/face_dataset.py b/ppcls/data/dataloader/face_dataset.py index 8879939a6ea728b915cb1756a8c3292d3c8e76f9..a32cc2c5f89aa8c8e4904e7decc6ec5fb996aab3 100644 --- a/ppcls/data/dataloader/face_dataset.py +++ b/ppcls/data/dataloader/face_dataset.py @@ -14,40 +14,11 @@ from ppcls.data.preprocess import transform as transform_func # code is based on AdaFace: https://github.com/mk-minchul/AdaFace -def _get_image_size(img): - if F._is_pil_image(img): - return img.size - elif F._is_numpy_image(img): - return img.shape[:2][::-1] - elif F._is_tensor_image(img): - return img.shape[1:][::-1] # chw - else: - raise TypeError("Unexpected type {}".format(type(img))) - - class AdaFaceDataset(Dataset): - def __init__( - self, - root_dir, - label_path, - transform=None, - low_res_augmentation_prob=0.0, - crop_augmentation_prob=0.0, - photometric_augmentation_prob=0.0, ): + def __init__(self, root_dir, label_path, transform=None): self.root_dir = root_dir - self.low_res_augmentation_prob = low_res_augmentation_prob - self.crop_augmentation_prob = crop_augmentation_prob - self.photometric_augmentation_prob = photometric_augmentation_prob - self.random_resized_crop = transforms.RandomResizedCrop( - size=(112, 112), - scale=(0.2, 1.0), - ratio=(0.75, 1.3333333333333333)) - self.photometric = transforms.ColorJitter( - brightness=0.5, contrast=0.5, saturation=0.5, hue=0) self.transform = create_operators(transform) - self.tot_rot_try = 0 - self.rot_success = 0 with open(label_path) as fd: lines = fd.readlines() self.samples = [] @@ -73,65 +44,11 @@ class AdaFaceDataset(Dataset): # if 'WebFace' in self.root: # # swap rgb to bgr since image is in rgb for webface - # sample = Image.fromarray(np.asarray(sample)[:, :, ::-1]) - - sample, _ = self.augment(sample) + # sample = Image.fromarray(np.asarray(sample)[:, :, ::-1] if self.transform is not None: sample = transform_func(sample, self.transform) - return sample, target - def augment(self, sample): - - # crop with zero padding augmentation - if np.random.random() < self.crop_augmentation_prob: - # RandomResizedCrop augmentation - new = np.zeros_like(np.array(sample)) - # orig_W, orig_H = F._get_image_size(sample) - orig_W, orig_H = _get_image_size(sample) - i, j, h, w = self.random_resized_crop._get_param(sample) - cropped = F.crop(sample, i, j, h, w) - new[i:i + h, j:j + w, :] = np.array(cropped) - sample = Image.fromarray(new.astype(np.uint8)) - crop_ratio = min(h, w) / max(orig_H, orig_W) - else: - crop_ratio = 1.0 - - # low resolution augmentation - if np.random.random() < self.low_res_augmentation_prob: - # low res augmentation - img_np, resize_ratio = low_res_augmentation(np.array(sample)) - sample = Image.fromarray(img_np.astype(np.uint8)) - else: - resize_ratio = 1 - - # photometric augmentation - if np.random.random() < self.photometric_augmentation_prob: - sample = self.photometric(sample) - information_score = resize_ratio * crop_ratio - return sample, information_score - - -def low_res_augmentation(img): - # resize the image to a small size and enlarge it back - img_shape = img.shape - side_ratio = np.random.uniform(0.2, 1.0) - small_side = int(side_ratio * img_shape[0]) - interpolation = np.random.choice([ - cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, - cv2.INTER_LANCZOS4 - ]) - small_img = cv2.resize( - img, (small_side, small_side), interpolation=interpolation) - interpolation = np.random.choice([ - cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, - cv2.INTER_LANCZOS4 - ]) - aug_img = cv2.resize( - small_img, (img_shape[1], img_shape[0]), interpolation=interpolation) - - return aug_img, side_ratio - class FiveValidationDataset(Dataset): def __init__(self, val_data_path, concat_mem_file_name): @@ -243,4 +160,4 @@ def get_val_data(data_path): lfw, lfw_issame = get_val_pair(data_path, 'lfw') cplfw, cplfw_issame = get_val_pair(data_path, 'cplfw') calfw, calfw_issame = get_val_pair(data_path, 'calfw') - return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame \ No newline at end of file + return agedb_30, cfp_fp, lfw, agedb_30_issame, cfp_fp_issame, lfw_issame, cplfw, cplfw_issame, calfw, calfw_issame diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index 353db19c16201655d18d0adea42e219aa5b128d2..aede295a89af6289252e9120172e174259a612ad 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -34,6 +34,9 @@ from ppcls.data.preprocess.ops.operators import Pad from ppcls.data.preprocess.ops.operators import ToTensor from ppcls.data.preprocess.ops.operators import Normalize from ppcls.data.preprocess.ops.operators import RandomHorizontalFlip +from ppcls.data.preprocess.ops.operators import CropWithPadding +from ppcls.data.preprocess.ops.operators import RandomInterpolationAugment +from ppcls.data.preprocess.ops.operators import ColorJitter from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index e4996abc0c0e02ac3a0f9900cfbcbbaaf711fec5..cbd9e1990951884d17e4a2b6eb962a653a2e8d77 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -25,8 +25,8 @@ import cv2 import numpy as np from PIL import Image, ImageOps, __version__ as PILLOW_VERSION from paddle.vision.transforms import ColorJitter as RawColorJitter -from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip - +from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop +from paddle.vision.transforms import functional as F from .autoaugment import ImageNetPolicy from .functional import augmentations from ppcls.utils import logger @@ -93,6 +93,42 @@ class UnifiedResize(object): return self.resize_func(src, size) +class RandomInterpolationAugment(object): + def __init__(self, prob): + self.prob = prob + + def _aug(self, img): + img_shape = img.shape + side_ratio = np.random.uniform(0.2, 1.0) + small_side = int(side_ratio * img_shape[0]) + interpolation = np.random.choice([ + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, + cv2.INTER_CUBIC, cv2.INTER_LANCZOS4 + ]) + small_img = cv2.resize( + img, (small_side, small_side), interpolation=interpolation) + interpolation = np.random.choice([ + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, + cv2.INTER_CUBIC, cv2.INTER_LANCZOS4 + ]) + aug_img = cv2.resize( + small_img, (img_shape[1], img_shape[0]), + interpolation=interpolation) + return aug_img + + def __call__(self, img): + if np.random.random() < self.prob: + if isinstance(img, np.ndarray): + return self._aug(img) + else: + pil_img = np.array(img) + aug_img = self._aug(pil_img) + img = Image.fromarray(aug_img.astype(np.uint8)) + return img + else: + return img + + class OperatorParamError(ValueError): """ OperatorParamError """ @@ -170,6 +206,52 @@ class ResizeImage(object): return self._resize_func(img, (w, h)) +class CropWithPadding(RandomResizedCrop): + """ + crop image and padding to original size + """ + + def __init__(self, + prob=1, + padding_num=0, + size=224, + scale=(0.08, 1.0), + ratio=(3. / 4, 4. / 3), + interpolation='bilinear', + key=None): + super().__init__(size, scale, ratio, interpolation, key) + self.prob = prob + self.padding_num = padding_num + + def __call__(self, img): + is_cv2_img = False + if isinstance(img, np.ndarray): + flag = True + if np.random.random() < self.prob: + # RandomResizedCrop augmentation + new = np.zeros_like(np.array(img)) + self.padding_num + # orig_W, orig_H = F._get_image_size(sample) + orig_W, orig_H = self._get_image_size(img) + i, j, h, w = self._get_param(img) + cropped = F.crop(img, i, j, h, w) + new[i:i + h, j:j + w, :] = np.array(cropped) + if not isinstance: + new = Image.fromarray(new.astype(np.uint8)) + return new + else: + return img + + def _get_image_size(self, img): + if F._is_pil_image(img): + return img.size + elif F._is_numpy_image(img): + return img.shape[:2][::-1] + elif F._is_tensor_image(img): + return img.shape[1:][::-1] # chw + else: + raise TypeError("Unexpected type {}".format(type(img))) + + class CropImage(object): """ crop image """ @@ -434,16 +516,18 @@ class ColorJitter(RawColorJitter): """ColorJitter. """ - def __init__(self, *args, **kwargs): + def __init__(self, prob=2, *args, **kwargs): super().__init__(*args, **kwargs) + self.prob = prob def __call__(self, img): - if not isinstance(img, Image.Image): - img = np.ascontiguousarray(img) - img = Image.fromarray(img) - img = super()._apply_image(img) - if isinstance(img, Image.Image): - img = np.asarray(img) + if np.random.random() < self.prob: + if not isinstance(img, Image.Image): + img = np.ascontiguousarray(img) + img = Image.fromarray(img) + img = super()._apply_image(img) + if isinstance(img, Image.Image): + img = np.asarray(img) return img