Unverified commit c22803f2 authored by W Wenyu, committed by GitHub

upgrade adamw for new paddle version (#7507)

Parent 1a336e5f
@@ -16,10 +16,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle.optimizer import AdamW
from functools import partial
import re
IS_PADDLE_LATER_2_4 = (
    int(paddle.version.major) >= 2 and
    int(paddle.version.minor) >= 4) or int(paddle.version.major) == 0
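For context, a minimal sketch of how the `IS_PADDLE_LATER_2_4` flag evaluates; it assumes, as the check above does, that `paddle.version.major` and `paddle.version.minor` are numeric strings (e.g. `"2"` and `"4"`) and that develop builds report `"0"` as the major version:

```python
import paddle

# paddle.version.major / paddle.version.minor are strings such as "2" and "4";
# develop (nightly) builds report "0" for the major version.
major = int(paddle.version.major)
minor = int(paddle.version.minor)

# Released 2.4+ builds satisfy the first clause; develop builds fall through
# to the second clause and are also treated as "2.4 or later".
is_later_2_4 = (major >= 2 and minor >= 4) or major == 0
print(is_later_2_4)
```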
def layerwise_lr_decay(decay_rate, name_dict, n_layers, param):
"""
@@ -48,6 +53,9 @@ def layerwise_lr_decay(decay_rate, name_dict, n_layers, param):
    elif 'cls_token' in static_name or 'patch_embed' in static_name:
        ratio = decay_rate**(n_layers + 1)
    if IS_PADDLE_LATER_2_4:
        return ratio
    else:
        param.optimize_attr['learning_rate'] *= ratio
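A rough usage sketch of the two code paths above, with a toy model standing in for a real backbone (the `name_dict` built here is only a placeholder; `build_adamwdl` further down constructs the real mapping): on Paddle >= 2.4 the partially bound function returns the ratio, while on older versions it rescales the parameter's `optimize_attr` in place.

```python
import paddle
from functools import partial

# Toy stand-in for a backbone; the real name_dict maps runtime parameter
# names to their static (structured) names.
model = paddle.nn.Linear(4, 4)
name_dict = {p.name: p.name for p in model.parameters()}

decay_fn = partial(layerwise_lr_decay, 0.75, name_dict, 12)

for p in model.parameters():
    if IS_PADDLE_LATER_2_4:
        ratio = decay_fn(p)  # returned here, later consumed by AdamW via lr_ratio
    else:
        decay_fn(p)          # scales p.optimize_attr['learning_rate'] in place
```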
@@ -172,6 +180,22 @@ class AdamWDL(AdamW):
        self.set_param_lr_func = partial(
            set_param_lr_func, layerwise_decay, name_dict,
            n_layers) if set_param_lr_func is not None else set_param_lr_func

        if IS_PADDLE_LATER_2_4:
            super(AdamWDL, self).__init__(
                learning_rate=learning_rate,
                parameters=parameters,
                beta1=beta1,
                beta2=beta2,
                epsilon=epsilon,
                grad_clip=grad_clip,
                name=name,
                apply_decay_param_fun=apply_decay_param_fun,
                weight_decay=weight_decay,
                lazy_mode=lazy_mode,
                multi_precision=multi_precision,
                lr_ratio=self.set_param_lr_func)
        else:
            super(AdamWDL, self).__init__(
                learning_rate=learning_rate,
                parameters=parameters,
@@ -185,10 +209,10 @@ class AdamWDL(AdamW):
lazy_mode=lazy_mode,
multi_precision=multi_precision)
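The `lr_ratio` keyword passed in the Paddle >= 2.4 branch lets the stock `paddle.optimizer.AdamW` apply the per-parameter scaling itself, which is why the custom optimize-op hook below is only needed on older versions. A standalone sketch with a toy model and a constant ratio standing in for the layer-wise decay callable (note that some Paddle releases only support `lr_ratio` on GPU):

```python
import paddle
from paddle.optimizer import AdamW

model = paddle.nn.Linear(4, 4)

# lr_ratio maps a parameter to a multiplier on the base learning rate;
# a constant 0.5 stands in here for the layer-wise decay function.
opt = AdamW(
    learning_rate=1e-4,
    parameters=model.parameters(),
    weight_decay=0.05,
    lr_ratio=lambda p: 0.5)
```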
def _append_optimize_op(self, block, param_and_grad):
def _append_optimize_op(self, block, param_and_grad):
if self.set_param_lr_func is None:
return super(AdamWDL, self)._append_optimize_op(block,
param_and_grad)
return super(AdamWDL, self)._append_optimize_op(block, param_and_grad)
self._append_decoupled_weight_decay(block, param_and_grad)
prev_lr = param_and_grad[0].optimize_attr["learning_rate"]
@@ -199,6 +223,10 @@ class AdamWDL(AdamW):
return res
if not IS_PADDLE_LATER_2_4:
AdamWDL._append_optimize_op = _append_optimize_op
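The conditional assignment above is a plain monkey-patch: on Paddle versions before 2.4 the module-level `_append_optimize_op` is attached to `AdamWDL`, where it behaves like an ordinary method and reproduces the per-parameter scaling that newer Paddle handles natively through `lr_ratio`. A generic illustration of the pattern with a toy class (not the optimizer):

```python
class Greeter:
    pass


def _greet(self, name):
    return "hello " + name


# Assigning a module-level function to a class attribute makes it a
# regular bound method on instances.
Greeter.greet = _greet
print(Greeter().greet("paddle"))  # -> "hello paddle"
```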
def build_adamwdl(model,
lr=1e-4,
weight_decay=0.05,
......
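Finally, a hedged end-to-end sketch of how this module is typically consumed. The `build_adamwdl` signature is truncated above, so only the visible arguments are used, and the model is a stand-in for a real PaddleDetection backbone:

```python
import paddle

model = paddle.nn.Linear(4, 4)  # stand-in; normally the detector's model

# Only the arguments visible in the truncated signature above are passed.
optimizer = build_adamwdl(model, lr=1e-4, weight_decay=0.05)

# On Paddle >= 2.4 the layer-wise ratios reach AdamW through lr_ratio; on
# older Paddle the patched _append_optimize_op rescales each parameter's
# learning rate around the base optimize op instead.
```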