Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
4dda18a8
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
4dda18a8
编写于
10月 15, 2021
作者:
Z
Zeng Jinle
提交者:
GitHub
10月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix momentum ops (#36452)
上级
8566cc98
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
41 additions
and
35 deletions
+41
-35
paddle/fluid/operators/optimizers/momentum_op.h
paddle/fluid/operators/optimizers/momentum_op.h
+35
-32
python/paddle/fluid/tests/unittests/test_merged_momentum_op.py
...n/paddle/fluid/tests/unittests/test_merged_momentum_op.py
+6
-3
未找到文件。
paddle/fluid/operators/optimizers/momentum_op.h
浏览文件 @
4dda18a8
...
@@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor {
...
@@ -173,14 +173,15 @@ class CPUDenseMomentumFunctor {
}
}
};
};
template
<
typename
T
,
typename
MT
,
typename
UpdateMethod
>
template
<
typename
T
,
typename
MT
,
RegularizationType
kRegType
,
typename
UpdateMethod
>
class
DenseMomentumFunctor
;
class
DenseMomentumFunctor
;
// NOTE(dzh) for performance.
// NOTE(dzh) for performance.
// avoid if/else inside the kernel; implement GPU UseNesterov/NoNesterov as two
// avoid if/else inside the kernel; implement GPU UseNesterov/NoNesterov as two
// functors.
// functors.
template
<
typename
T
,
typename
MT
>
template
<
typename
T
,
typename
MT
,
RegularizationType
kRegType
>
class
DenseMomentumFunctor
<
T
,
MT
,
UseNesterov
>
{
class
DenseMomentumFunctor
<
T
,
MT
,
kRegType
,
UseNesterov
>
{
private:
private:
const
T
*
param_
;
const
T
*
param_
;
const
T
*
grad_
;
const
T
*
grad_
;
...
@@ -193,7 +194,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
...
@@ -193,7 +194,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
T
*
param_out_
;
T
*
param_out_
;
MT
*
velocity_out_
;
MT
*
velocity_out_
;
MT
*
master_param_out_
;
MT
*
master_param_out_
;
const
RegularizationType
regularization_flag_
;
const
MT
regularization_coeff_
;
const
MT
regularization_coeff_
;
public:
public:
...
@@ -201,7 +201,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
...
@@ -201,7 +201,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
const
MultiPrecisionType
<
MT
>*
learning_rate
,
const
MultiPrecisionType
<
MT
>*
learning_rate
,
const
MT
*
master_param
,
const
MT
mu
,
const
MT
*
master_param
,
const
MT
mu
,
const
MT
rescale_grad
,
const
int64_t
num
,
const
MT
rescale_grad
,
const
int64_t
num
,
const
RegularizationType
regularization_flag
,
const
MT
regularization_coeff
,
T
*
param_out
,
const
MT
regularization_coeff
,
T
*
param_out
,
MT
*
velocity_out
,
MT
*
master_param_out
)
MT
*
velocity_out
,
MT
*
master_param_out
)
:
param_
(
param
),
:
param_
(
param
),
...
@@ -215,7 +214,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
...
@@ -215,7 +214,6 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
param_out_
(
param_out
),
param_out_
(
param_out
),
velocity_out_
(
velocity_out
),
velocity_out_
(
velocity_out
),
master_param_out_
(
master_param_out
),
master_param_out_
(
master_param_out
),
regularization_flag_
(
regularization_flag
),
regularization_coeff_
(
regularization_coeff
)
{}
regularization_coeff_
(
regularization_coeff
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
// put memory access in register
// put memory access in register
...
@@ -225,9 +223,9 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
...
@@ -225,9 +223,9 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
const
MT
lr
=
static_cast
<
MT
>
(
lr_
[
0
]);
const
MT
lr
=
static_cast
<
MT
>
(
lr_
[
0
]);
const
MT
velocity
=
velocity_
[
i
];
const
MT
velocity
=
velocity_
[
i
];
grad
=
regularization_flag_
==
RegularizationType
::
kL2DECAY
if
(
kRegType
==
RegularizationType
::
kL2DECAY
)
{
?
grad
+
regularization_coeff_
*
param
grad
+=
regularization_coeff_
*
param
;
:
grad
;
}
MT
velocity_out
=
velocity
*
mu_
+
grad
;
MT
velocity_out
=
velocity
*
mu_
+
grad
;
MT
param_out
=
param
-
(
grad
+
velocity_out
*
mu_
)
*
lr
;
MT
param_out
=
param
-
(
grad
+
velocity_out
*
mu_
)
*
lr
;
...
@@ -240,8 +238,8 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
...
@@ -240,8 +238,8 @@ class DenseMomentumFunctor<T, MT, UseNesterov> {
}
}
};
};
template
<
typename
T
,
typename
MT
>
template
<
typename
T
,
typename
MT
,
RegularizationType
kRegType
>
class
DenseMomentumFunctor
<
T
,
MT
,
NoNesterov
>
{
class
DenseMomentumFunctor
<
T
,
MT
,
kRegType
,
NoNesterov
>
{
private:
private:
const
T
*
param_
;
const
T
*
param_
;
const
T
*
grad_
;
const
T
*
grad_
;
...
@@ -254,7 +252,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
...
@@ -254,7 +252,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
T
*
param_out_
;
T
*
param_out_
;
MT
*
velocity_out_
;
MT
*
velocity_out_
;
MT
*
master_param_out_
;
MT
*
master_param_out_
;
const
RegularizationType
regularization_flag_
;
const
MT
regularization_coeff_
;
const
MT
regularization_coeff_
;
public:
public:
...
@@ -262,7 +259,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
...
@@ -262,7 +259,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
const
MultiPrecisionType
<
MT
>*
learning_rate
,
const
MultiPrecisionType
<
MT
>*
learning_rate
,
const
MT
*
master_param
,
const
MT
mu
,
const
MT
*
master_param
,
const
MT
mu
,
const
MT
rescale_grad
,
const
int64_t
num
,
const
MT
rescale_grad
,
const
int64_t
num
,
const
RegularizationType
regularization_flag
,
const
MT
regularization_coeff
,
T
*
param_out
,
const
MT
regularization_coeff
,
T
*
param_out
,
MT
*
velocity_out
,
MT
*
master_param_out
)
MT
*
velocity_out
,
MT
*
master_param_out
)
:
param_
(
param
),
:
param_
(
param
),
...
@@ -276,7 +272,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
...
@@ -276,7 +272,6 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
param_out_
(
param_out
),
param_out_
(
param_out
),
velocity_out_
(
velocity_out
),
velocity_out_
(
velocity_out
),
master_param_out_
(
master_param_out
),
master_param_out_
(
master_param_out
),
regularization_flag_
(
regularization_flag
),
regularization_coeff_
(
regularization_coeff
)
{}
regularization_coeff_
(
regularization_coeff
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
// put memory access in register
// put memory access in register
...
@@ -286,9 +281,9 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
...
@@ -286,9 +281,9 @@ class DenseMomentumFunctor<T, MT, NoNesterov> {
const
MT
lr
=
static_cast
<
MT
>
(
lr_
[
0
]);
const
MT
lr
=
static_cast
<
MT
>
(
lr_
[
0
]);
const
MT
velocity
=
velocity_
[
i
];
const
MT
velocity
=
velocity_
[
i
];
grad
=
regularization_flag_
==
RegularizationType
::
kL2DECAY
if
(
kRegType
==
RegularizationType
::
kL2DECAY
)
{
?
grad
+
regularization_coeff_
*
param
grad
+=
regularization_coeff_
*
param
;
:
grad
;
}
MT
velocity_out
=
velocity
*
mu_
+
grad
;
MT
velocity_out
=
velocity
*
mu_
+
grad
;
MT
param_out
=
param
-
lr
*
velocity_out
;
MT
param_out
=
param
-
lr
*
velocity_out
;
...
@@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel<T> {
...
@@ -522,23 +517,31 @@ class MomentumOpKernel : public framework::OpKernel<T> {
platform
::
ForRange
<
DeviceContext
>
for_range
(
platform
::
ForRange
<
DeviceContext
>
for_range
(
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
static_cast
<
const
DeviceContext
&>
(
ctx
.
device_context
()),
param
->
numel
());
param
->
numel
());
if
(
use_nesterov
)
{
#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor
<
T
,
MT
,
UseNesterov
>
functor
(
DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
param
->
data
<
T
>
(),
grad
->
data
<
T
>
(),
velocity
->
data
<
MT
>
(),
param->data<T>(), grad->data<T>(), velocity->data<MT>(), \
learning_rate
->
data
<
MPDType
>
(),
master_in_data
,
mu
,
rescale_grad
,
learning_rate->data<MPDType>(), master_in_data, mu, rescale_grad, \
param
->
numel
(),
regularization_flag
,
regularization_coeff
,
param->numel(), regularization_coeff, \
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
param_out->mutable_data<T>(ctx.GetPlace()), \
velocity_out
->
mutable_data
<
MT
>
(
ctx
.
GetPlace
()),
master_out_data
);
velocity_out->mutable_data<MT>(ctx.GetPlace()), master_out_data); \
for_range
(
functor
);
for_range(functor);
if
(
use_nesterov
)
{
if
(
regularization_flag
==
RegularizationType
::
kL2DECAY
)
{
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL
(
UseNesterov
,
RegularizationType
::
kL2DECAY
);
}
else
{
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL
(
UseNesterov
,
RegularizationType
::
kNONE
);
}
}
else
{
}
else
{
DenseMomentumFunctor
<
T
,
MT
,
NoNesterov
>
functor
(
if
(
regularization_flag
==
RegularizationType
::
kL2DECAY
)
{
param
->
data
<
T
>
(),
grad
->
data
<
T
>
(),
velocity
->
data
<
MT
>
()
,
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL
(
NoNesterov
,
learning_rate
->
data
<
MPDType
>
(),
master_in_data
,
mu
,
rescale_grad
,
RegularizationType
::
kL2DECAY
);
param
->
numel
(),
regularization_flag
,
regularization_coeff
,
}
else
{
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())
,
PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL
(
NoNesterov
,
velocity_out
->
mutable_data
<
MT
>
(
ctx
.
GetPlace
()),
master_out_data
);
RegularizationType
::
kNONE
);
for_range
(
functor
);
}
}
}
}
}
...
...
python/paddle/fluid/tests/unittests/test_merged_momentum_op.py
浏览文件 @
4dda18a8
...
@@ -102,7 +102,7 @@ def run_momentum_op(params,
...
@@ -102,7 +102,7 @@ def run_momentum_op(params,
'Param'
:
p
,
'Param'
:
p
,
'Grad'
:
g
,
'Grad'
:
g
,
'Velocity'
:
v
,
'Velocity'
:
v
,
'LearningRate'
:
lr_var
'LearningRate'
:
lr_var
,
}
}
outputs
=
{
'ParamOut'
:
p
,
'VelocityOut'
:
v
}
outputs
=
{
'ParamOut'
:
p
,
'VelocityOut'
:
v
}
if
multi_precision
:
if
multi_precision
:
...
@@ -115,7 +115,7 @@ def run_momentum_op(params,
...
@@ -115,7 +115,7 @@ def run_momentum_op(params,
'Param'
:
param_vars
,
'Param'
:
param_vars
,
'Grad'
:
grad_vars
,
'Grad'
:
grad_vars
,
'Velocity'
:
velocity_vars
,
'Velocity'
:
velocity_vars
,
'LearningRate'
:
lr_var
'LearningRate'
:
lr_var
,
}
}
outputs
=
{
'ParamOut'
:
param_vars
,
'VelocityOut'
:
velocity_vars
}
outputs
=
{
'ParamOut'
:
param_vars
,
'VelocityOut'
:
velocity_vars
}
if
multi_precision
:
if
multi_precision
:
...
@@ -176,7 +176,10 @@ class TestMergedMomentum(unittest.TestCase):
...
@@ -176,7 +176,10 @@ class TestMergedMomentum(unittest.TestCase):
outs2
=
run_op
(
False
)
outs2
=
run_op
(
False
)
self
.
assertEqual
(
len
(
outs1
),
len
(
outs2
))
self
.
assertEqual
(
len
(
outs1
),
len
(
outs2
))
for
i
,
(
out1
,
out2
)
in
enumerate
(
zip
(
outs1
,
outs2
)):
for
i
,
(
out1
,
out2
)
in
enumerate
(
zip
(
outs1
,
outs2
)):
self
.
assertTrue
(
np
.
allclose
(
out1
,
out2
,
atol
=
1e-7
))
if
isinstance
(
place
,
paddle
.
CUDAPlace
):
self
.
assertTrue
(
np
.
array_equal
(
out1
,
out2
))
else
:
self
.
assertTrue
(
np
.
allclose
(
out1
,
out2
,
atol
=
1e-7
))
def
get_places
(
self
):
def
get_places
(
self
):
places
=
[
paddle
.
CPUPlace
()]
places
=
[
paddle
.
CPUPlace
()]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录