未验证 提交 bd524e8a 编写于 作者: W Walter 提交者: GitHub

Merge pull request #1819 from weisy11/reid

复现reid-strong-baseline
......@@ -278,6 +278,7 @@ class ResNet(TheseusLayer):
config,
stages_pattern,
version="vb",
stem_act="relu",
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
data_format="NCHW",
......@@ -311,13 +312,13 @@ class ResNet(TheseusLayer):
[[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
}
self.stem = nn.Sequential(* [
self.stem = nn.Sequential(*[
ConvBNLayer(
num_channels=in_c,
num_filters=out_c,
filter_size=k,
stride=s,
act="relu",
act=stem_act,
lr_mult=self.lr_mult_list[0],
data_format=data_format)
for in_c, out_c, k, s in self.stem_cfg[version]
......
......@@ -17,21 +17,32 @@ from __future__ import absolute_import, division, print_function
import paddle
import paddle.nn as nn
from ppcls.arch.utils import get_param_attr_dict
class BNNeck(nn.Layer):
def __init__(self, num_features):
def __init__(self, num_features, **kwargs):
super().__init__()
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=1.0))
bias_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
trainable=False)
if 'weight_attr' in kwargs:
weight_attr = get_param_attr_dict(kwargs['weight_attr'])
bias_attr = None
if 'bias_attr' in kwargs:
bias_attr = get_param_attr_dict(kwargs['bias_attr'])
self.feat_bn = nn.BatchNorm1D(
num_features,
momentum=0.9,
epsilon=1e-05,
weight_attr=weight_attr,
bias_attr=bias_attr)
self.flatten = nn.Flatten()
def forward(self, x):
......
......@@ -19,16 +19,29 @@ from __future__ import print_function
import paddle
import paddle.nn as nn
from ppcls.arch.utils import get_param_attr_dict
class FC(nn.Layer):
def __init__(self, embedding_size, class_num):
def __init__(self, embedding_size, class_num, **kwargs):
super(FC, self).__init__()
self.embedding_size = embedding_size
self.class_num = class_num
weight_attr = paddle.ParamAttr(
initializer=paddle.nn.initializer.XavierNormal())
self.fc = paddle.nn.Linear(
self.embedding_size, self.class_num, weight_attr=weight_attr)
if 'weight_attr' in kwargs:
weight_attr = get_param_attr_dict(kwargs['weight_attr'])
bias_attr = None
if 'bias_attr' in kwargs:
bias_attr = get_param_attr_dict(kwargs['bias_attr'])
self.fc = nn.Linear(
self.embedding_size,
self.class_num,
weight_attr=weight_attr,
bias_attr=bias_attr)
def forward(self, input, label=None):
out = self.fc(input)
......
......@@ -14,9 +14,11 @@
import six
import types
import paddle
from difflib import SequenceMatcher
from . import backbone
from typing import Any, Dict, Union
def get_architectures():
......@@ -51,3 +53,47 @@ def similar_architectures(name='', names=[], thresh=0.1, topk=10):
scores.sort(key=lambda x: x[1], reverse=True)
similar_names = [names[s[0]] for s in scores[:min(topk, len(scores))]]
return similar_names
def get_param_attr_dict(ParamAttr_config: Union[None, bool, Dict[str, Dict]]
) -> Union[None, bool, paddle.ParamAttr]:
"""parse ParamAttr from an dict
Args:
ParamAttr_config (Union[None, bool, Dict[str, Dict]]): ParamAttr configure
Returns:
Union[None, bool, paddle.ParamAttr]: Generated ParamAttr
"""
if ParamAttr_config is None:
return None
if isinstance(ParamAttr_config, bool):
return ParamAttr_config
ParamAttr_dict = {}
if 'initializer' in ParamAttr_config:
initializer_cfg = ParamAttr_config.get('initializer')
if 'name' in initializer_cfg:
initializer_name = initializer_cfg.pop('name')
ParamAttr_dict['initializer'] = getattr(
paddle.nn.initializer, initializer_name)(**initializer_cfg)
else:
raise ValueError(f"'name' must specified in initializer_cfg")
if 'learning_rate' in ParamAttr_config:
# NOTE: only support an single value now
learning_rate_value = ParamAttr_config.get('learning_rate')
if isinstance(learning_rate_value, (int, float)):
ParamAttr_dict['learning_rate'] = learning_rate_value
else:
raise ValueError(
f"learning_rate_value must be float or int, but got {type(learning_rate_value)}"
)
if 'regularizer' in ParamAttr_config:
regularizer_cfg = ParamAttr_config.get('regularizer')
if 'name' in regularizer_cfg:
# L1Decay or L2Decay
regularizer_name = regularizer_cfg.pop('name')
ParamAttr_dict['regularizer'] = getattr(
paddle.regularizer, regularizer_name)(**regularizer_cfg)
else:
raise ValueError(f"'name' must specified in regularizer_cfg")
return paddle.ParamAttr(**ParamAttr_dict)
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 40
eval_during_train: True
eval_interval: 10
epochs: 120
print_batch_step: 20
use_visualdl: False
eval_mode: "retrieval"
retrieval_feature_from: "backbone" # 'backbone' or 'neck'
# used for static mode and model export
image_shape: [3, 256, 128]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "RecModel"
infer_output_key: "features"
infer_add_softmax: False
Backbone:
name: "ResNet50"
pretrained: True
stem_act: null
BackboneStopLayer:
name: "flatten"
Head:
name: "FC"
embedding_size: 2048
class_num: 751
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
- TripletLossV2:
weight: 1.0
margin: 0.3
normalize_feature: False
feature_from: "backbone"
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Adam
lr:
name: Piecewise
decay_epochs: [40, 70]
values: [0.00035, 0.000035, 0.0000035]
warmup_epoch: 10
by_epoch: True
last_epoch: 0
regularizer:
name: 'L2'
coeff: 0.0005
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_train"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- RandFlipImage:
flip_code: 1
- Pad:
padding: 10
- RandCropImageV2:
size: [128, 256]
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
num_instances: 4
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
Query:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "query"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Gallery:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_test"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 40
eval_during_train: True
eval_interval: 10
epochs: 120
print_batch_step: 20
use_visualdl: False
eval_mode: "retrieval"
retrieval_feature_from: "features" # 'backbone' or 'features'
# used for static mode and model export
image_shape: [3, 256, 128]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "RecModel"
infer_output_key: "features"
infer_add_softmax: False
Backbone:
name: "ResNet50_last_stage_stride1"
pretrained: True
stem_act: null
BackboneStopLayer:
name: "flatten"
Neck:
name: BNNeck
num_features: &feat_dim 2048
weight_attr:
initializer:
name: Constant
value: 1.0
bias_attr:
initializer:
name: Constant
value: 0.0
learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero
Head:
name: "FC"
embedding_size: *feat_dim
class_num: 751
weight_attr:
initializer:
name: Normal
std: 0.001
bias_attr: False
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
- TripletLossV2:
weight: 1.0
margin: 0.3
normalize_feature: False
feature_from: "backbone"
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Adam
lr:
name: Piecewise
decay_epochs: [30, 60]
values: [0.00035, 0.000035, 0.0000035]
warmup_epoch: 10
warmup_start_lr: 0.0000035
by_epoch: True
last_epoch: 0
regularizer:
name: 'L2'
coeff: 0.0005
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_train"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- RandFlipImage:
flip_code: 1
- Pad:
padding: 10
- RandCropImageV2:
size: [128, 256]
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.4
r1: 0.3
mean: [0.485, 0.456, 0.406]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
num_instances: 4
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
Query:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "query"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Gallery:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_test"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: "./output/"
device: "gpu"
save_interval: 40
eval_during_train: True
eval_interval: 10
epochs: 120
print_batch_step: 20
use_visualdl: False
eval_mode: "retrieval"
retrieval_feature_from: "features" # 'backbone' or 'features'
# used for static mode and model export
image_shape: [3, 256, 128]
save_inference_dir: "./inference"
# model architecture
Arch:
name: "RecModel"
infer_output_key: "features"
infer_add_softmax: False
Backbone:
name: "ResNet50_last_stage_stride1"
pretrained: True
stem_act: null
BackboneStopLayer:
name: "flatten"
Neck:
name: BNNeck
num_features: &feat_dim 2048
weight_attr:
initializer:
name: Constant
value: 1.0
bias_attr:
initializer:
name: Constant
value: 0.0
learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero
Head:
name: "FC"
embedding_size: *feat_dim
class_num: &class_num 751
weight_attr:
initializer:
name: Normal
std: 0.001
bias_attr: False
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
- TripletLossV2:
weight: 1.0
margin: 0.3
normalize_feature: False
feature_from: "backbone"
- CenterLoss:
weight: 0.0005
num_classes: *class_num
feat_dim: *feat_dim
feature_from: "backbone"
Eval:
- CELoss:
weight: 1.0
Optimizer:
- Adam:
scope: RecModel
lr:
name: Piecewise
decay_epochs: [30, 60]
values: [0.00035, 0.000035, 0.0000035]
warmup_epoch: 10
warmup_start_lr: 0.0000035
by_epoch: True
last_epoch: 0
regularizer:
name: 'L2'
coeff: 0.0005
- SGD:
scope: CenterLoss
lr:
name: Constant
learning_rate: 1000.0 # NOTE: set to ori_lr*(1/centerloss_weight) to avoid manually scaling centers' gradidents.
# data loader for train and eval
DataLoader:
Train:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_train"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- RandFlipImage:
flip_code: 1
- Pad:
padding: 10
- RandCropImageV2:
size: [128, 256]
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.4
r1: 0.3
mean: [0.485, 0.456, 0.406]
sampler:
name: DistributedRandomIdentitySampler
batch_size: 64
num_instances: 4
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
Query:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "query"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Gallery:
dataset:
name: "Market1501"
image_root: "./dataset/"
cls_label_path: "bounding_box_test"
backend: "pil"
transform_ops:
- ResizeImage:
size: [128, 256]
return_numpy: False
backend: "pil"
- ToTensor:
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Metric:
Eval:
- Recallk:
topk: [1, 5]
- mAP: {}
......@@ -43,7 +43,11 @@ class Market1501(Dataset):
"""
_dataset_dir = 'market1501/Market-1501-v15.09.15'
def __init__(self, image_root, cls_label_path, transform_ops=None):
def __init__(self,
image_root,
cls_label_path,
transform_ops=None,
backend="cv2"):
self._img_root = image_root
self._cls_path = cls_label_path # the sub folder in the dataset
self._dataset_dir = osp.join(image_root, self._dataset_dir,
......@@ -51,6 +55,7 @@ class Market1501(Dataset):
self._check_before_run()
if transform_ops:
self._transform_ops = create_operators(transform_ops)
self.backend = backend
self._dtype = paddle.get_default_dtype()
self._load_anno(relabel=True if 'train' in self._cls_path else False)
......@@ -92,10 +97,12 @@ class Market1501(Dataset):
def __getitem__(self, idx):
try:
img = Image.open(self.images[idx]).convert('RGB')
img = np.array(img, dtype="float32").astype(np.uint8)
if self.backend == "cv2":
img = np.array(img, dtype="float32").astype(np.uint8)
if self._transform_ops:
img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1))
if self.backend == "cv2":
img = img.transpose((2, 0, 1))
return (img, self.labels[idx], self.cameras[idx])
except Exception as ex:
logger.error("Exception occured when parse line: {} with msg: {}".
......
......@@ -25,10 +25,14 @@ from ppcls.data.preprocess.ops.operators import DecodeImage
from ppcls.data.preprocess.ops.operators import ResizeImage
from ppcls.data.preprocess.ops.operators import CropImage
from ppcls.data.preprocess.ops.operators import RandCropImage
from ppcls.data.preprocess.ops.operators import RandCropImageV2
from ppcls.data.preprocess.ops.operators import RandFlipImage
from ppcls.data.preprocess.ops.operators import NormalizeImage
from ppcls.data.preprocess.ops.operators import ToCHWImage
from ppcls.data.preprocess.ops.operators import AugMix
from ppcls.data.preprocess.ops.operators import Pad
from ppcls.data.preprocess.ops.operators import ToTensor
from ppcls.data.preprocess.ops.operators import Normalize
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
......
......@@ -23,8 +23,9 @@ import math
import random
import cv2
import numpy as np
from PIL import Image
from PIL import Image, ImageOps, __version__ as PILLOW_VERSION
from paddle.vision.transforms import ColorJitter as RawColorJitter
from paddle.vision.transforms import ToTensor, Normalize
from .autoaugment import ImageNetPolicy
from .functional import augmentations
......@@ -32,7 +33,7 @@ from ppcls.utils import logger
class UnifiedResize(object):
def __init__(self, interpolation=None, backend="cv2"):
def __init__(self, interpolation=None, backend="cv2", return_numpy=True):
_cv2_interp_from_str = {
'nearest': cv2.INTER_NEAREST,
'bilinear': cv2.INTER_LINEAR,
......@@ -56,12 +57,17 @@ class UnifiedResize(object):
resample = random.choice(resample)
return cv2.resize(src, size, interpolation=resample)
def _pil_resize(src, size, resample):
def _pil_resize(src, size, resample, return_numpy=True):
if isinstance(resample, tuple):
resample = random.choice(resample)
pil_img = Image.fromarray(src)
if isinstance(src, np.ndarray):
pil_img = Image.fromarray(src)
else:
pil_img = src
pil_img = pil_img.resize(size, resample)
return np.asarray(pil_img)
if return_numpy:
return np.asarray(pil_img)
return pil_img
if backend.lower() == "cv2":
if isinstance(interpolation, str):
......@@ -73,7 +79,8 @@ class UnifiedResize(object):
elif backend.lower() == "pil":
if isinstance(interpolation, str):
interpolation = _pil_interp_from_str[interpolation.lower()]
self.resize_func = partial(_pil_resize, resample=interpolation)
self.resize_func = partial(
_pil_resize, resample=interpolation, return_numpy=return_numpy)
else:
logger.warning(
f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. Use \"cv2\" instead."
......@@ -81,6 +88,8 @@ class UnifiedResize(object):
self.resize_func = cv2.resize
def __call__(self, src, size):
if isinstance(size, list):
size = tuple(size)
return self.resize_func(src, size)
......@@ -99,14 +108,15 @@ class DecodeImage(object):
self.channel_first = channel_first # only enabled when to_np is True
def __call__(self, img):
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
data = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(data, 1)
if not isinstance(img, np.ndarray):
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
data = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(data, 1)
if self.to_rgb:
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
......@@ -125,7 +135,8 @@ class ResizeImage(object):
size=None,
resize_short=None,
interpolation=None,
backend="cv2"):
backend="cv2",
return_numpy=True):
if resize_short is not None and resize_short > 0:
self.resize_short = resize_short
self.w = None
......@@ -139,10 +150,16 @@ class ResizeImage(object):
'both 'size' and 'resize_short' are None")
self._resize_func = UnifiedResize(
interpolation=interpolation, backend=backend)
interpolation=interpolation,
backend=backend,
return_numpy=return_numpy)
def __call__(self, img):
img_h, img_w = img.shape[:2]
if isinstance(img, np.ndarray):
img_h, img_w = img.shape[:2]
else:
img_w, img_h = img.size
if self.resize_short is not None:
percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent))
......@@ -222,6 +239,40 @@ class RandCropImage(object):
return self._resize_func(img, size)
class RandCropImageV2(object):
""" RandCropImageV2 is different from RandCropImage,
it will Select a cutting position randomly in a uniform distribution way,
and cut according to the given size without resize at last."""
def __init__(self, size):
if type(size) is int:
self.size = (size, size) # (h, w)
else:
self.size = size
def __call__(self, img):
if isinstance(img, np.ndarray):
img_h, img_w = img.shap[0], img.shap[1]
else:
img_w, img_h = img.size
tw, th = self.size
if img_h + 1 < th or img_w + 1 < tw:
raise ValueError(
"Required crop size {} is larger then input image size {}".
format((th, tw), (img_h, img_w)))
if img_w == tw and img_h == th:
return img
top = random.randint(0, img_h - th + 1)
left = random.randint(0, img_w - tw + 1)
if isinstance(img, np.ndarray):
return img[top:top + th, left:left + tw, :]
else:
return img.crop((left, top, left + tw, top + th))
class RandFlipImage(object):
""" random flip image
flip_code:
......@@ -237,7 +288,10 @@ class RandFlipImage(object):
def __call__(self, img):
if random.randint(0, 1) == 1:
return cv2.flip(img, self.flip_code)
if isinstance(img, np.ndarray):
return cv2.flip(img, self.flip_code)
else:
return img.transpose(Image.FLIP_LEFT_RIGHT)
else:
return img
......@@ -391,3 +445,58 @@ class ColorJitter(RawColorJitter):
if isinstance(img, Image.Image):
img = np.asarray(img)
return img
class Pad(object):
"""
Pads the given PIL.Image on all sides with specified padding mode and fill value.
adapted from: https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Pad
"""
def __init__(self, padding: int, fill: int=0,
padding_mode: str="constant"):
self.padding = padding
self.fill = fill
self.padding_mode = padding_mode
def _parse_fill(self, fill, img, min_pil_version, name="fillcolor"):
# Process fill color for affine transforms
major_found, minor_found = (int(v)
for v in PILLOW_VERSION.split('.')[:2])
major_required, minor_required = (
int(v) for v in min_pil_version.split('.')[:2])
if major_found < major_required or (major_found == major_required and
minor_found < minor_required):
if fill is None:
return {}
else:
msg = (
"The option to fill background area of the transformed image, "
"requires pillow>={}")
raise RuntimeError(msg.format(min_pil_version))
num_bands = len(img.getbands())
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_bands > 1:
fill = tuple([fill] * num_bands)
if isinstance(fill, (list, tuple)):
if len(fill) != num_bands:
msg = (
"The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands))
fill = tuple(fill)
return {name: fill}
def __call__(self, img):
opts = self._parse_fill(self.fill, img, "2.3.0", name="fill")
if img.mode == "P":
palette = img.getpalette()
img = ImageOps.expand(img, border=self.padding, **opts)
img.putpalette(palette)
return img
return ImageOps.expand(img, border=self.padding, **opts)
......@@ -26,15 +26,21 @@ import numpy as np
class Pixels(object):
def __init__(self, mode="const", mean=[0., 0., 0.]):
self._mode = mode
self._mean = mean
self._mean = np.array(mean)
def __call__(self, h=224, w=224, c=3):
def __call__(self, h=224, w=224, c=3, channel_first=False):
if self._mode == "rand":
return np.random.normal(size=(1, 1, 3))
return np.random.normal(size=(
1, 1, 3)) if not channel_first else np.random.normal(size=(
3, 1, 1))
elif self._mode == "pixel":
return np.random.normal(size=(h, w, c))
return np.random.normal(size=(
h, w, c)) if not channel_first else np.random.normal(size=(
c, h, w))
elif self._mode == "const":
return self._mean
return np.reshape(self._mean, (
1, 1, c)) if not channel_first else np.reshape(self._mean,
(c, 1, 1))
else:
raise Exception(
"Invalid mode in RandomErasing, only support \"const\", \"rand\", \"pixel\""
......@@ -69,7 +75,13 @@ class RandomErasing(object):
return img
for _ in range(self.attempt):
area = img.shape[0] * img.shape[1]
if isinstance(img, np.ndarray):
img_h, img_w, img_c = img.shape
channel_first = False
else:
img_c, img_h, img_w = img.shape
channel_first = True
area = img_h * img_w
target_area = random.uniform(self.sl, self.sh) * area
aspect_ratio = random.uniform(*self.r1)
......@@ -79,13 +91,19 @@ class RandomErasing(object):
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img.shape[1] and h < img.shape[0]:
pixels = self.get_pixels(h, w, img.shape[2])
x1 = random.randint(0, img.shape[0] - h)
y1 = random.randint(0, img.shape[1] - w)
if img.shape[2] == 3:
img[x1:x1 + h, y1:y1 + w, :] = pixels
if w < img_w and h < img_h:
pixels = self.get_pixels(h, w, img_c, channel_first)
x1 = random.randint(0, img_h - h)
y1 = random.randint(0, img_w - w)
if img_c == 3:
if channel_first:
img[:, x1:x1 + h, y1:y1 + w] = pixels
else:
img[x1:x1 + h, y1:y1 + w, :] = pixels
else:
img[x1:x1 + h, y1:y1 + w, 0] = pixels[0]
if channel_first:
img[0, x1:x1 + h, y1:y1 + w] = pixels[0]
else:
img[x1:x1 + h, y1:y1 + w, 0] = pixels[:, :, 0]
return img
return img
......@@ -288,8 +288,9 @@ class Engine(object):
world_size = dist.get_world_size()
self.config["Global"]["distributed"] = world_size != 1
if self.mode == "train":
std_gpu_num = 8 if self.config["Optimizer"][
"name"] == "AdamW" else 4
std_gpu_num = 8 if isinstance(
self.config["Optimizer"],
dict) and self.config["Optimizer"]["name"] == "AdamW" else 4
if world_size != std_gpu_num:
msg = f"The training strategy provided by PaddleClas is based on {std_gpu_num} gpus. But the number of gpu is {world_size} in current training. Please modify the stategy (learning rate, batch size and so on) if use this config to train."
logger.warning(msg)
......@@ -337,6 +338,7 @@ class Engine(object):
self.max_iter = len(self.train_dataloader) - 1 if platform.system(
) == "Windows" else len(self.train_dataloader)
for epoch_id in range(best_metric["epoch"] + 1,
self.config["Global"]["epochs"] + 1):
acc = 0.0
......
......@@ -126,7 +126,15 @@ def cal_feature(engine, name='gallery'):
out = engine.model(batch[0], batch[1])
if "Student" in out:
out = out["Student"]
batch_feas = out["features"]
# get features
if engine.config["Global"].get("retrieval_feature_from",
"features") == "features":
# use neck's output as features
batch_feas = out["features"]
else:
# use backbone's output as features
batch_feas = out["backbone"]
# do norm
if engine.config["Global"].get("feature_normalize", True):
......
......@@ -53,7 +53,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
out = forward(engine, batch)
loss_dict = engine.train_loss_func(out, batch[1])
# step opt
# backward & step opt
if engine.amp:
scaled = engine.scaler.scale(loss_dict["loss"])
scaled.backward()
......@@ -63,12 +63,15 @@ def train_epoch(engine, epoch_id, print_batch_step):
loss_dict["loss"].backward()
for i in range(len(engine.optimizer)):
engine.optimizer[i].step()
# clear grad
for i in range(len(engine.optimizer)):
engine.optimizer[i].clear_grad()
# step lr
# step lr(by step)
for i in range(len(engine.lr_sch)):
engine.lr_sch[i].step()
if not getattr(engine.lr_sch[i], "by_epoch", False):
engine.lr_sch[i].step()
# below code just for logging
# update metric_for_logger
......@@ -80,6 +83,11 @@ def train_epoch(engine, epoch_id, print_batch_step):
log_info(engine, batch_size, epoch_id, iter_id)
tic = time.time()
# step lr(by epoch)
for i in range(len(engine.lr_sch)):
if getattr(engine.lr_sch[i], "by_epoch", False):
engine.lr_sch[i].step()
def forward(engine, batch):
if not engine.is_rec:
......
......@@ -39,7 +39,7 @@ def update_loss(trainer, loss_dict, batch_size):
def log_info(trainer, batch_size, epoch_id, iter_id):
lr_msg = ", ".join([
"lr_{}: {:.8f}".format(i + 1, lr.get_lr())
"lr({}): {:.8f}".format(lr.__class__.__name__, lr.get_lr())
for i, lr in enumerate(trainer.lr_sch)
])
metric_msg = ", ".join([
......@@ -64,7 +64,7 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
for i, lr in enumerate(trainer.lr_sch):
logger.scaler(
name="lr_{}".format(i + 1),
name="lr({})".format(lr.__class__.__name__),
value=lr.get_lr(),
step=trainer.global_step,
writer=trainer.vdl_writer)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Dict
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class CenterLoss(nn.Layer):
def __init__(self, num_classes=5013, feat_dim=2048):
"""Center loss
paper : [A Discriminative Feature Learning Approach for Deep Face Recognition](https://link.springer.com/content/pdf/10.1007%2F978-3-319-46478-7_31.pdf)
code reference: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/layers/center_loss.py#L7
Args:
num_classes (int): number of classes.
feat_dim (int): number of feature dimensions.
feature_from (str): feature from "backbone" or "features"
"""
def __init__(self,
num_classes: int,
feat_dim: int,
feature_from: str="features"):
super(CenterLoss, self).__init__()
self.num_classes = num_classes
self.feat_dim = feat_dim
self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype(
"float64") #random center
self.feature_from = feature_from
random_init_centers = paddle.randn(
shape=[self.num_classes, self.feat_dim])
self.centers = self.create_parameter(
shape=(self.num_classes, self.feat_dim),
default_initializer=nn.initializer.Assign(random_init_centers))
self.add_parameter("centers", self.centers)
def __call__(self, input, target):
"""
inputs: network output: {"features: xxx", "logits": xxxx}
target: image label
def __call__(self, input: Dict[str, paddle.Tensor],
target: paddle.Tensor) -> Dict[str, paddle.Tensor]:
"""compute center loss.
Args:
input (Dict[str, paddle.Tensor]): {'features': (batch_size, feature_dim), ...}.
target (paddle.Tensor): ground truth label with shape (batch_size, ).
Returns:
Dict[str, paddle.Tensor]: {'CenterLoss': loss}.
"""
feats = input["features"]
feats = input[self.feature_from]
labels = target
# squeeze labels to shape (batch_size, )
if labels.ndim >= 2 and labels.shape[-1] == 1:
labels = paddle.squeeze(labels, axis=[-1])
batch_size = feats.shape[0]
distmat = paddle.pow(feats, 2).sum(axis=1, keepdim=True).expand([batch_size, self.num_classes]) + \
paddle.pow(self.centers, 2).sum(axis=1, keepdim=True).expand([self.num_classes, batch_size]).t()
distmat = distmat.addmm(x=feats, y=self.centers.t(), beta=1, alpha=-2)
#calc feat * feat
dist1 = paddle.sum(paddle.square(feats), axis=1, keepdim=True)
dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
#dist2 of centers
dist2 = paddle.sum(paddle.square(self.centers), axis=1,
keepdim=True) #num_classes
dist2 = paddle.expand(dist2,
[self.num_classes, batch_size]).astype("float64")
dist2 = paddle.transpose(dist2, [1, 0])
#first x * x + y * y
distmat = paddle.add(dist1, dist2)
tmp = paddle.matmul(feats, paddle.transpose(self.centers, [1, 0]))
distmat = distmat - 2.0 * tmp
#generate the mask
classes = paddle.arange(self.num_classes).astype("int64")
labels = paddle.expand(
paddle.unsqueeze(labels, 1), (batch_size, self.num_classes))
mask = paddle.equal(
paddle.expand(classes, [batch_size, self.num_classes]),
labels).astype("float64") #get mask
dist = paddle.multiply(distmat, mask)
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
classes = paddle.arange(self.num_classes).astype(labels.dtype)
labels = labels.unsqueeze(1).expand([batch_size, self.num_classes])
mask = labels.equal(classes.expand([batch_size, self.num_classes]))
dist = distmat * mask.astype(feats.dtype)
loss = dist.clip(min=1e-12, max=1e+12).sum() / batch_size
# return loss
return {'CenterLoss': loss}
......@@ -28,9 +28,13 @@ class TripletLossV2(nn.Layer):
margin (float): margin for triplet.
"""
def __init__(self, margin=0.5, normalize_feature=True):
def __init__(self,
margin=0.5,
normalize_feature=True,
feature_from="features"):
super(TripletLossV2, self).__init__()
self.margin = margin
self.feature_from = feature_from
self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
self.normalize_feature = normalize_feature
......@@ -40,7 +44,7 @@ class TripletLossV2(nn.Layer):
inputs: feature matrix with shape (batch_size, feat_dim)
target: ground truth labels with shape (num_classes)
"""
inputs = input["features"]
inputs = input[self.feature_from]
if self.normalize_feature:
inputs = 1. * inputs / (paddle.expand_as(
......
......@@ -115,7 +115,9 @@ def build_optimizer(config, epochs, step_each_epoch, model_list=None):
optim_model.append(m)
else:
# opmizer for module in model, such as backbone, neck, head...
if hasattr(model_list[i], optim_scope):
if optim_scope == model_list[i].__class__.__name__:
optim_model.append(model_list[i])
elif hasattr(model_list[i], optim_scope):
optim_model.append(getattr(model_list[i], optim_scope))
optim = getattr(optimizer, optim_name)(
......
......@@ -75,6 +75,23 @@ class Linear(object):
return learning_rate
class Constant(LRScheduler):
"""
Constant learning rate
Args:
lr (float): The initial learning rate. It is a python float number.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self, learning_rate, last_epoch=-1, **kwargs):
self.learning_rate = learning_rate
self.last_epoch = last_epoch
super().__init__()
def get_lr(self):
return self.learning_rate
class Cosine(object):
"""
Cosine learning rate decay
......@@ -188,6 +205,7 @@ class Piecewise(object):
The type of element in the list is python float.
warmup_epoch(int): The epoch numbers for LinearWarmup. Default: 0.
warmup_start_lr(float): Initial learning rate of warm up. Default: 0.0.
by_epoch(bool): Whether lr decay by epoch. Default: False.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
......@@ -198,6 +216,7 @@ class Piecewise(object):
epochs,
warmup_epoch=0,
warmup_start_lr=0.0,
by_epoch=False,
last_epoch=-1,
**kwargs):
super().__init__()
......@@ -205,24 +224,41 @@ class Piecewise(object):
msg = f"When using warm up, the value of \"Global.epochs\" must be greater than value of \"Optimizer.lr.warmup_epoch\". The value of \"Optimizer.lr.warmup_epoch\" has been set to {epochs}."
logger.warning(msg)
warmup_epoch = epochs
self.boundaries = [step_each_epoch * e for e in decay_epochs]
self.boundaries_steps = [step_each_epoch * e for e in decay_epochs]
self.boundaries_epoch = decay_epochs
self.values = values
self.last_epoch = last_epoch
self.warmup_steps = round(warmup_epoch * step_each_epoch)
self.warmup_epoch = warmup_epoch
self.warmup_start_lr = warmup_start_lr
self.by_epoch = by_epoch
def __call__(self):
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.values[0],
if self.by_epoch:
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries_epoch,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=self.warmup_start_lr,
end_lr=self.values[0],
last_epoch=self.last_epoch)
else:
learning_rate = lr.PiecewiseDecay(
boundaries=self.boundaries_steps,
values=self.values,
last_epoch=self.last_epoch)
if self.warmup_steps > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_steps,
start_lr=self.warmup_start_lr,
end_lr=self.values[0],
last_epoch=self.last_epoch)
setattr(learning_rate, "by_epoch", self.by_epoch)
return learning_rate
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册