Unverified commit 3541a80d, authored by Wei Shengyu and committed by GitHub

Merge pull request #2269 from zengshao0622/merge_CAE

Merge CAE
......@@ -69,6 +69,7 @@ from .model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG
from .model_zoo.van import VAN_tiny
from .model_zoo.peleenet import PeleeNet
from .model_zoo.convnext import ConvNeXt_tiny
from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
from .variant_models.resnet_variant import ResNet50_last_stage_stride1
from .variant_models.vgg_variant import VGG19Sigmoid
......
This diff is collapsed.
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 20
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: cae_base_patch16_224
class_num: 102
drop_rate: 0.0
drop_path_rate: 0.1
attn_drop_rate: 0.0
use_mean_pooling: True
init_scale: 0.001
use_rel_pos_bias: True
use_abs_pos_emb: False
init_values: 0.1
lin_probe: False
sin_pos_emb: True
abs_pos_emb: False
enable_linear_eval: False
model_key: model|module|state_dict
rel_pos_bias: True
model_ema:
enable_model_ema: False
model_ema_decay: 0.9999
model_ema_force_cpu: False
pretrained: True
# loss function config for training/eval process
Loss:
Train:
- SoftTargetCrossEntropy:
weight: 1.0
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamWDL
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.05
layerwise_decay: 0.65
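  # layerwise_decay scales each transformer block's lr by 0.65**(n_layers - block_index);
  # see AdamWDL._layerwise_lr_decay added in this PR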
lr:
name: Cosine
learning_rate: 0.001
eta_min: 1e-6
warmup_epoch: 10
warmup_start_lr: 1e-6
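    # lr ramps from warmup_start_lr to learning_rate over the first 10 epochs,
    # then follows a cosine decay toward eta_min for the remaining epochs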
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/flowers102/
cls_label_path: ./dataset/flowers102/train_list.txt
batch_transform_ops:
- MixupCutmixHybrid:
mixup_alpha: 0.8
cutmix_alpha: 1.0
switch_prob: 0.5
num_classes: 102
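        # with both alphas > 0, each batch is mixed with either mixup or cutmix;
        # switch_prob is the chance of picking cutmix (see MixupCutmixHybrid below)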
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bilinear
- RandFlipImage:
flip_code: 1
- RandAugment:
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.3
r1: 0.3
sampler:
name: DistributedBatchSampler
batch_size: 16
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/flowers102/
cls_label_path: ./dataset/flowers102/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 16
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
......@@ -42,6 +42,7 @@ from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
import numpy as np
from PIL import Image
......
......@@ -23,6 +23,9 @@ import numpy as np
from ppcls.utils import logger
from ppcls.data.preprocess.ops.fmix import sample_mask
import paddle
import paddle.nn.functional as F
class BatchOperator(object):
""" BatchOperator """
......@@ -229,3 +232,270 @@ class OpSampler(object):
list(self.ops.keys()), weights=list(self.ops.values()), k=1)[0]
# return batch directly when None Op
return op(batch) if op else batch
class MixupCutmixHybrid(object):
""" Mixup/Cutmix that applies different params to each element or whole batch
Args:
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
prob (float): probability of applying mixup or cutmix per batch or element
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
        mode (str): how to apply mixup/cutmix params: per 'batch', 'pair' (pair of elements), or 'elem' (element)
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
label_smoothing (float): apply label smoothing to the mixed target tensor
num_classes (int): number of classes for target
"""
def __init__(self,
mixup_alpha=1.,
cutmix_alpha=0.,
cutmix_minmax=None,
prob=1.0,
switch_prob=0.5,
mode='batch',
correct_lam=True,
label_smoothing=0.1,
num_classes=4):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
if self.cutmix_minmax is not None:
assert len(self.cutmix_minmax) == 2
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
self.cutmix_alpha = 1.0
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mode = mode
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
        self.mixup_enabled = True  # set to False to disable mixing (intended to be set by the train loop)
def _one_hot(self, x, num_classes, on_value=1., off_value=0.):
x = paddle.cast(x, dtype='int64')
on_value = paddle.full([x.shape[0], num_classes], on_value)
off_value = paddle.full([x.shape[0], num_classes], off_value)
return paddle.where(
F.one_hot(x, num_classes) == 1, on_value, off_value)
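    # Build the soft target for a mixed batch: label-smoothed one-hot vectors of the
    # original labels and the flipped (reverse-order) labels, blended with weight lam.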
def _mixup_target(self, target, num_classes, lam=1., smoothing=0.0):
off_value = smoothing / num_classes
on_value = 1. - smoothing + off_value
y1 = self._one_hot(
target,
num_classes,
on_value=on_value,
off_value=off_value, )
y2 = self._one_hot(
target.flip(0),
num_classes,
on_value=on_value,
off_value=off_value)
return y1 * lam + y2 * (1. - lam)
def _rand_bbox(self, img_shape, lam, margin=0., count=None):
""" Standard CutMix bounding-box
Generates a random square bbox based on lambda value. This impl includes
support for enforcing a border margin as percent of bbox dimensions.
Args:
img_shape (tuple): Image shape as tuple
lam (float): Cutmix lambda value
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
count (int): Number of bbox to generate
"""
ratio = np.sqrt(1 - lam)
img_h, img_w = img_shape[-2:]
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
yl = np.clip(cy - cut_h // 2, 0, img_h)
yh = np.clip(cy + cut_h // 2, 0, img_h)
xl = np.clip(cx - cut_w // 2, 0, img_w)
xh = np.clip(cx + cut_w // 2, 0, img_w)
return yl, yh, xl, xh
def _rand_bbox_minmax(self, img_shape, minmax, count=None):
""" Min-Max CutMix bounding-box
Inspired by Darknet cutmix impl, generates a random rectangular bbox
based on min/max percent values applied to each dimension of the input image.
        Typical minmax values are in the 0.2-0.3 range for the min and 0.8-0.9 for the max.
Args:
img_shape (tuple): Image shape as tuple
minmax (tuple or list): Min and max bbox ratios (as percent of image size)
count (int): Number of bbox to generate
"""
assert len(minmax) == 2
img_h, img_w = img_shape[-2:]
cut_h = np.random.randint(
int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
cut_w = np.random.randint(
int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
yl = np.random.randint(0, img_h - cut_h, size=count)
xl = np.random.randint(0, img_w - cut_w, size=count)
yu = yl + cut_h
xu = xl + cut_w
return yl, yu, xl, xu
def _cutmix_bbox_and_lam(self,
img_shape,
lam,
ratio_minmax=None,
correct_lam=True,
count=None):
""" Generate bbox and apply lambda correction.
"""
if ratio_minmax is not None:
yl, yu, xl, xu = self._rand_bbox_minmax(
img_shape, ratio_minmax, count=count)
else:
yl, yu, xl, xu = self._rand_bbox(img_shape, lam, count=count)
if correct_lam or ratio_minmax is not None:
bbox_area = (yu - yl) * (xu - xl)
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
return (yl, yu, xl, xu), lam
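    # Per-sample mixing parameters: for each element, decide mixup vs cutmix with
    # switch_prob, draw lam from Beta(alpha, alpha), and keep lam = 1 (no mixing)
    # with probability 1 - mix_prob.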
def _params_per_elem(self, batch_size):
lam = np.ones(batch_size, dtype=np.float32)
        use_cutmix = np.zeros(batch_size, dtype=bool)
if self.mixup_enabled:
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand(batch_size) < self.switch_prob
lam_mix = np.where(
use_cutmix,
np.random.beta(
self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
np.random.beta(
self.mixup_alpha, self.mixup_alpha, size=batch_size))
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(
self.mixup_alpha, self.mixup_alpha, size=batch_size)
elif self.cutmix_alpha > 0.:
                use_cutmix = np.ones(batch_size, dtype=bool)
lam_mix = np.random.beta(
self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = np.where(
np.random.rand(batch_size) < self.mix_prob,
lam_mix.astype(np.float32), lam)
return lam, use_cutmix
def _params_per_batch(self):
lam = 1.
use_cutmix = False
if self.mixup_enabled and np.random.rand() < self.mix_prob:
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand() < self.switch_prob
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.cutmix_alpha > 0.:
use_cutmix = True
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = float(lam_mix)
return lam, use_cutmix
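    # Per-element mixing: sample i is paired with its reverse-order partner
    # j = batch_size - 1 - i, either by pasting a cutmix box from x_orig[j] or by
    # linear interpolation x[i] * lam + x_orig[j] * (1 - lam).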
def _mix_elem(self, x):
batch_size = len(x)
lam_batch, use_cutmix = self._params_per_elem(batch_size)
        x_orig = x.clone()  # keep an unmodified original as the mixing source
for i in range(batch_size):
j = batch_size - i - 1
lam = lam_batch[i]
if lam != 1.:
if use_cutmix[i]:
(yl, yh, xl, xh), lam = self._cutmix_bbox_and_lam(
x[i].shape,
lam,
ratio_minmax=self.cutmix_minmax,
correct_lam=self.correct_lam)
if yl < yh and xl < xh:
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
return paddle.to_tensor(lam_batch, dtype=x.dtype).unsqueeze(1)
def _mix_pair(self, x):
batch_size = len(x)
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
        x_orig = x.clone()  # keep an unmodified original as the mixing source
for i in range(batch_size // 2):
j = batch_size - i - 1
lam = lam_batch[i]
if lam != 1.:
if use_cutmix[i]:
(yl, yh, xl, xh), lam = self._cutmix_bbox_and_lam(
x[i].shape,
lam,
ratio_minmax=self.cutmix_minmax,
correct_lam=self.correct_lam)
if yl < yh and xl < xh:
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
x[j] = x[j] * lam + x_orig[i] * (1 - lam)
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
return paddle.to_tensor(lam_batch, dtype=x.dtype).unsqueeze(1)
def _mix_batch(self, x):
lam, use_cutmix = self._params_per_batch()
if lam == 1.:
return 1.
if use_cutmix:
(yl, yh, xl, xh), lam = self._cutmix_bbox_and_lam(
x.shape,
lam,
ratio_minmax=self.cutmix_minmax,
correct_lam=self.correct_lam)
if yl < yh and xl < xh:
x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
else:
x_flipped = x.flip(0) * (1. - lam)
x[:] = x * lam + x_flipped
return lam
def _unpack(self, batch):
""" _unpack """
assert isinstance(batch, list), \
'batch should be a list filled with tuples (img, label)'
bs = len(batch)
assert bs > 0, 'size of the batch data should > 0'
#imgs, labels = list(zip(*batch))
imgs = []
labels = []
for item in batch:
imgs.append(item[0])
labels.append(item[1])
return np.array(imgs), np.array(labels), bs
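    # Entry point used by the data loader: mixes the images in place according to
    # `mode`, converts hard labels to soft targets via _mixup_target, and returns
    # a list of (mixed_image, soft_label) pairs.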
def __call__(self, batch):
x, target, bs = self._unpack(batch)
x = paddle.to_tensor(x)
target = paddle.to_tensor(target)
assert len(x) % 2 == 0, 'Batch size should be even when using this'
if self.mode == 'elem':
lam = self._mix_elem(x)
elif self.mode == 'pair':
lam = self._mix_pair(x)
else:
lam = self._mix_batch(x)
target = self._mixup_target(target, self.num_classes, lam,
self.label_smoothing)
return list(zip(x.numpy(), target.numpy()))
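# Illustrative usage sketch (not part of this PR; example values only):
#   op = MixupCutmixHybrid(mixup_alpha=0.8, cutmix_alpha=1.0, switch_prob=0.5, num_classes=102)
#   mixed = op([(img, label) for img, label in zip(images, labels)])  # batch size must be even
# Each item of `mixed` is (mixed_image, soft_label_vector of length num_classes).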
......@@ -17,6 +17,7 @@ from .supconloss import SupConLoss
from .pairwisecosface import PairwiseCosface
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss
from .softtargetceloss import SoftTargetCrossEntropy
from .distillationloss import DistillationCELoss
from .distillationloss import DistillationGTCELoss
......
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class SoftTargetCrossEntropy(nn.Layer):
def __init__(self):
super().__init__()
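    # Cross entropy against a soft target distribution:
    #   loss_i = -sum_c target[i, c] * log_softmax(x[i])[c], averaged over the batch.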
def forward(self, x, target):
loss = paddle.sum(-target * F.log_softmax(x, axis=-1), axis=-1)
loss = loss.mean()
return {"SoftTargetCELoss": loss}
    def __str__(self):
return type(self).__name__
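# Minimal usage sketch (assumption, not from this PR): for logits of shape [N, C] and
# soft targets of the same shape (rows summing to 1, e.g. produced by MixupCutmixHybrid),
#   SoftTargetCrossEntropy()(logits, soft_targets)["SoftTargetCELoss"]
# reduces to the standard cross-entropy when the targets are one-hot.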
......@@ -272,3 +272,145 @@ class AdamW(object):
def _apply_decay_param_fun(self, name):
return name not in self.no_weight_decay_param_name_list
class AdamWDL(object):
"""
    The AdamWDL optimizer is based on AdamW, with a dynamic (layer-wise decayed) learning-rate setting.
    It is generally used for transformer models.
"""
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
weight_decay=None,
multi_precision=False,
grad_clip=None,
layerwise_decay=None,
filter_bias_and_bn=True,
**args):
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.grad_clip = grad_clip
self.weight_decay = weight_decay
self.multi_precision = multi_precision
self.layerwise_decay = layerwise_decay
self.filter_bias_and_bn = filter_bias_and_bn
class AdamWDLImpl(optim.AdamW):
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
parameters=None,
weight_decay=0.01,
apply_decay_param_fun=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
layerwise_decay=1.0,
n_layers=12,
name_dict=None,
name=None):
if not isinstance(layerwise_decay, float) and \
not isinstance(layerwise_decay, fluid.framework.Variable):
raise TypeError("coeff should be float or Tensor.")
self.layerwise_decay = layerwise_decay
self.name_dict = name_dict
self.n_layers = n_layers
self.set_param_lr_fun = self._layerwise_lr_decay
super().__init__(
learning_rate=learning_rate,
parameters=parameters,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
grad_clip=grad_clip,
name=name,
apply_decay_param_fun=apply_decay_param_fun,
weight_decay=weight_decay,
lazy_mode=lazy_mode,
multi_precision=multi_precision)
def _append_optimize_op(self, block, param_and_grad):
            if self.set_param_lr_fun is None:
                return super()._append_optimize_op(block, param_and_grad)
self._append_decoupled_weight_decay(block, param_and_grad)
prev_lr = param_and_grad[0].optimize_attr["learning_rate"]
self.set_param_lr_fun(self.layerwise_decay, self.name_dict,
self.n_layers, param_and_grad[0])
            # execute the Adam op (bypass AdamW's decoupled weight decay, which was applied above)
res = super(optim.AdamW, self)._append_optimize_op(block,
param_and_grad)
param_and_grad[0].optimize_attr["learning_rate"] = prev_lr
return res
# Layerwise decay
def _layerwise_lr_decay(self, decay_rate, name_dict, n_layers, param):
"""
Args:
decay_rate (float):
The layer-wise decay ratio.
                name_dict (dict):
                    The keys of name_dict are the framework parameter names (param.name)
                    and the values are the structured names returned by
                    model.named_parameters().
                n_layers (int):
                    Total number of layers in the transformer encoder.
                param (Parameter):
                    The parameter whose per-parameter learning rate is scaled in place.
            """
ratio = 1.0
static_name = name_dict[param.name]
if "blocks" in static_name:
idx = static_name.find("blocks.")
layer = int(static_name[idx:].split(".")[1])
ratio = decay_rate**(n_layers - layer)
elif "embed" in static_name:
ratio = decay_rate**(n_layers + 1)
param.optimize_attr["learning_rate"] *= ratio
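        # Example with the config above (n_layers=12 assumed): decay_rate=0.65 gives
        # blocks.11.* a factor of 0.65**1, blocks.0.* a factor of 0.65**12, and the
        # embedding parameters 0.65**13 on top of the base learning rate.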
def __call__(self, model_list):
model = model_list[0]
if self.weight_decay and self.filter_bias_and_bn:
skip = {}
if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
decay_dict = {
param.name: not (len(param.shape) == 1 or
name.endswith(".bias") or name in skip)
for name, param in model.named_parameters()
if not 'teacher' in name
}
parameters = [
param for param in model.parameters()
if 'teacher' not in param.name
]
weight_decay = 0.
        else:
            parameters = model.parameters()
            decay_dict = None  # no parameter-wise weight-decay filtering in this branch
opt_args = dict(
learning_rate=self.learning_rate, weight_decay=self.weight_decay)
opt_args['parameters'] = parameters
if decay_dict is not None:
opt_args['apply_decay_param_fun'] = lambda n: decay_dict[n]
opt_args['epsilon'] = self.epsilon
opt_args['beta1'] = self.beta1
opt_args['beta2'] = self.beta2
if self.layerwise_decay and self.layerwise_decay < 1.0:
opt_args['layerwise_decay'] = self.layerwise_decay
name_dict = dict()
for n, p in model.named_parameters():
name_dict[p.name] = n
opt_args['name_dict'] = name_dict
opt_args['n_layers'] = model.get_num_layers()
optimizer = self.AdamWDLImpl(**opt_args)
return optimizer