Unverified commit 2a56f4b3, authored by zhiboniu, committed by GitHub

fix lamb optimizer always_adapt (#54654)

* fix lamb always_adapt

* fix optest

* fix all optests
Parent 82eeda69
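What the change does: the layer-wise trust ratio is now computed only when weight decay is non-zero or `always_adapt` is set; otherwise the step falls back to the plain learning rate. A minimal NumPy sketch of that gate (illustrative only; the function name is mine, the math mirrors the `lamb_step` reference in the updated test below):

```python
import numpy as np

def gated_lamb_lr(lr, param, m_hat, v_hat, epsilon, weight_decay, always_adapt=False):
    """Layer-wise LR used by LAMB after this fix (sketch, not the kernel code).

    m_hat / v_hat are the bias-corrected first / second moments.
    """
    if weight_decay > 0 or always_adapt:
        r_1 = np.linalg.norm(param)
        r_2 = np.linalg.norm(m_hat / (np.sqrt(v_hat) + epsilon) + weight_decay * param)
        return lr * r_1 / r_2  # trust-ratio scaled learning rate
    return lr  # no weight decay and always_adapt=False: behave like plain Adam
```

With `weight_decay=0.0` and `always_adapt=False` this returns `lr` unchanged, which is the case the new `TestLambOp3` below exercises.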
@@ -1231,7 +1231,7 @@
   backward : label_smooth_grad
 
 - op : lamb_
-  args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1=0.9, float beta2=0.999, float epsilon=1.0e-6f, bool multi_precision=false)
+  args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1=0.9, float beta2=0.999, float epsilon=1.0e-6f, bool always_adapt=false, bool multi_precision=false)
   output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
   infer_meta :
     func : LambInferMeta
......
@@ -2061,6 +2061,7 @@ void LambInferMeta(const MetaTensor& param,
                    float beta1,
                    float beta2,
                    float epsilon,
+                   bool always_adapt,
                    bool multi_precision,
                    MetaTensor* param_out,
                    MetaTensor* moment1_out,
......
@@ -358,6 +358,7 @@ void LambInferMeta(const MetaTensor& param,
                    float beta1,
                    float beta2,
                    float epsilon,
+                   bool always_adapt,
                    bool multi_precision,
                    MetaTensor* param_out,
                    MetaTensor* moment1_out,
......
@@ -16,6 +16,7 @@
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/lamb_functors.h"
 
 namespace phi {
@@ -35,6 +36,7 @@ void ComputeImpl(const Context& dev_ctx,
                  float beta1_f,
                  float beta2_f,
                  float epsilon_f,
+                 bool always_adapt,
                  bool multi_precision,
                  DenseTensor* param_out,
                  DenseTensor* mom1_out,
@@ -58,6 +60,7 @@ void LambKernel(const Context& dev_ctx,
                 float beta1,
                 float beta2,
                 float epsilon,
+                bool always_adapt,
                 bool multi_precision,
                 DenseTensor* param_out,
                 DenseTensor* moment1_out,
@@ -81,6 +84,7 @@ void LambKernel(const Context& dev_ctx,
                 beta1,
                 beta2,
                 epsilon,
+                always_adapt,
                 multi_precision,
                 param_out,
                 moment1_out,
@@ -103,6 +107,7 @@ void LambKernel(const Context& dev_ctx,
                 beta1,
                 beta2,
                 epsilon,
+                always_adapt,
                 multi_precision,
                 param_out,
                 moment1_out,
@@ -128,6 +133,7 @@ void ComputeImpl(const Context& dev_ctx,
                  float beta1_f,
                  float beta2_f,
                  float epsilon_f,
+                 bool always_adapt,
                  bool multi_precision UNUSED,
                  DenseTensor* param_out,
                  DenseTensor* mom1_out,
@@ -232,16 +238,28 @@ void ComputeImpl(const Context& dev_ctx,
   // paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h Please modify it
   // together
   DenseTensor p_norm_t;
-  p_norm_t.Resize(phi::make_ddim({1}));
-  auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
+  DataType dtype = phi::CppTypeToDataType<MT>::Type();
+  FullKernel<MT, Context>(
+      dev_ctx, std::vector<int64_t>({1}), 0, dtype, &p_norm_t);
+  auto* p_norm_ptr = p_norm_t.data<MT>();
 
   DenseTensor trust_ratio_div_norm_t;
-  trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
-  auto* trust_ratio_div_norm_ptr =
-      dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
+  FullKernel<MT, Context>(
+      dev_ctx, std::vector<int64_t>({1}), 0, dtype, &trust_ratio_div_norm_t);
+  auto* trust_ratio_div_norm_ptr = trust_ratio_div_norm_t.data<MT>();
+
+  // DenseTensor p_norm_t;
+  // p_norm_t.Resize(phi::make_ddim({1}));
+  // auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
+
+  // DenseTensor trust_ratio_div_norm_t;
+  // trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
+  // auto* trust_ratio_div_norm_ptr =
+  //     dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
 
   // TODO(zengjinle): remove the following Eigen operations when
   // *skip_update == true.
+  if (weight_decay > static_cast<MT>(0) || always_adapt) {
   memory_utils::Buffer buffer(dev_ctx.GetPlace());
   phi::funcs::SquaredL2Norm(
       dev_ctx,
@@ -252,6 +270,7 @@ void ComputeImpl(const Context& dev_ctx,
       &buffer);
   phi::funcs::SquaredL2Norm(
       dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
+  }
 
   if (VLOG_IS_ON(1)) {
     const auto& name = "Param";
......
@@ -33,6 +33,7 @@ void LambKernel(const Context& dev_ctx,
                 float beta1,
                 float beta2,
                 float epsilon,
+                bool always_adapt,
                 bool multi_precision,
                 DenseTensor* param_out,
                 DenseTensor* moment1_out,
......
@@ -17,6 +17,7 @@
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/selected_rows.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/lamb_functors.h"
 #include "paddle/phi/kernels/funcs/selected_rows_functor.h"
@@ -38,6 +39,7 @@ void ComputeRowImpl(const Context& dev_ctx,
                     float beta1_f,
                     float beta2_f,
                     float epsilon_f,
+                    bool always_adapt,
                     bool multi_precision,
                     DenseTensor* param_out,
                     DenseTensor* mom1_out,
@@ -61,6 +63,7 @@ void LambKernel(const Context& dev_ctx,
                 float beta1,
                 float beta2,
                 float epsilon,
+                bool always_adapt,
                 bool multi_precision,
                 DenseTensor* param_out,
                 DenseTensor* moment1_out,
@@ -84,6 +87,7 @@ void LambKernel(const Context& dev_ctx,
                 beta1,
                 beta2,
                 epsilon,
+                always_adapt,
                 multi_precision,
                 param_out,
                 moment1_out,
@@ -106,6 +110,7 @@ void LambKernel(const Context& dev_ctx,
                 beta1,
                 beta2,
                 epsilon,
+                always_adapt,
                 multi_precision,
                 param_out,
                 moment1_out,
@@ -131,6 +136,7 @@ void ComputeRowImpl(const Context& dev_ctx,
                     float beta1_f,
                     float beta2_f,
                     float epsilon_f,
+                    bool always_adapt,
                     bool multi_precision UNUSED,
                     DenseTensor* param_out,
                     DenseTensor* mom1_out,
@@ -285,17 +291,30 @@ void ComputeRowImpl(const Context& dev_ctx,
   // Update parameter
   // The code in the following part is exactly the same as that in
   // paddle/phi/kernels/impl/lamb_kernel_impl.h Please modify it together
+  // DenseTensor p_norm_t;
+  // p_norm_t.Resize(phi::make_ddim({1}));
+  // auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
+
+  // DenseTensor trust_ratio_div_norm_t;
+  // trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
+  // auto* trust_ratio_div_norm_ptr =
+  //     dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
+
   DenseTensor p_norm_t;
-  p_norm_t.Resize(phi::make_ddim({1}));
-  auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
+  DataType dtype = phi::CppTypeToDataType<MT>::Type();
+  FullKernel<MT, Context>(
+      dev_ctx, std::vector<int64_t>({1}), 0, dtype, &p_norm_t);
+  auto* p_norm_ptr = p_norm_t.data<MT>();
 
   DenseTensor trust_ratio_div_norm_t;
-  trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
-  auto* trust_ratio_div_norm_ptr =
-      dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
+  FullKernel<MT, Context>(
+      dev_ctx, std::vector<int64_t>({1}), 0, dtype, &trust_ratio_div_norm_t);
+  auto* trust_ratio_div_norm_ptr = trust_ratio_div_norm_t.data<MT>();
 
   // TODO(zengjinle): remove the following Eigen operations when
   // *skip_update == true.
+  if (weight_decay > static_cast<MT>(0) || always_adapt) {
   memory_utils::Buffer buffer(dev_ctx.GetPlace());
   phi::funcs::SquaredL2Norm(
       dev_ctx,
@@ -306,6 +325,7 @@ void ComputeRowImpl(const Context& dev_ctx,
       &buffer);
   phi::funcs::SquaredL2Norm(
       dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
+  }
 
   if (VLOG_IS_ON(1)) {
     const auto& name = "Param";
......
@@ -34,6 +34,7 @@ void LambKernel(const Context& dev_ctx,
                 float beta1,
                 float beta2,
                 float epsilon,
+                bool always_adapt,
                 bool multi_precision,
                 DenseTensor* param_out,
                 DenseTensor* moment1_out,
......
@@ -38,6 +38,7 @@ void LambKernel(const Context& dev_ctx,
                 float beta1,
                 float beta2,
                 float epsilon,
+                bool always_adapt,
                 bool multi_precision,
                 DenseTensor* param_outs,
                 DenseTensor* moment1_out,
......
@@ -71,6 +71,9 @@ class Lamb(Optimizer):
             ( :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_fluid_clip_ClipGradByNorm` ,
             :ref:`api_paddle_fluid_clip_ClipGradByValue` ). If you want better convergence, it is recommended
             to use :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
+        exclude_from_weight_decay_fn (function, optional): whether to skip weight decay for a parameter; the function takes the parameter as input and weight decay is skipped when it returns True.
+        always_adapt (bool, optional): whether to always use layer-wise LR adaptation. By default, adaptation is skipped on parameters that are
+            excluded from weight decay; when always_adapt is True, LR adaptation is always enabled.
         name(str|None): For detailed information, please refer to
             :ref:`api_guide_Name` . Usually name is no need to set and None by default.
     Examples:
@@ -106,6 +109,7 @@ class Lamb(Optimizer):
         grad_clip=None,
         exclude_from_weight_decay_fn=None,
         multi_precision=False,
+        always_adapt=False,
         name=None,
     ):
         assert learning_rate is not None
@@ -136,6 +140,7 @@ class Lamb(Optimizer):
         self._used_master_weights = {}
         # TODO(zengjinle): expose API as soon as possible
         self._multi_precision = multi_precision
+        self.always_adapt = always_adapt
 
     def _get_parameter(self, name, scope=None):
         if scope is None:
@@ -253,6 +258,7 @@ class Lamb(Optimizer):
                 self._beta1,
                 self._beta2,
                 self._epsilon,
+                self.always_adapt,
                 find_master,
             )
             return None
@@ -279,6 +285,7 @@ class Lamb(Optimizer):
             "beta2": self._beta2,
             "epsilon": self._epsilon,
             "weight_decay": weight_decay,
+            "always_adapt": self.always_adapt,
             "multi_precision": find_master,
         }
......
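Once the flag is plumbed through, turning it on from Python is a single keyword argument. A minimal usage sketch (the model here is hypothetical, just to have parameters; `lamb_weight_decay` is the existing weight-decay argument of `paddle.optimizer.Lamb`):

```python
import paddle

# Hypothetical two-layer model, only so there are parameters to optimize.
model = paddle.nn.Sequential(paddle.nn.Linear(10, 10), paddle.nn.Linear(10, 1))

opt = paddle.optimizer.Lamb(
    learning_rate=0.002,
    lamb_weight_decay=0.0,   # no weight decay on these parameters ...
    always_adapt=True,       # ... but still apply layer-wise LR adaptation
    parameters=model.parameters(),
)

x = paddle.randn([4, 10])
loss = model(x).mean()
loss.backward()
opt.step()
opt.clear_grad()
```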
@@ -158,6 +158,7 @@ def run_model(use_distributed_lamb, use_fp16, use_master_param_norm, **kwargs):
     else:
         optimizer_class = paddle.optimizer.Lamb
         kwargs = dict(kwargs)
+        kwargs['always_adapt'] = True
         kwargs.pop('clip_after_allreduce', None)
         kwargs.pop('alignment', None)
         kwargs.pop('use_master_acc_grad', None)
......
@@ -38,6 +38,7 @@ def lamb_wrapper(
     beta1=0.9,
     beta2=0.999,
     weight_decay=0.01,
+    always_adapt=False,
 ):
     return paddle._C_ops.lamb_(
         param,
@@ -54,6 +55,7 @@ def lamb_wrapper(
         beta2,
         epsilon,
         False,
+        False,
     )
@@ -64,6 +66,7 @@ class TestLambOp1(OpTest):
             'beta1': 0.78,
             'beta2': 0.836,
             'weight_decay': 0.01,
+            'always_adapt': False,
         }
 
     def setUp(self):
@@ -120,6 +123,18 @@ class TestLambOp2(TestLambOp1):
             'beta1': 0.9,
             'beta2': 0.999,
             'weight_decay': 0.01,
+            'always_adapt': False,
+        }
+
+
+class TestLambOp3(TestLambOp1):
+    def set_attrs(self):
+        self.attrs = {
+            'epsilon': 1e-8,
+            'beta1': 0.9,
+            'beta2': 0.999,
+            'weight_decay': 0.0,
+            'always_adapt': False,
         }
@@ -130,6 +145,7 @@ class TestLambOpMultipleSteps(TestLambOp1):
             'beta1': 0.9,
             'beta2': 0.999,
             'weight_decay': 0.01,
+            'always_adapt': False,
         }
         self.num_steps = 10
@@ -189,6 +205,7 @@ def lamb_step(inputs, attributes):
     beta2 = attributes['beta2']
     epsilon = attributes['epsilon']
     weight_decay = attributes['weight_decay']
+    always_adapt = attributes['always_adapt']
 
     moment1_out = beta1 * moment1 + (1 - beta1) * grad
     moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
@@ -196,12 +213,15 @@ def lamb_step(inputs, attributes):
     moment1_unbiased = moment1_out / (1 - beta1_pow)
     moment2_unbiased = moment2_out / (1 - beta2_pow)
 
-    r_1 = np.linalg.norm(param)
-    r_2 = np.linalg.norm(
-        moment1_unbiased / (np.sqrt(moment2_unbiased) + epsilon)
-        + weight_decay * param
-    )
-    lr_t = lr * r_1 / r_2
+    if weight_decay > 0 or always_adapt:
+        r_1 = np.linalg.norm(param)
+        r_2 = np.linalg.norm(
+            moment1_unbiased / (np.sqrt(moment2_unbiased) + epsilon)
+            + weight_decay * param
+        )
+        lr_t = lr * r_1 / r_2
+    else:
+        lr_t = lr
 
     param_out = param - lr_t * (
         moment1_unbiased / (np.sqrt(moment2_unbiased) + epsilon)
@@ -234,6 +254,7 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
     beta2 = attributes['beta2']
     epsilon = attributes['epsilon']
     weight_decay = attributes['weight_decay']
+    always_adapt = attributes['always_adapt']
 
     moment1_out = np.zeros(shape=[height, row_numel])
     moment2_out = np.zeros(shape=[height, row_numel])
@@ -256,13 +277,16 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
             update_value
         )
 
-    def update_param():
-        r_1 = np.linalg.norm(param)
-        r_2 = np.linalg.norm(
-            moment1_out / (np.sqrt(moment2_out) + epsilon)
-            + weight_decay * param
-        )
-        lr_t = lr * r_1 / r_2
+    def update_param(weight_decay, always_adapt):
+        if weight_decay > 0 or always_adapt:
+            r_1 = np.linalg.norm(param)
+            r_2 = np.linalg.norm(
+                moment1_out / (np.sqrt(moment2_out) + epsilon)
+                + weight_decay * param
+            )
+            lr_t = lr * r_1 / r_2
+        else:
+            lr_t = lr
 
         param_out = param - lr_t * (
             moment1_out / (np.sqrt(moment2_out) + epsilon)
@@ -275,7 +299,7 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
             update_value = np_grad[rows.index(row_id)]
             update_mom(row_id, update_value)
 
-    update_param()
+    update_param(weight_decay, always_adapt)
 
     beta1_pow_out = beta1_pow * beta1
     beta2_pow_out = beta2_pow * beta2
@@ -307,6 +331,7 @@ class TestSparseLambOp(unittest.TestCase):
             'beta1': beta1,
             'beta2': beta2,
             'weight_decay': 0.05,
+            'always_adapt': False,
         }
 
         grad_selected_rows = scope.var('Grad').get_selected_rows()
......
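For reference, the update encoded by the NumPy oracle above (and gated the same way in the C++ kernels) can be summarized as follows, where the hatted symbols denote the bias-corrected first and second moments, lambda is `weight_decay`, and eta is the base learning rate (notation added here for readability):

```latex
r_1 = \lVert w_t \rVert_2, \qquad
r_2 = \Bigl\lVert \tfrac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda w_t \Bigr\rVert_2, \qquad
\eta_t =
\begin{cases}
  \eta \, r_1 / r_2, & \text{if } \lambda > 0 \ \text{or } \texttt{always\_adapt} \\
  \eta, & \text{otherwise}
\end{cases}
\qquad
w_{t+1} = w_t - \eta_t \Bigl( \tfrac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda w_t \Bigr)
```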