Unverified commit 91cf918f, authored by zhaoyingli, committed by GitHub

add layerwise learning rate for adamw (#35569)

* add layerwise learning rate for adamw

* fix format

* add unit test

* add NotImplementedError

* add GPU unit test

* update GPU place
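In short, the new `lr_ratio` argument of `paddle.optimizer.AdamW` takes a callable that maps each parameter to a per-layer multiplier on the learning rate. A minimal dygraph usage sketch, adapted from the unit test added in this commit (the layer sizes and the ratio function are illustrative, and the `"linear_1"` check assumes Paddle's default `linear_<i>` parameter naming):

import paddle

linear1 = paddle.nn.Linear(13, 8)
linear2 = paddle.nn.Linear(8, 5)

def layerwise_lr(param):
    # Illustrative ratio: use a smaller step for the second layer's parameters.
    return 0.8 if "linear_1" in param.name else 1.0

opt = paddle.optimizer.AdamW(
    learning_rate=0.01,
    parameters=linear1.parameters() + linear2.parameters(),
    weight_decay=0.01,
    lr_ratio=layerwise_lr)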
Parent 0bbff93c
@@ -236,6 +236,10 @@ class AdamWOpMaker : public AdamOpMaker {
  public:
   void Make() {
     AdamOpMaker::Make();
+    AddAttr<float>("lr_ratio",
+                   "(float, default 1.0) "
+                   "layerwise learning rate decay")
+        .SetDefault(1.0f);
     AddAttr<float>("coeff",
                    "(float, default 0.01) "
                    "coeff of the weight decay")
......
@@ -20,17 +20,17 @@ namespace operators {
 template <typename T, typename MT>
 __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
-                               MT beta1_pow_, MT beta2_pow_, const MT* moment1,
-                               MT* moment1_out, const MT* moment2,
-                               MT* moment2_out, const MT* lr_, const T* grad,
-                               const T* param, T* param_out,
-                               const MT* master_param, MT* master_param_out,
-                               int ndim) {
-  MT lr = *lr_;
+                               MT lr_ratio, MT beta1_pow_, MT beta2_pow_,
+                               const MT* moment1, MT* moment1_out,
+                               const MT* moment2, MT* moment2_out,
+                               const MT* lr_, const T* grad, const T* param,
+                               T* param_out, const MT* master_param,
+                               MT* master_param_out, int ndim) {
+  MT lr = *lr_ * lr_ratio;
+  MT lr_orig = lr;
   MT beta1_pow = beta1_pow_;
   MT beta2_pow = beta2_pow_;
-  MT wd = static_cast<MT>(1.0) - coeff * lr;
   lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
         (static_cast<MT>(1.0) - beta1_pow);
@@ -43,8 +43,8 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
     MT mom2 = moment2[id];
     mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
     mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-    p = wd * p -
-        lr * (mom1 /
+    p -= lr_orig * coeff * p;
+    p -= lr * (mom1 /
               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
     moment1_out[id] = mom1;
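For reference, the same per-element update written as a NumPy sketch (a minimal mirror of the kernel body above, not Paddle API): `lr_ratio` scales the raw learning rate once, so it multiplies both the decoupled weight-decay term and the bias-corrected Adam step.

import numpy as np

def adamw_elementwise_update(p, g, mom1, mom2, lr, lr_ratio, coeff,
                             beta1, beta2, epsilon, beta1_pow, beta2_pow):
    # Mirrors AdamWKernelREG after this change.
    lr = lr * lr_ratio                      # layerwise-scaled learning rate
    lr_orig = lr                            # kept for the weight-decay term
    lr = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)  # bias correction
    mom1 = beta1 * mom1 + (1.0 - beta1) * g
    mom2 = beta2 * mom2 + (1.0 - beta2) * g * g
    p = p - lr_orig * coeff * p             # decoupled weight decay
    p = p - lr * (mom1 / (np.sqrt(mom2) + epsilon * np.sqrt(1.0 - beta2_pow)))
    return p, mom1, mom2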
@@ -57,18 +57,16 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
 }

 template <typename T, typename MT>
-__global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff,
-                               const MT* beta1_pow_, const MT* beta2_pow_,
-                               const MT* moment1, MT* moment1_out,
-                               const MT* moment2, MT* moment2_out,
-                               const MT* lr_, const T* grad, const T* param,
-                               T* param_out, const MT* master_param,
-                               MT* master_param_out, int ndim) {
-  MT lr = *lr_;
+__global__ void AdamWKernelMEM(
+    MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT* beta1_pow_,
+    const MT* beta2_pow_, const MT* moment1, MT* moment1_out, const MT* moment2,
+    MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out,
+    const MT* master_param, MT* master_param_out, int ndim) {
+  MT lr = *lr_ * lr_ratio;
+  MT lr_orig = lr;
   MT beta1_pow = *beta1_pow_;
   MT beta2_pow = *beta2_pow_;
-  MT wd = static_cast<MT>(1.0) - coeff * lr;
   lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
         (static_cast<MT>(1.0) - beta1_pow);
@@ -81,8 +79,8 @@ __global__ void AdamWKernelMEM(MT beta1, MT beta2, MT epsilon, MT coeff,
     MT mom2 = static_cast<MT>(moment2[id]);
     mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
     mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-    p = wd * p -
-        lr * (mom1 /
+    p -= lr_orig * coeff * p;
+    p -= lr * (mom1 /
               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
     moment1_out[id] = mom1;
@@ -103,16 +101,16 @@ __global__ void UpdateAdamWBetaPow(T beta1, T beta2, const T* beta1_pow_,
 template <typename T, typename MT>
 __global__ void SparseAdamWCUDAKernelREG(
-    MT beta1, MT beta2, MT epsilon, MT coeff, const MT beta1_pow,
+    MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio, const MT beta1_pow,
     const MT beta2_pow, const MT* mom1_, MT* mom1_out_, const MT* mom2_,
     MT* mom2_out_, const MT* lr_, const T* grad_, const T* param_,
     T* param_out_, const MT* master_param, MT* master_param_out,
     const int64_t* rows_, int64_t row_numel, int64_t row_count, bool lazy_mode,
     int ndim) {
   int id = blockIdx.x * blockDim.x + threadIdx.x;
-  MT lr = *lr_;
-  MT wd = static_cast<MT>(1.0) - coeff * lr;
+  MT lr = *lr_ * lr_ratio;
+  MT lr_orig = lr;
   lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
         (static_cast<MT>(1.0) - beta1_pow);
@@ -130,8 +128,8 @@ __global__ void SparseAdamWCUDAKernelREG(
             : static_cast<MT>(0);
     mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
     mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-    p = wd * p -
-        lr * (mom1 / (sqrt(mom2) +
+    p -= lr_orig * coeff * p;
+    p -= lr * (mom1 / (sqrt(mom2) +
                   epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
     // Write back to global memory
@@ -165,7 +163,9 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
     bool lazy_mode = ctx.Attr<bool>("lazy_mode");
     bool use_global_beta_pow = ctx.Attr<bool>("use_global_beta_pow");
     VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
-    float coeff = ctx.Attr<float>("coeff");
+    MPDType coeff = static_cast<MPDType>(ctx.Attr<float>("coeff"));
+    MPDType lr_ratio = static_cast<MPDType>(ctx.Attr<float>("lr_ratio"));

     auto* param = ctx.Input<LoDTensor>("Param");
     auto* grad_var = ctx.InputVar("Grad");
@@ -301,7 +301,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
         beta2_pow->place() == platform::CPUPlace()) {
       // Compute with betapow in REG
       AdamWKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
+          beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data<MPDType>(),
           *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
           mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
           mom2->data<MPDType>(),
@@ -318,7 +318,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
       }
     } else {
       AdamWKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-          beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
+          beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data<MPDType>(),
           beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
           mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
           mom2->data<MPDType>(),
@@ -377,7 +377,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
         SparseAdamWCUDAKernelREG<
             T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
-            beta1, beta2, epsilon, coeff, *beta1_pow->data<MPDType>(),
+            beta1, beta2, epsilon, coeff, lr_ratio, *beta1_pow->data<MPDType>(),
             *beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
             mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
             mom2->data<MPDType>(),
@@ -395,7 +395,7 @@ class AdamWOpCUDAKernel : public framework::OpKernel<T> {
       }
     } else {
       SparseAdamWFunctor<T, GPUAdamW, MPDType> functor(
-          beta1, beta2, epsilon, coeff, beta1_pow->data<MPDType>(),
+          beta1, beta2, epsilon, coeff, lr_ratio, beta1_pow->data<MPDType>(),
           beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
           mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
           mom2->data<MPDType>(),
......
@@ -32,12 +32,13 @@ template <typename T>
 class AdamWFunctor<T, CPUAdamW> {
  private:
   const T coeff_;
+  const T lr_ratio_;
   const T* lr_;
   T* param_;

  public:
-  AdamWFunctor(const T coeff, const T* lr, T* param)
-      : coeff_(coeff), lr_(lr), param_(param) {}
+  AdamWFunctor(const T coeff, const T lr_ratio, const T* lr, T* param)
+      : coeff_(coeff), lr_ratio_(lr_ratio), lr_(lr), param_(param) {}

   inline HOSTDEVICE void operator()(size_t numel) const {
     Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param{
@@ -46,7 +47,7 @@ class AdamWFunctor<T, CPUAdamW> {
     T lr = *lr_;

     // Calculation
-    param = param * (1 - lr * coeff_);
+    param -= lr * lr_ratio_ * coeff_ * param;
   }
 };
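On the CPU path this functor only performs the decoupled weight-decay step; the Adam moment update then runs in the base AdamOpKernel. A hedged NumPy sketch of what `AdamWFunctor<T, CPUAdamW>` now computes:

import numpy as np

def cpu_adamw_decay(param, lr, lr_ratio, coeff):
    # Mirror of the Eigen expression above: in-place decoupled weight decay,
    # scaled by the layerwise ratio; the regular Adam update on `param`
    # happens afterwards in AdamOpKernel::Compute.
    param -= lr * lr_ratio * coeff * param
    return param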
@@ -60,6 +61,7 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
   MT beta2_;
   MT epsilon_;
   MT coeff_;
+  MT lr_ratio_;

   const MT* beta1_pow_;
   const MT* beta2_pow_;
@@ -80,7 +82,7 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
   bool lazy_mode_;

  public:
-  SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff,
+  SparseAdamWFunctor(MT beta1, MT beta2, MT epsilon, MT coeff, MT lr_ratio,
                      const MT* beta1_pow, const MT* beta2_pow, const MT* mom1,
                      MT* mom1_out, const MT* mom2, MT* mom2_out, const MT* lr,
                      const T* grad, const T* param, T* param_out,
@@ -91,6 +93,7 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
         beta2_(beta2),
         epsilon_(epsilon),
         coeff_(coeff),
+        lr_ratio_(lr_ratio),
         beta1_pow_(beta1_pow),
         beta2_pow_(beta2_pow),
         moment1_(mom1),
@@ -112,21 +115,21 @@ class SparseAdamWFunctor<T, GPUAdamW, MT> {
     // The following code is the same as dense
     MT mom1 = moment1_[i];
     MT mom2 = moment2_[i];
-    MT lr = *lr_;
+    MT lr = *lr_ * lr_ratio_;
+    MT lr_orig = lr;
     MT beta1_pow = *beta1_pow_;
     MT beta2_pow = *beta2_pow_;
     MT p = master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);

     // Calculation
-    MT wd = static_cast<MT>(1.0) - coeff_ * lr;
     lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
           (static_cast<MT>(1.0) - beta1_pow);

     mom1 = beta1_ * mom1 + (static_cast<MT>(1.0) - beta1_) * g;
     mom2 = beta2_ * mom2 + (static_cast<MT>(1.0) - beta2_) * g * g;
-    p = wd * p -
-        lr * (mom1 /
-              (sqrt(mom2) + epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));
+    p -= lr_orig * coeff_ * p;
+    p -= lr * (mom1 / (sqrt(mom2) +
+               epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));

     // Write back to global memory
     moment1_out_[i] = mom1;
@@ -187,6 +190,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
     }

     T coeff = static_cast<T>(ctx.Attr<float>("coeff"));
+    T lr_ratio = static_cast<T>(ctx.Attr<float>("lr_ratio"));
     auto* lr = ctx.Input<LoDTensor>("LearningRate");

     LoDTensor* param;
@@ -198,7 +202,8 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
       param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>("Param"));
     }

-    AdamWFunctor<T, CPUAdamW> functor(coeff, lr->data<T>(), param->data<T>());
+    AdamWFunctor<T, CPUAdamW> functor(coeff, lr_ratio, lr->data<T>(),
+                                      param->data<T>());
     functor(param->numel());

     AdamOpKernel<DeviceContext, T>::Compute(ctx);
......
@@ -16,6 +16,7 @@ import unittest
 import paddle
 import numpy as np
 import paddle.fluid as fluid
+from functools import partial


 class TestAdamWOp(unittest.TestCase):
@@ -148,5 +149,91 @@ class TestAdamWOpGroupWithLR(TestAdamWOp):
             adam.clear_gradients()


+def simple_lr_setting(param, decay_rate, n_layers):
+    if "fc_0" in param.name or "linear_1" in param.name:
+        depth = int(param.name.split("_")[2]) + 1
+    elif "fc_1" in param.name or "linear_2" in param.name:
+        depth = int(param.name.split("_")[2]) + 2
+    else:
+        depth = 0
+
+    return decay_rate**(n_layers + 2 - depth)
+
+
+class TestAdamWOpLayerwiseLR(TestAdamWOp):
+    def test_adamw_op_dygraph(self):
+        paddle.disable_static()
+        value = np.arange(26).reshape(2, 13).astype("float32")
+        a = paddle.to_tensor(value)
+        linear1 = paddle.nn.Linear(13, 8)
+        linear2 = paddle.nn.Linear(8, 5)
+
+        simple_lr_fun = partial(simple_lr_setting, decay_rate=0.8, n_layers=2)
+
+        adam = paddle.optimizer.AdamW(
+            learning_rate=0.01,
+            parameters=[{
+                'params': linear1.parameters()
+            }, {
+                'params': linear2.parameters(),
+            }],
+            apply_decay_param_fun=lambda name: True,
+            weight_decay=0.01,
+            lr_ratio=simple_lr_fun)
+
+        for _ in range(2):
+            a1 = linear1(a)
+            out = linear2(a1)
+            out.backward()
+            adam.step()
+            adam.clear_gradients()
+
+    def test_adamw_op(self):
+        paddle.enable_static()
+        place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \
+            else fluid.CPUPlace()
+        train_prog = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(train_prog, startup):
+            with fluid.unique_name.guard():
+                x = fluid.data(name='x', shape=[None, 10], dtype='float32')
+                y = fluid.data(name='y', shape=[None, 1], dtype='float32')
+
+                fc1 = fluid.layers.fc(input=x, size=32, act=None)
+                prediction = fluid.layers.fc(input=fc1, size=1, act=None)
+                cost = fluid.layers.square_error_cost(input=prediction, label=y)
+                avg_cost = fluid.layers.mean(cost)
+
+                simple_lr_fun = partial(
+                    simple_lr_setting, decay_rate=0.8, n_layers=2)
+
+                beta1 = fluid.layers.create_global_var(
+                    shape=[1], value=0.85, dtype='float32', persistable=True)
+                beta2 = fluid.layers.create_global_var(
+                    shape=[1], value=0.95, dtype='float32', persistable=True)
+                betas = [beta1, beta2]
+                opt = paddle.optimizer.AdamW(
+                    learning_rate=1e-5,
+                    beta1=beta1,
+                    beta2=beta2,
+                    weight_decay=0.01,
+                    epsilon=1e-8,
+                    lr_ratio=simple_lr_fun)
+                opt.minimize(avg_cost)
+
+        exe = fluid.Executor(place)
+        exe.run(startup)
+        for _ in range(2):
+            inputs = np.random.random(size=[8, 10]).astype('float32')
+            outputs = np.random.random(size=[8, 1]).astype('float32')
+            rets = exe.run(train_prog,
+                           feed={"x": inputs,
+                                 "y": outputs},
+                           fetch_list=[avg_cost])
+            assert rets[0] is not None
+        paddle.disable_static()
+
+
 if __name__ == "__main__":
     unittest.main()
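To make the test's schedule concrete, here is a small worked example that exercises the `simple_lr_setting` helper defined above. The parameter name is a hypothetical stand-in for Paddle's usual `linear_<i>.w_0` naming, used only to drive the branches:

from types import SimpleNamespace

fake_param = SimpleNamespace(name="linear_1.w_0")  # assumed/hypothetical name
# "linear_1" matches the first branch and name.split("_")[2] == "0",
# so depth = 1 and the ratio is 0.8 ** (2 + 2 - 1) = 0.512.
print(simple_lr_setting(fake_param, decay_rate=0.8, n_layers=2))  # ~0.512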
@@ -18,6 +18,7 @@ from ..fluid import core
 from ..fluid import framework
 from ..fluid.framework import Variable
 from ..fluid.dygraph import base as imperative_base
+from collections import Callable
 import paddle

 _C_ops = core.ops
@@ -63,6 +64,10 @@ class AdamW(Adam):
         epsilon (float, optional): A small float value for numerical stability.
             The default value is 1e-08.
         weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01.
+        lr_ratio (function|None, optional): If it is not None,
+            the learning rate will be updated with layerwise learning rate ratio.
+            Otherwise, the learning rate is the original.
+            Default: None.
         apply_decay_param_fun (function|None, optional): If it is not None,
             only tensors that makes apply_decay_param_fun(Tensor.name)==True
             will be updated with weight decay. It only works when we want to specify tensors.
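A sketch of the contract the `lr_ratio` docstring describes (the layer name is illustrative, not from the diff): the callable receives a parameter and returns a float multiplier, and the op then uses `learning_rate * lr_ratio(param)` for that parameter's update, including its weight-decay term.

def lr_ratio_fn(param):
    # Illustrative: use half-size steps for an assumed embedding layer.
    return 0.5 if "embedding" in param.name else 1.0

# Effective step size used by the adamw op for a given parameter:
#   effective_lr = learning_rate * lr_ratio_fn(param)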
@@ -140,6 +145,7 @@ class AdamW(Adam):
                  epsilon=1e-8,
                  parameters=None,
                  weight_decay=0.01,
+                 lr_ratio=None,
                  apply_decay_param_fun=None,
                  grad_clip=None,
                  lazy_mode=False,
@@ -163,6 +169,12 @@ class AdamW(Adam):
         self._apply_decay_param_fun = apply_decay_param_fun
         self._coeff = coeff
         self._lr_to_coeff = dict()
+        if lr_ratio is not None:
+            assert isinstance(lr_ratio, Callable)
+            if core.is_compiled_with_xpu() or core.is_compiled_with_npu():
+                raise NotImplementedError(
+                    "'lr_ratio' is unimplemented in XPU and NPU")
+        self._lr_ratio = lr_ratio

         super(AdamW, self).__init__(
             learning_rate=learning_rate,
@@ -278,6 +290,8 @@ class AdamW(Adam):

         # create the adamw optimize op
         if framework.in_dygraph_mode():
+            lr_ratio_ = 1. if self._lr_ratio is None else self._lr_ratio(
+                param_and_grad[0])
             _beta1 = self._beta1 if not isinstance(
                 self._beta1, Variable) else self._beta1.numpy().item(0)
@@ -288,7 +302,8 @@ class AdamW(Adam):
                 beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
                 moment2, beta1_pow_acc, beta2_pow_acc, 'epsilon', self._epsilon,
                 'lazy_mode', self._lazy_mode, 'min_row_size_to_use_multithread',
-                1000, 'beta1', _beta1, 'beta2', _beta2, 'coeff', self._coeff)
+                1000, 'beta1', _beta1, 'beta2', _beta2, 'coeff', self._coeff,
+                "lr_ratio", lr_ratio_)

             return None
@@ -321,6 +336,8 @@ class AdamW(Adam):
             "multi_precision": find_master,
             "with_decay": with_decay,
             "coeff": self._coeff,
+            "lr_ratio": 1.
+            if self._lr_ratio is None else self._lr_ratio(param_and_grad[0])
         }

         if isinstance(self._beta1, Variable):
......
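One behavioural consequence of the `__init__` changes above, shown as a hedged example (assumes a Paddle build without XPU/NPU support and default dygraph mode; the printed message is our own wording, the AssertionError itself carries no text):

import paddle

try:
    paddle.optimizer.AdamW(
        learning_rate=0.01,
        parameters=paddle.nn.Linear(4, 4).parameters(),
        lr_ratio=0.5)  # not callable -> rejected by the isinstance check
except AssertionError:
    print("lr_ratio must be a callable: parameter -> float")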