diff --git a/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml b/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml index cf13d83bd00b490cba28636fc556dcdbc2fbb4d1..a0cc1fb186371b746ac7469d81966b8264afbcea 100644 --- a/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml +++ b/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml @@ -11,8 +11,7 @@ Global: print_batch_step: 20 use_visualdl: False eval_mode: "retrieval" - re_ranking: False - feat_from: "backbone" # 'backbone' or 'neck' + retrieval_feature_from: "backbone" # 'backbone' or 'features' # used for static mode and model export image_shape: [3, 256, 128] save_inference_dir: "./inference" @@ -23,7 +22,7 @@ Arch: infer_output_key: "features" infer_add_softmax: False Backbone: - name: "ResNet50_last_stage_stride1" + name: "ResNet50" pretrained: True stem_act: null BackboneStopLayer: @@ -32,36 +31,30 @@ Arch: name: "FC" embedding_size: 2048 class_num: 751 - weight_attr: - initializer: - name: Normal - std: 0.001 - bias_attr: False # loss function config for traing/eval process Loss: Train: - CELoss: weight: 1.0 - epsilon: 0.1 - TripletLossV2: weight: 1.0 margin: 0.3 - normalize_feature: false - feat_from: "backbone" + normalize_feature: False + feature_from: "backbone" Eval: - CELoss: weight: 1.0 Optimizer: - name: Momentum - momentum: 0.9 + name: Adam lr: name: Piecewise - decay_epochs: [30, 60] + decay_epochs: [40, 70] values: [0.00035, 0.000035, 0.0000035] warmup_epoch: 10 - warmup_start_lr: 0.0000035 + by_epoch: True + last_epoch: 0 regularizer: name: 'L2' coeff: 0.0005 @@ -73,26 +66,26 @@ DataLoader: name: "Market1501" image_root: "./dataset/" cls_label_path: "bounding_box_train" + backend: "pil" transform_ops: - ResizeImage: size: [128, 256] + return_numpy: False - RandFlipImage: flip_code: 1 - Pad: padding: 10 - - RandCropImage: + - RandCropImageV2: size: [128, 256] - scale: [0.8022, 0.8022] - ratio: [0.5, 0.5] - - NormalizeImage: + - ToTensor: + - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - order: '' sampler: name: DistributedRandomIdentitySampler batch_size: 64 num_instances: 4 - drop_last: True + drop_last: False shuffle: True loader: num_workers: 4 @@ -103,13 +96,15 @@ DataLoader: name: "Market1501" image_root: "./dataset/" cls_label_path: "query" + backend: "pil" transform_ops: - ResizeImage: size: [128, 256] - - NormalizeImage: + return_numpy: False + - ToTensor: + - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - order: '' sampler: name: DistributedBatchSampler batch_size: 128 @@ -124,13 +119,15 @@ DataLoader: name: "Market1501" image_root: "./dataset/" cls_label_path: "bounding_box_test" + backend: "pil" transform_ops: - ResizeImage: size: [128, 256] - - NormalizeImage: + return_numpy: False + - ToTensor: + - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - order: '' sampler: name: DistributedBatchSampler batch_size: 128 diff --git a/ppcls/configs/Pedestrian/strong_baseline_m1.yaml b/ppcls/configs/Pedestrian/strong_baseline_m1.yaml index e1be618913d1270574aa4c460446136e06095b11..1f8c1289581026df5d76a5e5bc279e6678fd991e 100644 --- a/ppcls/configs/Pedestrian/strong_baseline_m1.yaml +++ b/ppcls/configs/Pedestrian/strong_baseline_m1.yaml @@ -10,10 +10,8 @@ Global: epochs: 120 print_batch_step: 20 use_visualdl: False - warmup_by_epoch: True eval_mode: "retrieval" - re_ranking: False - feat_from: "neck" # 'backbone' or 'neck' + retrieval_feature_from: "features" # 'backbone' or 'features' # used for static mode and model export image_shape: [3, 256, 128]
save_inference_dir: "./inference" @@ -40,7 +38,7 @@ Arch: initializer: name: Constant value: 0.0 - learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias + learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero Head: name: "FC" embedding_size: *feat_dim @@ -60,8 +58,8 @@ Loss: - TripletLossV2: weight: 1.0 margin: 0.3 - normalize_feature: false - feat_from: "backbone" + normalize_feature: False + feature_from: "backbone" Eval: - CELoss: weight: 1.0 @@ -74,6 +72,8 @@ Optimizer: values: [0.00035, 0.000035, 0.0000035] warmup_epoch: 10 warmup_start_lr: 0.0000035 + by_epoch: True + last_epoch: 0 regularizer: name: 'L2' coeff: 0.0005 @@ -85,36 +85,32 @@ DataLoader: name: "Market1501" image_root: "./dataset/" cls_label_path: "bounding_box_train" + backend: "pil" transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - ResizeImage: size: [128, 256] + return_numpy: False - RandFlipImage: flip_code: 1 - Pad: padding: 10 - - RandCropImage: + - RandCropImageV2: size: [128, 256] - scale: [ 0.8022, 0.8022 ] - ratio: [ 0.5, 0.5 ] - - NormalizeImage: - scale: 0.00392157 + - ToTensor: + - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - order: '' - RandomErasing: EPSILON: 0.5 sl: 0.02 sh: 0.4 r1: 0.3 - mean: [0.4914, 0.4822, 0.4465] + mean: [0.485, 0.456, 0.406] sampler: name: DistributedRandomIdentitySampler batch_size: 64 num_instances: 4 - drop_last: True + drop_last: False shuffle: True loader: num_workers: 4 @@ -125,17 +121,15 @@ DataLoader: name: "Market1501" image_root: "./dataset/" cls_label_path: "query" + backend: "pil" transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - ResizeImage: size: [128, 256] - - NormalizeImage: - scale: 0.00392157 + return_numpy: False + - ToTensor: + - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - order: '' sampler: name: DistributedBatchSampler batch_size: 128 @@ -150,17 +144,15 @@ DataLoader: name: "Market1501" image_root: "./dataset/" cls_label_path: "bounding_box_test" + backend: "pil" transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - ResizeImage: size: [128, 256] - - NormalizeImage: - scale: 0.00392157 + return_numpy: False + - ToTensor: + - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - order: '' sampler: name: DistributedBatchSampler batch_size: 128 diff --git a/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml b/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml index 042cdffdec49852200ecdfe4dd9d2e52434eb873..4c6fce2a4825cb78b129cd934d5c4df9f74e8874 100644 --- a/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml +++ b/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml @@ -10,10 +10,8 @@ Global: epochs: 120 print_batch_step: 20 use_visualdl: False - warmup_by_epoch: True eval_mode: "retrieval" - re_ranking: False - feat_from: "neck" # 'backbone' or 'neck' + retrieval_feature_from: "features" # 'backbone' or 'features' # used for static mode and model export image_shape: [3, 256, 128] save_inference_dir: "./inference" @@ -40,7 +38,7 @@ Arch: initializer: name: Constant value: 0.0 - learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias + learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero Head: name: "FC" embedding_size: *feat_dim @@ -60,8 +58,8 @@ Loss: - TripletLossV2: weight: 1.0 margin: 0.3 - normalize_feature: false - feat_from: "backbone" + normalize_feature: False + feature_from: 
"backbone" - CenterLoss: weight: 0.0005 num_classes: *class_num @@ -80,7 +78,8 @@ Optimizer: values: [0.00035, 0.000035, 0.0000035] warmup_epoch: 10 warmup_start_lr: 0.0000035 - warmup_by_epoch: True + by_epoch: True + last_epoch: 0 regularizer: name: 'L2' coeff: 0.0005 @@ -97,36 +96,32 @@ DataLoader: name: "Market1501" image_root: "./dataset/" cls_label_path: "bounding_box_train" + backend: "pil" transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - ResizeImage: size: [128, 256] + return_numpy: False - RandFlipImage: flip_code: 1 - Pad: padding: 10 - - RandCropImage: + - RandCropImageV2: size: [128, 256] - scale: [ 0.8022, 0.8022 ] - ratio: [ 0.5, 0.5 ] - - NormalizeImage: - scale: 0.00392157 + - ToTensor: + - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - order: '' - RandomErasing: EPSILON: 0.5 sl: 0.02 sh: 0.4 r1: 0.3 - mean: [0.4914, 0.4822, 0.4465] + mean: [0.485, 0.456, 0.406] sampler: name: DistributedRandomIdentitySampler batch_size: 64 num_instances: 4 - drop_last: True + drop_last: False shuffle: True loader: num_workers: 4 @@ -137,17 +132,15 @@ DataLoader: name: "Market1501" image_root: "./dataset/" cls_label_path: "query" + backend: "pil" transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - ResizeImage: size: [128, 256] - - NormalizeImage: - scale: 0.00392157 + return_numpy: False + - ToTensor: + - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - order: '' sampler: name: DistributedBatchSampler batch_size: 128 @@ -162,17 +155,15 @@ DataLoader: name: "Market1501" image_root: "./dataset/" cls_label_path: "bounding_box_test" + backend: "pil" transform_ops: - - DecodeImage: - to_rgb: True - channel_first: False - ResizeImage: size: [128, 256] - - NormalizeImage: - scale: 0.00392157 + return_numpy: False + - ToTensor: + - Normalize: mean: [0.485, 0.456, 0.406] std: [0.229, 0.224, 0.225] - order: '' sampler: name: DistributedBatchSampler batch_size: 128 diff --git a/ppcls/data/dataloader/person_dataset.py b/ppcls/data/dataloader/person_dataset.py index 2812b2d9373104b910389567d61587af489a661d..97af957c4aef31d7a1b691f2a5f5c037d07deea4 100644 --- a/ppcls/data/dataloader/person_dataset.py +++ b/ppcls/data/dataloader/person_dataset.py @@ -43,7 +43,11 @@ class Market1501(Dataset): """ _dataset_dir = 'market1501/Market-1501-v15.09.15' - def __init__(self, image_root, cls_label_path, transform_ops=None): + def __init__(self, + image_root, + cls_label_path, + transform_ops=None, + backend="cv2"): self._img_root = image_root self._cls_path = cls_label_path # the sub folder in the dataset self._dataset_dir = osp.join(image_root, self._dataset_dir, @@ -51,6 +55,7 @@ class Market1501(Dataset): self._check_before_run() if transform_ops: self._transform_ops = create_operators(transform_ops) + self.backend = backend self._dtype = paddle.get_default_dtype() self._load_anno(relabel=True if 'train' in self._cls_path else False) @@ -92,10 +97,12 @@ class Market1501(Dataset): def __getitem__(self, idx): try: img = Image.open(self.images[idx]).convert('RGB') - img = np.array(img, dtype="float32").astype(np.uint8) + if self.backend == "cv2": + img = np.array(img, dtype="float32").astype(np.uint8) if self._transform_ops: img = transform(img, self._transform_ops) - img = img.transpose((2, 0, 1)) + if self.backend == "cv2": + img = img.transpose((2, 0, 1)) return (img, self.labels[idx], self.cameras[idx]) except Exception as ex: logger.error("Exception occured when parse line: {} with msg: {}". 
diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index 912fe84dc084cda7ff60e37827d2f378cf40e8c7..fc52c1f67ab79a783fc058dee7716892937d2aee 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -30,6 +30,8 @@ from ppcls.data.preprocess.ops.operators import NormalizeImage from ppcls.data.preprocess.ops.operators import ToCHWImage from ppcls.data.preprocess.ops.operators import AugMix from ppcls.data.preprocess.ops.operators import Pad +from ppcls.data.preprocess.ops.operators import ToTensor +from ppcls.data.preprocess.ops.operators import Normalize from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index 0515ce73f4d52414d207b7b4e022bd22fd38d222..e6be4b2bb5b62c621ee3e3a689ab0d2e864b022f 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -22,10 +22,11 @@ import six import math import random import cv2 +from typing import Sequence import numpy as np -from PIL import Image +from PIL import Image, ImageOps, __version__ as PILLOW_VERSION from paddle.vision.transforms import ColorJitter as RawColorJitter -from paddle.vision.transforms import Pad +from paddle.vision.transforms import ToTensor, Normalize from .autoaugment import ImageNetPolicy from .functional import augmentations @@ -33,7 +34,7 @@ from ppcls.utils import logger class UnifiedResize(object): - def __init__(self, interpolation=None, backend="cv2"): + def __init__(self, interpolation=None, backend="cv2", return_numpy=True): _cv2_interp_from_str = { 'nearest': cv2.INTER_NEAREST, 'bilinear': cv2.INTER_LINEAR, @@ -57,12 +58,17 @@ class UnifiedResize(object): resample = random.choice(resample) return cv2.resize(src, size, interpolation=resample) - def _pil_resize(src, size, resample): + def _pil_resize(src, size, resample, return_numpy=True): if isinstance(resample, tuple): resample = random.choice(resample) - pil_img = Image.fromarray(src) + if isinstance(src, np.ndarray): + pil_img = Image.fromarray(src) + else: + pil_img = src pil_img = pil_img.resize(size, resample) - return np.asarray(pil_img) + if return_numpy: + return np.asarray(pil_img) + return pil_img if backend.lower() == "cv2": if isinstance(interpolation, str): @@ -74,7 +80,8 @@ class UnifiedResize(object): elif backend.lower() == "pil": if isinstance(interpolation, str): interpolation = _pil_interp_from_str[interpolation.lower()] - self.resize_func = partial(_pil_resize, resample=interpolation) + self.resize_func = partial( _pil_resize, resample=interpolation, return_numpy=return_numpy) else: logger.warning( f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. Use \"cv2\" instead."
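Reviewer note: the new `return_numpy` flag only controls the output type of the PIL backend; `_pil_resize` now accepts either an `np.ndarray` or a `PIL.Image` as input. A quick sketch of that contract (this assumes `UnifiedResize`'s existing `__call__(src, size)` passthrough, which this diff does not change; `interpolation` is pinned to `"bilinear"` for portability):

```python
import numpy as np
from PIL import Image

from ppcls.data.preprocess.ops.operators import UnifiedResize

resize_pil = UnifiedResize(
    interpolation="bilinear", backend="pil", return_numpy=False)
resize_np = UnifiedResize(interpolation="bilinear", backend="pil")

src = Image.new("RGB", (64, 128))                # PIL size is (w, h)
out_pil = resize_pil(src, (128, 256))            # stays a PIL.Image
out_np = resize_np(np.asarray(src), (128, 256))  # ndarray in, ndarray out

print(type(out_pil), out_pil.size)   # <class 'PIL.Image.Image'> (128, 256)
print(type(out_np), out_np.shape)    # <class 'numpy.ndarray'> (256, 128, 3)
```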
@@ -129,7 +136,8 @@ class ResizeImage(object): size=None, resize_short=None, interpolation=None, - backend="cv2"): + backend="cv2", + return_numpy=True): if resize_short is not None and resize_short > 0: self.resize_short = resize_short self.w = None @@ -143,10 +151,16 @@ 'both 'size' and 'resize_short' are None") self._resize_func = UnifiedResize( - interpolation=interpolation, backend=backend) + interpolation=interpolation, + backend=backend, + return_numpy=return_numpy) def __call__(self, img): - img_h, img_w = img.shape[:2] + if isinstance(img, np.ndarray): + img_h, img_w = img.shape[:2] + else: + img_w, img_h = img.size + if self.resize_short is not None: percent = float(self.resize_short) / min(img_w, img_h) w = int(round(img_w * percent)) @@ -226,6 +240,40 @@ class RandCropImage(object): return self._resize_func(img, size) +class RandCropImageV2(object): + """ RandCropImageV2 differs from RandCropImage: it selects a crop + position uniformly at random and crops to the given size directly, + without a final resize.""" + + def __init__(self, size): + if type(size) is int: + self.size = (size, size) # (w, h) + else: + self.size = size + + def __call__(self, img): + if isinstance(img, np.ndarray): + img_h, img_w = img.shape[0], img.shape[1] + else: + img_w, img_h = img.size + tw, th = self.size + + if img_h + 1 < th or img_w + 1 < tw: + raise ValueError( + "Required crop size {} is larger than input image size {}". + format((th, tw), (img_h, img_w))) + + if img_w == tw and img_h == th: + return img + + top = random.randint(0, img_h - th + 1) + left = random.randint(0, img_w - tw + 1) + if isinstance(img, np.ndarray): + return img[top:top + th, left:left + tw, :] + else: + return img.crop((left, top, left + tw, top + th)) + + class RandFlipImage(object): """ random flip image flip_code: @@ -241,7 +289,15 @@ def __call__(self, img): if random.randint(0, 1) == 1: - return cv2.flip(img, self.flip_code) + if isinstance(img, np.ndarray): + return cv2.flip(img, self.flip_code) + # PIL path: dispatch on flip_code (1: horizontal, 0: vertical, -1: both) + elif self.flip_code == 1: + return img.transpose(Image.FLIP_LEFT_RIGHT) + elif self.flip_code == 0: + return img.transpose(Image.FLIP_TOP_BOTTOM) + else: + return img.transpose(Image.ROTATE_180) else: return img @@ -395,3 +451,58 @@ class ColorJitter(RawColorJitter): if isinstance(img, Image.Image): img = np.asarray(img) return img + + +class Pad(object): + """ + Pads the given PIL.Image on all sides with specified padding mode and fill value.
+ adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad + """ + + def __init__(self, padding: int, fill: int=0, + padding_mode: str="constant"): + self.padding = padding + self.fill = fill + self.padding_mode = padding_mode + + def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"): + # Process fill color for affine transforms + major_found, minor_found = (int(v) + for v in PILLOW_VERSION.split('.')[:2]) + major_required, minor_required = ( + int(v) for v in min_pil_version.split('.')[:2]) + if major_found < major_required or (major_found == major_required and + minor_found < minor_required): + if fill is None: + return {} + else: + msg = ( + "The option to fill background area of the transformed image, " + "requires pillow>={}") + raise RuntimeError(msg.format(min_pil_version)) + + num_bands = len(img.getbands()) + if fill is None: + fill = 0 + if isinstance(fill, (int, float)) and num_bands > 1: + fill = tuple([fill] * num_bands) + if isinstance(fill, (list, tuple)): + if len(fill) != num_bands: + msg = ( + "The number of elements in 'fill' does not match the number of " + "bands of the image ({} != {})") + raise ValueError(msg.format(len(fill), num_bands)) + + fill = tuple(fill) + + return {name: fill} + + def __call__(self, img): + opts = self._parse_fill(self.fill, img, "2.3.0", name="fill") + if img.mode == "P": + palette = img.getpalette() + img = ImageOps.expand(img, border=self.padding, **opts) + img.putpalette(palette) + return img + + return ImageOps.expand(img, border=self.padding, **opts) diff --git a/ppcls/data/preprocess/ops/random_erasing.py b/ppcls/data/preprocess/ops/random_erasing.py index f234abbbac87cf8230e4d619fe7832e8309abcdb..7a77a6affec930af062d169c5bb258cebe67fea4 100644 --- a/ppcls/data/preprocess/ops/random_erasing.py +++ b/ppcls/data/preprocess/ops/random_erasing.py @@ -25,15 +25,21 @@ import numpy as np class Pixels(object): def __init__(self, mode="const", mean=[0., 0., 0.]): self._mode = mode - self._mean = mean + self._mean = np.array(mean) - def __call__(self, h=224, w=224, c=3): + def __call__(self, h=224, w=224, c=3, channel_first=False): if self._mode == "rand": - return np.random.normal(size=(1, 1, 3)) + return np.random.normal(size=( + 1, 1, 3)) if not channel_first else np.random.normal(size=( + 3, 1, 1)) elif self._mode == "pixel": - return np.random.normal(size=(h, w, c)) + return np.random.normal(size=( + h, w, c)) if not channel_first else np.random.normal(size=( + c, h, w)) elif self._mode == "const": - return self._mean + return np.reshape(self._mean, ( + 1, 1, c)) if not channel_first else np.reshape(self._mean, + (c, 1, 1)) else: raise Exception( "Invalid mode in RandomErasing, only support \"const\", \"rand\", \"pixel\"" @@ -68,7 +74,13 @@ class RandomErasing(object): return img for _ in range(self.attempt): - area = img.shape[0] * img.shape[1] + if isinstance(img, np.ndarray): + img_h, img_w, img_c = img.shape + channel_first = False + else: + img_c, img_h, img_w = img.shape + channel_first = True + area = img_h * img_w target_area = random.uniform(self.sl, self.sh) * area aspect_ratio = random.uniform(*self.r1) @@ -78,13 +90,19 @@ class RandomErasing(object): h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < img.shape[1] and h < img.shape[0]: - pixels = self.get_pixels(h, w, img.shape[2]) - x1 = random.randint(0, img.shape[0] - h) - y1 = random.randint(0, img.shape[1] - w) - if img.shape[2] == 3: - img[x1:x1 + h, 
y1:y1 + w, :] = pixels + if w < img_w and h < img_h: + pixels = self.get_pixels(h, w, img_c, channel_first) + x1 = random.randint(0, img_h - h) + y1 = random.randint(0, img_w - w) + if img_c == 3: + if channel_first: + img[:, x1:x1 + h, y1:y1 + w] = pixels + else: + img[x1:x1 + h, y1:y1 + w, :] = pixels else: - img[x1:x1 + h, y1:y1 + w, 0] = pixels[0] + if channel_first: + img[0, x1:x1 + h, y1:y1 + w] = pixels[0] + else: + img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0] return img return img diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index c696f8d91eee2cfe61a7acf129d708780ac97d0d..a1dfe6de56fd4b2eced1e15c4ec406715281198a 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -304,25 +304,12 @@ class Engine(object): self.max_iter = len(self.train_dataloader) - 1 if platform.system( ) == "Windows" else len(self.train_dataloader) - # step lr once before first epoch when when Global.warmup_by_epoch=True - if self.config["Global"].get("warmup_by_epoch", False): - for i in range(len(self.lr_sch)): - self.lr_sch[i].step() - logger.info( - "lr_sch step once before the first epoch, when Global.warmup_by_epoch=True" - ) - for epoch_id in range(best_metric["epoch"] + 1, self.config["Global"]["epochs"] + 1): acc = 0.0 # for one epoch train self.train_epoch_func(self, epoch_id, print_batch_step) - # lr step when Global.warmup_by_epoch=True - if self.config["Global"].get("warmup_by_epoch", False): - for i in range(len(self.lr_sch)): - self.lr_sch[i].step() - if self.use_dali: self.train_dataloader.reset() metric_msg = ", ".join([ diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index 6945443d17167716b22f07e46f6d9ed303d6b1dd..05c5d0c35d0f6fdfcd0a8f1dc1a8a121026ede99 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -16,8 +16,6 @@ from __future__ import division from __future__ import print_function import platform - -import numpy as np import paddle from ppcls.utils import logger @@ -51,48 +49,33 @@ def retrieval_eval(engine, epoch_id=0): metric_dict = {metric_key: 0.} else: metric_dict = dict() - reranking_flag = engine.config['Global'].get('re_ranking', False) - logger.info(f"re_ranking={reranking_flag}") - if not reranking_flag: - for block_idx, block_fea in enumerate(fea_blocks): - similarity_matrix = paddle.matmul( - block_fea, gallery_feas, transpose_y=True) - if query_query_id is not None: - query_id_block = query_id_blocks[block_idx] - query_id_mask = (query_id_block != gallery_unique_id.t()) - - image_id_block = image_id_blocks[block_idx] - image_id_mask = (image_id_block != gallery_img_id.t()) - - keep_mask = paddle.logical_or(query_id_mask, image_id_mask) - similarity_matrix = similarity_matrix * keep_mask.astype( - "float32") + for block_idx, block_fea in enumerate(fea_blocks): + similarity_matrix = paddle.matmul( + block_fea, gallery_feas, transpose_y=True) + if query_query_id is not None: + query_id_block = query_id_blocks[block_idx] + query_id_mask = (query_id_block != gallery_unique_id.t()) + + image_id_block = image_id_blocks[block_idx] + image_id_mask = (image_id_block != gallery_img_id.t()) + + keep_mask = paddle.logical_or(query_id_mask, image_id_mask) + similarity_matrix = similarity_matrix * keep_mask.astype( + "float32") + else: + keep_mask = None + + metric_tmp = engine.eval_metric_func(similarity_matrix, + image_id_blocks[block_idx], + gallery_img_id, keep_mask) + + for key in metric_tmp: + if key not in metric_dict: + metric_dict[key] = metric_tmp[key] * 
block_fea.shape[ + 0] / len(query_feas) else: - keep_mask = None - - metric_tmp = engine.eval_metric_func( - similarity_matrix, image_id_blocks[block_idx], - gallery_img_id, keep_mask) - - for key in metric_tmp: - if key not in metric_dict: - metric_dict[key] = metric_tmp[key] * block_fea.shape[ - 0] / len(query_feas) - else: - metric_dict[key] += metric_tmp[key] * block_fea.shape[ - 0] / len(query_feas) - else: - metric_dict = dict() - distmat = re_ranking( - query_feas, gallery_feas, k1=20, k2=6, lambda_value=0.3) - cmc, mAP = eval_func(distmat, - np.squeeze(query_img_id.numpy()), - np.squeeze(gallery_img_id.numpy()), - np.squeeze(query_query_id.numpy()), - np.squeeze(gallery_unique_id.numpy())) - metric_dict["recall1(RK)"] = cmc[0] - metric_dict["recall5(RK)"] = cmc[4] - metric_dict["mAP(RK)"] = mAP + metric_dict[key] += metric_tmp[key] * block_fea.shape[ + 0] / len(query_feas) metric_info_list = [] for key in metric_dict: @@ -105,159 +88,6 @@ def retrieval_eval(engine, epoch_id=0): return metric_dict[metric_key] -def re_ranking(queFea, - galFea, - k1=20, - k2=6, - lambda_value=0.5, - local_distmat=None, - only_local=False): - # if feature vector is numpy, you should use 'paddle.tensor' transform it to tensor - query_num = queFea.shape[0] - all_num = query_num + galFea.shape[0] - if only_local: - original_dist = local_distmat - else: - feat = paddle.concat([queFea, galFea]) - logger.info('using GPU to compute original distance') - - # L2 distance - distmat = paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]) + \ - paddle.pow(feat, 2).sum(axis=1, keepdim=True).expand([all_num, all_num]).t() - distmat = distmat.addmm(x=feat, y=feat.t(), alpha=-2.0, beta=1.0) - # Cosine distance - # distmat = paddle.matmul(queFea, galFea, transpose_y=True) - # if query_query_id is not None: - # query_id_mask = (queCid != galCid.t()) - # image_id_mask = (queId != galId.t()) - # keep_mask = paddle.logical_or(query_id_mask, image_id_mask) - # distmat = distmat * keep_mask.astype("float32") - - original_dist = distmat.cpu().numpy() - del feat - if local_distmat is not None: - original_dist = original_dist + local_distmat - - gallery_num = original_dist.shape[0] - original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) - V = np.zeros_like(original_dist).astype(np.float16) - initial_rank = np.argsort(original_dist).astype(np.int32) - logger.info('starting re_ranking') - for i in range(all_num): - # k-reciprocal neighbors - forward_k_neigh_index = initial_rank[i, :k1 + 1] - backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] - fi = np.where(backward_k_neigh_index == i)[0] - k_reciprocal_index = forward_k_neigh_index[fi] - k_reciprocal_expansion_index = k_reciprocal_index - for j in range(len(k_reciprocal_index)): - candidate = k_reciprocal_index[j] - candidate_forward_k_neigh_index = initial_rank[candidate, :int( - np.around(k1 / 2)) + 1] - candidate_backward_k_neigh_index = initial_rank[ - candidate_forward_k_neigh_index, :int(np.around(k1 / 2)) + 1] - fi_candidate = np.where( - candidate_backward_k_neigh_index == candidate)[0] - candidate_k_reciprocal_index = candidate_forward_k_neigh_index[ - fi_candidate] - if len( - np.intersect1d(candidate_k_reciprocal_index, - k_reciprocal_index)) > 2 / 3 * len( - candidate_k_reciprocal_index): - k_reciprocal_expansion_index = np.append( - k_reciprocal_expansion_index, candidate_k_reciprocal_index) - - k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) - weight = np.exp(-original_dist[i, 
k_reciprocal_expansion_index]) - V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) - original_dist = original_dist[:query_num, ] - if k2 != 1: - V_qe = np.zeros_like(V, dtype=np.float16) - for i in range(all_num): - V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) - V = V_qe - del V_qe - del initial_rank - invIndex = [] - for i in range(gallery_num): - invIndex.append(np.where(V[:, i] != 0)[0]) - - jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) - for i in range(query_num): - temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) - indNonZero = np.where(V[i, :] != 0)[0] - indImages = [invIndex[ind] for ind in indNonZero] - for j in range(len(indNonZero)): - temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum( - V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]) - jaccard_dist[i] = 1 - temp_min / (2 - temp_min) - - final_dist = jaccard_dist * (1 - lambda_value - ) + original_dist * lambda_value - del original_dist - del V - del jaccard_dist - final_dist = final_dist[:query_num, query_num:] - return final_dist - - -def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): - """Evaluation with market1501 metric - Key: for each query identity, its gallery images from the same camera view are discarded. - """ - num_q, num_g = distmat.shape - if num_g < max_rank: - max_rank = num_g - print("Note: number of gallery samples is quite small, got {}".format( - num_g)) - indices = np.argsort(distmat, axis=1) - matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) - - # compute cmc curve for each query - all_cmc = [] - all_AP = [] - num_valid_q = 0. # number of valid query - for q_idx in range(num_q): - # get query pid and camid - q_pid = q_pids[q_idx] - q_camid = q_camids[q_idx] - - # remove gallery samples that have the same pid and camid with query - order = indices[q_idx] - remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) - keep = np.invert(remove) - - # compute cmc curve - # binary vector, positions with value 1 are correct matches - orig_cmc = matches[q_idx][keep] - if not np.any(orig_cmc): - # this condition is true when query identity does not appear in gallery - continue - - cmc = orig_cmc.cumsum() - cmc[cmc > 1] = 1 - - all_cmc.append(cmc[:max_rank]) - num_valid_q += 1. - - # compute average precision - # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision - num_rel = orig_cmc.sum() - tmp_cmc = orig_cmc.cumsum() - tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] - tmp_cmc = np.asarray(tmp_cmc) * orig_cmc - AP = tmp_cmc.sum() / num_rel - all_AP.append(AP) - - assert num_valid_q > 0, "Error: all query identities do not appear in gallery" - - all_cmc = np.asarray(all_cmc).astype(np.float32) - all_cmc = all_cmc.sum(0) / num_valid_q - mAP = np.mean(all_AP) - - return all_cmc, mAP - - def cal_feature(engine, name='gallery'): has_unique_id = False all_unique_id = None @@ -298,12 +128,13 @@ def cal_feature(engine, name='gallery'): out = out["Student"] # get features - if engine.config["Global"].get("feat_from", 'backbone') == 'backbone': + if engine.config["Global"].get("retrieval_feature_from", + "features") == "features": + # use neck's output as features + batch_feas = out["features"] + else: # use backbone's output as features batch_feas = out["backbone"] - else: - # use neck's output as features - batch_feas = out["neck"] # do norm if engine.config["Global"].get("feature_normalize", True): diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index 90ed776f4d300e2b359b52dd631d63114a987b31..14db79e73e9e51d16d5784b7aa48a6afb12a7e0f 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -68,9 +68,9 @@ def train_epoch(engine, epoch_id, print_batch_step): for i in range(len(engine.optimizer)): engine.optimizer[i].clear_grad() - # step lr - if engine.config["Global"].get("warmup_by_epoch", False) is False: - for i in range(len(engine.lr_sch)): + # step lr(by step) + for i in range(len(engine.lr_sch)): + if not getattr(engine.lr_sch[i], "by_epoch", False): engine.lr_sch[i].step() # below code just for logging @@ -83,6 +83,11 @@ def train_epoch(engine, epoch_id, print_batch_step): log_info(engine, batch_size, epoch_id, iter_id) tic = time.time() + # step lr(by epoch) + for i in range(len(engine.lr_sch)): + if getattr(engine.lr_sch[i], "by_epoch", False): + engine.lr_sch[i].step() + def forward(engine, batch): if not engine.is_rec: diff --git a/ppcls/loss/centerloss.py b/ppcls/loss/centerloss.py index dbc214b75963d08e5fbac44e67069c561d7f281e..22ec55592d45219997c28ae07b380cc5cbd7d36b 100644 --- a/ppcls/loss/centerloss.py +++ b/ppcls/loss/centerloss.py @@ -28,17 +28,17 @@ class CenterLoss(nn.Layer): Args: num_classes (int): number of classes. feat_dim (int): number of feature dimensions. - feat_from (str): features from backbone or neck + feature_from (str): feature from "backbone" or "features" """ def __init__(self, num_classes: int, feat_dim: int, - feat_from: str='backbone'): + feature_from: str="features"): super(CenterLoss, self).__init__() self.num_classes = num_classes self.feat_dim = feat_dim - self.feat_from = feat_from + self.feature_from = feature_from random_init_centers = paddle.randn( shape=[self.num_classes, self.feat_dim]) self.centers = self.create_parameter( @@ -57,7 +57,7 @@ class CenterLoss(nn.Layer): Returns: Dict[str, paddle.Tensor]: {'CenterLoss': loss}. 
""" - feats = input[self.feat_from] + feats = input[self.feature_from] labels = target # squeeze labels to shape (batch_size, ) diff --git a/ppcls/loss/triplet.py b/ppcls/loss/triplet.py index f07dc22ef0763ba67efd4150f6db44e4e3edffe3..0da7cc5dffb8f54807fa3d4da12b002755e54452 100644 --- a/ppcls/loss/triplet.py +++ b/ppcls/loss/triplet.py @@ -31,10 +31,10 @@ class TripletLossV2(nn.Layer): def __init__(self, margin=0.5, normalize_feature=True, - feat_from='backbone'): + feature_from="features"): super(TripletLossV2, self).__init__() self.margin = margin - self.feat_from = feat_from + self.feature_from = feature_from self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin) self.normalize_feature = normalize_feature @@ -44,7 +44,7 @@ class TripletLossV2(nn.Layer): inputs: feature matrix with shape (batch_size, feat_dim) target: ground truth labels with shape (num_classes) """ - inputs = input[self.feat_from] + inputs = input[self.feature_from] if self.normalize_feature: inputs = 1. * inputs / (paddle.expand_as( diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py index 97d23e914dc11da1f7752a2c1ab810126ffdd43c..4d69bed722b9d9a4f7981a107fbec09c0c6e35a2 100644 --- a/ppcls/optimizer/learning_rate.py +++ b/ppcls/optimizer/learning_rate.py @@ -205,6 +205,7 @@ class Piecewise(object): The type of element in the list is python float. warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0. warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0. + by_epoch(bool): Whether lr decay by epoch. Default: False. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. """ @@ -215,7 +216,7 @@ class Piecewise(object): epochs, warmup_epoch=0, warmup_start_lr=0.0, - warmup_by_epoch=False, + by_epoch=False, last_epoch=-1, **kwargs): super().__init__() @@ -230,33 +231,34 @@ class Piecewise(object): self.warmup_steps = round(warmup_epoch * step_each_epoch) self.warmup_epoch = warmup_epoch self.warmup_start_lr = warmup_start_lr - self.warmup_by_epoch = warmup_by_epoch + self.by_epoch = by_epoch def __call__(self): - if self.warmup_by_epoch is False: + if self.by_epoch: learning_rate = lr.PiecewiseDecay( - boundaries=self.boundaries_steps, + boundaries=self.boundaries_epoch, values=self.values, last_epoch=self.last_epoch) - if self.warmup_steps > 0: + if self.warmup_epoch > 0: learning_rate = lr.LinearWarmup( learning_rate=learning_rate, - warmup_steps=self.warmup_steps, + warmup_steps=self.warmup_epoch, start_lr=self.warmup_start_lr, end_lr=self.values[0], last_epoch=self.last_epoch) else: learning_rate = lr.PiecewiseDecay( - boundaries=self.boundaries_epoch, + boundaries=self.boundaries_steps, values=self.values, last_epoch=self.last_epoch) - if self.warmup_epoch > 0: + if self.warmup_steps > 0: learning_rate = lr.LinearWarmup( learning_rate=learning_rate, - warmup_steps=self.warmup_epoch, + warmup_steps=self.warmup_steps, start_lr=self.warmup_start_lr, end_lr=self.values[0], last_epoch=self.last_epoch) + setattr(learning_rate, "by_epoch", self.by_epoch) return learning_rate