From 1b5e00e82a58bc4e77305e8740859b19ba1faf76 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 23 Aug 2022 07:54:58 +0000 Subject: [PATCH] add PP-ShiTuV2 code --- ppcls/arch/backbone/__init__.py | 1 + ppcls/arch/backbone/base/theseus_layer.py | 2 + .../backbone/legendary_models/pp_lcnet_v2.py | 4 +- .../arch/backbone/variant_models/__init__.py | 1 + .../variant_models/pp_lcnetv2_variant.py | 44 ++++ .../GeneralRecognitionV2_PPLCNetV2_base.yaml | 198 ++++++++++++++++++ ppcls/data/dataloader/imagenet_dataset.py | 25 +-- ppcls/data/dataloader/vehicle_dataset.py | 23 +- ppcls/data/preprocess/__init__.py | 1 + ppcls/data/preprocess/ops/operators.py | 26 ++- ppcls/engine/engine.py | 4 + ppcls/loss/__init__.py | 1 + ppcls/loss/tripletangularmarginloss.py | 115 ++++++++++ 13 files changed, 418 insertions(+), 27 deletions(-) create mode 100644 ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py create mode 100644 ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml create mode 100644 ppcls/loss/tripletangularmarginloss.py diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 545725f7..bfc96a57 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -73,6 +73,7 @@ from .model_zoo.convnext import ConvNeXt_tiny from .variant_models.resnet_variant import ResNet50_last_stage_stride1 from .variant_models.vgg_variant import VGG19Sigmoid from .variant_models.pp_lcnet_variant import PPLCNet_x2_5_Tanh +from .variant_models.pp_lcnetv2_variant import PPLCNetV2_base_ShiTu from .model_zoo.adaface_ir_net import AdaFace_IR_18, AdaFace_IR_34, AdaFace_IR_50, AdaFace_IR_101, AdaFace_IR_152, AdaFace_IR_SE_50, AdaFace_IR_SE_101, AdaFace_IR_SE_152, AdaFace_IR_SE_200 diff --git a/ppcls/arch/backbone/base/theseus_layer.py b/ppcls/arch/backbone/base/theseus_layer.py index a533cdc7..6a4d6c0a 100644 --- a/ppcls/arch/backbone/base/theseus_layer.py +++ b/ppcls/arch/backbone/base/theseus_layer.py @@ -158,6 +158,8 @@ class TheseusLayer(nn.Layer): return False parent_layer = layer_dict["layer"] + msg = f"Successfully set the layers that after stop_layer_name('{stop_layer_name}') to IdentityLayer." + logger.info(msg) return True def update_res( diff --git a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py index 40264092..b48d33e0 100644 --- a/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py +++ b/ppcls/arch/backbone/legendary_models/pp_lcnet_v2.py @@ -306,8 +306,8 @@ class PPLCNetV2(TheseusLayer): self.dropout = Dropout(p=dropout_prob, mode="downscale_in_infer") self.flatten = nn.Flatten(start_axis=1, stop_axis=-1) - in_features = self.class_expand if self.use_last_conv else NET_CONFIG[ - "stage4"][0] * 2 * scale + in_features = self.class_expand if self.use_last_conv else make_divisible( + NET_CONFIG["stage4"][0] * 2 * scale) self.fc = Linear(in_features, class_num) def forward(self, x): diff --git a/ppcls/arch/backbone/variant_models/__init__.py b/ppcls/arch/backbone/variant_models/__init__.py index 75cf29ff..d2fcd0bd 100644 --- a/ppcls/arch/backbone/variant_models/__init__.py +++ b/ppcls/arch/backbone/variant_models/__init__.py @@ -1,3 +1,4 @@ from .resnet_variant import ResNet50_last_stage_stride1 from .vgg_variant import VGG19Sigmoid from .pp_lcnet_variant import PPLCNet_x2_5_Tanh +from .pp_lcnetv2_variant import PPLCNetV2_base_ShiTu diff --git a/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py b/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py new file mode 100644 index 00000000..ebde9af8 --- /dev/null +++ b/ppcls/arch/backbone/variant_models/pp_lcnetv2_variant.py @@ -0,0 +1,44 @@ +from paddle.nn import Conv2D, Identity +from ..legendary_models.pp_lcnet_v2 import PPLCNetV2_base, RepDepthwiseSeparable, MODEL_URLS, _load_pretrained + +__all__ = ["PPLCNetV2_base_ShiTu"] + + +def PPLCNetV2_base_ShiTu(pretrained=False, use_ssld=False, **kwargs): + + model = PPLCNetV2_base(pretrained=False, use_ssld=use_ssld, **kwargs) + + def remove_ReLU_function(conv, pattern): + new_conv = Identity() + return new_conv + + def last_stride_1_function(conv, pattern): + new_conv = Conv2D( + weight_attr=conv._weight_attr, + in_channels=conv._in_channels, + out_channels=conv._out_channels, + kernel_size=conv._kernel_size, + stride=1, + padding=conv._padding, + groups=conv._groups, + bias_attr=conv._bias_attr) + return new_conv + + pattern_act = ["act"] + pattern_last_stride = [ + "stages[3][0].dw_conv_list[0].conv", + "stages[3][0].dw_conv_list[1].conv", + "stages[3][0].dw_conv", + "stages[3][0].pw_conv.conv", + "stages[3][1].dw_conv_list[0].conv", + "stages[3][1].dw_conv_list[1].conv", + "stages[3][1].dw_conv_list[2].conv", + "stages[3][1].dw_conv", + "stages[3][1].pw_conv.conv", + ] + model.upgrade_sublayer(pattern_last_stride, last_stride_1_function) + model.upgrade_sublayer(pattern_act, remove_ReLU_function) + + # load params again after upgrade some layers + _load_pretrained(pretrained, model, MODEL_URLS["PPLCNetV2_base"], use_ssld) + return model diff --git a/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml b/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml new file mode 100644 index 00000000..b3babb25 --- /dev/null +++ b/ppcls/configs/GeneralRecognitionV2/GeneralRecognitionV2_PPLCNetV2_base.yaml @@ -0,0 +1,198 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 100 + print_batch_step: 20 + use_visualdl: False + eval_mode: retrieval + retrieval_feature_from: features # 'backbone' or 'features' + re_ranking: False + use_dali: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + +# AMP: +# scale_loss: 65536 +# use_dynamic_loss_scaling: True +# # O1: mixed fp16 +# level: O1 + +# model architecture +Arch: + name: RecModel + infer_output_key: features + infer_add_softmax: False + + Backbone: + name: PPLCNetV2_base_ShiTu + pretrained: True + use_ssld: True + class_expand: &feat_dim 512 + BackboneStopLayer: + name: flatten + Neck: + name: BNNeck + num_features: *feat_dim + weight_attr: + initializer: + name: Constant + value: 1.0 + bias_attr: + initializer: + name: Constant + value: 0.0 + learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero + Head: + name: FC + embedding_size: *feat_dim + class_num: 192612 + weight_attr: + initializer: + name: Normal + std: 0.001 + bias_attr: False + +# loss function config for traing/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + - TripletAngleMarinLoss: + weight: 1.0 + margin: 0.5 + reduction: mean + add_absolute: True + absolute_loss_weight: 0.1 + normalize_feature: True + feature_from: features + ap_value: 0.8 + an_value: 0.4 + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.04 + warmup_epoch: 5 + regularizer: + name: L2 + coeff: 0.00001 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetDataset + image_root: ./dataset/ + cls_label_path: ./dataset/train_reg_all_data.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: [224, 224] + return_numpy: False + interpolation: bilinear + backend: cv2 + - RandFlipImage: + flip_code: 1 + - Pad_cv2: + padding: 10 + - RandCropImageV2: + size: [224, 224] + - RandomRotation: + prob: 0.5 + degrees: 90 + interpolation: bilinear + - ResizeImage: + size: [224, 224] + return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: hwc + sampler: + name: DistributedBatchSampler + batch_size: 256 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + Query: + dataset: + name: VeriWild + image_root: ./dataset/Aliproduct/ + cls_label_path: ./dataset/Aliproduct/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: [224, 224] + return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: hwc + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + + Gallery: + dataset: + name: VeriWild + image_root: ./dataset/Aliproduct/ + cls_label_path: ./dataset/Aliproduct/val_list.txt + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + size: [224, 224] + return_numpy: False + interpolation: bilinear + backend: cv2 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: hwc + sampler: + name: DistributedBatchSampler + batch_size: 64 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Eval: + - Recallk: + topk: [1, 5] diff --git a/ppcls/data/dataloader/imagenet_dataset.py b/ppcls/data/dataloader/imagenet_dataset.py index 0e2b1c8a..87188c16 100644 --- a/ppcls/data/dataloader/imagenet_dataset.py +++ b/ppcls/data/dataloader/imagenet_dataset.py @@ -21,14 +21,14 @@ from .common_dataset import CommonDataset class ImageNetDataset(CommonDataset): - def __init__( - self, - image_root, - cls_label_path, - transform_ops=None, - delimiter=None): + def __init__(self, + image_root, + cls_label_path, + transform_ops=None, + delimiter=None): self.delimiter = delimiter if delimiter is not None else " " - super(ImageNetDataset, self).__init__(image_root, cls_label_path, transform_ops) + super(ImageNetDataset, self).__init__(image_root, cls_label_path, + transform_ops) def _load_anno(self, seed=None): assert os.path.exists(self._cls_path) @@ -40,8 +40,9 @@ class ImageNetDataset(CommonDataset): lines = fd.readlines() if seed is not None: np.random.RandomState(seed).shuffle(lines) - for l in lines: - l = l.strip().split(self.delimiter) - self.images.append(os.path.join(self._img_root, l[0])) - self.labels.append(np.int64(l[1])) - assert os.path.exists(self.images[-1]) + for line in lines: + line = line.strip().split(self.delimiter) + self.images.append(os.path.join(self._img_root, line[0])) + self.labels.append(np.int64(line[1])) + assert os.path.exists(self.images[ + -1]), f"path {self.images[-1]} does not exist." diff --git a/ppcls/data/dataloader/vehicle_dataset.py b/ppcls/data/dataloader/vehicle_dataset.py index 2981a57a..8c89e382 100644 --- a/ppcls/data/dataloader/vehicle_dataset.py +++ b/ppcls/data/dataloader/vehicle_dataset.py @@ -89,11 +89,7 @@ class CompCars(Dataset): class VeriWild(Dataset): - def __init__( - self, - image_root, - cls_label_path, - transform_ops=None, ): + def __init__(self, image_root, cls_label_path, transform_ops=None): self._img_root = image_root self._cls_path = cls_label_path if transform_ops: @@ -109,12 +105,14 @@ class VeriWild(Dataset): self.cameras = [] with open(self._cls_path) as fd: lines = fd.readlines() - for l in lines: - l = l.strip().split() - self.images.append(os.path.join(self._img_root, l[0])) - self.labels.append(np.int64(l[1])) - self.cameras.append(np.int64(l[2])) + for line in lines: + line = line.strip().split() + self.images.append(os.path.join(self._img_root, line[0])) + self.labels.append(np.int64(line[1])) + if len(line) >= 3: + self.cameras.append(np.int64(line[2])) assert os.path.exists(self.images[-1]) + self.has_camera = len(self.cameras) > 0 def __getitem__(self, idx): try: @@ -123,7 +121,10 @@ class VeriWild(Dataset): if self._transform_ops: img = transform(img, self._transform_ops) img = img.transpose((2, 0, 1)) - return (img, self.labels[idx], self.cameras[idx]) + if self.has_camera: + return (img, self.labels[idx], self.cameras[idx]) + else: + return (img, self.labels[idx]) except Exception as ex: logger.error("Exception occured when parse line: {} with msg: {}". format(self.images[idx], ex)) diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index d0cfcf24..0a91a85b 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -38,6 +38,7 @@ from ppcls.data.preprocess.ops.operators import CropWithPadding from ppcls.data.preprocess.ops.operators import RandomInterpolationAugment from ppcls.data.preprocess.ops.operators import ColorJitter from ppcls.data.preprocess.ops.operators import RandomCropImage +from ppcls.data.preprocess.ops.operators import RandomRotation from ppcls.data.preprocess.ops.operators import Padv2 from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index e617b8a7..b36319bc 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -26,6 +26,7 @@ import cv2 import numpy as np from PIL import Image, ImageOps, __version__ as PILLOW_VERSION from paddle.vision.transforms import ColorJitter as RawColorJitter +from paddle.vision.transforms import RandomRotation as RawRandomRotation from paddle.vision.transforms import ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop from paddle.vision.transforms import functional as F from .autoaugment import ImageNetPolicy @@ -181,7 +182,8 @@ class DecodeImage(object): img = np.asarray(img)[:, :, ::-1] # BRG if self.to_rgb: - assert img.shape[2] == 3, f"invalid shape of image[{img.shape}]" + assert img.shape[ + 2] == 3, f"invalid shape of image[{img.shape}]" img = img[:, :, ::-1] if self.channel_first: @@ -495,7 +497,13 @@ class RandFlipImage(object): if isinstance(img, np.ndarray): return cv2.flip(img, self.flip_code) else: - return img.transpose(Image.FLIP_LEFT_RIGHT) + if self.flip_code == 1: + return img.transpose(Image.FLIP_LEFT_RIGHT) + elif self.flip_code == 0: + return img.transpose(Image.FLIP_TOP_BOTTOM) + else: + return img.transpose(Image.FLIP_LEFT_RIGHT).transpose( + Image.FLIP_LEFT_RIGHT) else: return img @@ -653,6 +661,20 @@ class ColorJitter(RawColorJitter): return img +class RandomRotation(RawRandomRotation): + """RandomRotation. + """ + + def __init__(self, prob=0.5, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prob = prob + + def __call__(self, img): + if np.random.random() < self.prob: + img = super()._apply_image(img) + return img + + class Pad(object): """ Pads the given PIL.Image on all sides with specified padding mode and fill value. diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 1aa0a1e0..c740eb20 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -114,6 +114,10 @@ class Engine(object): #TODO(gaotingquan): support rec class_num = config["Arch"].get("class_num", None) self.config["DataLoader"].update({"class_num": class_num}) + self.model = build_model(self.config, self.mode) + # print(*self.model.state_dict().keys(), sep='\n') + print(self.model.backbone.stages[3][0].dw_conv_list[0].conv) + exit(0) # build dataloader if self.mode == 'train': self.train_dataloader = build_dataloader( diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index 489aea7f..d4d54881 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -12,6 +12,7 @@ from .msmloss import MSMLoss from .npairsloss import NpairsLoss from .trihardloss import TriHardLoss from .triplet import TripletLoss, TripletLossV2 +from .tripletangularmarginloss import TTripletAngularMarginLoss from .supconloss import SupConLoss from .pairwisecosface import PairwiseCosface from .dmlloss import DMLLoss diff --git a/ppcls/loss/tripletangularmarginloss.py b/ppcls/loss/tripletangularmarginloss.py new file mode 100644 index 00000000..fa32a197 --- /dev/null +++ b/ppcls/loss/tripletangularmarginloss.py @@ -0,0 +1,115 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +import paddle.nn as nn + + +class TripletAngularMarginLoss(nn.Layer): + """A more robust triplet loss with hard positive/negative mining on angular margin instead of relative distance between d(a,p) and d(a,n). + + Args: + margin (float, optional): angular margin. Defaults to 0.5. + normalize_feature (bool, optional): whether to apply L2-norm in feature before computing distance(cos-similarity). Defaults to True. + reduction (str, optional): reducing option within an batch . Defaults to "mean". + add_absolute (bool, optional): whether add absolute loss within d(a,p) or d(a,n). Defaults to False. + absolute_loss_weight (float, optional): weight for absolute loss. Defaults to 1.0. + ap_value (float, optional): weight for d(a, p). Defaults to 0.9. + an_value (float, optional): weight for d(a, n). Defaults to 0.5. + feature_from (str, optional): which key feature from. Defaults to "features". + """ + + def __init__(self, + margin=0.5, + normalize_feature=True, + reduction="mean", + add_absolute=False, + absolute_loss_weight=1.0, + ap_value=0.9, + an_value=0.5, + feature_from="features"): + super(TripletAngleMarginLoss, self).__init__() + self.margin = margin + self.feature_from = feature_from + self.ranking_loss = paddle.nn.loss.MarginRankingLoss( + margin=margin, reduction=reduction) + self.normalize_feature = normalize_feature + self.add_absolute = add_absolute + self.ap_value = ap_value + self.an_value = an_value + self.absolute_loss_weight = absolute_loss_weight + + def forward(self, input, target): + """ + Args: + inputs: feature matrix with shape (batch_size, feat_dim) + target: ground truth labels with shape (num_classes) + """ + inputs = input[self.feature_from] + + if self.normalize_feature: + inputs = paddle.divide( + inputs, paddle.norm( + inputs, p=2, axis=-1, keepdim=True)) + + bs = inputs.shape[0] + + # compute distance(cos-similarity) + dist = paddle.matmul(inputs, inputs.t()) + + # hard negative mining + is_pos = paddle.expand(target, ( + bs, bs)).equal(paddle.expand(target, (bs, bs)).t()) + is_neg = paddle.expand(target, ( + bs, bs)).not_equal(paddle.expand(target, (bs, bs)).t()) + + # `dist_ap` means distance(anchor, positive) + # both `dist_ap` and `relative_p_inds` with shape [N, 1] + dist_ap = paddle.min(paddle.reshape( + paddle.masked_select(dist, is_pos), (bs, -1)), + axis=1, + keepdim=True) + # `dist_an` means distance(anchor, negative) + # both `dist_an` and `relative_n_inds` with shape [N, 1] + dist_an = paddle.max(paddle.reshape( + paddle.masked_select(dist, is_neg), (bs, -1)), + axis=1, + keepdim=True) + # shape [N] + dist_ap = paddle.squeeze(dist_ap, axis=1) + dist_an = paddle.squeeze(dist_an, axis=1) + + # Compute ranking hinge loss + y = paddle.ones_like(dist_an) + loss = self.ranking_loss(dist_ap, dist_an, y) + + if self.add_absolute: + absolut_loss_ap = self.ap_value - dist_ap + absolut_loss_ap = paddle.where(absolut_loss_ap > 0, + absolut_loss_ap, + paddle.zeros_like(absolut_loss_ap)) + + absolut_loss_an = dist_an - self.an_value + absolut_loss_an = paddle.where(absolut_loss_an > 0, + absolut_loss_an, + paddle.ones_like(absolut_loss_an)) + + loss = (absolut_loss_an.mean() + absolut_loss_ap.mean() + ) * self.absolute_loss_weight + loss.mean() + + return {"TripletAngularMarginLoss": loss} -- GitLab