Unverified commit 2a56f4b3, authored by zhiboniu, committed by GitHub

fix lamb optimizer always_adapt (#54654)

* fix lamb always_adapt

* fix optest

* fix all optests
Parent: 82eeda69
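This patch makes the LAMB kernels and the Python reference apply the layer-wise LR adaptation (the trust ratio r_1 / r_2) only when weight decay is active or `always_adapt` is set; otherwise the plain learning rate is used and the two norm computations are skipped. A minimal NumPy sketch of that single-step behavior, mirroring the `lamb_step` reference updated in the tests below (function and variable names here are illustrative, not the kernel's actual symbols):

```python
import numpy as np

def lamb_step_sketch(param, grad, m, v, lr, beta1, beta2, epsilon,
                     weight_decay, always_adapt, beta1_pow, beta2_pow):
    # Adam-style moment updates with bias correction.
    m_out = beta1 * m + (1 - beta1) * grad
    v_out = beta2 * v + (1 - beta2) * np.square(grad)
    m_hat = m_out / (1 - beta1_pow)
    v_hat = v_out / (1 - beta2_pow)

    update = m_hat / (np.sqrt(v_hat) + epsilon) + weight_decay * param

    # The fix: the trust ratio (layer-wise LR adaptation) is applied only when
    # weight decay is nonzero or always_adapt=True; otherwise lr is used as-is.
    if weight_decay > 0 or always_adapt:
        r_1 = np.linalg.norm(param)
        r_2 = np.linalg.norm(update)
        lr_t = lr * r_1 / r_2
    else:
        lr_t = lr

    param_out = param - lr_t * update
    return param_out, m_out, v_out
```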
......@@ -1231,7 +1231,7 @@
backward : label_smooth_grad
- op : lamb_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1=0.9, float beta2=0.999, float epsilon=1.0e-6f, bool multi_precision=false)
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1=0.9, float beta2=0.999, float epsilon=1.0e-6f, bool always_adapt=false, bool multi_precision=false)
output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
infer_meta :
func : LambInferMeta
......
......@@ -2061,6 +2061,7 @@ void LambInferMeta(const MetaTensor& param,
float beta1,
float beta2,
float epsilon,
bool always_adapt,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* moment1_out,
......
......@@ -358,6 +358,7 @@ void LambInferMeta(const MetaTensor& param,
float beta1,
float beta2,
float epsilon,
bool always_adapt,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* moment1_out,
......
......@@ -16,6 +16,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/lamb_functors.h"
namespace phi {
......@@ -35,6 +36,7 @@ void ComputeImpl(const Context& dev_ctx,
float beta1_f,
float beta2_f,
float epsilon_f,
bool always_adapt,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* mom1_out,
......@@ -58,6 +60,7 @@ void LambKernel(const Context& dev_ctx,
float beta1,
float beta2,
float epsilon,
bool always_adapt,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment1_out,
......@@ -81,6 +84,7 @@ void LambKernel(const Context& dev_ctx,
beta1,
beta2,
epsilon,
always_adapt,
multi_precision,
param_out,
moment1_out,
......@@ -103,6 +107,7 @@ void LambKernel(const Context& dev_ctx,
beta1,
beta2,
epsilon,
always_adapt,
multi_precision,
param_out,
moment1_out,
......@@ -128,6 +133,7 @@ void ComputeImpl(const Context& dev_ctx,
float beta1_f,
float beta2_f,
float epsilon_f,
bool always_adapt,
bool multi_precision UNUSED,
DenseTensor* param_out,
DenseTensor* mom1_out,
......@@ -232,26 +238,39 @@ void ComputeImpl(const Context& dev_ctx,
// paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h Please modify it
// together
DenseTensor p_norm_t;
p_norm_t.Resize(phi::make_ddim({1}));
auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
DataType dtype = phi::CppTypeToDataType<MT>::Type();
FullKernel<MT, Context>(
dev_ctx, std::vector<int64_t>({1}), 0, dtype, &p_norm_t);
auto* p_norm_ptr = p_norm_t.data<MT>();
DenseTensor trust_ratio_div_norm_t;
trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
auto* trust_ratio_div_norm_ptr =
dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
FullKernel<MT, Context>(
dev_ctx, std::vector<int64_t>({1}), 0, dtype, &trust_ratio_div_norm_t);
auto* trust_ratio_div_norm_ptr = trust_ratio_div_norm_t.data<MT>();
// DenseTensor p_norm_t;
// p_norm_t.Resize(phi::make_ddim({1}));
// auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
// DenseTensor trust_ratio_div_norm_t;
// trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
// auto* trust_ratio_div_norm_ptr =
// dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
// TODO(zengjinle): remove the following Eigen operations when
// *skip_update == true.
memory_utils::Buffer buffer(dev_ctx.GetPlace());
phi::funcs::SquaredL2Norm(
dev_ctx,
reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
p_norm_ptr,
numel,
&buffer);
phi::funcs::SquaredL2Norm(
dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
if (weight_decay > static_cast<MT>(0) || always_adapt) {
memory_utils::Buffer buffer(dev_ctx.GetPlace());
phi::funcs::SquaredL2Norm(
dev_ctx,
reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
p_norm_ptr,
numel,
&buffer);
phi::funcs::SquaredL2Norm(
dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
}
if (VLOG_IS_ON(1)) {
const auto& name = "Param";
......
......@@ -33,6 +33,7 @@ void LambKernel(const Context& dev_ctx,
float beta1,
float beta2,
float epsilon,
bool always_adapt,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment1_out,
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/lamb_functors.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
......@@ -38,6 +39,7 @@ void ComputeRowImpl(const Context& dev_ctx,
float beta1_f,
float beta2_f,
float epsilon_f,
bool always_adapt,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* mom1_out,
......@@ -61,6 +63,7 @@ void LambKernel(const Context& dev_ctx,
float beta1,
float beta2,
float epsilon,
bool always_adapt,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment1_out,
......@@ -84,6 +87,7 @@ void LambKernel(const Context& dev_ctx,
beta1,
beta2,
epsilon,
always_adapt,
multi_precision,
param_out,
moment1_out,
......@@ -106,6 +110,7 @@ void LambKernel(const Context& dev_ctx,
beta1,
beta2,
epsilon,
always_adapt,
multi_precision,
param_out,
moment1_out,
......@@ -131,6 +136,7 @@ void ComputeRowImpl(const Context& dev_ctx,
float beta1_f,
float beta2_f,
float epsilon_f,
bool always_adapt,
bool multi_precision UNUSED,
DenseTensor* param_out,
DenseTensor* mom1_out,
......@@ -285,27 +291,41 @@ void ComputeRowImpl(const Context& dev_ctx,
// Update parameter
// The code in the following part is exactly the same as that in
// paddle/phi/kernels/impl/lamb_kernel_impl.h Please modify it together
// DenseTensor p_norm_t;
// p_norm_t.Resize(phi::make_ddim({1}));
// auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
// DenseTensor trust_ratio_div_norm_t;
// trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
// auto* trust_ratio_div_norm_ptr =
// dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
DenseTensor p_norm_t;
p_norm_t.Resize(phi::make_ddim({1}));
auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
DataType dtype = phi::CppTypeToDataType<MT>::Type();
FullKernel<MT, Context>(
dev_ctx, std::vector<int64_t>({1}), 0, dtype, &p_norm_t);
auto* p_norm_ptr = p_norm_t.data<MT>();
DenseTensor trust_ratio_div_norm_t;
trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
auto* trust_ratio_div_norm_ptr =
dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
FullKernel<MT, Context>(
dev_ctx, std::vector<int64_t>({1}), 0, dtype, &trust_ratio_div_norm_t);
auto* trust_ratio_div_norm_ptr = trust_ratio_div_norm_t.data<MT>();
// TODO(zengjinle): remove the following Eigen operations when
// *skip_update == true.
memory_utils::Buffer buffer(dev_ctx.GetPlace());
phi::funcs::SquaredL2Norm(
dev_ctx,
reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
p_norm_ptr,
numel,
&buffer);
phi::funcs::SquaredL2Norm(
dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
if (weight_decay > static_cast<MT>(0) || always_adapt) {
memory_utils::Buffer buffer(dev_ctx.GetPlace());
phi::funcs::SquaredL2Norm(
dev_ctx,
reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
p_norm_ptr,
numel,
&buffer);
phi::funcs::SquaredL2Norm(
dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
}
if (VLOG_IS_ON(1)) {
const auto& name = "Param";
......
......@@ -34,6 +34,7 @@ void LambKernel(const Context& dev_ctx,
float beta1,
float beta2,
float epsilon,
bool always_adapt,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment1_out,
......
......@@ -38,6 +38,7 @@ void LambKernel(const Context& dev_ctx,
float beta1,
float beta2,
float epsilon,
bool always_adapt,
bool multi_precision,
DenseTensor* param_outs,
DenseTensor* moment1_out,
......
......@@ -71,6 +71,9 @@ class Lamb(Optimizer):
( :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_fluid_clip_ClipGradByNorm` ,
:ref:`api_paddle_fluid_clip_ClipGradByValue` ). If you want better convergence, it is recommended
to use :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
        exclude_from_weight_decay_fn (function, optional): a function that takes a parameter as input and returns True if weight decay should be skipped for that parameter.
        always_adapt (bool, optional): whether to always use layer-wise LR adaptation. By default, adaptation is skipped for parameters that are
            excluded from weight decay; when always_adapt is True, LR adaptation is always enabled. Default False.
        name(str|None): For detailed information, please refer to
            :ref:`api_guide_Name` . Normally there is no need to set it; defaults to None.
Examples:
......@@ -106,6 +109,7 @@ class Lamb(Optimizer):
grad_clip=None,
exclude_from_weight_decay_fn=None,
multi_precision=False,
always_adapt=False,
name=None,
):
assert learning_rate is not None
......@@ -136,6 +140,7 @@ class Lamb(Optimizer):
self._used_master_weights = {}
# TODO(zengjinle): expose API as soon as possible
self._multi_precision = multi_precision
self.always_adapt = always_adapt
def _get_parameter(self, name, scope=None):
if scope is None:
......@@ -253,6 +258,7 @@ class Lamb(Optimizer):
self._beta1,
self._beta2,
self._epsilon,
self.always_adapt,
find_master,
)
return None
......@@ -279,6 +285,7 @@ class Lamb(Optimizer):
"beta2": self._beta2,
"epsilon": self._epsilon,
"weight_decay": weight_decay,
"always_adapt": self.always_adapt,
"multi_precision": find_master,
}
......
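For reference, a minimal usage sketch of the flag added to the Python API above; `always_adapt` comes from this patch, while the other arguments (`learning_rate`, `lamb_weight_decay`, `parameters`) are assumed to follow the existing `paddle.optimizer.Lamb` signature:

```python
import paddle

# A toy model; any nn.Layer works here.
linear = paddle.nn.Linear(10, 10)

# always_adapt=True forces the layer-wise LR adaptation (trust ratio) even for
# parameters whose effective weight decay is 0, e.g. those excluded via
# exclude_from_weight_decay_fn.
opt = paddle.optimizer.Lamb(
    learning_rate=0.002,
    lamb_weight_decay=0.01,
    parameters=linear.parameters(),
    always_adapt=True,
)

loss = linear(paddle.rand([4, 10])).mean()
loss.backward()
opt.step()
opt.clear_grad()
```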
......@@ -158,6 +158,7 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs):
else:
optimizer_class = paddle.optimizer.Lamb
kwargs = dict(kwargs)
kwargs['always_adapt'] = True
kwargs.pop('clip_after_allreduce', None)
kwargs.pop('alignment', None)
kwargs.pop('use_master_acc_grad', None)
......
......@@ -38,6 +38,7 @@ def lamb_wrapper(
beta1=0.9,
beta2=0.999,
weight_decay=0.01,
always_adapt=False,
):
return paddle._C_ops.lamb_(
param,
......@@ -54,6 +55,7 @@ def lamb_wrapper(
beta2,
epsilon,
False,
False,
)
......@@ -64,6 +66,7 @@ class TestLambOp1(OpTest):
'beta1': 0.78,
'beta2': 0.836,
'weight_decay': 0.01,
'always_adapt': False,
}
def setUp(self):
......@@ -120,6 +123,18 @@ class TestLambOp2(TestLambOp1):
'beta1': 0.9,
'beta2': 0.999,
'weight_decay': 0.01,
'always_adapt': False,
}
class TestLambOp3(TestLambOp1):
def set_attrs(self):
self.attrs = {
'epsilon': 1e-8,
'beta1': 0.9,
'beta2': 0.999,
'weight_decay': 0.0,
'always_adapt': False,
}
......@@ -130,6 +145,7 @@ class TestLambOpMultipleSteps(TestLambOp1):
'beta1': 0.9,
'beta2': 0.999,
'weight_decay': 0.01,
'always_adapt': False,
}
self.num_steps = 10
......@@ -189,6 +205,7 @@ def lamb_step(inputs, attributes):
beta2 = attributes['beta2']
epsilon = attributes['epsilon']
weight_decay = attributes['weight_decay']
always_adapt = attributes['always_adapt']
moment1_out = beta1 * moment1 + (1 - beta1) * grad
moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
......@@ -196,12 +213,15 @@ def lamb_step(inputs, attributes):
moment1_unbiased = moment1_out / (1 - beta1_pow)
moment2_unbiased = moment2_out / (1 - beta2_pow)
r_1 = np.linalg.norm(param)
r_2 = np.linalg.norm(
moment1_unbiased / (np.sqrt(moment2_unbiased) + epsilon)
+ weight_decay * param
)
lr_t = lr * r_1 / r_2
if weight_decay > 0 or always_adapt:
r_1 = np.linalg.norm(param)
r_2 = np.linalg.norm(
moment1_unbiased / (np.sqrt(moment2_unbiased) + epsilon)
+ weight_decay * param
)
lr_t = lr * r_1 / r_2
else:
lr_t = lr
param_out = param - lr_t * (
moment1_unbiased / (np.sqrt(moment2_unbiased) + epsilon)
......@@ -234,6 +254,7 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
beta2 = attributes['beta2']
epsilon = attributes['epsilon']
weight_decay = attributes['weight_decay']
always_adapt = attributes['always_adapt']
moment1_out = np.zeros(shape=[height, row_numel])
moment2_out = np.zeros(shape=[height, row_numel])
......@@ -256,13 +277,16 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
update_value
)
def update_param():
r_1 = np.linalg.norm(param)
r_2 = np.linalg.norm(
moment1_out / (np.sqrt(moment2_out) + epsilon)
+ weight_decay * param
)
lr_t = lr * r_1 / r_2
def update_param(weight_decay, always_adapt):
if weight_decay > 0 or always_adapt:
r_1 = np.linalg.norm(param)
r_2 = np.linalg.norm(
moment1_out / (np.sqrt(moment2_out) + epsilon)
+ weight_decay * param
)
lr_t = lr * r_1 / r_2
else:
lr_t = lr
param_out = param - lr_t * (
moment1_out / (np.sqrt(moment2_out) + epsilon)
......@@ -275,7 +299,7 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
update_value = np_grad[rows.index(row_id)]
update_mom(row_id, update_value)
update_param()
update_param(weight_decay, always_adapt)
beta1_pow_out = beta1_pow * beta1
beta2_pow_out = beta2_pow * beta2
......@@ -307,6 +331,7 @@ class TestSparseLambOp(unittest.TestCase):
'beta1': beta1,
'beta2': beta2,
'weight_decay': 0.05,
'always_adapt': False,
}
grad_selected_rows = scope.var('Grad').get_selected_rows()
......