diff --git a/ppcls/arch/backbone/legendary_models/resnet.py b/ppcls/arch/backbone/legendary_models/resnet.py index 639475e0d684260089bc5cc1275332e5065d8a29..643e860faf022000453e00cad637ef1ad572e0dc 100644 --- a/ppcls/arch/backbone/legendary_models/resnet.py +++ b/ppcls/arch/backbone/legendary_models/resnet.py @@ -278,6 +278,7 @@ class ResNet(TheseusLayer): config, stages_pattern, version="vb", + stem_act="relu", class_num=1000, lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0], data_format="NCHW", @@ -311,13 +312,13 @@ class ResNet(TheseusLayer): [[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]] } - self.stem = nn.Sequential(* [ + self.stem = nn.Sequential(*[ ConvBNLayer( num_channels=in_c, num_filters=out_c, filter_size=k, stride=s, - act="relu", + act=stem_act, lr_mult=self.lr_mult_list[0], data_format=data_format) for in_c, out_c, k, s in self.stem_cfg[version] diff --git a/ppcls/arch/gears/bnneck.py b/ppcls/arch/gears/bnneck.py index d4d867c6722c8f18e98dfa34384289773a1b17a4..c2f10c79f9c3862102f7b425c18018d2c4cce15e 100644 --- a/ppcls/arch/gears/bnneck.py +++ b/ppcls/arch/gears/bnneck.py @@ -17,21 +17,32 @@ from __future__ import absolute_import, division, print_function import paddle import paddle.nn as nn +from ppcls.arch.utils import get_param_attr_dict + class BNNeck(nn.Layer): - def __init__(self, num_features): + def __init__(self, num_features, **kwargs): super().__init__() weight_attr = paddle.ParamAttr( initializer=paddle.nn.initializer.Constant(value=1.0)) bias_attr = paddle.ParamAttr( initializer=paddle.nn.initializer.Constant(value=0.0), trainable=False) + + if 'weight_attr' in kwargs: + weight_attr = get_param_attr_dict(kwargs['weight_attr']) + + bias_attr = None + if 'bias_attr' in kwargs: + bias_attr = get_param_attr_dict(kwargs['bias_attr']) + self.feat_bn = nn.BatchNorm1D( num_features, momentum=0.9, epsilon=1e-05, weight_attr=weight_attr, bias_attr=bias_attr) + self.flatten = nn.Flatten() def forward(self, x): diff --git a/ppcls/arch/gears/fc.py b/ppcls/arch/gears/fc.py index b32474195e1ada4cd0a17b493f68f65a242d82cd..279c5496e4aeeef86f1ebdafbdbfe7468391fa2d 100644 --- a/ppcls/arch/gears/fc.py +++ b/ppcls/arch/gears/fc.py @@ -19,16 +19,29 @@ from __future__ import print_function import paddle import paddle.nn as nn +from ppcls.arch.utils import get_param_attr_dict + class FC(nn.Layer): - def __init__(self, embedding_size, class_num): + def __init__(self, embedding_size, class_num, **kwargs): super(FC, self).__init__() self.embedding_size = embedding_size self.class_num = class_num + weight_attr = paddle.ParamAttr( initializer=paddle.nn.initializer.XavierNormal()) - self.fc = paddle.nn.Linear( - self.embedding_size, self.class_num, weight_attr=weight_attr) + if 'weight_attr' in kwargs: + weight_attr = get_param_attr_dict(kwargs['weight_attr']) + + bias_attr = None + if 'bias_attr' in kwargs: + bias_attr = get_param_attr_dict(kwargs['bias_attr']) + + self.fc = nn.Linear( + self.embedding_size, + self.class_num, + weight_attr=weight_attr, + bias_attr=bias_attr) def forward(self, input, label=None): out = self.fc(input) diff --git a/ppcls/arch/utils.py b/ppcls/arch/utils.py index 308475d7dbe7e4b9702a9e9e2eb3a0210da26e7a..785b7fbbe7e609e5314b549355165d83715bd48a 100644 --- a/ppcls/arch/utils.py +++ b/ppcls/arch/utils.py @@ -14,9 +14,11 @@ import six import types +import paddle from difflib import SequenceMatcher from . 
import backbone +from typing import Any, Dict, Union def get_architectures(): @@ -51,3 +53,47 @@ def similar_architectures(name='', names=[], thresh=0.1, topk=10): scores.sort(key=lambda x: x[1], reverse=True) similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]] return similar_names + + +def get_param_attr_dict(ParamAttr_config: Union[None, bool, Dict[str, Dict]] ) -> Union[None, bool, paddle.ParamAttr]: + """parse a ParamAttr from a config dict + + Args: + ParamAttr_config (Union[None, bool, Dict[str, Dict]]): ParamAttr configuration + + Returns: + Union[None, bool, paddle.ParamAttr]: Generated ParamAttr + """ + if ParamAttr_config is None: + return None + if isinstance(ParamAttr_config, bool): + return ParamAttr_config + ParamAttr_dict = {} + if 'initializer' in ParamAttr_config: + initializer_cfg = ParamAttr_config.get('initializer') + if 'name' in initializer_cfg: + initializer_name = initializer_cfg.pop('name') + ParamAttr_dict['initializer'] = getattr( + paddle.nn.initializer, initializer_name)(**initializer_cfg) + else: + raise ValueError("'name' must be specified in initializer_cfg") + if 'learning_rate' in ParamAttr_config: + # NOTE: only a single value is supported now + learning_rate_value = ParamAttr_config.get('learning_rate') + if isinstance(learning_rate_value, (int, float)): + ParamAttr_dict['learning_rate'] = learning_rate_value + else: + raise ValueError( + f"learning_rate_value must be float or int, but got {type(learning_rate_value)}" + ) + if 'regularizer' in ParamAttr_config: + regularizer_cfg = ParamAttr_config.get('regularizer') + if 'name' in regularizer_cfg: + # L1Decay or L2Decay + regularizer_name = regularizer_cfg.pop('name') + ParamAttr_dict['regularizer'] = getattr( + paddle.regularizer, regularizer_name)(**regularizer_cfg) + else: + raise ValueError("'name' must be specified in regularizer_cfg") + return paddle.ParamAttr(**ParamAttr_dict)
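A quick usage sketch may help here (illustrative, not part of the patch): the dict shapes below are copied from the BNNeck/FC entries in the new YAML configs further down, and show what `get_param_attr_dict` returns for each of its three accepted input types.

```python
import paddle

from ppcls.arch.utils import get_param_attr_dict

# Passthrough cases: None keeps the framework defaults, and False
# disables the parameter entirely (e.g. "bias_attr: False" on the FC head).
assert get_param_attr_dict(None) is None
assert get_param_attr_dict(False) is False

# Dict case, mirroring the BNNeck "bias_attr" entry in the m1 configs:
# constant-zero init plus a tiny learning rate that freezes the bias.
bias_attr = get_param_attr_dict({
    "initializer": {"name": "Constant", "value": 0.0},
    "learning_rate": 1.0e-20,
})
print(isinstance(bias_attr, paddle.ParamAttr))  # True
```

Note that the function pops `name` out of the nested initializer/regularizer dicts, so pass a copy if the same config dict is reused elsewhere.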
diff --git a/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml b/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a0395f3b129bd0f2148e0e9cfd62dadaf8692ff9 --- /dev/null +++ b/ppcls/configs/Pedestrian/strong_baseline_baseline.yaml @@ -0,0 +1,147 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output/" + device: "gpu" + save_interval: 40 + eval_during_train: True + eval_interval: 10 + epochs: 120 + print_batch_step: 20 + use_visualdl: False + eval_mode: "retrieval" + retrieval_feature_from: "backbone" # 'backbone' or 'features' + # used for static mode and model export + image_shape: [3, 256, 128] + save_inference_dir: "./inference" + +# model architecture +Arch: + name: "RecModel" + infer_output_key: "features" + infer_add_softmax: False + Backbone: + name: "ResNet50" + pretrained: True + stem_act: null + BackboneStopLayer: + name: "flatten" + Head: + name: "FC" + embedding_size: 2048 + class_num: 751 + +# loss function config for training/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + - TripletLossV2: + weight: 1.0 + margin: 0.3 + normalize_feature: False + feature_from: "backbone" + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Adam + lr: + name: Piecewise + decay_epochs: [40, 70] + values: [0.00035, 0.000035, 0.0000035] + warmup_epoch: 10 + by_epoch: True + last_epoch: 0 + regularizer: + name: 'L2' + coeff: 0.0005 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: "Market1501" + image_root: "./dataset/" + cls_label_path: "bounding_box_train" + backend: "pil" + transform_ops: + - ResizeImage: + size: [128, 256] + return_numpy: False + backend: "pil" + - RandFlipImage: + flip_code: 1 + - Pad: + padding: 10 + - RandCropImageV2: + size: [128, 256] + - ToTensor: + - Normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + sampler: + name: DistributedRandomIdentitySampler + batch_size: 64 + num_instances: 4 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + Eval: + Query: + dataset: + name: "Market1501" + image_root: "./dataset/" + cls_label_path: "query" + backend: "pil" + transform_ops: + - ResizeImage: + size: [128, 256] + return_numpy: False + backend: "pil" + - ToTensor: + - Normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + + Gallery: + dataset: + name: "Market1501" + image_root: "./dataset/" + cls_label_path: "bounding_box_test" + backend: "pil" + transform_ops: + - ResizeImage: + size: [128, 256] + return_numpy: False + backend: "pil" + - ToTensor: + - Normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Eval: + - Recallk: + topk: [1, 5] + - mAP: {} diff --git a/ppcls/configs/Pedestrian/strong_baseline_m1.yaml b/ppcls/configs/Pedestrian/strong_baseline_m1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ef4b605aee5de905494b67beda0bd545a8b12fcb --- /dev/null +++ b/ppcls/configs/Pedestrian/strong_baseline_m1.yaml @@ -0,0 +1,172 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output/" + device: "gpu" + save_interval: 40 + eval_during_train: True + eval_interval: 10 + epochs: 120 + print_batch_step: 20 + use_visualdl: False + eval_mode: "retrieval" + retrieval_feature_from: "features" # 'backbone' or 'features' + # used for static mode and model export + image_shape: [3, 256, 128] + save_inference_dir: "./inference" + +# model architecture +Arch: + name: "RecModel" + infer_output_key: "features" + infer_add_softmax: False + Backbone: + name: "ResNet50_last_stage_stride1" + pretrained: True + stem_act: null + BackboneStopLayer: + name: "flatten" + Neck: + name: BNNeck + num_features: &feat_dim 2048 + weight_attr: + initializer: + name: Constant + value: 1.0 + bias_attr: + initializer: + name: Constant + value: 0.0 + learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero + Head: + name: "FC" + embedding_size: *feat_dim + class_num: 751 + weight_attr: + initializer: + name: Normal + std: 0.001 + bias_attr: False + +# loss function config for training/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + - TripletLossV2: + weight: 1.0 + margin: 0.3 + normalize_feature: False + feature_from: "backbone" + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + name: Adam + lr: + name: Piecewise + decay_epochs: [30, 60] + values: [0.00035, 0.000035, 0.0000035] + warmup_epoch: 10 + warmup_start_lr: 0.0000035 + by_epoch: True + last_epoch: 0 + regularizer: + name: 'L2' + coeff: 0.0005 + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: "Market1501" + image_root: "./dataset/" + cls_label_path: "bounding_box_train" + backend: "pil" + transform_ops: + - ResizeImage: + size: [128, 256] + return_numpy: False + backend: "pil" + -
RandFlipImage: + flip_code: 1 + - Pad: + padding: 10 + - RandCropImageV2: + size: [128, 256] + - ToTensor: + - Normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + - RandomErasing: + EPSILON: 0.5 + sl: 0.02 + sh: 0.4 + r1: 0.3 + mean: [0.485, 0.456, 0.406] + sampler: + name: DistributedRandomIdentitySampler + batch_size: 64 + num_instances: 4 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + Eval: + Query: + dataset: + name: "Market1501" + image_root: "./dataset/" + cls_label_path: "query" + backend: "pil" + transform_ops: + - ResizeImage: + size: [128, 256] + return_numpy: False + backend: "pil" + - ToTensor: + - Normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + + Gallery: + dataset: + name: "Market1501" + image_root: "./dataset/" + cls_label_path: "bounding_box_test" + backend: "pil" + transform_ops: + - ResizeImage: + size: [128, 256] + return_numpy: False + backend: "pil" + - ToTensor: + - Normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Eval: + - Recallk: + topk: [1, 5] + - mAP: {} diff --git a/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml b/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c14bb209875354d9bc0e485aa4aa8b910d116b9 --- /dev/null +++ b/ppcls/configs/Pedestrian/strong_baseline_m1_centerloss.yaml @@ -0,0 +1,183 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: "./output/" + device: "gpu" + save_interval: 40 + eval_during_train: True + eval_interval: 10 + epochs: 120 + print_batch_step: 20 + use_visualdl: False + eval_mode: "retrieval" + retrieval_feature_from: "features" # 'backbone' or 'features' + # used for static mode and model export + image_shape: [3, 256, 128] + save_inference_dir: "./inference" + +# model architecture +Arch: + name: "RecModel" + infer_output_key: "features" + infer_add_softmax: False + Backbone: + name: "ResNet50_last_stage_stride1" + pretrained: True + stem_act: null + BackboneStopLayer: + name: "flatten" + Neck: + name: BNNeck + num_features: &feat_dim 2048 + weight_attr: + initializer: + name: Constant + value: 1.0 + bias_attr: + initializer: + name: Constant + value: 0.0 + learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero + Head: + name: "FC" + embedding_size: *feat_dim + class_num: &class_num 751 + weight_attr: + initializer: + name: Normal + std: 0.001 + bias_attr: False + +# loss function config for training/eval process +Loss: + Train: + - CELoss: + weight: 1.0 + epsilon: 0.1 + - TripletLossV2: + weight: 1.0 + margin: 0.3 + normalize_feature: False + feature_from: "backbone" + - CenterLoss: + weight: 0.0005 + num_classes: *class_num + feat_dim: *feat_dim + feature_from: "backbone" + Eval: + - CELoss: + weight: 1.0 + +Optimizer: + - Adam: + scope: RecModel + lr: + name: Piecewise + decay_epochs: [30, 60] + values: [0.00035, 0.000035, 0.0000035] + warmup_epoch: 10 + warmup_start_lr: 0.0000035 + by_epoch: True + last_epoch: 0 + regularizer: + name: 'L2' + coeff: 0.0005 + - SGD: + scope: CenterLoss + lr: + name: Constant + learning_rate: 1000.0 # NOTE: set to
ori_lr*(1/centerloss_weight) to avoid manually scaling the centers' gradients; here 1000.0 * centerloss_weight 0.0005 = 0.5, the effective center lr. + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: "Market1501" + image_root: "./dataset/" + cls_label_path: "bounding_box_train" + backend: "pil" + transform_ops: + - ResizeImage: + size: [128, 256] + return_numpy: False + backend: "pil" + - RandFlipImage: + flip_code: 1 + - Pad: + padding: 10 + - RandCropImageV2: + size: [128, 256] + - ToTensor: + - Normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + - RandomErasing: + EPSILON: 0.5 + sl: 0.02 + sh: 0.4 + r1: 0.3 + mean: [0.485, 0.456, 0.406] + sampler: + name: DistributedRandomIdentitySampler + batch_size: 64 + num_instances: 4 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + Eval: + Query: + dataset: + name: "Market1501" + image_root: "./dataset/" + cls_label_path: "query" + backend: "pil" + transform_ops: + - ResizeImage: + size: [128, 256] + return_numpy: False + backend: "pil" + - ToTensor: + - Normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + + Gallery: + dataset: + name: "Market1501" + image_root: "./dataset/" + cls_label_path: "bounding_box_test" + backend: "pil" + transform_ops: + - ResizeImage: + size: [128, 256] + return_numpy: False + backend: "pil" + - ToTensor: + - Normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + sampler: + name: DistributedBatchSampler + batch_size: 128 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Eval: + - Recallk: + topk: [1, 5] + - mAP: {} diff --git a/ppcls/data/dataloader/person_dataset.py b/ppcls/data/dataloader/person_dataset.py index 2812b2d9373104b910389567d61587af489a661d..97af957c4aef31d7a1b691f2a5f5c037d07deea4 100644 --- a/ppcls/data/dataloader/person_dataset.py +++ b/ppcls/data/dataloader/person_dataset.py @@ -43,7 +43,11 @@ class Market1501(Dataset): """ _dataset_dir = 'market1501/Market-1501-v15.09.15' - def __init__(self, image_root, cls_label_path, transform_ops=None): + def __init__(self, + image_root, + cls_label_path, + transform_ops=None, + backend="cv2"): self._img_root = image_root self._cls_path = cls_label_path # the sub folder in the dataset self._dataset_dir = osp.join(image_root, self._dataset_dir, @@ -51,6 +55,7 @@ self._check_before_run() if transform_ops: self._transform_ops = create_operators(transform_ops) + self.backend = backend self._dtype = paddle.get_default_dtype() self._load_anno(relabel=True if 'train' in self._cls_path else False) @@ -92,10 +97,12 @@ def __getitem__(self, idx): try: img = Image.open(self.images[idx]).convert('RGB') - img = np.array(img, dtype="float32").astype(np.uint8) + if self.backend == "cv2": + img = np.array(img, dtype="float32").astype(np.uint8) if self._transform_ops: + img = transform(img, self._transform_ops) - img = img.transpose((2, 0, 1)) + if self.backend == "cv2": + img = img.transpose((2, 0, 1)) return (img, self.labels[idx], self.cameras[idx]) except Exception as ex: logger.error("Exception occurred when parsing line: {} with msg: {}".
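Because `Market1501` now hands PIL images straight through when `backend="pil"`, the training pipeline from the configs above can be exercised in isolation. A minimal sketch follows ("demo.jpg" is a placeholder path; the operator arguments are copied from the train config, and the `RandomErasing` keyword names are assumed to match the YAML keys, since `create_operators` forwards them as kwargs):

```python
from PIL import Image

from ppcls.data.preprocess.ops.operators import (
    ResizeImage, RandFlipImage, Pad, RandCropImageV2, ToTensor, Normalize)
from ppcls.data.preprocess.ops.random_erasing import RandomErasing

# placeholder for any RGB image on disk
img = Image.open("demo.jpg").convert("RGB")

ops = [
    ResizeImage(size=[128, 256], return_numpy=False, backend="pil"),
    RandFlipImage(flip_code=1),        # PIL branch: Image.FLIP_LEFT_RIGHT
    Pad(padding=10),                   # ImageOps.expand on all four sides
    RandCropImageV2(size=[128, 256]),  # crop back to 128x256, no resize
    ToTensor(),                        # CHW float tensor, values in [0, 1]
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    RandomErasing(EPSILON=0.5, sl=0.02, sh=0.4, r1=0.3,
                  mean=[0.485, 0.456, 0.406]),  # takes the channel_first path
]
for op in ops:
    img = op(img)
print(img.shape)  # [3, 256, 128]
```

The image stays a PIL object up to `ToTensor`, after which the CHW-aware branch of `RandomErasing` (added in the hunks below) does the erasing on the tensor.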
diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index 075ee89278e2e099ce3c9cbc108dfe159e2012f2..62066016a47c8cef7bd31bc7d238f202ea6455f0 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -25,10 +25,14 @@ from ppcls.data.preprocess.ops.operators import DecodeImage from ppcls.data.preprocess.ops.operators import ResizeImage from ppcls.data.preprocess.ops.operators import CropImage from ppcls.data.preprocess.ops.operators import RandCropImage +from ppcls.data.preprocess.ops.operators import RandCropImageV2 from ppcls.data.preprocess.ops.operators import RandFlipImage from ppcls.data.preprocess.ops.operators import NormalizeImage from ppcls.data.preprocess.ops.operators import ToCHWImage from ppcls.data.preprocess.ops.operators import AugMix +from ppcls.data.preprocess.ops.operators import Pad +from ppcls.data.preprocess.ops.operators import ToTensor +from ppcls.data.preprocess.ops.operators import Normalize from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index 8075ced904de51551c8946905f874e002178abba..157f44f1ab15ffd1162aeada37dba9296ee0ca00 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -23,8 +23,9 @@ import math import random import cv2 import numpy as np -from PIL import Image +from PIL import Image, ImageOps, __version__ as PILLOW_VERSION from paddle.vision.transforms import ColorJitter as RawColorJitter +from paddle.vision.transforms import ToTensor, Normalize from .autoaugment import ImageNetPolicy from .functional import augmentations @@ -32,7 +33,7 @@ from ppcls.utils import logger class UnifiedResize(object): - def __init__(self, interpolation=None, backend="cv2"): + def __init__(self, interpolation=None, backend="cv2", return_numpy=True): _cv2_interp_from_str = { 'nearest': cv2.INTER_NEAREST, 'bilinear': cv2.INTER_LINEAR, @@ -56,12 +57,17 @@ class UnifiedResize(object): resample = random.choice(resample) return cv2.resize(src, size, interpolation=resample) - def _pil_resize(src, size, resample): + def _pil_resize(src, size, resample, return_numpy=True): if isinstance(resample, tuple): resample = random.choice(resample) - pil_img = Image.fromarray(src) + if isinstance(src, np.ndarray): + pil_img = Image.fromarray(src) + else: + pil_img = src pil_img = pil_img.resize(size, resample) - return np.asarray(pil_img) + if return_numpy: + return np.asarray(pil_img) + return pil_img if backend.lower() == "cv2": if isinstance(interpolation, str): @@ -73,7 +79,8 @@ class UnifiedResize(object): elif backend.lower() == "pil": if isinstance(interpolation, str): interpolation = _pil_interp_from_str[interpolation.lower()] - self.resize_func = partial(_pil_resize, resample=interpolation) + self.resize_func = partial( + _pil_resize, resample=interpolation, return_numpy=return_numpy) else: logger.warning( f"The backend of Resize only supports \"cv2\" or \"PIL\". \"{backend}\" is unavailable. Use \"cv2\" instead."
@@ -81,6 +88,8 @@ class UnifiedResize(object): self.resize_func = cv2.resize def __call__(self, src, size): + if isinstance(size, list): + size = tuple(size) return self.resize_func(src, size) @@ -99,14 +108,15 @@ class DecodeImage(object): self.channel_first = channel_first # only enabled when to_np is True def __call__(self, img): - if six.PY2: - assert type(img) is str and len( - img) > 0, "invalid input 'img' in DecodeImage" - else: - assert type(img) is bytes and len( - img) > 0, "invalid input 'img' in DecodeImage" - data = np.frombuffer(img, dtype='uint8') - img = cv2.imdecode(data, 1) + if not isinstance(img, np.ndarray): + if six.PY2: + assert type(img) is str and len( + img) > 0, "invalid input 'img' in DecodeImage" + else: + assert type(img) is bytes and len( + img) > 0, "invalid input 'img' in DecodeImage" + data = np.frombuffer(img, dtype='uint8') + img = cv2.imdecode(data, 1) if self.to_rgb: assert img.shape[2] == 3, 'invalid shape of image[%s]' % ( img.shape) @@ -125,7 +135,8 @@ class ResizeImage(object): size=None, resize_short=None, interpolation=None, - backend="cv2"): + backend="cv2", + return_numpy=True): if resize_short is not None and resize_short > 0: self.resize_short = resize_short self.w = None @@ -139,10 +150,16 @@ 'both 'size' and 'resize_short' are None") self._resize_func = UnifiedResize( - interpolation=interpolation, backend=backend) + interpolation=interpolation, + backend=backend, + return_numpy=return_numpy) def __call__(self, img): - img_h, img_w = img.shape[:2] + if isinstance(img, np.ndarray): + img_h, img_w = img.shape[:2] + else: + img_w, img_h = img.size + if self.resize_short is not None: percent = float(self.resize_short) / min(img_w, img_h) w = int(round(img_w * percent)) @@ -222,6 +239,40 @@ class RandCropImage(object): return self._resize_func(img, size) +class RandCropImageV2(object): + """ RandCropImageV2 is different from RandCropImage: it picks the crop position randomly from a uniform distribution and crops at the given size, with no resize afterwards.""" + + def __init__(self, size): + if type(size) is int: + self.size = (size, size) # (w, h) + else: + self.size = size + + def __call__(self, img): + if isinstance(img, np.ndarray): + img_h, img_w = img.shape[0], img.shape[1] + else: + img_w, img_h = img.size + tw, th = self.size + + if img_h + 1 < th or img_w + 1 < tw: + raise ValueError( + "Required crop size {} is larger than input image size {}". + format((th, tw), (img_h, img_w))) + + if img_w == tw and img_h == th: + return img + + # random.randint is inclusive on both ends, so no +1 on the upper bound + top = random.randint(0, img_h - th) + left = random.randint(0, img_w - tw) + if isinstance(img, np.ndarray): + return img[top:top + th, left:left + tw, :] + else: + return img.crop((left, top, left + tw, top + th)) + + class RandFlipImage(object): """ random flip image flip_code: @@ -237,7 +288,10 @@ class RandFlipImage(object): def __call__(self, img): if random.randint(0, 1) == 1: - return cv2.flip(img, self.flip_code) + if isinstance(img, np.ndarray): + return cv2.flip(img, self.flip_code) + else: + return img.transpose(Image.FLIP_LEFT_RIGHT) else: return img
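To make the crop geometry above concrete, here is a small sanity check (a sketch, not part of the patch; both input kinds share the same bounds logic, only the slicing differs):

```python
import numpy as np
from PIL import Image

from ppcls.data.preprocess.ops.operators import RandCropImageV2

crop = RandCropImageV2(size=(100, 50))  # (tw, th): width 100, height 50

arr = np.zeros((120, 200, 3), dtype=np.uint8)  # HWC numpy input
print(crop(arr).shape)  # (50, 100, 3)

pil = Image.new("RGB", (200, 120))  # PIL size is reported as (w, h)
print(crop(pil).size)   # (100, 50)
```

Unlike `RandCropImage`, nothing is resized after the crop, which is why the training configs pad by 10 pixels first and then crop back to the network input size.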
@@ -391,3 +445,58 @@ class ColorJitter(RawColorJitter): if isinstance(img, Image.Image): img = np.asarray(img) return img + + +class Pad(object): + """ + Pads the given PIL.Image on all sides with specified padding mode and fill value. + adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad + """ + + def __init__(self, padding: int, fill: int=0, + padding_mode: str="constant"): + self.padding = padding + self.fill = fill + self.padding_mode = padding_mode + + def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"): + # Process fill color for affine transforms + major_found, minor_found = (int(v) + for v in PILLOW_VERSION.split('.')[:2]) + major_required, minor_required = ( + int(v) for v in min_pil_version.split('.')[:2]) + if major_found < major_required or (major_found == major_required and + minor_found < minor_required): + if fill is None: + return {} + else: + msg = ( + "The option to fill background area of the transformed image " + "requires pillow>={}") + raise RuntimeError(msg.format(min_pil_version)) + + num_bands = len(img.getbands()) + if fill is None: + fill = 0 + if isinstance(fill, (int, float)) and num_bands > 1: + fill = tuple([fill] * num_bands) + if isinstance(fill, (list, tuple)): + if len(fill) != num_bands: + msg = ( + "The number of elements in 'fill' does not match the number of " + "bands of the image ({} != {})") + raise ValueError(msg.format(len(fill), num_bands)) + + fill = tuple(fill) + + return {name: fill} + + def __call__(self, img): + opts = self._parse_fill(self.fill, img, "2.3.0", name="fill") + if img.mode == "P": + palette = img.getpalette() + img = ImageOps.expand(img, border=self.padding, **opts) + img.putpalette(palette) + return img + + return ImageOps.expand(img, border=self.padding, **opts) diff --git a/ppcls/data/preprocess/ops/random_erasing.py b/ppcls/data/preprocess/ops/random_erasing.py index 1b7c03abdd95e1a5ec2bff0aa18a480bed81bb3d..648b41ea532eb8a767015de6abcdf7fc0448e34c 100644 --- a/ppcls/data/preprocess/ops/random_erasing.py +++ b/ppcls/data/preprocess/ops/random_erasing.py @@ -26,15 +26,21 @@ import numpy as np class Pixels(object): def __init__(self, mode="const", mean=[0., 0., 0.]): self._mode = mode - self._mean = mean + self._mean = np.array(mean) - def __call__(self, h=224, w=224, c=3): + def __call__(self, h=224, w=224, c=3, channel_first=False): if self._mode == "rand": - return np.random.normal(size=(1, 1, 3)) + return np.random.normal(size=( + 1, 1, 3)) if not channel_first else np.random.normal(size=( + 3, 1, 1)) elif self._mode == "pixel": - return np.random.normal(size=(h, w, c)) + return np.random.normal(size=( + h, w, c)) if not channel_first else np.random.normal(size=( + c, h, w)) elif self._mode == "const": - return self._mean + return np.reshape(self._mean, ( + 1, 1, c)) if not channel_first else np.reshape(self._mean, + (c, 1, 1)) else: raise Exception( "Invalid mode in RandomErasing, only support \"const\", \"rand\", \"pixel\"" @@ -69,7 +75,13 @@ class RandomErasing(object): return img for _ in range(self.attempt): - area = img.shape[0] * img.shape[1] + if isinstance(img, np.ndarray): + img_h, img_w, img_c = img.shape + channel_first = False + else: + img_c, img_h, img_w = img.shape + channel_first = True + area = img_h * img_w target_area = random.uniform(self.sl, self.sh) * area aspect_ratio = random.uniform(*self.r1) @@ -79,13 +91,19 @@ h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < img.shape[1] and h < img.shape[0]: - pixels = self.get_pixels(h, w, img.shape[2]) - x1 = random.randint(0, img.shape[0] - h) - y1 = random.randint(0, img.shape[1] - w) - if img.shape[2] == 3: - img[x1:x1 + h,
y1:y1 + w, :] = pixels + if w < img_w and h < img_h: + pixels = self.get_pixels(h, w, img_c, channel_first) + x1 = random.randint(0, img_h - h) + y1 = random.randint(0, img_w - w) + if img_c == 3: + if channel_first: + img[:, x1:x1 + h, y1:y1 + w] = pixels + else: + img[x1:x1 + h, y1:y1 + w, :] = pixels else: - img[x1:x1 + h, y1:y1 + w, 0] = pixels[0] + if channel_first: + img[0, x1:x1 + h, y1:y1 + w] = pixels[0] + else: + img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0] return img return img diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index c5164002bf519120880652fbad5dcf10b5e6f33e..460856e3187285377b1b4ddd47d3618dab0e7dc1 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -288,8 +288,9 @@ class Engine(object): world_size = dist.get_world_size() self.config["Global"]["distributed"] = world_size != 1 if self.mode == "train": - std_gpu_num = 8 if self.config["Optimizer"][ "name"] == "AdamW" else 4 + std_gpu_num = 8 if isinstance( + self.config["Optimizer"], + dict) and self.config["Optimizer"]["name"] == "AdamW" else 4 if world_size != std_gpu_num: msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the strategy (learning rate, batch size and so on) if you use this config to train." logger.warning(msg) @@ -337,6 +338,7 @@ class Engine(object): self.max_iter = len(self.train_dataloader) - 1 if platform.system( ) == "Windows" else len(self.train_dataloader) + for epoch_id in range(best_metric["epoch"] + 1, self.config["Global"]["epochs"] + 1): acc = 0.0 diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index b481efae11bf2832b1c965bf0fa43ff0f295abd4..05c5d0c35d0f6fdfcd0a8f1dc1a8a121026ede99 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -126,7 +126,15 @@ def cal_feature(engine, name='gallery'): out = engine.model(batch[0], batch[1]) if "Student" in out: out = out["Student"] - batch_feas = out["features"] + + # get features + if engine.config["Global"].get("retrieval_feature_from", + "features") == "features": + # use neck's output as features + batch_feas = out["features"] + else: + # use backbone's output as features + batch_feas = out["backbone"] # do norm if engine.config["Global"].get("feature_normalize", True): diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index 1e944a609d066a6a193c5af55ce56bc931c82eeb..14db79e73e9e51d16d5784b7aa48a6afb12a7e0f 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -53,7 +53,7 @@ def train_epoch(engine, epoch_id, print_batch_step): out = forward(engine, batch) loss_dict = engine.train_loss_func(out, batch[1]) - # step opt + # backward & step opt if engine.amp: scaled = engine.scaler.scale(loss_dict["loss"]) scaled.backward() @@ -63,12 +63,15 @@ loss_dict["loss"].backward() for i in range(len(engine.optimizer)): engine.optimizer[i].step() + # clear grad for i in range(len(engine.optimizer)): engine.optimizer[i].clear_grad() - # step lr + + # step lr (by step) for i in range(len(engine.lr_sch)): - engine.lr_sch[i].step() + if not getattr(engine.lr_sch[i], "by_epoch", False): + engine.lr_sch[i].step() # below code just for logging # update metric_for_logger @@ -80,6 +83,11 @@ log_info(engine, batch_size, epoch_id, iter_id) tic = time.time() + # step lr (by epoch) + for i in
range(len(engine.lr_sch)): + if getattr(engine.lr_sch[i], "by_epoch", False): + engine.lr_sch[i].step() + def forward(engine, batch): if not engine.is_rec: diff --git a/ppcls/engine/train/utils.py b/ppcls/engine/train/utils.py index 7f64104da9b745e020c70b5804329ca96a6f35df..ca211ff932f19ca63804a5a1ff52def5eb89477f 100644 --- a/ppcls/engine/train/utils.py +++ b/ppcls/engine/train/utils.py @@ -39,7 +39,7 @@ def update_loss(trainer, loss_dict, batch_size): def log_info(trainer, batch_size, epoch_id, iter_id): lr_msg = ", ".join([ - "lr_{}: {:.8f}".format(i + 1, lr.get_lr()) + "lr({}): {:.8f}".format(lr.__class__.__name__, lr.get_lr()) for i, lr in enumerate(trainer.lr_sch) ]) metric_msg = ", ".join([ @@ -64,7 +64,7 @@ def log_info(trainer, batch_size, epoch_id, iter_id): for i, lr in enumerate(trainer.lr_sch): logger.scaler( - name="lr_{}".format(i + 1), + name="lr({})".format(lr.__class__.__name__), value=lr.get_lr(), step=trainer.global_step, writer=trainer.vdl_writer) diff --git a/ppcls/loss/centerloss.py b/ppcls/loss/centerloss.py index d85b3f2a90c781c2fdabf57ca852140c5a1090ba..23a86ee8875c1863beae749ea873f4cb662510d0 100644 --- a/ppcls/loss/centerloss.py +++ b/ppcls/loss/centerloss.py @@ -1,54 +1,80 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import absolute_import from __future__ import division from __future__ import print_function + +from typing import Dict + import paddle import paddle.nn as nn -import paddle.nn.functional as F class CenterLoss(nn.Layer): - def __init__(self, num_classes=5013, feat_dim=2048): + """Center loss + paper : [A Discriminative Feature Learning Approach for Deep Face Recognition](https://link.springer.com/content/pdf/10.1007%2F978-3-319-46478-7_31.pdf) + code reference: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/layers/center_loss.py#L7 + Args: + num_classes (int): number of classes. + feat_dim (int): number of feature dimensions. + feature_from (str): feature from "backbone" or "features" + """ + + def __init__(self, + num_classes: int, + feat_dim: int, + feature_from: str="features"): super(CenterLoss, self).__init__() self.num_classes = num_classes self.feat_dim = feat_dim - self.centers = paddle.randn( - shape=[self.num_classes, self.feat_dim]).astype( - "float64") #random center + self.feature_from = feature_from + random_init_centers = paddle.randn( + shape=[self.num_classes, self.feat_dim]) + self.centers = self.create_parameter( + shape=(self.num_classes, self.feat_dim), + default_initializer=nn.initializer.Assign(random_init_centers)) + self.add_parameter("centers", self.centers) - def __call__(self, input, target): - """ - inputs: network output: {"features: xxx", "logits": xxxx} - target: image label + def __call__(self, input: Dict[str, paddle.Tensor], + target: paddle.Tensor) -> Dict[str, paddle.Tensor]: + """compute center loss. + + Args: + input (Dict[str, paddle.Tensor]): {'features': (batch_size, feature_dim), ...}. 
+ target (paddle.Tensor): ground truth label with shape (batch_size, ). + + Returns: + Dict[str, paddle.Tensor]: {'CenterLoss': loss}. """ - feats = input["features"] + feats = input[self.feature_from] labels = target + + # squeeze labels to shape (batch_size, ) + if labels.ndim >= 2 and labels.shape[-1] == 1: + labels = paddle.squeeze(labels, axis=[-1]) + batch_size = feats.shape[0] + distmat = paddle.pow(feats, 2).sum(axis=1, keepdim=True).expand([batch_size, self.num_classes]) + \ + paddle.pow(self.centers, 2).sum(axis=1, keepdim=True).expand([self.num_classes, batch_size]).t() + distmat = distmat.addmm(x=feats, y=self.centers.t(), beta=1, alpha=-2) - #calc feat * feat - dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True) - dist1 = paddle.expand(dist1, [batch_size, self.num_classes]) - - #dist2 of centers - dist2 = paddle.sum(paddle.square(self.centers), axis=1, - keepdim=True) #num_classes - dist2 = paddle.expand(dist2, - [self.num_classes, batch_size]).astype("float64") - dist2 = paddle.transpose(dist2, [1, 0]) - - #first x * x + y * y - distmat = paddle.add(dist1, dist2) - tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0])) - distmat = distmat - 2.0 * tmp - - #generate the mask - classes = paddle.arange(self.num_classes).astype("int64") - labels = paddle.expand( - paddle.unsqueeze(labels, 1), (batch_size, self.num_classes)) - mask = paddle.equal( - paddle.expand(classes, [batch_size, self.num_classes]), - labels).astype("float64") #get mask - - dist = paddle.multiply(distmat, mask) - loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size + classes = paddle.arange(self.num_classes).astype(labels.dtype) + labels = labels.unsqueeze(1).expand([batch_size, self.num_classes]) + mask = labels.equal(classes.expand([batch_size, self.num_classes])) + dist = distmat * mask.astype(feats.dtype) + loss = dist.clip(min=1e-12, max=1e+12).sum() / batch_size return {'CenterLoss': loss} diff --git a/ppcls/loss/triplet.py b/ppcls/loss/triplet.py index 458ee2e27d7b550fecfe16e5208047a8919b89d0..0da7cc5dffb8f54807fa3d4da12b002755e54452 100644 --- a/ppcls/loss/triplet.py +++ b/ppcls/loss/triplet.py @@ -28,9 +28,13 @@ class TripletLossV2(nn.Layer): margin (float): margin for triplet. """ - def __init__(self, margin=0.5, normalize_feature=True): + def __init__(self, + margin=0.5, + normalize_feature=True, + feature_from="features"): super(TripletLossV2, self).__init__() self.margin = margin + self.feature_from = feature_from self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin) self.normalize_feature = normalize_feature @@ -40,7 +44,7 @@ inputs: feature matrix with shape (batch_size, feat_dim) target: ground truth labels with shape (batch_size,) """ - inputs = input["features"] + inputs = input[self.feature_from] if self.normalize_feature: inputs = 1. * inputs / (paddle.expand_as( diff --git a/ppcls/optimizer/__init__.py b/ppcls/optimizer/__init__.py index 44d7b5ac0b33f267f6893d39bd42d27c8bac0573..bdee9f9b6c4b605a85b635f6a12de5eda6165c90 100644 --- a/ppcls/optimizer/__init__.py +++ b/ppcls/optimizer/__init__.py @@ -115,7 +115,9 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None): optim_model.append(m) else: # optimizer for module in model, such as backbone, neck, head...
- if hasattr(model_list[i], optim_scope): + if optim_scope == model_list[i].__class__.__name__: + optim_model.append(model_list[i]) + elif hasattr(model_list[i], optim_scope): optim_model.append(getattr(model_list[i], optim_scope)) optim = getattr(optimizer, optim_name)( diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py index b59387dd935c805078ffdb435788373e07743807..1a4561133f948831b9ca0d69821a3394f092fae7 100644 --- a/ppcls/optimizer/learning_rate.py +++ b/ppcls/optimizer/learning_rate.py @@ -75,6 +75,23 @@ class Linear(object): return learning_rate +class Constant(LRScheduler): + """ + Constant learning rate + Args: + learning_rate (float): The initial learning rate. It is a python float number. + last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. + """ + + def __init__(self, learning_rate, last_epoch=-1, **kwargs): + self.learning_rate = learning_rate + self.last_epoch = last_epoch + super().__init__() + + def get_lr(self): + return self.learning_rate + + class Cosine(object): """ Cosine learning rate decay @@ -188,6 +205,7 @@ class Piecewise(object): The type of element in the list is python float. warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0. warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0. + by_epoch(bool): Whether lr decay by epoch. Default: False. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate. """ @@ -198,6 +216,7 @@ def __init__(self, epochs, warmup_epoch=0, warmup_start_lr=0.0, + by_epoch=False, last_epoch=-1, **kwargs): super().__init__() @@ -205,24 +224,41 @@ msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}." logger.warning(msg) warmup_epoch = epochs - self.boundaries = [step_each_epoch * e for e in decay_epochs] + self.boundaries_steps = [step_each_epoch * e for e in decay_epochs] + self.boundaries_epoch = decay_epochs self.values = values self.last_epoch = last_epoch self.warmup_steps = round(warmup_epoch * step_each_epoch) + self.warmup_epoch = warmup_epoch self.warmup_start_lr = warmup_start_lr + self.by_epoch = by_epoch def __call__(self): - learning_rate = lr.PiecewiseDecay( - boundaries=self.boundaries, - values=self.values, - last_epoch=self.last_epoch) - if self.warmup_steps > 0: - learning_rate = lr.LinearWarmup( - learning_rate=learning_rate, - warmup_steps=self.warmup_steps, - start_lr=self.warmup_start_lr, - end_lr=self.values[0], + if self.by_epoch: + learning_rate = lr.PiecewiseDecay( + boundaries=self.boundaries_epoch, + values=self.values, + last_epoch=self.last_epoch) + if self.warmup_epoch > 0: + learning_rate = lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=self.warmup_epoch, + start_lr=self.warmup_start_lr, + end_lr=self.values[0], + last_epoch=self.last_epoch) + else: + learning_rate = lr.PiecewiseDecay( + boundaries=self.boundaries_steps, + values=self.values, last_epoch=self.last_epoch) + if self.warmup_steps > 0: + learning_rate = lr.LinearWarmup( + learning_rate=learning_rate, + warmup_steps=self.warmup_steps, + start_lr=self.warmup_start_lr, + end_lr=self.values[0], + last_epoch=self.last_epoch) + setattr(learning_rate, "by_epoch", self.by_epoch) return learning_rate
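Finally, a sketch of how the two scheduler changes behave at runtime (illustrative, not part of the patch; it assumes `Piecewise`'s remaining constructor keywords — `step_each_epoch`, `decay_epochs`, `values` — match the attributes used in its body, as the hunk context suggests):

```python
from ppcls.optimizer.learning_rate import Constant, Piecewise

# Constant lr, as used for the CenterLoss optimizer above.
center_lr = Constant(learning_rate=1000.0)
print(center_lr.get_lr())  # 1000.0

# Piecewise with by_epoch=True: boundaries and warmup are counted in
# epochs, so train.py now calls .step() once per epoch for it.
sched = Piecewise(
    step_each_epoch=100,  # unused for decay when by_epoch=True
    decay_epochs=[30, 60],
    values=[0.00035, 0.000035, 0.0000035],
    epochs=120,
    warmup_epoch=10,
    warmup_start_lr=0.0000035,
    by_epoch=True)()
print(sched.by_epoch)  # True

for epoch in range(120):
    ...           # one epoch of training, optimizer steps per iteration
    sched.step()  # stepped per epoch because by_epoch is True
```

Together with the `scope` matching change in `build_optimizer` above, the `scope: CenterLoss` entry selects the `CenterLoss` layer by its class name, so the centers get the Constant lr while the model parameters follow the Piecewise schedule.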