Unverified commit 3541a80d, authored by Wei Shengyu and committed by GitHub

Merge pull request #2269 from zengshao0622/merge_CAE

Merge CAE
......@@ -69,6 +69,7 @@ from .model_zoo.repvgg import RepVGG_A0, RepVGG_A1, RepVGG_A2, RepVGG_B0, RepVGG
from .model_zoo.van import VAN_tiny
from .model_zoo.peleenet import PeleeNet
from .model_zoo.convnext import ConvNeXt_tiny
from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224
from .variant_models.resnet_variant import ResNet50_last_stage_stride1
from .variant_models.vgg_variant import VGG19Sigmoid
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code was heavily based on https://github.com/PaddlePaddle/VIMER/blob/main/CAE/models/modeling_finetune.py
# reference: https://arxiv.org/abs/2202.03026
import collections
from itertools import repeat
import math
import numpy as np
from scipy import interpolate  # used below when resizing relative position bias tables
from functools import partial
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ....utils.download import get_weights_path_from_url
MODEL_URLS = {
"cae_base_patch16_224":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/cae_base_patch16_224_pretrained.pdparams",
"cae_large_patch16_224":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/cae_large_patch16_224_pretrained.pdparams"
}
__all__ = list(MODEL_URLS.keys())
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
def trunc_normal_(tensor, mean=0., std=1.):
nn.initializer.TruncatedNormal(mean=mean, std=std)(tensor)
def drop_path(x, drop_prob: float=0., training: bool=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor.floor_() # binarize
output = x / keep_prob * random_tensor
return output
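A quick sanity check of the behaviour described in the docstring above (illustrative only; assumes the function above is in scope and a Paddle runtime is available):

import paddle

x = paddle.ones([8, 4])
y = drop_path(x, drop_prob=0.5, training=True)
# Each sample is either zeroed out or rescaled by 1 / keep_prob = 2.0,
# so the expected value of y equals x.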
class DropPath(nn.Layer):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Layer):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias_attr=True)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias_attr=True)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
        # x = self.drop(x)
        # the first dropout is commented out to follow the original BERT implementation
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Layer):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
window_size=None,
attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim**-0.5
self.zeros_ = nn.initializer.Constant(value=0.)
self.qkv = nn.Linear(dim, all_head_dim * 3, bias_attr=False)
if qkv_bias:
self.q_bias = self.create_parameter(
[all_head_dim], default_initializer=self.zeros_)
self.v_bias = self.create_parameter(
[all_head_dim], default_initializer=self.zeros_)
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1) + 3
self.relative_position_bias_table = self.create_parameter(
[self.num_relative_distance, num_heads],
default_initializer=self.zeros_) # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token to cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = paddle.arange(window_size[0])
coords_w = paddle.arange(window_size[1])
coords = paddle.stack(paddle.meshgrid(
[coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :,
None] - coords_flatten[:,
None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.transpose(
[1, 2, 0]) # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[
0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
paddle.zeros((window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(
-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index",
relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim, bias_attr=True)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
k_bias = paddle.zeros_like(self.v_bias)
k_bias.stop_gradient = True
qkv_bias = paddle.concat((self.q_bias, k_bias, self.v_bias))
# qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, C // self.num_heads]).transpose([2, 0, 3, 1, 4])
qkv = F.linear(x=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape([B, N, 3, self.num_heads, -1]).transpose(
[2, 0, 3, 1, 4])
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
        attn = q @ k.transpose([0, 1, 3, 2])
if self.relative_position_bias_table is not None:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1]) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.transpose(
[2, 0, 1]) # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = F.softmax(attn, axis=-1)
attn = self.attn_drop(attn)
        x = (attn @ v).transpose([0, 2, 1, 3]).reshape([B, N, -1])
x = self.proj(x)
x = self.proj_drop(x)
return x
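Illustrative shape check for the attention layer above (values chosen to match a 14x14 patch grid plus one class token; not part of the original file):

import paddle

attn_layer = Attention(dim=768, num_heads=12, qkv_bias=True, window_size=(14, 14))
tokens = paddle.randn([2, 14 * 14 + 1, 768])   # [batch, 197 tokens, embed_dim]
out = attn_layer(tokens)                       # [2, 197, 768]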
class Block(nn.Layer):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
init_values=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
window_size=None,
attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
window_size=window_size,
attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
        if init_values is not None and init_values > 0:
self.gamma_1 = self.create_parameter(
[dim],
default_initializer=nn.initializer.Constant(value=init_values))
self.gamma_2 = self.create_parameter(
[dim],
default_initializer=nn.initializer.Constant(value=init_values))
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None):
if self.gamma_1 is None:
x = x + self.drop_path(
self.attn(
self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(
self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Layer):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
to_2tuple = _ntuple(2)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] //
patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.in_chans = in_chans
self.out_chans = embed_dim
self.proj = nn.Conv2D(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias_attr=True)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose([0, 2, 1])
return x
def _init_weights(self):
fan_out = self.out_chans
fan_in = self.patch_size[0] * self.patch_size[1] * self.in_chans
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.XavierUniform(fan_in, fan_out)) # MAE
bias_attr = paddle.ParamAttr(initializer=nn.initializer.Constant(0.0))
return weight_attr, bias_attr
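For reference, a 224x224 RGB image with 16x16 patches yields 14 * 14 = 196 patch tokens of dimension embed_dim; a minimal sketch:

import paddle

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = paddle.randn([1, 3, 224, 224])
tokens = patch_embed(img)   # [1, 196, 768]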
class RelativePositionBias(nn.Layer):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (
2 * window_size[1] - 1) + 3
self.zeros_ = nn.initializer.Constant(value=0.)
self.relative_position_bias_table = self.create_parameter(
[self.num_relative_distance, num_heads],
default_initializer=self.zeros_) # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token to cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = paddle.arange(window_size[0])
coords_w = paddle.arange(window_size[1])
coords = paddle.stack(paddle.meshgrid(
[coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = paddle.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :,
None] - coords_flatten[:,
None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.transpose(
[1, 2, 0]) # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
paddle.zeros((window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(
-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index",
relative_position_index)
def forward(self):
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.reshape([-1])].reshape([
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1]) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.transpose([2, 0, 1]) # nH, Wh*Ww, Wh*Ww
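A shared relative-position bias for a 14x14 window keeps (2*14 - 1)**2 + 3 = 732 learnable rows (the extra 3 cover cls-to-token, token-to-cls and cls-to-cls); an illustrative check:

rel_bias = RelativePositionBias(window_size=(14, 14), num_heads=12)
bias = rel_bias()   # [12, 197, 197]: one bias map per head, class token included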
def get_sinusoid_encoding_table(n_position, d_hid, token=False):
''' Sinusoid position encoding table '''
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
if token:
        sinusoid_table = np.concatenate(
            [sinusoid_table, np.zeros([1, d_hid])], axis=0)
return paddle.to_tensor(sinusoid_table).unsqueeze(0)
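Illustrative call (shapes only): even feature indices carry the sine terms and odd indices the cosine terms.

pos_table = get_sinusoid_encoding_table(n_position=196, d_hid=768)
# pos_table.shape == [1, 196, 768]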
class VisionTransformer(nn.Layer):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
class_num=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
init_values=None,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
use_shared_rel_pos_bias=False,
use_mean_pooling=True,
init_scale=0.001,
lin_probe=False,
sin_pos_emb=True,
args=None):
super().__init__()
self.class_num = class_num
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.use_mean_pooling = use_mean_pooling
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.zeros_ = nn.initializer.Constant(value=0.)
self.ones_ = nn.initializer.Constant(value=1.)
self.cls_token = self.create_parameter(
[1, 1, embed_dim], default_initializer=self.zeros_)
self.use_abs_pos_emb = use_abs_pos_emb
if use_abs_pos_emb:
self.pos_embed = self.create_parameter(
[1, num_patches + 1, embed_dim],
default_initializer=self.zeros_)
elif sin_pos_emb:
            # use fixed sine-cosine positional embeddings
self.pos_embed = self.create_parameter(
[1, num_patches + 1, embed_dim],
default_initializer=self.zeros_)
self.pos_embed.set_value(
self.build_2d_sincos_position_embedding(embed_dim))
self.pos_embed.stop_gradient = True # fixed sin-cos embedding
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(
window_size=self.patch_embed.patch_shape, num_heads=num_heads)
else:
self.rel_pos_bias = None
dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.LayerList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
window_size=self.patch_embed.patch_shape
if use_rel_pos_bias else None) for i in range(depth)
])
self.norm = nn.Identity() if use_mean_pooling else norm_layer(
embed_dim)
self.lin_probe = lin_probe
# NOTE: batch norm
if lin_probe:
# TODO
from models.lincls_bn import LP_BatchNorm
self.fc_norm = LP_BatchNorm(embed_dim, affine=False)
else:
if use_mean_pooling:
self.fc_norm = norm_layer(embed_dim)
else:
self.fc_norm = None
self.head = nn.Linear(embed_dim,
class_num) if class_num > 0 else nn.Identity()
if self.pos_embed is not None and use_abs_pos_emb:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
# trunc_normal_(self.mask_token, std=.02)
trunc_normal_(self.head.weight, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
self.head.weight.set_value(self.head.weight * init_scale)
self.head.bias.set_value(self.head.bias * init_scale)
def build_2d_sincos_position_embedding(self,
embed_dim=768,
temperature=10000.):
h, w = self.patch_embed.patch_shape
grid_w = paddle.arange(w, dtype=paddle.float32)
grid_h = paddle.arange(h, dtype=paddle.float32)
grid_w, grid_h = paddle.meshgrid(grid_w, grid_h)
assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
pos_dim = embed_dim // 4
omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim
omega = 1. / (temperature**omega)
out_w = paddle.einsum('m,d->md', grid_w.flatten(), omega)
out_h = paddle.einsum('m,d->md', grid_h.flatten(), omega)
pos_emb = paddle.concat(
[
paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h),
paddle.cos(out_h)
],
axis=1)[None, :, :]
# if not self.use_mean_pooling:
pe_token = paddle.zeros([1, 1, embed_dim], dtype=paddle.float32)
pos_emb = paddle.concat([pe_token, pos_emb], axis=1)
return pos_emb
def fix_init_weight(self):
def rescale(param, layer_id):
param.set_value(param / math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight, layer_id + 1)
rescale(layer.mlp.fc2.weight, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
self.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
self.zeros_(m.bias)
self.ones_(m.weight)
def get_num_layers(self):
return len(self.blocks)
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, class_num, global_pool=''):
self.class_num = class_num
self.head = nn.Linear(self.embed_dim,
class_num) if class_num > 0 else nn.Identity()
def forward_features(self, x, is_train=True):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.shape
cls_tokens = self.cls_token.expand(
[batch_size, -1,
-1]) # stole cls_tokens impl from Phil Wang, thanks
x = paddle.concat((cls_tokens, x), axis=1)
if self.pos_embed is not None:
if self.use_abs_pos_emb:
x = x + self.pos_embed.expand(
[batch_size, -1, -1]).astype(x.dtype).clone().detach()
else:
x = x + self.pos_embed.expand(
[batch_size, -1, -1]).astype(x.dtype).clone().detach()
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias(
) if self.rel_pos_bias is not None else None
for blk in self.blocks:
x = blk(x, rel_pos_bias=rel_pos_bias)
x = self.norm(x)
if self.fc_norm is not None:
t = x[:, 1:, :]
if self.lin_probe:
if self.use_mean_pooling:
return self.fc_norm(t.mean(1), is_train=is_train)
else:
return self.fc_norm(x[:, 0], is_train=is_train)
else:
return self.fc_norm(t.mean(1))
else:
return x[:, 0]
def forward(self, x, is_train=True):
x = self.forward_features(x, is_train)
x = self.head(x)
return x
def _enable_linear_eval(model):
zeros_ = nn.initializer.Constant(value=0.)
normal_ = nn.initializer.Normal(mean=0.0, std=0.01)
linear_keyword = 'head'
head_norm = 'fc_norm'
requires_grad = []
for name, param in model.named_parameters():
if name not in [
'%s.weight' % linear_keyword, '%s.bias' % linear_keyword
] and head_norm not in name:
param.stop_gradient = True
else:
requires_grad.append(name)
# init the fc layer
normal_(getattr(model, linear_keyword).weight)
zeros_(getattr(model, linear_keyword).bias)
return
def _load_pretrained(pretrained,
pretrained_url,
model,
model_keys,
model_ema_configs,
abs_pos_emb,
rel_pos_bias,
use_ssld=False):
    if pretrained is False:
        return
    elif pretrained is True:
        local_weight_path = get_weights_path_from_url(pretrained_url).replace(
            ".pdparams", "")
        checkpoint = paddle.load(local_weight_path + ".pdparams")
    elif isinstance(pretrained, str):
        # pretrained is treated as a local path to a *.pdparams file
        checkpoint = paddle.load(pretrained)
    else:
        raise RuntimeError(
            "pretrained type is not available. Please use `string` or `boolean` type.")
checkpoint_model = None
for model_key in model_keys.split('|'):
if model_key in checkpoint:
checkpoint_model = checkpoint[model_key]
break
if checkpoint_model is None:
checkpoint_model = checkpoint
state_dict = model.state_dict()
all_keys = list(checkpoint_model.keys())
# NOTE: remove all decoder keys
all_keys = [key for key in all_keys if key.startswith('encoder.')]
for key in all_keys:
new_key = key.replace('encoder.', '')
checkpoint_model[new_key] = checkpoint_model[key]
checkpoint_model.pop(key)
for key in list(checkpoint_model.keys()):
if key.startswith('regressor_and_decoder.'):
checkpoint_model.pop(key)
if key.startswith('teacher_network.'):
checkpoint_model.pop(key)
# NOTE: replace norm with fc_norm
for key in list(checkpoint_model.keys()):
if key.startswith('norm.'):
new_key = key.replace('norm.', 'fc_norm.')
checkpoint_model[new_key] = checkpoint_model[key]
checkpoint_model.pop(key)
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[
k].shape:
del checkpoint_model[k]
if model.use_rel_pos_bias and "rel_pos_bias.relative_position_bias_table" in checkpoint_model:
num_layers = model.get_num_layers()
rel_pos_bias = checkpoint_model[
"rel_pos_bias.relative_position_bias_table"]
for i in range(num_layers):
checkpoint_model["blocks.%d.attn.relative_position_bias_table" %
i] = rel_pos_bias.clone()
checkpoint_model.pop("rel_pos_bias.relative_position_bias_table")
all_keys = list(checkpoint_model.keys())
for key in all_keys:
if "relative_position_index" in key:
checkpoint_model.pop(key)
if "relative_position_bias_table" in key and rel_pos_bias:
rel_pos_bias = checkpoint_model[key]
            src_num_pos, num_attn_heads = rel_pos_bias.shape
            dst_num_pos, _ = model.state_dict()[key].shape
dst_patch_shape = model.patch_embed.patch_shape
if dst_patch_shape[0] != dst_patch_shape[1]:
raise NotImplementedError()
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (
dst_patch_shape[1] * 2 - 1)
src_size = int((src_num_pos - num_extra_tokens)**0.5)
dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
if src_size != dst_size:
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
def geometric_progression(a, r, n):
return a * (1.0 - r**n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_size // 2)
if gp > dst_size // 2:
right = q
else:
left = q
dis = []
cur = 1
for i in range(src_size // 2):
dis.append(cur)
cur += q**(i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_size // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
all_rel_pos_bias = []
for i in range(num_attn_heads):
                    z = rel_pos_bias[:, i].reshape(
                        [src_size, src_size]).numpy()
                    f = interpolate.interp2d(x, y, z, kind='cubic')
                    all_rel_pos_bias.append(
                        paddle.to_tensor(
                            f(dx, dy),
                            dtype=rel_pos_bias.dtype).reshape([-1, 1]))
rel_pos_bias = paddle.concat(all_rel_pos_bias, axis=-1)
new_rel_pos_bias = paddle.concat(
(rel_pos_bias, extra_tokens), axis=0)
checkpoint_model[key] = new_rel_pos_bias
# interpolate position embedding
if 'pos_embed' in checkpoint_model and abs_pos_emb:
pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**
0.5)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                [-1, orig_size, orig_size, embedding_size]).transpose(
                    [0, 3, 1, 2])
pos_tokens = paddle.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode='bicubic',
align_corners=False)
            pos_tokens = pos_tokens.transpose([0, 2, 3, 1]).flatten(1, 2)
new_pos_embed = paddle.concat((extra_tokens, pos_tokens), axis=1)
checkpoint_model['pos_embed'] = new_pos_embed
msg = model.set_state_dict(checkpoint_model)
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters()
if not p.stop_gradient).item()
return
def cae_base_patch16_224(pretrained=True, use_ssld=False, **kwargs):
config = kwargs.copy()
enable_linear_eval = config.pop('enable_linear_eval')
model_keys = config.pop('model_key')
model_ema_configs = config.pop('model_ema')
abs_pos_emb = config.pop('abs_pos_emb')
rel_pos_bias = config.pop('rel_pos_bias')
    if 'pretrained' in config:
        pretrained = config.pop('pretrained')
model = VisionTransformer(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(
nn.LayerNorm, epsilon=1e-6),
**config)
if enable_linear_eval:
_enable_linear_eval(model)
_load_pretrained(
pretrained,
MODEL_URLS["cae_base_patch16_224"],
model,
model_keys,
model_ema_configs,
abs_pos_emb,
rel_pos_bias,
use_ssld=False)
return model
def cae_large_patch16_224(pretrained=True, use_ssld=False, **kwargs):
config = kwargs.copy()
enable_linear_eval = config.pop('enable_linear_eval')
model_keys = config.pop('model_key')
model_ema_configs = config.pop('model_ema')
abs_pos_emb = config.pop('abs_pos_emb')
rel_pos_bias = config.pop('rel_pos_bias')
    if 'pretrained' in config:
        pretrained = config.pop('pretrained')
model = VisionTransformer(
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(
nn.LayerNorm, epsilon=1e-6),
        **config)
if enable_linear_eval:
_enable_linear_eval(model)
_load_pretrained(
pretrained,
MODEL_URLS["cae_large_patch16_224"],
model,
model_keys,
model_ema_configs,
abs_pos_emb,
rel_pos_bias,
use_ssld=False)
return model
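A minimal usage sketch of the factory above, assuming a Paddle runtime; the extra keyword arguments mirror the Arch section of the fine-tuning config below and are consumed before the backbone is built:

import paddle

model = cae_base_patch16_224(
    pretrained=False,                       # skip weight downloading in this sketch
    class_num=102,
    enable_linear_eval=False,
    model_key='model|module|state_dict',
    model_ema={'enable_model_ema': False},
    abs_pos_emb=False,
    rel_pos_bias=True)
logits = model(paddle.randn([1, 3, 224, 224]))   # [1, 102]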
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 20
eval_during_train: True
eval_interval: 1
epochs: 100
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: cae_base_patch16_224
class_num: 102
drop_rate: 0.0
drop_path_rate: 0.1
attn_drop_rate: 0.0
use_mean_pooling: True
init_scale: 0.001
use_rel_pos_bias: True
use_abs_pos_emb: False
init_values: 0.1
lin_probe: False
sin_pos_emb: True
abs_pos_emb: False
enable_linear_eval: False
model_key: model|module|state_dict
rel_pos_bias: True
model_ema:
enable_model_ema: False
model_ema_decay: 0.9999
model_ema_force_cpu: False
pretrained: True
# loss function config for training/eval process
Loss:
Train:
- SoftTargetCrossEntropy:
weight: 1.0
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: AdamWDL
beta1: 0.9
beta2: 0.999
epsilon: 1e-8
weight_decay: 0.05
layerwise_decay: 0.65
lr:
name: Cosine
learning_rate: 0.001
eta_min: 1e-6
warmup_epoch: 10
warmup_start_lr: 1e-6
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/flowers102/
cls_label_path: ./dataset/flowers102/train_list.txt
batch_transform_ops:
- MixupCutmixHybrid:
mixup_alpha: 0.8
cutmix_alpha: 1.0
switch_prob: 0.5
num_classes: 102
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
interpolation: bilinear
- RandFlipImage:
flip_code: 1
- RandAugment:
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- RandomErasing:
EPSILON: 0.5
sl: 0.02
sh: 0.3
r1: 0.3
sampler:
name: DistributedBatchSampler
batch_size: 16
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/flowers102/
cls_label_path: ./dataset/flowers102/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 16
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
......@@ -42,6 +42,7 @@ from ppcls.data.preprocess.ops.operators import RandomRotation
from ppcls.data.preprocess.ops.operators import Padv2
from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator
from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid
import numpy as np
from PIL import Image
......
......@@ -23,6 +23,9 @@ import numpy as np
from ppcls.utils import logger
from ppcls.data.preprocess.ops.fmix import sample_mask
import paddle
import paddle.nn.functional as F
class BatchOperator(object):
""" BatchOperator """
......@@ -229,3 +232,270 @@ class OpSampler(object):
list(self.ops.keys()), weights=list(self.ops.values()), k=1)[0]
# return batch directly when None Op
return op(batch) if op else batch
class MixupCutmixHybrid(object):
""" Mixup/Cutmix that applies different params to each element or whole batch
Args:
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
prob (float): probability of applying mixup or cutmix per batch or element
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
        mode (str): how to apply mixup/cutmix params, per 'batch', 'pair' (pair of elements), or 'elem' (element)
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
label_smoothing (float): apply label smoothing to the mixed target tensor
num_classes (int): number of classes for target
"""
def __init__(self,
mixup_alpha=1.,
cutmix_alpha=0.,
cutmix_minmax=None,
prob=1.0,
switch_prob=0.5,
mode='batch',
correct_lam=True,
label_smoothing=0.1,
num_classes=4):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
if self.cutmix_minmax is not None:
assert len(self.cutmix_minmax) == 2
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
self.cutmix_alpha = 1.0
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mode = mode
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
        self.mixup_enabled = True  # set to False to disable mixing (intended to be set by the train loop)
def _one_hot(self, x, num_classes, on_value=1., off_value=0.):
x = paddle.cast(x, dtype='int64')
on_value = paddle.full([x.shape[0], num_classes], on_value)
off_value = paddle.full([x.shape[0], num_classes], off_value)
return paddle.where(
F.one_hot(x, num_classes) == 1, on_value, off_value)
def _mixup_target(self, target, num_classes, lam=1., smoothing=0.0):
off_value = smoothing / num_classes
on_value = 1. - smoothing + off_value
y1 = self._one_hot(
target,
num_classes,
on_value=on_value,
off_value=off_value, )
y2 = self._one_hot(
target.flip(0),
num_classes,
on_value=on_value,
off_value=off_value)
return y1 * lam + y2 * (1. - lam)
def _rand_bbox(self, img_shape, lam, margin=0., count=None):
""" Standard CutMix bounding-box
Generates a random square bbox based on lambda value. This impl includes
support for enforcing a border margin as percent of bbox dimensions.
Args:
img_shape (tuple): Image shape as tuple
lam (float): Cutmix lambda value
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
count (int): Number of bbox to generate
"""
ratio = np.sqrt(1 - lam)
img_h, img_w = img_shape[-2:]
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
yl = np.clip(cy - cut_h // 2, 0, img_h)
yh = np.clip(cy + cut_h // 2, 0, img_h)
xl = np.clip(cx - cut_w // 2, 0, img_w)
xh = np.clip(cx + cut_w // 2, 0, img_w)
return yl, yh, xl, xh
def _rand_bbox_minmax(self, img_shape, minmax, count=None):
""" Min-Max CutMix bounding-box
Inspired by Darknet cutmix impl, generates a random rectangular bbox
based on min/max percent values applied to each dimension of the input image.
        Typical minmax values are in the 0.2-0.3 range for the min and the 0.8-0.9 range for the max.
Args:
img_shape (tuple): Image shape as tuple
minmax (tuple or list): Min and max bbox ratios (as percent of image size)
count (int): Number of bbox to generate
"""
assert len(minmax) == 2
img_h, img_w = img_shape[-2:]
cut_h = np.random.randint(
int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
cut_w = np.random.randint(
int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
yl = np.random.randint(0, img_h - cut_h, size=count)
xl = np.random.randint(0, img_w - cut_w, size=count)
yu = yl + cut_h
xu = xl + cut_w
return yl, yu, xl, xu
def _cutmix_bbox_and_lam(self,
img_shape,
lam,
ratio_minmax=None,
correct_lam=True,
count=None):
""" Generate bbox and apply lambda correction.
"""
if ratio_minmax is not None:
yl, yu, xl, xu = self._rand_bbox_minmax(
img_shape, ratio_minmax, count=count)
else:
yl, yu, xl, xu = self._rand_bbox(img_shape, lam, count=count)
if correct_lam or ratio_minmax is not None:
bbox_area = (yu - yl) * (xu - xl)
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
return (yl, yu, xl, xu), lam
def _params_per_elem(self, batch_size):
lam = np.ones(batch_size, dtype=np.float32)
        use_cutmix = np.zeros(batch_size, dtype=bool)
if self.mixup_enabled:
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand(batch_size) < self.switch_prob
lam_mix = np.where(
use_cutmix,
np.random.beta(
self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
np.random.beta(
self.mixup_alpha, self.mixup_alpha, size=batch_size))
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(
self.mixup_alpha, self.mixup_alpha, size=batch_size)
elif self.cutmix_alpha > 0.:
                use_cutmix = np.ones(batch_size, dtype=bool)
lam_mix = np.random.beta(
self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = np.where(
np.random.rand(batch_size) < self.mix_prob,
lam_mix.astype(np.float32), lam)
return lam, use_cutmix
def _params_per_batch(self):
lam = 1.
use_cutmix = False
if self.mixup_enabled and np.random.rand() < self.mix_prob:
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand() < self.switch_prob
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.cutmix_alpha > 0.:
use_cutmix = True
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = float(lam_mix)
return lam, use_cutmix
def _mix_elem(self, x):
batch_size = len(x)
lam_batch, use_cutmix = self._params_per_elem(batch_size)
x_orig = x.clone(
) # need to keep an unmodified original for mixing source
for i in range(batch_size):
j = batch_size - i - 1
lam = lam_batch[i]
if lam != 1.:
if use_cutmix[i]:
(yl, yh, xl, xh), lam = self._cutmix_bbox_and_lam(
x[i].shape,
lam,
ratio_minmax=self.cutmix_minmax,
correct_lam=self.correct_lam)
if yl < yh and xl < xh:
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
return paddle.to_tensor(lam_batch, dtype=x.dtype).unsqueeze(1)
def _mix_pair(self, x):
batch_size = len(x)
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
x_orig = x.clone(
) # need to keep an unmodified original for mixing source
for i in range(batch_size // 2):
j = batch_size - i - 1
lam = lam_batch[i]
if lam != 1.:
if use_cutmix[i]:
(yl, yh, xl, xh), lam = self._cutmix_bbox_and_lam(
x[i].shape,
lam,
ratio_minmax=self.cutmix_minmax,
correct_lam=self.correct_lam)
if yl < yh and xl < xh:
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
x[j] = x[j] * lam + x_orig[i] * (1 - lam)
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
return paddle.to_tensor(lam_batch, dtype=x.dtype).unsqueeze(1)
def _mix_batch(self, x):
lam, use_cutmix = self._params_per_batch()
if lam == 1.:
return 1.
if use_cutmix:
(yl, yh, xl, xh), lam = self._cutmix_bbox_and_lam(
x.shape,
lam,
ratio_minmax=self.cutmix_minmax,
correct_lam=self.correct_lam)
if yl < yh and xl < xh:
x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
else:
x_flipped = x.flip(0) * (1. - lam)
x[:] = x * lam + x_flipped
return lam
def _unpack(self, batch):
""" _unpack """
assert isinstance(batch, list), \
'batch should be a list filled with tuples (img, label)'
bs = len(batch)
assert bs > 0, 'size of the batch data should > 0'
#imgs, labels = list(zip(*batch))
imgs = []
labels = []
for item in batch:
imgs.append(item[0])
labels.append(item[1])
return np.array(imgs), np.array(labels), bs
def __call__(self, batch):
x, target, bs = self._unpack(batch)
x = paddle.to_tensor(x)
target = paddle.to_tensor(target)
assert len(x) % 2 == 0, 'Batch size should be even when using this'
if self.mode == 'elem':
lam = self._mix_elem(x)
elif self.mode == 'pair':
lam = self._mix_pair(x)
else:
lam = self._mix_batch(x)
target = self._mixup_target(target, self.num_classes, lam,
self.label_smoothing)
return list(zip(x.numpy(), target.numpy()))
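An illustrative call of the batch operator above on a dummy batch (the parameter values mirror the MixupCutmixHybrid entry in the training config; not part of the original file):

import numpy as np

mixer = MixupCutmixHybrid(
    mixup_alpha=0.8, cutmix_alpha=1.0, switch_prob=0.5, num_classes=102)
dummy_batch = [(np.random.rand(3, 224, 224).astype('float32'), i % 102)
               for i in range(16)]
mixed = mixer(dummy_batch)   # list of (image, soft_label) pairs, soft_label shape [102]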
......@@ -17,6 +17,7 @@ from .supconloss import SupConLoss
from .pairwisecosface import PairwiseCosface
from .dmlloss import DMLLoss
from .distanceloss import DistanceLoss
from .softtargetceloss import SoftTargetCrossEntropy
from .distillationloss import DistillationCELoss
from .distillationloss import DistillationGTCELoss
......
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class SoftTargetCrossEntropy(nn.Layer):
def __init__(self):
super().__init__()
def forward(self, x, target):
loss = paddle.sum(-target * F.log_softmax(x, axis=-1), axis=-1)
loss = loss.mean()
return {"SoftTargetCELoss": loss}
def __str__(self, ):
return type(self).__name__
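Minimal usage sketch of the loss above with the kind of soft targets produced by MixupCutmixHybrid (illustrative shapes):

import paddle
import paddle.nn.functional as F

criterion = SoftTargetCrossEntropy()
logits = paddle.randn([4, 102])
soft_targets = F.softmax(paddle.randn([4, 102]), axis=-1)   # rows sum to 1
loss = criterion(logits, soft_targets)["SoftTargetCELoss"]  # scalar tensor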
......@@ -272,3 +272,145 @@ class AdamW(object):
def _apply_decay_param_fun(self, name):
return name not in self.no_weight_decay_param_name_list
class AdamWDL(object):
"""
The AdamWDL optimizer is implemented based on the AdamW Optimization with dynamic lr setting.
Generally it's used for transformer model.
"""
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
weight_decay=None,
multi_precision=False,
grad_clip=None,
layerwise_decay=None,
filter_bias_and_bn=True,
**args):
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.grad_clip = grad_clip
self.weight_decay = weight_decay
self.multi_precision = multi_precision
self.layerwise_decay = layerwise_decay
self.filter_bias_and_bn = filter_bias_and_bn
class AdamWDLImpl(optim.AdamW):
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
parameters=None,
weight_decay=0.01,
apply_decay_param_fun=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
layerwise_decay=1.0,
n_layers=12,
name_dict=None,
name=None):
if not isinstance(layerwise_decay, float) and \
not isinstance(layerwise_decay, fluid.framework.Variable):
                raise TypeError("layerwise_decay should be float or Tensor.")
self.layerwise_decay = layerwise_decay
self.name_dict = name_dict
self.n_layers = n_layers
self.set_param_lr_fun = self._layerwise_lr_decay
super().__init__(
learning_rate=learning_rate,
parameters=parameters,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
grad_clip=grad_clip,
name=name,
apply_decay_param_fun=apply_decay_param_fun,
weight_decay=weight_decay,
lazy_mode=lazy_mode,
multi_precision=multi_precision)
def _append_optimize_op(self, block, param_and_grad):
if self.set_param_lr_fun is None:
                return super()._append_optimize_op(block, param_and_grad)
self._append_decoupled_weight_decay(block, param_and_grad)
prev_lr = param_and_grad[0].optimize_attr["learning_rate"]
self.set_param_lr_fun(self.layerwise_decay, self.name_dict,
self.n_layers, param_and_grad[0])
            # execute the Adam op
res = super(optim.AdamW, self)._append_optimize_op(block,
param_and_grad)
param_and_grad[0].optimize_attr["learning_rate"] = prev_lr
return res
# Layerwise decay
def _layerwise_lr_decay(self, decay_rate, name_dict, n_layers, param):
"""
Args:
decay_rate (float):
The layer-wise decay ratio.
name_dict (dict):
The keys of name_dict is dynamic name of model while the value
of name_dict is static name.
Use model.named_parameters() to get name_dict.
n_layers (int):
Total number of layers in the transformer encoder.
"""
ratio = 1.0
static_name = name_dict[param.name]
if "blocks" in static_name:
idx = static_name.find("blocks.")
layer = int(static_name[idx:].split(".")[1])
ratio = decay_rate**(n_layers - layer)
elif "embed" in static_name:
ratio = decay_rate**(n_layers + 1)
param.optimize_attr["learning_rate"] *= ratio
    def __call__(self, model_list):
        model = model_list[0]
        # decay_dict stays None when bias/norm filtering is disabled
        decay_dict = None
        if self.weight_decay and self.filter_bias_and_bn:
skip = {}
if hasattr(model, 'no_weight_decay'):
skip = model.no_weight_decay()
decay_dict = {
param.name: not (len(param.shape) == 1 or
name.endswith(".bias") or name in skip)
for name, param in model.named_parameters()
if not 'teacher' in name
}
parameters = [
param for param in model.parameters()
if 'teacher' not in param.name
]
weight_decay = 0.
else:
parameters = model.parameters()
opt_args = dict(
learning_rate=self.learning_rate, weight_decay=self.weight_decay)
opt_args['parameters'] = parameters
if decay_dict is not None:
opt_args['apply_decay_param_fun'] = lambda n: decay_dict[n]
opt_args['epsilon'] = self.epsilon
opt_args['beta1'] = self.beta1
opt_args['beta2'] = self.beta2
if self.layerwise_decay and self.layerwise_decay < 1.0:
opt_args['layerwise_decay'] = self.layerwise_decay
name_dict = dict()
for n, p in model.named_parameters():
name_dict[p.name] = n
opt_args['name_dict'] = name_dict
opt_args['n_layers'] = model.get_num_layers()
optimizer = self.AdamWDLImpl(**opt_args)
return optimizer
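For intuition, with layerwise_decay = 0.65 and n_layers = 12 (the cae_base setting in the config above), _layerwise_lr_decay scales the learning rate roughly as follows (hypothetical standalone sketch):

decay_rate, n_layers = 0.65, 12
lr_ratio = {f"blocks.{i}": decay_rate ** (n_layers - i) for i in range(n_layers)}
lr_ratio["patch_embed"] = decay_rate ** (n_layers + 1)   # embeddings get the smallest LR
lr_ratio["head"] = 1.0                                   # the classifier keeps the base LR
# e.g. blocks.0 -> 0.65**12 ~= 0.0057, blocks.11 -> 0.65**1 = 0.65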