Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
5380a547
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5380a547
编写于
10月 20, 2017
作者:
K
kavyasrinet
提交者:
GitHub
10月 20, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adding Nesterov Momentum (#4948)
上级
23785584
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
59 addition
and
6 deletion
+59
-6
paddle/operators/momentum_op.cc
paddle/operators/momentum_op.cc
+7
-2
paddle/operators/momentum_op.h
paddle/operators/momentum_op.h
+8
-1
python/paddle/v2/framework/tests/test_momentum_op.py
python/paddle/v2/framework/tests/test_momentum_op.py
+43
-2
python/paddle/v2/framework/tests/test_rmsprop_op.py
python/paddle/v2/framework/tests/test_rmsprop_op.py
+1
-1
未找到文件。
paddle/operators/momentum_op.cc
浏览文件 @
5380a547
...
@@ -75,12 +75,17 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -75,12 +75,17 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"VelocityOut"
,
"(Tensor) Output updated velocity"
);
AddOutput
(
"VelocityOut"
,
"(Tensor) Output updated velocity"
);
AddAttr
<
float
>
(
"mu"
,
"(float) Momentum coefficient"
);
AddAttr
<
float
>
(
"mu"
,
"(float) Momentum coefficient"
);
AddAttr
<
bool
>
(
"useNesterov"
,
"(bool) Use Nesterov Momentum"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Momentum Algorithm (momentum).
Momentum Algorithm
with a flag for Nestrov Moemntum
(momentum).
velocity = mu * velocity + gradient
velocity = mu * velocity + gradient
param = param - learning_rate * velocity
if (use_nesterov):
param = param - gradient * learning_rate + mu * velocity * learning_rate
else:
param = param - learning_rate * velocity
)DOC"
);
)DOC"
);
}
}
...
...
paddle/operators/momentum_op.h
浏览文件 @
5380a547
...
@@ -34,6 +34,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
...
@@ -34,6 +34,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
float
mu
=
ctx
.
Attr
<
float
>
(
"mu"
);
float
mu
=
ctx
.
Attr
<
float
>
(
"mu"
);
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"useNesterov"
);
auto
p_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
p_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
param_out
);
auto
v_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity_out
);
auto
v_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
velocity_out
);
...
@@ -46,8 +47,14 @@ class MomentumOpKernel : public framework::OpKernel<T> {
...
@@ -46,8 +47,14 @@ class MomentumOpKernel : public framework::OpKernel<T> {
auto
place
=
ctx
.
GetEigenDevice
<
Place
>
();
auto
place
=
ctx
.
GetEigenDevice
<
Place
>
();
Eigen
::
DSizes
<
int
,
1
>
grad_dsize
(
grad
->
numel
());
Eigen
::
DSizes
<
int
,
1
>
grad_dsize
(
grad
->
numel
());
v_out
.
device
(
place
)
=
v
*
mu
+
g
;
v_out
.
device
(
place
)
=
v
*
mu
+
g
;
p_out
.
device
(
place
)
=
p
-
lr
.
broadcast
(
grad_dsize
)
*
v_out
;
if
(
use_nesterov
)
{
p_out
.
device
(
place
)
=
p
-
g
*
lr
.
broadcast
(
grad_dsize
)
+
v_out
*
mu
*
lr
.
broadcast
(
grad_dsize
);
}
else
{
p_out
.
device
(
place
)
=
p
-
lr
.
broadcast
(
grad_dsize
)
*
v_out
;
}
}
}
};
};
...
...
python/paddle/v2/framework/tests/test_momentum_op.py
浏览文件 @
5380a547
...
@@ -3,7 +3,7 @@ import numpy as np
...
@@ -3,7 +3,7 @@ import numpy as np
from
op_test
import
OpTest
from
op_test
import
OpTest
class
TestMomentumOp
(
OpTest
):
class
TestMomentumOp
1
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"momentum"
self
.
op_type
=
"momentum"
...
@@ -12,6 +12,7 @@ class TestMomentumOp(OpTest):
...
@@ -12,6 +12,7 @@ class TestMomentumOp(OpTest):
velocity
=
np
.
zeros
((
123
,
321
)).
astype
(
"float32"
)
velocity
=
np
.
zeros
((
123
,
321
)).
astype
(
"float32"
)
learning_rate
=
np
.
array
([
0.001
]).
astype
(
"float32"
)
learning_rate
=
np
.
array
([
0.001
]).
astype
(
"float32"
)
mu
=
0.0001
mu
=
0.0001
use_nesterov
=
False
self
.
inputs
=
{
self
.
inputs
=
{
'Param'
:
param
,
'Param'
:
param
,
...
@@ -23,7 +24,47 @@ class TestMomentumOp(OpTest):
...
@@ -23,7 +24,47 @@ class TestMomentumOp(OpTest):
self
.
attrs
=
{
'mu'
:
mu
}
self
.
attrs
=
{
'mu'
:
mu
}
velocity_out
=
mu
*
velocity
+
grad
velocity_out
=
mu
*
velocity
+
grad
param_out
=
param
-
learning_rate
*
velocity_out
if
use_nesterov
:
param_out
=
param
-
grad
*
learning_rate
+
\
velocity_out
*
mu
*
learning_rate
else
:
param_out
=
param
-
learning_rate
*
velocity_out
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestMomentumOp2
(
OpTest
):
'''Test Momentum with defaukt values for attributes
'''
def
setUp
(
self
):
self
.
op_type
=
"momentum"
param
=
np
.
random
.
random
((
123
,
321
)).
astype
(
"float32"
)
grad
=
np
.
random
.
random
((
123
,
321
)).
astype
(
"float32"
)
velocity
=
np
.
zeros
((
123
,
321
)).
astype
(
"float32"
)
learning_rate
=
np
.
array
([
0.001
]).
astype
(
"float32"
)
mu
=
0.0001
use_nesterov
=
True
self
.
inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'Velocity'
:
velocity
,
'LearningRate'
:
learning_rate
}
self
.
attrs
=
{
'mu'
:
mu
,
'useNesterov'
:
use_nesterov
}
velocity_out
=
mu
*
velocity
+
grad
if
use_nesterov
:
param_out
=
param
-
grad
*
learning_rate
+
\
velocity_out
*
mu
*
learning_rate
else
:
param_out
=
param
-
learning_rate
*
velocity_out
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
...
...
python/paddle/v2/framework/tests/test_rmsprop_op.py
浏览文件 @
5380a547
...
@@ -46,7 +46,7 @@ class TestRmspropOp1(OpTest):
...
@@ -46,7 +46,7 @@ class TestRmspropOp1(OpTest):
class
TestRmspropOp2
(
OpTest
):
class
TestRmspropOp2
(
OpTest
):
'''Test RMSProp with defau
k
t values for attributes
'''Test RMSProp with defau
l
t values for attributes
'''
'''
def
setUp
(
self
):
def
setUp
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录