Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8ff35506
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8ff35506
编写于
11月 23, 2020
作者:
F
furnace
提交者:
GitHub
11月 23, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor momentum op to combine weight (#27414)
* refactor momentum op to combine weight_decay (scale op and sum op)
上级
bd1d6d3b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
646 addition
and
135 deletion
+646
-135
paddle/fluid/operators/optimizers/momentum_op.cc
paddle/fluid/operators/optimizers/momentum_op.cc
+20
-0
paddle/fluid/operators/optimizers/momentum_op.h
paddle/fluid/operators/optimizers/momentum_op.h
+208
-116
python/paddle/fluid/contrib/__init__.py
python/paddle/fluid/contrib/__init__.py
+2
-0
python/paddle/fluid/contrib/optimizer.py
python/paddle/fluid/contrib/optimizer.py
+175
-0
python/paddle/fluid/tests/unittests/test_momentum_op.py
python/paddle/fluid/tests/unittests/test_momentum_op.py
+241
-19
未找到文件。
paddle/fluid/operators/optimizers/momentum_op.cc
浏览文件 @
8ff35506
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -61,6 +62,12 @@ void MomentumOpMaker::Make() {
"(bool, default false) "
"Use Nesterov Momentum"
)
.
SetDefault
(
false
);
AddAttr
<
std
::
string
>
(
"regularization_method"
,
"(string) regularization_method, right now only support l2decay or none"
)
.
SetDefault
(
""
);
AddAttr
<
float
>
(
"regularization_coeff"
,
"(float) regularization_coeff"
)
.
SetDefault
(
0
);
AddComment
(
R"DOC(
Momentum Optimizer.
...
...
@@ -90,3 +97,16 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL
(
momentum
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
MomentumOpKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_VERSION
(
momentum
)
.
AddCheckpoint
(
R"ROC(
Upgrade momentum add 2 attributes [regularization_method, regularization_coeff].
)ROC"
,
paddle
::
framework
::
compatible
::
OpVersionDesc
()
.
NewAttr
(
"regularization_method"
,
"(string) regularization_method, right now only support "
"l2decay or none"
,
std
::
string
(
""
))
.
NewAttr
(
"regularization_coeff"
,
"(float) regularization_coeff"
,
0.0
f
));
paddle/fluid/operators/optimizers/momentum_op.h
浏览文件 @
8ff35506
...
...
@@ -29,6 +29,12 @@ using framework::SelectedRows;
struct
NoNesterov
;
struct
UseNesterov
;
enum
class
RegularizationType
{
kNONE
=
0
,
kL1DECAY
=
1
,
// do not need support right now
kL2DECAY
=
2
,
};
class
MomentumOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
;
...
...
@@ -113,43 +119,60 @@ class MomentumOp : public framework::OperatorWithKernel {
template
<
typename
T
>
class
CPUDenseMomentumFunctor
{
private:
const
Tensor
*
param
;
const
Tensor
*
grad
;
const
Tensor
*
velocity
;
const
Tensor
*
learning_rate
;
const
T
mu
;
const
T
use_nesterov
;
Tensor
*
param_out
;
Tensor
*
velocity_out
;
const
Tensor
*
param_
;
const
Tensor
*
grad_
;
const
Tensor
*
velocity_
;
const
Tensor
*
learning_rate_
;
const
T
mu_
;
const
T
use_nesterov_
;
RegularizationType
regularization_flag_
;
const
T
regularization_coeff_
;
Tensor
*
param_out_
;
Tensor
*
velocity_out_
;
public:
CPUDenseMomentumFunctor
(
const
Tensor
*
param
,
const
Tensor
*
grad
,
const
Tensor
*
velocity
,
const
Tensor
*
learning_rate
,
const
T
mu
,
const
bool
use_nesterov
,
Tensor
*
param_out
,
Tensor
*
velocity_out
)
:
param
(
param
),
grad
(
grad
),
velocity
(
velocity
),
learning_rate
(
learning_rate
),
mu
(
mu
),
use_nesterov
(
use_nesterov
),
param_out
(
param_out
),
velocity_out
(
velocity_out
)
{}
RegularizationType
regularization_flag
,
const
T
regularization_coeff
,
Tensor
*
param_out
,
Tensor
*
velocity_out
)
:
param_
(
param
),
grad_
(
grad
),
velocity_
(
velocity
),
learning_rate_
(
learning_rate
),
mu_
(
mu
),
use_nesterov_
(
use_nesterov
),
regularization_flag_
(
regularization_flag
),
regularization_coeff_
(
regularization_coeff
),
param_out_
(
param_out
),
velocity_out_
(
velocity_out
)
{}
inline
void
operator
()()
{
auto
p_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
v_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity_out
);
auto
p
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param
);
auto
v
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity
);
auto
g
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
grad
);
auto
*
lr
=
learning_rate
->
data
<
T
>
();
v_out
=
v
*
mu
+
g
;
if
(
use_nesterov
)
{
p_out
=
p
-
(
g
+
v_out
*
mu
)
*
lr
[
0
];
auto
param_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out_
);
auto
velocity_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity_out_
);
auto
param
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_
);
auto
velocity
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity_
);
auto
grad
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
grad_
);
auto
*
lr
=
learning_rate_
->
data
<
T
>
();
if
(
regularization_flag_
==
RegularizationType
::
kL2DECAY
)
{
velocity_out
=
velocity
*
mu_
+
param
*
regularization_coeff_
+
grad
;
if
(
use_nesterov_
)
{
param_out
=
param
-
(
param
*
regularization_coeff_
+
grad
+
velocity_out
*
mu_
)
*
lr
[
0
];
}
else
{
param_out
=
param
-
lr
[
0
]
*
velocity_out
;
}
}
else
{
p_out
=
p
-
lr
[
0
]
*
v_out
;
velocity_out
=
velocity
*
mu_
+
grad
;
if
(
use_nesterov_
)
{
param_out
=
param
-
(
grad
+
velocity_out
*
mu_
)
*
lr
[
0
];
}
else
{
param_out
=
param
-
lr
[
0
]
*
velocity_out
;
}
}
}
};
...
...
@@ -163,76 +186,100 @@ class DenseMomentumFunctor;
template
<
typename
T
>
class
DenseMomentumFunctor
<
T
,
UseNesterov
>
{
private:
const
T
*
p_
;
const
T
*
g_
;
const
T
*
v_
;
const
T
*
p
aram
_
;
const
T
*
g
rad
_
;
const
T
*
v
elocity
_
;
const
T
*
lr_
;
const
T
mu_
;
const
int64_t
num_
;
T
*
p_out_
;
T
*
v_out_
;
T
*
param_out_
;
T
*
velocity_out_
;
RegularizationType
regularization_flag_
;
const
T
regularization_coeff_
;
public:
DenseMomentumFunctor
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
DenseMomentumFunctor
(
const
T
*
p
aram
,
const
T
*
grad
,
const
T
*
velocity
,
const
T
*
learning_rate
,
const
T
mu
,
const
int64_t
num
,
T
*
p_out
,
T
*
v_out
)
:
p_
(
p
),
g_
(
g
),
v_
(
v
),
RegularizationType
regularization_flag
,
const
T
regularization_coeff
,
T
*
param_out
,
T
*
velocity_out
)
:
param_
(
param
),
grad_
(
grad
),
velocity_
(
velocity
),
lr_
(
learning_rate
),
mu_
(
mu
),
num_
(
num
),
p_out_
(
p_out
),
v_out_
(
v_out
)
{}
param_out_
(
param_out
),
velocity_out_
(
velocity_out
),
regularization_flag_
(
regularization_flag
),
regularization_coeff_
(
regularization_coeff
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
// put memory access in register
const
T
p
=
p
_
[
i
];
const
T
g
=
g
_
[
i
];
const
T
p
aram
=
param
_
[
i
];
T
grad
=
grad
_
[
i
];
const
T
lr
=
lr_
[
0
];
const
T
v
=
v_
[
i
];
T
v_out
=
v
*
mu_
+
g
;
T
p_out
=
p
-
(
g
+
v_out
*
mu_
)
*
lr
;
const
T
velocity
=
velocity_
[
i
];
grad
=
regularization_flag_
==
RegularizationType
::
kL2DECAY
?
grad
+
regularization_coeff_
*
param
:
grad
;
T
velocity_out
=
velocity
*
mu_
+
grad
;
T
param_out
=
param
-
(
grad
+
velocity_out
*
mu_
)
*
lr
;
// write reigster to memory
v
_out_
[
i
]
=
v
_out
;
p
_out_
[
i
]
=
p
_out
;
v
elocity_out_
[
i
]
=
velocity
_out
;
p
aram_out_
[
i
]
=
param
_out
;
}
};
template
<
typename
T
>
class
DenseMomentumFunctor
<
T
,
NoNesterov
>
{
private:
const
T
*
p_
;
const
T
*
g_
;
const
T
*
v_
;
const
T
*
p
aram
_
;
const
T
*
g
rad
_
;
const
T
*
v
elocity
_
;
const
T
*
lr_
;
const
T
mu_
;
const
int64_t
num_
;
T
*
p_out_
;
T
*
v_out_
;
T
*
param_out_
;
T
*
velocity_out_
;
RegularizationType
regularization_flag_
;
const
T
regularization_coeff_
;
public:
DenseMomentumFunctor
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
DenseMomentumFunctor
(
const
T
*
p
aram
,
const
T
*
grad
,
const
T
*
velocity
,
const
T
*
learning_rate
,
const
T
mu
,
const
int64_t
num
,
T
*
p_out
,
T
*
v_out
)
:
p_
(
p
),
g_
(
g
),
v_
(
v
),
RegularizationType
regularization_flag
,
const
T
regularization_coeff
,
T
*
param_out
,
T
*
velocity_out
)
:
param_
(
param
),
grad_
(
grad
),
velocity_
(
velocity
),
lr_
(
learning_rate
),
mu_
(
mu
),
num_
(
num
),
p_out_
(
p_out
),
v_out_
(
v_out
)
{}
param_out_
(
param_out
),
velocity_out_
(
velocity_out
),
regularization_flag_
(
regularization_flag
),
regularization_coeff_
(
regularization_coeff
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
const
{
// put memory access in register
const
T
p
=
p
_
[
i
];
const
T
g
=
g
_
[
i
];
const
T
p
aram
=
param
_
[
i
];
T
grad
=
grad
_
[
i
];
const
T
lr
=
lr_
[
0
];
const
T
v
=
v_
[
i
];
T
v_out
=
v
*
mu_
+
g
;
T
p_out
=
p
-
lr
*
v_out
;
const
T
velocity
=
velocity_
[
i
];
grad
=
regularization_flag_
==
RegularizationType
::
kL2DECAY
?
grad
+
regularization_coeff_
*
param
:
grad
;
T
velocity_out
=
velocity
*
mu_
+
grad
;
T
param_out
=
param
-
lr
*
velocity_out
;
// write reigster to memory
v
_out_
[
i
]
=
v
_out
;
p
_out_
[
i
]
=
p
_out
;
v
elocity_out_
[
i
]
=
velocity
_out
;
p
aram_out_
[
i
]
=
param
_out
;
}
};
...
...
@@ -242,92 +289,116 @@ class SparseMomentumFunctor;
template
<
typename
T
>
class
SparseMomentumFunctor
<
T
,
UseNesterov
>
{
private:
const
T
*
p_
;
const
T
*
g_
;
const
T
*
v_
;
const
T
*
p
aram
_
;
const
T
*
g
rad
_
;
const
T
*
v
elocity
_
;
const
T
*
lr_
;
const
T
mu_
;
const
int64_t
*
rows_
;
const
int64_t
row_numel_
;
const
int64_t
row_height_
;
T
*
p_out_
;
T
*
v_out_
;
T
*
param_out_
;
T
*
velocity_out_
;
RegularizationType
regularization_flag_
;
const
T
regularization_coeff_
;
public:
SparseMomentumFunctor
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
const
T
*
lr
,
const
T
mu
,
const
int64_t
*
rows
,
int64_t
row_numel
,
int64_t
row_height
,
T
*
p_out
,
T
*
v_out
)
:
p_
(
p
),
g_
(
g
),
v_
(
v
),
SparseMomentumFunctor
(
const
T
*
param
,
const
T
*
grad
,
const
T
*
velocity
,
const
T
*
lr
,
const
T
mu
,
const
int64_t
*
rows
,
int64_t
row_numel
,
int64_t
row_height
,
RegularizationType
regularization_flag
,
const
T
regularization_coeff
,
T
*
param_out
,
T
*
velocity_out
)
:
param_
(
param
),
grad_
(
grad
),
velocity_
(
velocity
),
lr_
(
lr
),
mu_
(
mu
),
rows_
(
rows
),
row_numel_
(
row_numel
),
row_height_
(
row_height
),
p_out_
(
p_out
),
v_out_
(
v_out
)
{}
param_out_
(
param_out
),
velocity_out_
(
velocity_out
),
regularization_flag_
(
regularization_flag
),
regularization_coeff_
(
regularization_coeff
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
{
auto
row_idx
=
math
::
BinarySearch
<
int64_t
>
(
rows_
,
row_height_
,
i
/
row_numel_
);
T
g
=
row_idx
>=
0
?
g
_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
static_cast
<
T
>
(
0
);
T
g
rad
=
row_idx
>=
0
?
grad
_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
static_cast
<
T
>
(
0
);
// put memory access in register
const
T
p
=
p
_
[
i
];
const
T
p
aram
=
param
_
[
i
];
const
T
lr
=
lr_
[
0
];
const
T
v
=
v_
[
i
];
T
v_out
=
v
*
mu_
+
g
;
T
p_out
=
p
-
(
g
+
v_out
*
mu_
)
*
lr
;
const
T
velocity
=
velocity_
[
i
];
grad
=
regularization_flag_
==
RegularizationType
::
kL2DECAY
?
grad
+
regularization_coeff_
*
param
:
grad
;
T
velocity_out
=
velocity
*
mu_
+
grad
;
T
param_out
=
param
-
(
grad
+
velocity_out
*
mu_
)
*
lr
;
// write reigster to memory
v
_out_
[
i
]
=
v
_out
;
p
_out_
[
i
]
=
p
_out
;
v
elocity_out_
[
i
]
=
velocity
_out
;
p
aram_out_
[
i
]
=
param
_out
;
}
};
template
<
typename
T
>
class
SparseMomentumFunctor
<
T
,
NoNesterov
>
{
private:
const
T
*
p_
;
const
T
*
g_
;
const
T
*
v_
;
const
T
*
p
aram
_
;
const
T
*
g
rad
_
;
const
T
*
v
elocity
_
;
const
T
*
lr_
;
const
T
mu_
;
const
int64_t
*
rows_
;
const
int64_t
row_numel_
;
const
int64_t
row_height_
;
T
*
p_out_
;
T
*
v_out_
;
T
*
param_out_
;
T
*
velocity_out_
;
RegularizationType
regularization_flag_
;
const
T
regularization_coeff_
;
public:
SparseMomentumFunctor
(
const
T
*
p
,
const
T
*
g
,
const
T
*
v
,
const
T
*
lr
,
const
T
mu
,
const
int64_t
*
rows
,
int64_t
row_numel
,
int64_t
row_height
,
T
*
p_out
,
T
*
v_out
)
:
p_
(
p
),
g_
(
g
),
v_
(
v
),
SparseMomentumFunctor
(
const
T
*
param
,
const
T
*
grad
,
const
T
*
velocity
,
const
T
*
lr
,
const
T
mu
,
const
int64_t
*
rows
,
int64_t
row_numel
,
int64_t
row_height
,
RegularizationType
regularization_flag
,
const
T
regularization_coeff
,
T
*
param_out
,
T
*
velocity_out
)
:
param_
(
param
),
grad_
(
grad
),
velocity_
(
velocity
),
lr_
(
lr
),
mu_
(
mu
),
rows_
(
rows
),
row_numel_
(
row_numel
),
row_height_
(
row_height
),
p_out_
(
p_out
),
v_out_
(
v_out
)
{}
param_out_
(
param_out
),
velocity_out_
(
velocity_out
),
regularization_flag_
(
regularization_flag
),
regularization_coeff_
(
regularization_coeff
)
{}
inline
HOSTDEVICE
void
operator
()(
size_t
i
)
{
auto
row_idx
=
math
::
BinarySearch
<
int64_t
>
(
rows_
,
row_height_
,
i
/
row_numel_
);
T
g
=
row_idx
>=
0
?
g
_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
static_cast
<
T
>
(
0
);
T
g
rad
=
row_idx
>=
0
?
grad
_
[
row_idx
*
row_numel_
+
i
%
row_numel_
]
:
static_cast
<
T
>
(
0
);
// put memory access in register
const
T
p
=
p
_
[
i
];
const
T
p
aram
=
param
_
[
i
];
const
T
lr
=
lr_
[
0
];
const
T
v
=
v_
[
i
];
T
v_out
=
v
*
mu_
+
g
;
T
p_out
=
p
-
v_out
*
lr
;
const
T
velocity
=
velocity_
[
i
];
grad
=
regularization_flag_
==
RegularizationType
::
kL2DECAY
?
grad
+
regularization_coeff_
*
param
:
grad
;
T
velocity_out
=
velocity
*
mu_
+
grad
;
T
param_out
=
param
-
velocity_out
*
lr
;
// write reigster to memory
v
_out_
[
i
]
=
v
_out
;
p
_out_
[
i
]
=
p
_out
;
v
elocity_out_
[
i
]
=
velocity
_out
;
p
aram_out_
[
i
]
=
param
_out
;
}
};
...
...
@@ -335,6 +406,24 @@ template <typename DeviceContext, typename T>
class
MomentumOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
std
::
string
regularization_method
=
ctx
.
Attr
<
std
::
string
>
(
"regularization_method"
);
if
(
regularization_method
!=
""
||
!
regularization_method
.
empty
())
{
PADDLE_ENFORCE_EQ
(
"l2_decay"
,
regularization_method
,
platform
::
errors
::
InvalidArgument
(
"if regularization_method is not null, "
"it should be l2_decay, but received %s"
,
regularization_method
));
}
T
regularization_coeff
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"regularization_coeff"
));
RegularizationType
regularization_flag
{
RegularizationType
::
kNONE
};
// disable regularization
if
(
regularization_method
==
"l2_decay"
)
{
regularization_flag
=
RegularizationType
::
kL2DECAY
;
}
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
...
...
@@ -343,6 +432,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
...
@@ -350,9 +440,9 @@ class MomentumOpKernel : public framework::OpKernel<T> {
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
if
(
platform
::
is_cpu_place
(
ctx
.
GetPlace
()))
{
CPUDenseMomentumFunctor
<
T
>
functor
(
param
,
grad
,
velocity
,
learning_rate
,
mu
,
use_nesterov
,
param_out
,
velocity_out
);
CPUDenseMomentumFunctor
<
T
>
functor
(
param
,
grad
,
velocity
,
learning_rate
,
mu
,
use_nesterov
,
regularization_flag
,
regularization_coeff
,
param_out
,
velocity_out
);
functor
();
}
else
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
platform
::
ForRange
<
DeviceContext
>
for_range
(
...
...
@@ -361,16 +451,16 @@ class MomentumOpKernel : public framework::OpKernel<T> {
if
(
use_nesterov
)
{
DenseMomentumFunctor
<
T
,
UseNesterov
>
functor
(
param
->
data
<
T
>
(),
grad
->
data
<
T
>
(),
velocity
->
data
<
T
>
(),
learning_rate
->
data
<
T
>
(),
mu
,
param
->
numel
(),
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
learning_rate
->
data
<
T
>
(),
mu
,
param
->
numel
(),
regularization_flag
,
regularization_coeff
,
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
for_range
(
functor
);
}
else
{
DenseMomentumFunctor
<
T
,
NoNesterov
>
functor
(
param
->
data
<
T
>
(),
grad
->
data
<
T
>
(),
velocity
->
data
<
T
>
(),
learning_rate
->
data
<
T
>
(),
mu
,
param
->
numel
(),
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
learning_rate
->
data
<
T
>
(),
mu
,
param
->
numel
(),
regularization_flag
,
regularization_coeff
,
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
for_range
(
functor
);
}
...
...
@@ -403,6 +493,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
param
->
data
<
T
>
(),
merged_grad
->
value
().
data
<
T
>
(),
velocity
->
data
<
T
>
(),
learning_rate
->
data
<
T
>
(),
mu
,
rows
,
row_numel
,
static_cast
<
int64_t
>
(
merged_grad
->
rows
().
size
()),
regularization_flag
,
regularization_coeff
,
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
for_range
(
functor
);
...
...
@@ -412,6 +503,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
param
->
data
<
T
>
(),
merged_grad
->
value
().
data
<
T
>
(),
velocity
->
data
<
T
>
(),
learning_rate
->
data
<
T
>
(),
mu
,
rows
,
row_numel
,
static_cast
<
int64_t
>
(
merged_grad
->
rows
().
size
()),
regularization_flag
,
regularization_coeff
,
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
for_range
(
functor
);
...
...
python/paddle/fluid/contrib/__init__.py
浏览文件 @
8ff35506
...
...
@@ -35,6 +35,7 @@ from . import mixed_precision
from
.mixed_precision
import
*
from
.
import
layers
from
.layers
import
*
from
.
import
optimizer
__all__
=
[]
__all__
+=
decoder
.
__all__
...
...
@@ -46,3 +47,4 @@ __all__ += utils.__all__
__all__
+=
extend_optimizer
.
__all__
__all__
+=
[
'mixed_precision'
]
__all__
+=
layers
.
__all__
__all__
+=
optimizer
.
__all__
python/paddle/fluid/contrib/optimizer.py
0 → 100644
浏览文件 @
8ff35506
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle.fluid.optimizer
import
Optimizer
from
paddle.fluid.regularizer
import
L1DecayRegularizer
from
paddle.fluid.regularizer
import
L2DecayRegularizer
from
paddle.fluid.regularizer
import
append_regularization_ops
from
paddle.fluid
import
framework
from
paddle.fluid
import
core
from
paddle.fluid.framework
import
program_guard
from
paddle.fluid.clip
import
append_gradient_clip_ops
__all__
=
[
'Momentum'
]
class
Momentum
(
Optimizer
):
"""
Simple Momentum optimizer with velocity state
This optimizer has a flag for Nestrov Momentum.
The update equations are as follows:
.. math::
& velocity = mu * velocity + gradient
& if (use\_nesterov):
&\quad param = param - (gradient + mu * velocity) * learning\_rate
& else:
&\quad param = param - learning\_rate * velocity
Parameters:
learning_rate (float|Variable): The learning rate used to update parameters.
\
Can be a float value or a Variable with one float value as data element.
momentum (float): Momentum factor
parameter_list (Iterable, optional): Iterable of ``Variable`` names to update to minimize ``loss``.
\
This parameter is required in dygraph mode.
\
The default value is None in static mode, at this time all parameters will be updated.
use_nesterov (bool, optional): Enables Nesterov momentum, default is false.
regularization (WeightDecayRegularizer, optional): The strategy of regularization. There are two method:
\
:ref:`api_fluid_regularizer_L1Decay` , :ref:`api_fluid_regularizer_L2Decay` . If a parameter has set
\
regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be
\
ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect.
\
Default None, meaning there is no regularization.
grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
some derived class of ``GradientClipBase`` . There are three cliping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
name (str, optional): This parameter is used by developers to print debugging information.
\
For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
paddle.enable_static()
place = fluid.CPUPlace()
main = fluid.Program()
with fluid.program_guard(main):
x = paddle.static.data(name='x', shape=[1, 13], dtype='float32')
y = paddle.static.data(name='y', shape=[1], dtype='float32')
linear = paddle.nn.Linear(13, 1)
y_predict = linear(x)
cost = paddle.nn.functional.square_error_cost(input=y_predict, label=y)
avg_cost = paddle.mean(cost)
moment_optimizer = fluid.contrib.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
moment_optimizer.minimize(avg_cost)
fetch_list = [avg_cost]
train_reader = paddle.batch(
paddle.dataset.uci_housing.train(), batch_size=1)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
exe = fluid.Executor(place)
exe.run(paddle.static.default_startup_program())
for data in train_reader():
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
"""
_velocity_acc_str
=
"velocity"
def
__init__
(
self
,
learning_rate
,
momentum
,
parameter_list
=
None
,
use_nesterov
=
False
,
regularization
=
None
,
grad_clip
=
None
,
name
=
None
):
assert
learning_rate
is
not
None
assert
momentum
is
not
None
predicate
=
lambda
regular
:
isinstance
(
regular
,
L2DecayRegularizer
)
py_regular
=
None
if
predicate
(
regularization
)
else
regularization
super
(
Momentum
,
self
).
__init__
(
learning_rate
=
learning_rate
,
parameter_list
=
parameter_list
,
regularization
=
py_regular
,
grad_clip
=
grad_clip
,
name
=
name
)
self
.
type
=
"momentum"
self
.
_momentum
=
momentum
self
.
_use_nesterov
=
bool
(
use_nesterov
)
self
.
_regularization_method
=
""
self
.
_regularization_coeff
=
0
if
(
isinstance
(
regularization
,
L2DecayRegularizer
)):
self
.
_regularization_method
=
"l2_decay"
self
.
_regularization_coeff
=
regularization
.
_regularization_coeff
def
_create_accumulators
(
self
,
block
,
parameters
):
assert
isinstance
(
block
,
framework
.
Block
)
for
p
in
parameters
:
self
.
_add_accumulator
(
self
.
_velocity_acc_str
,
p
)
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
assert
isinstance
(
block
,
framework
.
Block
)
velocity_acc
=
self
.
_get_accumulator
(
self
.
_velocity_acc_str
,
param_and_grad
[
0
])
lr
=
self
.
_create_param_lr
(
param_and_grad
)
if
framework
.
in_dygraph_mode
():
_
,
_
=
core
.
ops
.
momentum
(
param_and_grad
[
0
],
param_and_grad
[
1
],
velocity_acc
,
lr
,
param_and_grad
[
0
],
velocity_acc
,
'mu'
,
self
.
_momentum
,
'use_nesterov'
,
self
.
_use_nesterov
,
'regularization_method'
,
self
.
_regularization_method
,
'regularization_coeff'
,
self
.
_regularization_coeff
)
return
None
attrs
=
{
"mu"
:
self
.
_momentum
,
"use_nesterov"
:
self
.
_use_nesterov
,
"regularization_method"
:
self
.
_regularization_method
,
"regularization_coeff"
:
self
.
_regularization_coeff
}
inputs
=
{
"Param"
:
[
param_and_grad
[
0
]],
"Grad"
:
[
param_and_grad
[
1
]],
"Velocity"
:
[
velocity_acc
],
"LearningRate"
:
[
lr
]
}
outputs
=
{
"ParamOut"
:
[
param_and_grad
[
0
]],
"VelocityOut"
:
[
velocity_acc
]
}
# create the momentum optimize op
momentum_op
=
block
.
append_op
(
type
=
self
.
type
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
,
stop_gradient
=
True
)
return
momentum_op
python/paddle/fluid/tests/unittests/test_momentum_op.py
浏览文件 @
8ff35506
...
...
@@ -23,6 +23,33 @@ import paddle
import
paddle.fluid
as
fluid
def
calculate_momentum_by_numpy
(
param
,
grad
,
mu
,
velocity
,
use_nesterov
,
learning_rate
,
regularization_method
=
None
,
regularization_coeff
=
1.0
):
if
regularization_method
==
"l2_decay"
:
grad
=
grad
+
regularization_coeff
*
param
velocity_out
=
mu
*
velocity
+
grad
if
use_nesterov
:
param_out
=
param
-
(
grad
+
velocity_out
*
mu
)
*
learning_rate
else
:
param_out
=
param
-
learning_rate
*
velocity_out
else
:
velocity_out
=
mu
*
velocity
+
grad
if
use_nesterov
:
param_out
=
param
-
grad
*
learning_rate
-
\
velocity_out
*
mu
*
learning_rate
else
:
param_out
=
param
-
learning_rate
*
velocity_out
return
param_out
,
velocity_out
class
TestMomentumOp1
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"momentum"
...
...
@@ -45,12 +72,13 @@ class TestMomentumOp1(OpTest):
self
.
attrs
=
{
'mu'
:
mu
}
velocity_out
=
mu
*
velocity
+
grad
if
use_nesterov
:
param_out
=
param
-
grad
*
learning_rate
-
\
velocity_out
*
mu
*
learning_rate
else
:
param_out
=
param
-
learning_rate
*
velocity_out
param_out
,
velocity_out
=
calculate_momentum_by_numpy
(
param
=
param
,
grad
=
grad
,
mu
=
mu
,
velocity
=
velocity
,
use_nesterov
=
use_nesterov
,
learning_rate
=
learning_rate
)
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
...
...
@@ -92,12 +120,13 @@ class TestMomentumOp2(OpTest):
self
.
attrs
=
{
'mu'
:
mu
,
'use_nesterov'
:
use_nesterov
}
velocity_out
=
mu
*
velocity
+
grad
if
use_nesterov
:
param_out
=
param
-
grad
*
learning_rate
-
\
velocity_out
*
mu
*
learning_rate
else
:
param_out
=
param
-
learning_rate
*
velocity_out
param_out
,
velocity_out
=
calculate_momentum_by_numpy
(
param
=
param
,
grad
=
grad
,
mu
=
mu
,
velocity
=
velocity
,
use_nesterov
=
use_nesterov
,
learning_rate
=
learning_rate
)
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
...
...
@@ -141,12 +170,15 @@ class TestLarsMomentumOp(OpTest):
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
def
test_check_output
(
self
):
paddle
.
enable_static
()
self
.
check_output
()
class
TestSparseMomentumOp
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
use_nesterov
=
False
self
.
regularization_method
=
""
self
.
regularization_coeff
=
1.0
def
check_with_place
(
self
,
place
):
self
.
init_kernel
()
...
...
@@ -157,6 +189,8 @@ class TestSparseMomentumOp(unittest.TestCase):
row_numel
=
12
mu
=
1.0
use_nesterov
=
self
.
use_nesterov
regularization_method
=
self
.
regularization_method
regularization_coeff
=
self
.
regularization_coeff
# create and initialize Param Variable
param
=
scope
.
var
(
'Param'
).
get_tensor
()
...
...
@@ -198,7 +232,9 @@ class TestSparseMomentumOp(unittest.TestCase):
VelocityOut
=
'VelocityOut'
,
LearningRate
=
'LearningRate'
,
mu
=
mu
,
use_nesterov
=
use_nesterov
)
use_nesterov
=
use_nesterov
,
regularization_method
=
regularization_method
,
regularization_coeff
=
regularization_coeff
)
op
.
run
(
scope
,
place
)
# get and compare result
...
...
@@ -210,13 +246,19 @@ class TestSparseMomentumOp(unittest.TestCase):
_grad_np_array
=
np
.
full
((
height
,
row_numel
),
0.0
).
astype
(
"float32"
)
for
i
in
range
(
len
(
rows
)):
_grad_np_array
[
rows
[
i
]]
=
grad_np_array
[
i
]
_velocity_out
=
mu
*
velocity_np_array
+
_grad_np_array
_param
=
param_array
if
use_nesterov
:
_param_out
=
_param
-
(
_grad_np_array
+
_velocity_out
*
mu
)
*
lr_array
else
:
_param_out
=
_param
-
lr_array
*
_velocity_out
_param_out
,
_velocity_out
=
calculate_momentum_by_numpy
(
param
=
_param
,
grad
=
_grad_np_array
,
mu
=
mu
,
velocity
=
velocity_np_array
,
use_nesterov
=
use_nesterov
,
learning_rate
=
lr_array
,
regularization_method
=
regularization_method
,
regularization_coeff
=
regularization_coeff
)
self
.
assertTrue
((
_velocity_out
==
velocity_out_np_array
).
all
())
self
.
assertTrue
((
_param_out
==
param_out_np_array
).
all
())
...
...
@@ -251,6 +293,8 @@ class TestMomentumV2(unittest.TestCase):
adam
.
clear_gradients
()
def
test_momentum
(
self
):
paddle
.
enable_static
()
place
=
fluid
.
CPUPlace
()
main
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
):
...
...
@@ -279,5 +323,183 @@ class TestMomentumV2(unittest.TestCase):
self
.
assertRaises
(
ValueError
,
paddle
.
optimizer
.
Momentum
,
momentum
=
None
)
class
TestMomentumOpWithDecay
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"momentum"
self
.
dtype
=
np
.
float32
self
.
use_nesterov
=
True
self
.
regularization_method
=
'l2_decay'
self
.
regularization_coeff
=
0.9
self
.
init_config
()
param
=
np
.
random
.
random
((
123
,
321
)).
astype
(
self
.
dtype
)
grad
=
np
.
random
.
random
((
123
,
321
)).
astype
(
self
.
dtype
)
velocity
=
np
.
zeros
((
123
,
321
)).
astype
(
self
.
dtype
)
learning_rate
=
np
.
array
([
0.001
]).
astype
(
self
.
dtype
)
mu
=
0.0001
use_nesterov
=
self
.
use_nesterov
regularization_method
=
self
.
regularization_method
regularization_coeff
=
self
.
regularization_coeff
self
.
inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'Velocity'
:
velocity
,
'LearningRate'
:
learning_rate
}
self
.
attrs
=
{
'mu'
:
mu
,
'use_nesterov'
:
use_nesterov
,
'regularization_method'
:
regularization_method
,
'regularization_coeff'
:
regularization_coeff
}
grad
=
grad
+
regularization_coeff
*
param
param_out
,
velocity_out
=
calculate_momentum_by_numpy
(
param
=
param
,
grad
=
grad
,
mu
=
mu
,
velocity
=
velocity
,
use_nesterov
=
use_nesterov
,
learning_rate
=
learning_rate
)
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
def
init_config
(
self
):
pass
def
test_check_output
(
self
):
paddle
.
enable_static
()
self
.
check_output
()
class
TestMomentumOpWithDecayFP16
(
TestMomentumOpWithDecay
):
def
init_config
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
paddle
.
enable_static
()
self
.
check_output
(
atol
=
1e-3
)
class
TestMomentumOpWithDecay2
(
TestMomentumOpWithDecay
):
def
init_config
(
self
):
self
.
use_nesterov
=
False
class
TestSparseMomentumOpWithDecay
(
TestSparseMomentumOp
):
def
setUp
(
self
):
self
.
use_nesterov
=
False
self
.
regularization_method
=
'l2_decay'
self
.
regularization_coeff
=
0.9
class
TestSparseMomentumOpWithDecay2
(
TestSparseMomentumOpWithDecay
):
def
init_kernel
(
self
):
self
.
use_nesterov
=
True
class
TestMomentumOpWithDecayAPI
(
unittest
.
TestCase
):
def
_test_momentum_dygraph_common
(
self
,
regularization
):
paddle
.
disable_static
()
inp
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
[
10
,
10
]).
astype
(
"float32"
)
linear
=
paddle
.
nn
.
Linear
(
10
,
10
)
inp
=
paddle
.
to_tensor
(
inp
)
out
=
linear
(
inp
)
loss
=
paddle
.
mean
(
out
)
# This can be any optimizer supported by dygraph.
momentum
=
paddle
.
fluid
.
contrib
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameter_list
=
linear
.
parameters
(),
regularization
=
regularization
)
momentum
.
minimize
(
loss
)
def
test_momentum_dygraph_1
(
self
):
self
.
_test_momentum_dygraph_common
(
regularization
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.1
))
def
test_momentum_static
(
self
):
paddle
.
enable_static
()
place
=
fluid
.
CPUPlace
()
main
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
momentum_optimizer
=
paddle
.
fluid
.
contrib
.
optimizer
.
Momentum
(
learning_rate
=
0.1
,
momentum
=
0.9
)
momentum_optimizer
.
minimize
(
avg_cost
)
fetch_list
=
[
avg_cost
]
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
uci_housing
.
train
(),
batch_size
=
1
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
x
,
y
])
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
for
data
in
train_reader
():
exe
.
run
(
main
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
fetch_list
)
class
TestMomentumOpVsMomentumOpWithDecayAPI
(
unittest
.
TestCase
):
def
__update_params
(
self
,
momentum
,
linear
):
for
i
in
range
(
10
):
inp
=
paddle
.
full
(
shape
=
[
2
,
2
],
fill_value
=
i
,
dtype
=
'float32'
).
astype
(
"float32"
)
inp
=
paddle
.
to_tensor
(
inp
)
out
=
linear
(
inp
)
loss
=
paddle
.
mean
(
out
)
loss
.
backward
()
momentum
.
minimize
(
loss
)
def
__test_vs
(
self
,
place
=
fluid
.
CPUPlace
()):
paddle
.
disable_static
(
place
=
place
)
linear_old
=
paddle
.
nn
.
Linear
(
2
,
2
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
),
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
))
momentum_old
=
paddle
.
fluid
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameter_list
=
linear_old
.
parameters
(),
regularization
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.1
))
self
.
__update_params
(
momentum
=
momentum_old
,
linear
=
linear_old
)
linear_new
=
paddle
.
nn
.
Linear
(
2
,
2
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
),
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
))
momentum_new
=
paddle
.
fluid
.
contrib
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameter_list
=
linear_new
.
parameters
(),
regularization
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.1
))
self
.
__update_params
(
momentum
=
momentum_new
,
linear
=
linear_new
)
self
.
assertEqual
(
(
linear_old
.
weight
.
numpy
()
==
linear_new
.
weight
.
numpy
()).
all
(),
True
,
'the param weight updated by two Momentum optimizers should equal'
)
def
test_vs
(
self
,
place
=
fluid
.
CPUPlace
()):
places
=
[
fluid
.
CPUPlace
()]
if
paddle
.
fluid
.
core
.
is_compiled_with_cuda
():
places
.
append
(
fluid
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
__test_vs
(
place
=
place
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录