未验证 提交 75c762b9 编写于 作者: W Wenyu 提交者: GitHub

upgrade adamw for new paddle version (#7506)

上级 4d39dc22
...@@ -16,10 +16,15 @@ from __future__ import absolute_import ...@@ -16,10 +16,15 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import paddle
from paddle.optimizer import AdamW from paddle.optimizer import AdamW
from functools import partial from functools import partial
import re import re
IS_PADDLE_LATER_2_4 = (
int(paddle.version.major) >= 2 and
int(paddle.version.minor) >= 4) or int(paddle.version.major) == 0
def layerwise_lr_decay(decay_rate, name_dict, n_layers, param): def layerwise_lr_decay(decay_rate, name_dict, n_layers, param):
""" """
...@@ -48,6 +53,9 @@ def layerwise_lr_decay(decay_rate, name_dict, n_layers, param): ...@@ -48,6 +53,9 @@ def layerwise_lr_decay(decay_rate, name_dict, n_layers, param):
elif 'cls_token' in static_name or 'patch_embed' in static_name: elif 'cls_token' in static_name or 'patch_embed' in static_name:
ratio = decay_rate**(n_layers + 1) ratio = decay_rate**(n_layers + 1)
if IS_PADDLE_LATER_2_4:
return ratio
else:
param.optimize_attr['learning_rate'] *= ratio param.optimize_attr['learning_rate'] *= ratio
...@@ -172,6 +180,22 @@ class AdamWDL(AdamW): ...@@ -172,6 +180,22 @@ class AdamWDL(AdamW):
self.set_param_lr_func = partial( self.set_param_lr_func = partial(
set_param_lr_func, layerwise_decay, name_dict, set_param_lr_func, layerwise_decay, name_dict,
n_layers) if set_param_lr_func is not None else set_param_lr_func n_layers) if set_param_lr_func is not None else set_param_lr_func
if IS_PADDLE_LATER_2_4:
super(AdamWDL, self).__init__(
learning_rate=learning_rate,
parameters=parameters,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
grad_clip=grad_clip,
name=name,
apply_decay_param_fun=apply_decay_param_fun,
weight_decay=weight_decay,
lazy_mode=lazy_mode,
multi_precision=multi_precision,
lr_ratio=self.set_param_lr_func)
else:
super(AdamWDL, self).__init__( super(AdamWDL, self).__init__(
learning_rate=learning_rate, learning_rate=learning_rate,
parameters=parameters, parameters=parameters,
...@@ -185,10 +209,10 @@ class AdamWDL(AdamW): ...@@ -185,10 +209,10 @@ class AdamWDL(AdamW):
lazy_mode=lazy_mode, lazy_mode=lazy_mode,
multi_precision=multi_precision) multi_precision=multi_precision)
def _append_optimize_op(self, block, param_and_grad):
def _append_optimize_op(self, block, param_and_grad):
if self.set_param_lr_func is None: if self.set_param_lr_func is None:
return super(AdamWDL, self)._append_optimize_op(block, return super(AdamWDL, self)._append_optimize_op(block, param_and_grad)
param_and_grad)
self._append_decoupled_weight_decay(block, param_and_grad) self._append_decoupled_weight_decay(block, param_and_grad)
prev_lr = param_and_grad[0].optimize_attr["learning_rate"] prev_lr = param_and_grad[0].optimize_attr["learning_rate"]
...@@ -199,6 +223,10 @@ class AdamWDL(AdamW): ...@@ -199,6 +223,10 @@ class AdamWDL(AdamW):
return res return res
if not IS_PADDLE_LATER_2_4:
AdamWDL._append_optimize_op = _append_optimize_op
def build_adamwdl(model, def build_adamwdl(model,
lr=1e-4, lr=1e-4,
weight_decay=0.05, weight_decay=0.05,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册