diff --git a/docs/zh_CN/quick_start/quick_start_recognition.md b/docs/zh_CN/quick_start/quick_start_recognition.md index 91c1d7208419dd1fe04bd80db34126aa4c01b0dc..1cc1b8d1602e25a97d1406cf6178dd3fd84f9faa 100644 --- a/docs/zh_CN/quick_start/quick_start_recognition.md +++ b/docs/zh_CN/quick_start/quick_start_recognition.md @@ -42,9 +42,10 @@ ### 1.1 安装 PP-ShiTu android demo 可以通过扫描二维码或者[点击链接](https://paddle-imagenet-models-name.bj.bcebos.com/demos/PP-ShiTu.apk)下载并安装APP + **注:** 华为鸿蒙OS 3.0的系统可能会出现无法调用摄像头的情况,建议更换低版本系统或者使用其它安卓机型进行快速体验。 - +
diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 545725f71c23cfb0fa7198dd121fd1ff865fc760..bfc96a57dde765a23dbb0cb54909402bc820eede 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -73,6 +73,7 @@ from .model_zoo.convnext import ConvNeXt_tiny from .variant_models.resnet_variant import ResNet50_last_stage_stride1 from .variant_models.vgg_variant import VGG19Sigmoid from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh +from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200 diff --git a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py index 40264092a47deb1e11ed11d2edbda7135f0b5a75..ea24489c16c9d1281b3555546c5786a2168a8a38 100644 --- a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py +++ b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py @@ -126,6 +126,8 @@ class RepDepthwiseSeparable(TheseusLayer): use_se=False, use_shortcut=False): super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels self.is_repped = False self.dw_size = dw_size @@ -306,8 +308,8 @@ class PPLCNetV2(TheseusLayer): self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer") self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) - in_features = self.class_expand if self.use_last_conv else NET_CONFIG[ - "stage4"][0] * 2 * scale + in_features = self.class_expand if self.use_last_conv else make_divisible( + NET_CONFIG["stage4"][0] * 2 * scale) self.fc = Linear(in_features, class_num) def forward(self, x): diff --git a/ppcls/arch/backbone/variant_models/__init__.py b/ppcls/arch/backbone/variant_models/__init__.py index 75cf29ffa9c59b744972a9e82fba7a506219e83b..d2fcd0bdd9b83c6e87a6a0684382c380e5fff93a 100644 --- 
from paddle.nn import Conv2D, Identity

from ..legendary_models.pp_lcnet_v2 import MODEL_URLS, PPLCNetV2_base, RepDepthwiseSeparable, _load_pretrained

__all__ = ["PPLCNetV2_base_ShiTu"]


def PPLCNetV2_base_ShiTu(pretrained=False, use_ssld=False, **kwargs):
    """Build the ShiTu variant of ``PPLCNetV2_base``.

    The base network is constructed first and three groups of sub-layers are
    then swapped out through ``upgrade_sublayer``:

    1. the layer matched by ``act`` becomes ``Identity`` (i.e. the ReLU after
       ``last_conv`` is removed);
    2. ``last_conv`` is rebuilt with the same geometry but with a bias term;
    3. ``stages[3][0]`` and ``stages[3][1]`` are rebuilt with ``stride=1``.

    Args:
        pretrained: Passed to ``_load_pretrained`` once the layers have been
            replaced. The base model itself is always built with
            ``pretrained=False``.
        use_ssld: Passed to both ``PPLCNetV2_base`` and ``_load_pretrained``.
        **kwargs: Forwarded unchanged to ``PPLCNetV2_base``.

    Returns:
        The modified model instance.
    """
    model = PPLCNetV2_base(pretrained=False, use_ssld=use_ssld, **kwargs)

    def remove_ReLU_function(conv, pattern):
        # The matched layer is discarded entirely; both arguments are unused.
        return Identity()

    def add_bias_last_conv(conv, pattern):
        # Mirror the matched convolution's configuration, enabling the bias.
        return Conv2D(
            in_channels=conv._in_channels,
            out_channels=conv._out_channels,
            kernel_size=conv._kernel_size,
            stride=conv._stride,
            padding=conv._padding,
            groups=conv._groups,
            bias_attr=True)

    def last_stride_function(rep_block, pattern):
        # Same block configuration as the matched one, except for the stride.
        return RepDepthwiseSeparable(
            in_channels=rep_block.in_channels,
            out_channels=rep_block.out_channels,
            stride=1,
            dw_size=rep_block.dw_size,
            split_pw=rep_block.split_pw,
            use_rep=rep_block.use_rep,
            use_se=rep_block.use_se,
            use_shortcut=rep_block.use_shortcut)

    model.upgrade_sublayer(["act"], remove_ReLU_function)
    model.upgrade_sublayer(["last_conv"], add_bias_last_conv)
    model.upgrade_sublayer(["stages[3][0]", "stages[3][1]"],
                           last_stride_function)

    # The replaced layers are freshly initialised, so the weights are loaded
    # only after the upgrade, using the PPLCNetV2_base checkpoint URLs.
    _load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld)
    return model
or 'features' + re_ranking: False + use_dali: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + +AMP: + scale_loss: 65536 + use_dynamic_loss_scaling: True + # O1: mixed fp16 + level: O1 + +# model architecture +Arch: + name: RecModel + infer_output_key: features + infer_add_softmax: False + + Backbone: + name: PPLCNetV2_base_ShiTu + pretrained: True + use_ssld: True + class_expand: &feat_dim 512 + BackboneStopLayer: + name: flatten + Neck: + name: BNNeck + num_features: *feat_dim + weight_attr: + initializer: + name: Constant + value: 1.0 + bias_attr: + initializer: + name: Constant + value: 0.0 + learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero + Head: + name: FC + embedding_size: *feat_dim + class_num: 192612 + weight_attr: + initializer: + name: Normal + std: 0.001 + bias_attr: False + +# loss function config for training/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + - TripletAngularMarginLoss: + weight: 1.0 + feature_from: features + margin: 0.5 + reduction: mean + add_absolute: True + absolute_loss_weight: 0.1 + normalize_feature: True + ap_value: 0.8 + an_value: 0.4 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.06 # for 8gpu x 256bs + warmup_epoch: 5 + regularizer: + name: L2 + coeff: 0.00001 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ + cls_label_path: ./dataset/train_reg_all_data_v2.txt + relabel: True + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: [224, 224] + return_numpy: False + interpolation: bilinear + backend: cv2 + - RandFlipImage: + flip_code: 1 + - Pad: + padding: 10 + backend: cv2 + - RandCropImageV2: + size: [224, 224] + - RandomRotation: + prob: 0.5 + degrees: 90 + interpolation: bilinear + - ResizeImage: + size: [224, 224]
+ return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: hwc + sampler: + name: PKSampler + batch_size: 256 + sample_per_id: 4 + drop_last: False + shuffle: True + sample_method: "id_avg_prob" + id_list: [50030, 80700, 92019, 96015] # be careful when set relabel=True + ratio: [4, 4] + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + Query: + dataset: + name: VeriWild + image_root: ./dataset/Aliproduct/ + cls_label_path: ./dataset/Aliproduct/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: [224, 224] + return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: hwc + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + + Gallery: + dataset: + name: VeriWild + image_root: ./dataset/Aliproduct/ + cls_label_path: ./dataset/Aliproduct/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: [224, 224] + return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: hwc + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Eval: + - Recallk: + topk: [1, 5] + - mAP: {} diff --git a/ppcls/data/dataloader/imagenet_dataset.py b/ppcls/data/dataloader/imagenet_dataset.py index 0e2b1c8a0b4cbc7e2f251ed2363fca71ec3239e7..51f1fbec3df24164a8c7c41a29a0ad91128120aa 100644 --- a/ppcls/data/dataloader/imagenet_dataset.py +++ b/ppcls/data/dataloader/imagenet_dataset.py @@ -21,27 +21,54 @@ from .common_dataset import CommonDataset class 
class ImageNetDataset(CommonDataset):
    """ImageNetDataset

    Args:
        image_root (str): image root, path to `ILSVRC2012`
        cls_label_path (str): path to annotation file `train_list.txt` or `val_list.txt`
        transform_ops (list, optional): list of transform op(s). Defaults to None.
        delimiter (str, optional): delimiter. Defaults to None, in which case a single space is used.
        relabel (bool, optional): whether to remap labels to contiguous ids starting from 0
            when the original labels do not start from 0 or are discontinuous. The new ids
            follow the ascending order of the original labels. Defaults to False.
    """

    def __init__(self,
                 image_root,
                 cls_label_path,
                 transform_ops=None,
                 delimiter=None,
                 relabel=False):
        # Set before calling the parent constructor because `_load_anno` reads
        # both attributes (presumably it is invoked from CommonDataset.__init__,
        # which is not visible here -- TODO confirm).
        self.delimiter = delimiter if delimiter is not None else " "
        self.relabel = relabel
        super(ImageNetDataset, self).__init__(image_root, cls_label_path,
                                              transform_ops)

    def _load_anno(self, seed=None):
        """Populate ``self.images`` and ``self.labels`` from the annotation file.

        Each annotation line is split on ``self.delimiter``; field 0 is the
        image path relative to ``self._img_root`` and field 1 is the integer
        label.

        Args:
            seed (int, optional): when given, the annotation lines are shuffled
                with ``np.random.RandomState(seed)`` before being parsed.

        Raises:
            AssertionError: if the annotation file, the image root, or any
                listed image does not exist.
        """
        assert os.path.exists(
            self._cls_path), f"path {self._cls_path} does not exist."
        assert os.path.exists(
            self._img_root), f"path {self._img_root} does not exist."
        self.images = []
        self.labels = []

        with open(self._cls_path) as fd:
            lines = fd.readlines()

        label_map = None
        if self.relabel:
            label_set = {
                np.int64(line.strip().split(self.delimiter)[1])
                for line in lines
            }
            # A set has no defined iteration order, so sort explicitly: the
            # mapping is then deterministic and preserves the original label
            # order, which id-range based settings rely on. New ids are kept
            # as np.int64 to match the non-relabel branch.
            label_map = {
                oldlabel: np.int64(newlabel)
                for newlabel, oldlabel in enumerate(sorted(label_set))
            }

        if seed is not None:
            np.random.RandomState(seed).shuffle(lines)

        for line in lines:
            fields = line.strip().split(self.delimiter)
            self.images.append(os.path.join(self._img_root, fields[0]))
            label = np.int64(fields[1])
            self.labels.append(label if label_map is None else label_map[
                label])
            assert os.path.exists(self.images[
                -1]), f"path {self.images[-1]} does not exist."
""" + def __init__(self, dataset, batch_size, sample_per_id, shuffle=True, drop_last=True, + id_list=None, + ratio=None, sample_method="sample_avg_prob"): - super().__init__(dataset, batch_size, shuffle=shuffle, drop_last=drop_last) + super().__init__( + dataset, batch_size, shuffle=shuffle, drop_last=drop_last) assert batch_size % sample_per_id == 0, \ f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})." assert hasattr(self.dataset, @@ -67,6 +73,16 @@ class PKSampler(DistributedBatchSampler): logger.error( "PKSampler only support id_avg_prob and sample_avg_prob sample method, " "but receive {}.".format(self.sample_method)) + + if id_list and ratio: + assert len(id_list) % 2 == 0 and len(id_list) == len(ratio) * 2 + for i in range(len(self.prob_list)): + for j in range(len(ratio)): + if i >= id_list[j * 2] and i <= id_list[j * 2 + 1]: + self.prob_list[i] = self.prob_list[i] * ratio[j] + break + self.prob_list = self.prob_list / sum(self.prob_list) + diff = np.abs(sum(self.prob_list) - 1) if diff > 0.00000001: self.prob_list[-1] = 1 - sum(self.prob_list[:-1]) @@ -74,8 +90,8 @@ class PKSampler(DistributedBatchSampler): logger.error("PKSampler prob list error") else: logger.info( - "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".format(diff) - ) + "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob". 
+ format(diff)) def __iter__(self): label_per_batch = self.batch_size // self.sample_per_label diff --git a/ppcls/data/dataloader/vehicle_dataset.py b/ppcls/data/dataloader/vehicle_dataset.py index 2981a57a0516aa25145f39479a34635b3be063f8..e4fbcad6a48d31b02a6fac6063ccb10d4dccdb48 100644 --- a/ppcls/data/dataloader/vehicle_dataset.py +++ b/ppcls/data/dataloader/vehicle_dataset.py @@ -89,11 +89,7 @@ class CompCars(Dataset): class VeriWild(Dataset): - def __init__( - self, - image_root, - cls_label_path, - transform_ops=None, ): + def __init__(self, image_root, cls_label_path, transform_ops=None): self._img_root = image_root self._cls_path = cls_label_path if transform_ops: @@ -102,19 +98,23 @@ class VeriWild(Dataset): self._load_anno() def _load_anno(self): - assert os.path.exists(self._cls_path) - assert os.path.exists(self._img_root) + assert os.path.exists( + self._cls_path), f"path {self._cls_path} does not exist." + assert os.path.exists( + self._img_root), f"path {self._img_root} does not exist." 
self.images = [] self.labels = [] self.cameras = [] with open(self._cls_path) as fd: lines = fd.readlines() - for l in lines: - l = l.strip().split() - self.images.append(os.path.join(self._img_root, l[0])) - self.labels.append(np.int64(l[1])) - self.cameras.append(np.int64(l[2])) + for line in lines: + line = line.strip().split() + self.images.append(os.path.join(self._img_root, line[0])) + self.labels.append(np.int64(line[1])) + if len(line) >= 3: + self.cameras.append(np.int64(line[2])) assert os.path.exists(self.images[-1]) + self.has_camera = len(self.cameras) > 0 def __getitem__(self, idx): try: @@ -123,7 +123,10 @@ class VeriWild(Dataset): if self._transform_ops: img = transform(img, self._transform_ops) img = img.transpose((2, 0, 1)) - return (img, self.labels[idx], self.cameras[idx]) + if self.has_camera: + return (img, self.labels[idx], self.cameras[idx]) + else: + return (img, self.labels[idx]) except Exception as ex: logger.error("Exception occured when parse line: {} with msg: {}". 
format(self.images[idx], ex)) diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index d0cfcf2409d2d890adcf03ef0e03b2475625ead8..0a91a85b4275fccb64e665696548f71efe5ccbd0 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -38,6 +38,7 @@ from ppcls.data.preprocess.ops.operators import CropWithPadding from ppcls.data.preprocess.ops.operators import RandomInterpolationAugment from ppcls.data.preprocess.ops.operators import ColorJitter from ppcls.data.preprocess.ops.operators import RandomCropImage +from ppcls.data.preprocess.ops.operators import RandomRotation from ppcls.data.preprocess.ops.operators import Padv2 from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index e617b8a71afffeb9e18e4be412f5a3374bd387ec..c70b9cb723dce77755189aad16f3839046673f35 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -26,6 +26,7 @@ import cv2 import numpy as np from PIL import Image, ImageOps, __version__ as PILLOW_VERSION from paddle.vision.transforms import ColorJitter as RawColorJitter +from paddle.vision.transforms import RandomRotation as RawRandomRotation from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop from paddle.vision.transforms import functional as F from .autoaugment import ImageNetPolicy @@ -181,7 +182,8 @@ class DecodeImage(object): img = np.asarray(img)[:, :, ::-1] # BRG if self.to_rgb: - assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]" + assert img.shape[ + 2] == 3, f"invalid shape of image[{img.shape}]" img = img[:, :, ::-1] if self.channel_first: @@ -495,7 +497,13 @@ class RandFlipImage(object): if isinstance(img, np.ndarray): return cv2.flip(img, self.flip_code) else: - return img.transpose(Image.FLIP_LEFT_RIGHT) + if 
class RandomRotation(RawRandomRotation):
    """Apply the parent rotation transform with probability ``prob``.

    Args:
        prob (float, optional): probability of rotating the image on a given
            call. Defaults to 0.5.
        *args: forwarded unchanged to the parent constructor.
        **kwargs: forwarded unchanged to the parent constructor.
    """

    def __init__(self, prob=0.5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prob = prob

    def __call__(self, img):
        # One random draw per call; leave the image untouched unless the draw
        # falls below ``prob``.
        if not (np.random.random() < self.prob):
            return img
        return super()._apply_image(img)
self.padding, + cv2.BORDER_CONSTANT, + value=(self.fill, self.fill, self.fill)) return img - - return ImageOps.expand(img, border=self.padding, **opts) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 1aa0a1e05c306f46c77ff09b3fb6af344d3e01e3..5a3fe20a29bdc926ad9c46dca4122a26c8939747 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -114,6 +114,7 @@ class Engine(object): #TODO(gaotingquan): support rec class_num = config["Arch"].get("class_num", None) self.config["DataLoader"].update({"class_num": class_num}) + # build dataloader if self.mode == 'train': self.train_dataloader = build_dataloader( diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index 489aea7fb1ee4b4a8ff9388f3984e64965a1eac7..4f973eb09aef909d3be2edb9cc5ac2bc2e8b58f3 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -12,6 +12,7 @@ from .msmloss import MSMLoss from .npairsloss import NpairsLoss from .trihardloss import TriHardLoss from .triplet import TripletLoss, TripletLossV2 +from .tripletangularmarginloss import TripletAngularMarginLoss from .supconloss import SupConLoss from .pairwisecosface import PairwiseCosface from .dmlloss import DMLLoss diff --git a/ppcls/loss/tripletangularmarginloss.py b/ppcls/loss/tripletangularmarginloss.py new file mode 100644 index 0000000000000000000000000000000000000000..3a91d2d499fa22aadc7ca15322f4048b978fb19d --- /dev/null +++ b/ppcls/loss/tripletangularmarginloss.py @@ -0,0 +1,115 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn


class TripletAngularMarginLoss(nn.Layer):
    """A more robust triplet loss with hard positive/negative mining on angular margin instead of relative distance between d(a,p) and d(a,n).

    Here "distance" is the inner-product similarity between features (cosine
    similarity when ``normalize_feature`` is True), so a larger value means
    the two samples are closer.

    Args:
        margin (float, optional): angular margin. Defaults to 0.5.
        normalize_feature (bool, optional): whether to apply L2-norm in feature before computing distance(cos-similarity). Defaults to True.
        reduction (str, optional): reducing option within a batch. Defaults to "mean".
        add_absolute (bool, optional): whether to add absolute hinge terms on d(a,p) and d(a,n). Defaults to False.
        absolute_loss_weight (float, optional): weight for the absolute terms. Defaults to 1.0.
        ap_value (float, optional): threshold for d(a, p); a hardest positive whose similarity is below it is penalised. Defaults to 0.9.
        an_value (float, optional): threshold for d(a, n); a hardest negative whose similarity is above it is penalised. Defaults to 0.5.
        feature_from (str, optional): key of the input dict that holds the features. Defaults to "features".
    """

    def __init__(self,
                 margin=0.5,
                 normalize_feature=True,
                 reduction="mean",
                 add_absolute=False,
                 absolute_loss_weight=1.0,
                 ap_value=0.9,
                 an_value=0.5,
                 feature_from="features"):
        super(TripletAngularMarginLoss, self).__init__()
        self.margin = margin
        self.feature_from = feature_from
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(
            margin=margin, reduction=reduction)
        self.normalize_feature = normalize_feature
        self.add_absolute = add_absolute
        self.ap_value = ap_value
        self.an_value = an_value
        self.absolute_loss_weight = absolute_loss_weight

    def forward(self, input, target):
        """
        Args:
            input: dict holding the feature matrix under ``self.feature_from``,
                with shape (batch_size, feat_dim)
            target: ground truth labels, one per sample; must be expandable to
                (batch_size, batch_size)

        Returns:
            dict: ``{"TripletAngularMarginLoss": loss}``
        """
        inputs = input[self.feature_from]

        if self.normalize_feature:
            inputs = paddle.divide(
                inputs, paddle.norm(
                    inputs, p=2, axis=-1, keepdim=True))

        bs = inputs.shape[0]

        # compute distance(cos-similarity)
        dist = paddle.matmul(inputs, inputs.t())

        # pairwise same-label / different-label masks
        labels = paddle.expand(target, (bs, bs))
        is_pos = labels.equal(labels.t())
        is_neg = labels.not_equal(labels.t())

        # NOTE: the reshape to (bs, -1) only works when every anchor has the
        # same number of positives (and negatives) in the batch.

        # `dist_ap` is the hardest positive: the least similar same-label
        # sample, shape [N, 1]
        dist_ap = paddle.min(paddle.reshape(
            paddle.masked_select(dist, is_pos), (bs, -1)),
                             axis=1,
                             keepdim=True)
        # `dist_an` is the hardest negative: the most similar different-label
        # sample, shape [N, 1]
        dist_an = paddle.max(paddle.reshape(
            paddle.masked_select(dist, is_neg), (bs, -1)),
                             axis=1,
                             keepdim=True)
        # shape [N]
        dist_ap = paddle.squeeze(dist_ap, axis=1)
        dist_an = paddle.squeeze(dist_an, axis=1)

        # Compute ranking hinge loss: with y = 1 this asks for
        # dist_ap >= dist_an + margin
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_ap, dist_an, y)

        if self.add_absolute:
            # hinge on the positive side: max(0, ap_value - d(a, p))
            absolute_loss_ap = self.ap_value - dist_ap
            absolute_loss_ap = paddle.where(
                absolute_loss_ap > 0, absolute_loss_ap,
                paddle.zeros_like(absolute_loss_ap))

            # hinge on the negative side: max(0, d(a, n) - an_value).
            # The fallback must be zero, as on the positive side; filling
            # with ones would add a constant 1.0 for every anchor whose
            # negative already satisfies the threshold.
            absolute_loss_an = dist_an - self.an_value
            absolute_loss_an = paddle.where(
                absolute_loss_an > 0, absolute_loss_an,
                paddle.zeros_like(absolute_loss_an))

            loss = (absolute_loss_an.mean() + absolute_loss_ap.mean()
                    ) * self.absolute_loss_weight + loss.mean()

        return {"TripletAngularMarginLoss": loss}