Commit c779cc14 authored by dolcexu, committed by zengshao0622

adamwdl fix

Parent 995c8b26
@@ -20,6 +20,7 @@ import inspect
 from paddle import optimizer as optim
 from ppcls.utils import logger
+from functools import partial


 class SGD(object):
@@ -349,7 +350,8 @@ class AdamWDL(object):
         self.layerwise_decay = layerwise_decay
         self.name_dict = name_dict
         self.n_layers = n_layers
-        self.set_param_lr_fun = self._layerwise_lr_decay
+        self.set_param_lr_func = partial(
+            self._layerwise_lr_decay, layerwise_decay, name_dict, n_layers)
         super().__init__(
             learning_rate=learning_rate,
             parameters=parameters,
@@ -361,30 +363,16 @@ class AdamWDL(object):
             apply_decay_param_fun=apply_decay_param_fun,
             weight_decay=weight_decay,
             lazy_mode=lazy_mode,
-            multi_precision=multi_precision)
-
-    def _append_optimize_op(self, block, param_and_grad):
-        if self.set_param_lr_fun is None:
-            return super(AdamLW, self)._append_optimize_op(block,
-                                                           param_and_grad)
-
-        self._append_decoupled_weight_decay(block, param_and_grad)
-        prev_lr = param_and_grad[0].optimize_attr["learning_rate"]
-        self.set_param_lr_fun(self.layerwise_decay, self.name_dict,
-                              self.n_layers, param_and_grad[0])
-        # excute Adam op
-        res = super(optim.AdamW, self)._append_optimize_op(block,
-                                                           param_and_grad)
-        param_and_grad[0].optimize_attr["learning_rate"] = prev_lr
-        return res
+            multi_precision=multi_precision,
+            lr_ratio=self.set_param_lr_func)

     # Layerwise decay
     def _layerwise_lr_decay(self, decay_rate, name_dict, n_layers, param):
         """
         Args:
             decay_rate (float):
                 The layer-wise decay ratio.
             name_dict (dict):
                 The keys of name_dict is dynamic name of model while the value
                 of name_dict is static name.
                 Use model.named_parameters() to get name_dict.
@@ -399,7 +387,8 @@ class AdamWDL(object):
             ratio = decay_rate**(n_layers - layer)
         elif "embed" in static_name:
             ratio = decay_rate**(n_layers + 1)
-        param.optimize_attr["learning_rate"] *= ratio
+        # param.optimize_attr["learning_rate"] *= ratio
+        return ratio

     def __call__(self, model_list):
         model = model_list[0]
......
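For context, here is a minimal standalone sketch of the mechanism the commit switches to: the layer-wise decay ratio is computed by a plain function, its hyper-parameters are bound once with functools.partial, and the resulting one-argument callable is handed to Paddle's AdamW through its lr_ratio hook instead of mutating optimize_attr["learning_rate"] inside an overridden _append_optimize_op. The helper name layerwise_lr_ratio, the sample name_dict entries, and the _FakeParam stand-in are illustrative assumptions, not code from this repository.

# Illustrative sketch only; names and sample values are assumptions.
from functools import partial


def layerwise_lr_ratio(decay_rate, name_dict, n_layers, param):
    """Return the lr multiplier for one parameter (same idea as _layerwise_lr_decay)."""
    static_name = name_dict[param.name]            # dynamic name -> static name
    if "blocks" in static_name:
        idx = static_name.find("blocks.")
        layer = int(static_name[idx:].split(".")[1])
        return decay_rate ** (n_layers - layer)    # later blocks keep a larger lr
    if "embed" in static_name:
        return decay_rate ** (n_layers + 1)        # embeddings get the smallest lr
    return 1.0                                     # head and other params keep base lr


# Bind the hyper-parameters once, as the patch does with partial(...), so the
# callable only takes the parameter itself:
sample_name_dict = {"p1": "blocks.3.attn.qkv.weight", "p2": "patch_embed.proj.weight"}
ratio_fn = partial(layerwise_lr_ratio, 0.9, sample_name_dict, 12)


class _FakeParam:                                  # stand-in for a paddle Parameter
    def __init__(self, name):
        self.name = name


print(ratio_fn(_FakeParam("p1")))                  # 0.9 ** (12 - 3)
print(ratio_fn(_FakeParam("p2")))                  # 0.9 ** (12 + 1)

# With a Paddle version whose AdamW accepts lr_ratio (which the patched code
# relies on), the callable is passed straight to the optimizer:
#     paddle.optimizer.AdamW(..., parameters=..., lr_ratio=ratio_fn)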