Unverified commit 8693e1d0, authored by Wenyu, committed by GitHub

AdamWDL builder and VIT checkpoint functionality (#6232)

* upgrade adamw to adamwdl for transformer

* rename arg
Parent bf895541
@@ -340,12 +340,15 @@ class VisionTransformer(nn.Layer):
                  use_abs_pos_emb=False,
                  use_sincos_pos_emb=True,
                  with_fpn=True,
+                 use_checkpoint=False,
                  **args):
         super().__init__()
         self.img_size = img_size
         self.embed_dim = embed_dim
         self.with_fpn = with_fpn
+        self.use_checkpoint = use_checkpoint
+        if use_checkpoint:
+            print('please set: FLAGS_allocator_strategy=naive_best_fit')
         self.patch_embed = PatchEmbed(
             img_size=img_size,
             patch_size=patch_size,
@@ -575,7 +578,7 @@ class VisionTransformer(nn.Layer):
     def forward(self, x):
         x = x['image'] if isinstance(x, dict) else x
-        _, _, w, h = x.shape
+        _, _, h, w = x.shape
         x = self.patch_embed(x)
@@ -586,7 +589,8 @@ class VisionTransformer(nn.Layer):
         x = paddle.concat([cls_tokens, x], axis=1)
         if self.pos_embed is not None:
-            x = x + self.interpolate_pos_encoding(x, w, h)
+            # x = x + self.interpolate_pos_encoding(x, w, h)
+            x = x + self.interpolate_pos_encoding(x, h, w)
         x = self.pos_drop(x)
@@ -597,7 +601,12 @@ class VisionTransformer(nn.Layer):
         feats = []
         for idx, blk in enumerate(self.blocks):
-            x = blk(x, rel_pos_bias)
+            if self.use_checkpoint:
+                x = paddle.distributed.fleet.utils.recompute(
+                    blk, x, rel_pos_bias, **{"preserve_rng_state": True})
+            else:
+                x = blk(x, rel_pos_bias)
             if idx in self.out_indices:
                 xp = paddle.reshape(
                     paddle.transpose(
...
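The hunk above routes each transformer block through Paddle's recompute utility when use_checkpoint is enabled, trading a second forward pass during backward for a much smaller activation footprint (hence the constructor's printed reminder about FLAGS_allocator_strategy=naive_best_fit). A minimal, self-contained sketch of the same pattern; ToyBlock and the tensor shape are illustrative only, not part of this commit:

import paddle
import paddle.nn as nn
from paddle.distributed.fleet.utils import recompute


class ToyBlock(nn.Layer):
    """Stand-in for a transformer block; any nn.Layer works here."""

    def __init__(self, dim=64):
        super().__init__()
        self.fc = nn.Linear(dim, dim)

    def forward(self, x):
        return paddle.nn.functional.relu(self.fc(x))


block = ToyBlock()
x = paddle.randn([2, 16, 64])
x.stop_gradient = False

# Activations inside `block` are freed after the forward pass and recomputed
# on demand during backward, as in the use_checkpoint branch above.
y = recompute(block, x, **{"preserve_rng_state": True})
y.mean().backward()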
@@ -18,6 +18,7 @@ from __future__ import print_function
 from paddle.optimizer import AdamW
 from functools import partial
+import re
 
 
 def layerwise_lr_decay(decay_rate, name_dict, n_layers, param):
@@ -34,15 +35,20 @@ def layerwise_lr_decay(decay_rate, name_dict, n_layers, param):
     """
     ratio = 1.0
     static_name = name_dict[param.name]
-    if "blocks" in static_name:
-        idx = static_name.find("blocks.")
-        layer = int(static_name[idx:].split(".")[1])
+    if 'blocks.' in static_name or 'layers.' in static_name:
+        idx_1 = static_name.find('blocks.')
+        idx_2 = static_name.find('layers.')
+        assert any([x >= 0 for x in [idx_1, idx_2]]), ''
+        idx = idx_1 if idx_1 >= 0 else idx_2
+        # idx = re.findall('[blocks|layers]\.(\d+)\.', static_name)[0]
+        layer = int(static_name[idx:].split('.')[1])
         ratio = decay_rate**(n_layers - layer)
-    elif "cls_token" in static_name or 'patch_embed' in static_name:
+    elif 'cls_token' in static_name or 'patch_embed' in static_name:
         ratio = decay_rate**(n_layers + 1)
-    param.optimize_attr["learning_rate"] *= ratio
+    param.optimize_attr['learning_rate'] *= ratio
 
 
 class AdamWDL(AdamW):
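As a quick sanity check on the rule above: layerwise_lr_decay scales a parameter's learning rate by decay_rate**(n_layers - layer), where layer is parsed from the 'blocks.N.' (or 'layers.N.') segment of its static name, and by decay_rate**(n_layers + 1) for cls_token and patch_embed parameters. A small worked example of that arithmetic; decay_rate=0.75 and the parameter names are made up for illustration:

decay_rate, n_layers = 0.75, 12


def lr_ratio(static_name):
    # Mirrors the scaling rule in layerwise_lr_decay above.
    if 'blocks.' in static_name or 'layers.' in static_name:
        # str.find returns -1 when missing, so max() picks the real match.
        idx = max(static_name.find('blocks.'), static_name.find('layers.'))
        layer = int(static_name[idx:].split('.')[1])
        return decay_rate**(n_layers - layer)
    if 'cls_token' in static_name or 'patch_embed' in static_name:
        return decay_rate**(n_layers + 1)
    return 1.0


print(lr_ratio('blocks.11.attn.qkv.weight'))  # 0.75**1  = 0.75   (top block)
print(lr_ratio('blocks.0.attn.qkv.weight'))   # 0.75**12 ~ 0.0317 (bottom block)
print(lr_ratio('patch_embed.proj.weight'))    # 0.75**13 ~ 0.0238
print(lr_ratio('head.weight'))                # 1.0, left untouched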
@@ -156,16 +162,16 @@ class AdamWDL(AdamW):
                  multi_precision=False,
                  layerwise_decay=1.0,
                  n_layers=12,
-                 set_param_lr_fun=None,
+                 set_param_lr_func=None,
                  name_dict=None,
                  name=None):
         if not isinstance(layerwise_decay, float):
             raise TypeError("coeff should be float or Tensor.")
         self.layerwise_decay = layerwise_decay
         self.n_layers = n_layers
-        self.set_param_lr_fun = partial(
-            set_param_lr_fun, layerwise_decay, name_dict,
-            n_layers) if set_param_lr_fun is not None else set_param_lr_fun
+        self.set_param_lr_func = partial(
+            set_param_lr_func, layerwise_decay, name_dict,
+            n_layers) if set_param_lr_func is not None else set_param_lr_func
         super(AdamWDL, self).__init__(
             learning_rate=learning_rate,
             parameters=parameters,
@@ -180,20 +186,20 @@ class AdamWDL(AdamW):
             multi_precision=multi_precision)
 
     def _append_optimize_op(self, block, param_and_grad):
-        if self.set_param_lr_fun is None:
+        if self.set_param_lr_func is None:
             return super(AdamWDL, self)._append_optimize_op(block,
                                                             param_and_grad)
 
         self._append_decoupled_weight_decay(block, param_and_grad)
         prev_lr = param_and_grad[0].optimize_attr["learning_rate"]
-        self.set_param_lr_fun(param_and_grad[0])
+        self.set_param_lr_func(param_and_grad[0])
         # excute Adam op
         res = super(AdamW, self)._append_optimize_op(block, param_and_grad)
         param_and_grad[0].optimize_attr["learning_rate"] = prev_lr
         return res
 
 
-def build_adamw(model,
+def build_adamwdl(model,
                 lr=1e-4,
                 weight_decay=0.05,
                 betas=(0.9, 0.999),
@@ -201,15 +207,14 @@ def build_adamw(model,
                 num_layers=None,
                 filter_bias_and_bn=True,
                 skip_decay_names=None,
-                set_param_lr_fun=None):
+                set_param_lr_func='layerwise_lr_decay'):
 
     if skip_decay_names and filter_bias_and_bn:
         decay_dict = {
-            param.name: not (len(param.shape) == 1 or name.endswith(".bias") or
+            param.name: not (len(param.shape) == 1 or name.endswith('.bias') or
                              any([_n in name for _n in skip_decay_names]))
             for name, param in model.named_parameters()
         }
         parameters = [p for p in model.parameters()]
     else:
@@ -221,17 +226,15 @@ def build_adamw(model,
     if decay_dict is not None:
         opt_args['apply_decay_param_fun'] = lambda n: decay_dict[n]
-    if isinstance(set_param_lr_fun, str):
-        set_param_lr_fun = eval(set_param_lr_fun)
-    opt_args['set_param_lr_fun'] = set_param_lr_fun
+    if isinstance(set_param_lr_func, str):
+        set_param_lr_func = eval(set_param_lr_func)
+    opt_args['set_param_lr_func'] = set_param_lr_func
 
     opt_args['beta1'] = betas[0]
     opt_args['beta2'] = betas[1]
     opt_args['layerwise_decay'] = layer_decay
 
-    name_dict = dict()
-    for n, p in model.named_parameters():
-        name_dict[p.name] = n
+    name_dict = {p.name: n for n, p in model.named_parameters()}
     opt_args['name_dict'] = name_dict
     opt_args['n_layers'] = num_layers
...
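With the rename in place, the builder can also be called directly from Python. A hedged usage sketch: ToyViT, the layer_decay value, and the skip list are illustrative, and the import path is an assumption based on the relative import ("from .adamw import ...") added in the next file:

import paddle.nn as nn

from ppdet.optimizer.adamw import build_adamwdl  # assumed module location


class ToyViT(nn.Layer):
    """Tiny stand-in whose parameter names follow the blocks.N.* pattern."""

    def __init__(self, depth=12, dim=32):
        super().__init__()
        self.patch_embed = nn.Linear(dim, dim)
        self.blocks = nn.LayerList([nn.Linear(dim, dim) for _ in range(depth)])


model = ToyViT()
opt = build_adamwdl(
    model,
    lr=1e-4,
    weight_decay=0.05,
    betas=(0.9, 0.999),
    layer_decay=0.75,                   # forwarded as layerwise_decay
    num_layers=12,                      # depth seen by layerwise_lr_decay
    skip_decay_names=['patch_embed'],   # illustrative no-weight-decay list
    set_param_lr_func='layerwise_lr_decay')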
@@ -27,6 +27,8 @@ import paddle.regularizer as regularizer
 from ppdet.core.workspace import register, serializable
 import copy
 
+from .adamw import AdamWDL, build_adamwdl
+
 __all__ = ['LearningRate', 'OptimizerBuilder']
 
 from ppdet.utils.logger import setup_logger
@@ -317,8 +319,13 @@ class OptimizerBuilder():
         optim_args = self.optimizer.copy()
         optim_type = optim_args['type']
         del optim_args['type']
+
+        if optim_type == 'AdamWDL':
+            return build_adamwdl(model, lr=learning_rate, **optim_args)
+
         if optim_type != 'AdamW':
             optim_args['weight_decay'] = regularization
+
         op = getattr(optimizer, optim_type)
         if 'param_groups' in optim_args:
...
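The new branch means an optimizer config whose type is AdamWDL skips the standard construction path and goes straight to build_adamwdl. A rough sketch of that dispatch, with a plain dict standing in for the parsed YAML optimizer section (the keys and values shown are illustrative):

# `optim_args` plays the role of self.optimizer.copy() in the builder above;
# `learning_rate` would be the scheduled LR object produced elsewhere.
optim_args = {
    'type': 'AdamWDL',
    'weight_decay': 0.05,
    'layer_decay': 0.75,
    'num_layers': 12,
}

optim_type = optim_args.pop('type')
if optim_type == 'AdamWDL':
    opt = build_adamwdl(model, lr=learning_rate, **optim_args)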