Commit 3672480b (unverified)

fix lamb beta1pow beta2pow update (#38518)

Author: sneaxiy
Committed by: GitHub on Dec 29, 2021
Parent: 72a41e50

Showing 1 changed file with 108 additions and 72 deletions.

paddle/fluid/operators/optimizers/lamb_op.h (+108, -72)
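
For context: like Adam, LAMB tracks the running powers of the moment decay rates, which feed the bias-corrected moments. The scalar recurrence this commit relocates, in the notation of the code below, is

$$\beta_1^{t+1} = \beta_1^{t}\,\beta_1, \qquad \beta_2^{t+1} = \beta_2^{t}\,\beta_2, \qquad \hat{m}_t = \frac{m_t}{1 - \beta_1^{t}}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^{t}}.$$

Before this commit, each moment-update functor wrote the new powers itself; the diff removes that duplicated responsibility from the four moment functors and centralizes it in the parameter-update functor.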

@@ -52,19 +52,16 @@ struct LambMomentREGUpdateFunctor {
   const bool* skip_update_;

   LambMomentREGUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
-                             MT beta1_pow, MT* beta1_pow_out, MT beta2_pow,
-                             MT* beta2_pow_out, const MT* mom1, MT* mom1_out,
-                             const MT* mom2, MT* mom2_out, const T* grad,
-                             const MT* param, MT* trust_ratio_div,
-                             const bool* skip_update)
+                             MT beta1_pow, MT beta2_pow, const MT* mom1,
+                             MT* mom1_out, const MT* mom2, MT* mom2_out,
+                             const T* grad, const MT* param,
+                             MT* trust_ratio_div, const bool* skip_update)
       : weight_decay_(weight_decay),
         beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),

@@ -95,10 +92,6 @@ struct LambMomentREGUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta2_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 };
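
For reference, the per-element quantity these functors compute (visible in the unchanged lines above) is

$$\mathrm{trust\_ratio\_div}_i = \frac{\hat{m}_i}{\sqrt{\hat{v}_i} + \epsilon} + \lambda\, p_i,$$

where the bias-corrected moments are mom1_unbiased and mom2_unbiased, lambda is weight_decay_, and p_i is the parameter value. The deleted branch bolted the scalar beta-pow update onto this per-element loop; the remaining hunks apply the same removal to the other three functors and reintroduce the update elsewhere.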

@@ -113,9 +106,7 @@ struct LambMomentMENUpdateFunctor {
   MT epsilon_;

   const MT* beta1_pow_;
-  MT* beta1_pow_out_;
   const MT* beta2_pow_;
-  MT* beta2_pow_out_;
   const MT* moment1_;
   MT* moment1_out_;
   const MT* moment2_;

@@ -126,8 +117,7 @@ struct LambMomentMENUpdateFunctor {
   const bool* skip_update_;

   LambMomentMENUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
-                             const MT* beta1_pow, MT* beta1_pow_out,
-                             const MT* beta2_pow, MT* beta2_pow_out,
+                             const MT* beta1_pow, const MT* beta2_pow,
                              const MT* mom1, MT* mom1_out, const MT* mom2,
                              MT* mom2_out, const T* grad, const MT* param,
                              MT* trust_ratio_div, const bool* skip_update)

@@ -136,9 +126,7 @@ struct LambMomentMENUpdateFunctor {
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),

@@ -168,10 +156,6 @@ struct LambMomentMENUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta2_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 };

@@ -183,9 +167,7 @@ struct SparseLambMomentREGUpdateFunctor {
   T epsilon_;

   T beta1_pow_;
-  T* beta1_pow_out_;
   T beta2_pow_;
-  T* beta2_pow_out_;
   const T* moment1_;
   T* moment1_out_;
   const T* moment2_;

@@ -201,20 +183,18 @@ struct SparseLambMomentREGUpdateFunctor {
   const bool* skip_update_;

   SparseLambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
-                                   T beta1_pow, T* beta1_pow_out, T beta2_pow,
-                                   T* beta2_pow_out, const T* mom1, T* mom1_out,
-                                   const T* mom2, T* mom2_out, const T* grad,
-                                   const T* param, T* trust_ratio_div,
-                                   const int64_t* rows, int64_t row_numel,
-                                   int64_t row_count, const bool* skip_update)
+                                   T beta1_pow, T beta2_pow, const T* mom1,
+                                   T* mom1_out, const T* mom2, T* mom2_out,
+                                   const T* grad, const T* param,
+                                   T* trust_ratio_div, const int64_t* rows,
+                                   int64_t row_numel, int64_t row_count,
+                                   const bool* skip_update)
       : weight_decay_(weight_decay),
         beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),

@@ -246,10 +226,6 @@ struct SparseLambMomentREGUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta1_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }

   inline HOSTDEVICE void operator()(size_t i) const {

@@ -270,9 +246,7 @@ struct SparseLambMomentMENUpdateFunctor {
   T epsilon_;

   const T* beta1_pow_;
-  T* beta1_pow_out_;
   const T* beta2_pow_;
-  T* beta2_pow_out_;
   const T* moment1_;
   T* moment1_out_;
   const T* moment2_;

@@ -288,8 +262,7 @@ struct SparseLambMomentMENUpdateFunctor {
   const bool* skip_update_;

   SparseLambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
-                                   const T* beta1_pow, T* beta1_pow_out,
-                                   const T* beta2_pow, T* beta2_pow_out,
+                                   const T* beta1_pow, const T* beta2_pow,
                                    const T* mom1, T* mom1_out, const T* mom2,
                                    T* mom2_out, const T* grad, const T* param,
                                    T* trust_ratio_div, const int64_t* rows,

@@ -300,9 +273,7 @@ struct SparseLambMomentMENUpdateFunctor {
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),

@@ -334,10 +305,6 @@ struct SparseLambMomentMENUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta1_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }

   inline HOSTDEVICE void operator()(size_t i) const {

@@ -350,11 +317,44 @@ struct SparseLambMomentMENUpdateFunctor {
   }
 };

-template <typename T, bool IsMultiPrecision>
-struct LambParamUpateFunctor {
-  using MT = typename std::conditional<
-      IsMultiPrecision, typename details::MPTypeTrait<T>::Type, T>::type;
+template <typename MT, bool NeedUpdateBetaPow /*=true*/>
+struct LambBetaPowUpdateFunctor {
+  void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
+                   MT* beta2pow_out, MT beta1, MT beta2) {
+    beta1pow_ = beta1pow;
+    beta2pow_ = beta2pow;
+    beta1pow_out_ = beta1pow_out;
+    beta2pow_out_ = beta2pow_out;
+    beta1_ = beta1;
+    beta2_ = beta2;
+  }
+
+  HOSTDEVICE void UpdateBetaPow(size_t i) const {
+    if (i == 0) {
+      beta1pow_out_[0] = beta1pow_[0] * beta1_;
+      beta2pow_out_[0] = beta2pow_[0] * beta2_;
+    }
+  }
+
+ private:
+  const MT* beta1pow_;
+  const MT* beta2pow_;
+  MT* beta1pow_out_;
+  MT* beta2pow_out_;
+  MT beta1_;
+  MT beta2_;
+};
+
+template <typename MT>
+struct LambBetaPowUpdateFunctor<MT, /*NeedUpdateBetaPow=*/false> {
+  void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
+                   MT* beta2pow_out, MT beta1, MT beta2) {}
+  HOSTDEVICE void UpdateBetaPow(size_t) const {}
+};
+
+template <typename T, typename MT, bool IsMultiPrecision, bool UpdateBetaPow>
+struct LambParamUpateFunctor
+    : public LambBetaPowUpdateFunctor<MT, UpdateBetaPow> {
   const MT* lr_;
   const T* param_;
   const MT* master_param_;
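
The new LambBetaPowUpdateFunctor pair is a compile-time switch: the primary template performs the scalar update, while the false specialization is an empty no-op, so instantiations that do not need the beta-pow update carry no runtime flag and no dead branch. A minimal standalone sketch of the same pattern, with hypothetical names (an illustration, not Paddle code):

#include <cstddef>
#include <cstdio>

// Primary template: performs the real update. Only element 0 writes, so the
// scalar output is updated once per pass rather than once per element.
template <typename T, bool NeedUpdate /*=true*/>
struct BetaPowUpdater {
  void Set(const T* in, T* out, T beta) {
    in_ = in;
    out_ = out;
    beta_ = beta;
  }
  void Update(std::size_t i) const {
    if (i == 0) out_[0] = in_[0] * beta_;
  }

 private:
  const T* in_;
  T* out_;
  T beta_;
};

// Specialization for NeedUpdate == false: every method is an empty no-op,
// which the compiler can eliminate entirely.
template <typename T>
struct BetaPowUpdater<T, false> {
  void Set(const T*, T*, T) {}
  void Update(std::size_t) const {}
};

// A per-element functor mixes the updater in by inheritance, the same way
// LambParamUpateFunctor inherits LambBetaPowUpdateFunctor above.
template <typename T, bool NeedUpdate>
struct ParamUpdater : public BetaPowUpdater<T, NeedUpdate> {
  void operator()(std::size_t i) const {
    // ... per-element parameter work would go here ...
    this->Update(i);  // compiles to nothing when NeedUpdate == false
  }
};

int main() {
  float pow_in = 0.81f, pow_out = 0.0f;
  ParamUpdater<float, true> u;
  u.Set(&pow_in, &pow_out, 0.9f);
  u(0);
  std::printf("%f\n", pow_out);  // prints 0.729000
}

Because the no-op specialization has empty methods, no branch on a runtime flag survives in the hot loop, which matters for a functor that runs once per parameter element on the device.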

@@ -396,6 +396,7 @@ struct LambParamUpateFunctor {
     if (IsMultiPrecision) {
       master_param_out_[i] = param_out;
     }
+    this->UpdateBetaPow(i);
   }
 };
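
Since UpdateBetaPow (defined in the base class above) only writes at i == 0, fusing it into the per-element parameter functor still updates the two scalar beta-pow tensors exactly once per launch, on the same device as the parameters.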

@@ -501,6 +502,11 @@ class LambOpKernel : public framework::OpKernel<T> {
             : nullptr;

     // Update moments
+    bool should_update_beta_pow_later = false;
+    const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
+    MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
+    VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
+             << " , Beta2Pow place: " << beta2_pow.place();
     if (grad_var->IsType<framework::LoDTensor>()) {
       auto& grad = grad_var->Get<framework::LoDTensor>();
       if (platform::is_gpu_place(ctx.GetPlace()) &&

@@ -508,8 +514,7 @@ class LambOpKernel : public framework::OpKernel<T> {
           beta2_pow.place() == platform::CPUPlace()) {
         LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
             weight_decay, beta1, beta2, epsilon,
-            *beta1_pow.template data<MT>(), nullptr,
-            *beta2_pow.template data<MT>(), nullptr, mom1.template data<MT>(),
+            *beta1_pow.template data<MT>(),
+            *beta2_pow.template data<MT>(), mom1.template data<MT>(),
             mom1_out.template mutable_data<MT>(ctx.GetPlace()),
             mom2.template data<MT>(),
             mom2_out.template mutable_data<MT>(ctx.GetPlace()),

@@ -523,12 +528,17 @@ class LambOpKernel : public framework::OpKernel<T> {
         beta2_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
             beta2 * beta2_pow.template data<MT>()[0];
       } else {
+        beta1_pow_ptr = beta1_pow.template data<MT>();
+        beta2_pow_ptr = beta2_pow.template data<MT>();
+        beta1_pow_out_ptr =
+            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        beta2_pow_out_ptr =
+            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        should_update_beta_pow_later = true;
         LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
-            weight_decay, beta1, beta2, epsilon,
-            beta1_pow.template data<MT>(),
-            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace()),
-            beta2_pow.template data<MT>(),
-            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace()),
-            mom1.template data<MT>(),
+            weight_decay, beta1, beta2, epsilon,
+            static_cast<const MT*>(beta1_pow_ptr),
+            static_cast<const MT*>(beta2_pow_ptr), mom1.template data<MT>(),
             mom1_out.template mutable_data<MT>(ctx.GetPlace()),
             mom2.template data<MT>(),
             mom2_out.template mutable_data<MT>(ctx.GetPlace()),
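
Two paths, then: if the beta-pow tensors are on the CPU (previous hunk), the kernel updates them eagerly on the host; otherwise it records the device pointers, sets should_update_beta_pow_later = true, and leaves the update to the fused parameter-update functor at the end of the kernel. The SelectedRows hunks below make the same change for the sparse path.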

@@ -542,7 +552,12 @@ class LambOpKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(IsMultiPrecision, false,
                         platform::errors::Unimplemented(
                             "SelectedRows gradient is not supported when "
-                            "multi_precision=True"));
+                            "multi_precision=True."));
+      constexpr bool kIsSameType = std::is_same<T, MT>::value;
+      PADDLE_ENFORCE_EQ(kIsSameType, true,
+                        platform::errors::Unimplemented(
+                            "SelectedRows gradient is not supported when "
+                            "multi_precision=True."));
       auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"),
                                    "Input", "Grad", "Lamb");
       if (grad.rows().size() == 0) {

@@ -582,8 +597,8 @@ class LambOpKernel : public framework::OpKernel<T> {
         SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
             static_cast<T>(weight_decay), static_cast<T>(beta1),
             static_cast<T>(beta2), static_cast<T>(epsilon),
-            *beta1_pow.template data<T>(), nullptr,
-            *beta2_pow.template data<T>(), nullptr, mom1.template data<T>(),
+            *beta1_pow.template data<T>(),
+            *beta2_pow.template data<T>(), mom1.template data<T>(),
             mom1_out.template mutable_data<T>(ctx.GetPlace()),
             mom2.template data<T>(),
             mom2_out.template mutable_data<T>(ctx.GetPlace()),
             grad_data,

@@ -595,14 +610,18 @@ class LambOpKernel : public framework::OpKernel<T> {
         beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
             static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
       } else {
+        beta1_pow_ptr = beta1_pow.template data<MT>();
+        beta2_pow_ptr = beta2_pow.template data<MT>();
+        beta1_pow_out_ptr =
+            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        beta2_pow_out_ptr =
+            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        should_update_beta_pow_later = true;
         SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
             static_cast<T>(weight_decay), static_cast<T>(beta1),
             static_cast<T>(beta2), static_cast<T>(epsilon),
-            beta1_pow.template data<T>(),
-            beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
-            beta2_pow.template data<T>(),
-            beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
-            mom1.template data<T>(),
+            reinterpret_cast<const T*>(beta1_pow_ptr),
+            reinterpret_cast<const T*>(beta2_pow_ptr),
+            mom1.template data<T>(),
             mom1_out.template mutable_data<T>(ctx.GetPlace()),
             mom2.template data<T>(),
             mom2_out.template mutable_data<T>(ctx.GetPlace()),
             grad_data,

@@ -639,14 +658,31 @@ class LambOpKernel : public framework::OpKernel<T> {
     }
     trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();
-    LambParamUpateFunctor<T, IsMultiPrecision> param_update_functor(
-        lr.template data<MT>(), static_cast<const T*>(param_ptr),
-        static_cast<const MT*>(master_param_ptr),
-        p_norm_t.template data<MT>(), trust_ratio_div.template data<MT>(),
-        trust_ratio_div_norm_t.template data<MT>(),
-        static_cast<T*>(param_out_ptr),
-        static_cast<MT*>(master_param_out_ptr), skip_update_flag);
-    for_range(param_update_functor);
+
+#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow)          \
+  do {                                                                        \
+    LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow>  \
+        param_update_functor(                                                 \
+            lr.template data<MT>(), static_cast<const T*>(param_ptr),         \
+            static_cast<const MT*>(master_param_ptr),                         \
+            p_norm_t.template data<MT>(), trust_ratio_div.template data<MT>(),\
+            trust_ratio_div_norm_t.template data<MT>(),                       \
+            static_cast<T*>(param_out_ptr),                                   \
+            static_cast<MT*>(master_param_out_ptr), skip_update_flag);        \
+    if (__should_update_beta_pow) {                                           \
+      param_update_functor.SetBetaPows(beta1_pow_ptr, beta2_pow_ptr,          \
+                                       beta1_pow_out_ptr, beta2_pow_out_ptr,  \
+                                       beta1, beta2);                         \
+    }                                                                         \
+    for_range(param_update_functor);                                          \
+  } while (0)
+
+    if (should_update_beta_pow_later) {
+      CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
+    } else {
+      CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
+    }
+#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
   }
 };
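
One note on form: the helper macro wraps its body in the classic do { ... } while (0) idiom so that each invocation expands to a single statement and composes safely with unbraced if/else, as in the call sites above. A tiny self-contained illustration with a hypothetical macro:

#include <cstdio>

// Without do/while(0), the two statements below would break the if/else
// pairing when the macro is used without braces.
#define LOG_TWICE(msg)        \
  do {                        \
    std::printf("%s\n", msg); \
    std::printf("%s\n", msg); \
  } while (0)

int main() {
  bool verbose = true;
  if (verbose)
    LOG_TWICE("hello");  // expands to one statement; trailing ';' is harmless
  else
    std::printf("quiet\n");
  return 0;
}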