Commit 94855f4a
Authored on Oct 04, 2017 by Kavya Srinet
Fixed changes proposed in the review
Parent: 163d2871
Showing 3 changed files with 77 additions and 35 deletions:

  paddle/operators/rmsprop_op.cc                         +48  -21
  paddle/operators/rmsprop_op.h                          +13  -6
  python/paddle/v2/framework/tests/test_rmsprop_op.py    +16  -8
paddle/operators/rmsprop_op.cc
@@ -25,25 +25,32 @@ class RmspropOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContextBase *ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Param"),
                    "Input(Param) of RmspropOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("MeanSquare"),
+                   "Input(MeanSquare) of RmspropOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
+                   "Input(LearningRate) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Grad"),
                    "Input(Grad) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Moment"),
                    "Input(Moment) of RmspropOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
-                   "Input(LearningRate) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                    "Output(param_out) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
-                   "Output(moment_out) of RmspropOp should not be null.");
+                   "Output(Momentum_out) of RmspropOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"),
+                   "Output(MeanSquareOut) of RmspropOp should not be null.");
 
     auto param_dim = ctx->GetInputDim("Param");
     PADDLE_ENFORCE_EQ(
         param_dim, ctx->GetInputDim("Grad"),
         "Param and grad input of RmspropOp should have the same dimension.");
     PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
-                      "Param and moment input of RmspropOp should have the same dimension.");
+                      "Param and Momentum input of RmspropOp "
+                      "should have the same dimension.");
+    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"),
+                      "Param and Momentum input of RmspropOp "
+                      "should have the same dimension.");
 
     auto lr_dim = ctx->GetInputDim("LearningRate");
     PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
@@ -51,6 +58,7 @@ class RmspropOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("MomentOut", param_dim);
+    ctx->SetOutputDim("MeanSquareOut", param_dim);
   }
 };
@@ -59,27 +67,46 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
   RmspropOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("Param", "Input parameter");
-    AddInput("Grad", "Input gradient");
-    AddInput("Moment", "Second moment");
-    AddInput("LearningRate", "Learning Rate");
+    AddInput("Param",
+             "(Tensor, default Tensor<float>) "
+             "Input parameter value that has to be updated");
+    AddInput("MeanSquare",
+             "(Tensor, default Tensor<float>)"
+             " The mean square value that gets updated");
+    AddInput("LearningRate",
+             "(Tensor, default Tensor<float>) "
+             "The learning rate should be a tensor of size 1");
+    AddInput("Grad",
+             "(Tensor, default Tensor<float>) "
+             "Input gradient of the parameter");
+    AddInput("Moment",
+             "(Tensor, default Tensor<float>) The moment that gets updated");
 
-    AddOutput("ParamOut", "Output parameter");
-    AddOutput("MomentOut", "Output second moment");
+    AddOutput("ParamOut", "(Tensor) Output updated parameter value");
+    AddOutput("MomentOut", "(Tensor) Output updated moment");
+    AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value");
 
-    AddAttr<float>("epsilon", "Constant for numerical stability");
-    AddAttr<float>("decayRate", "Decay rate for moving average of gradients");
+    AddAttr<float>("epsilon",
+                   "(float, default 1e-10) Constant "
+                   "for numerical stability.")
+        .SetDefault(1e-10);
+    AddAttr<float>("decay",
+                   "(float, default 0.9) "
+                   "Discounting factor for coming gradient.")
+        .SetDefault(0.9);
+    AddAttr<float>("momentum", "(float, default 0.0) Constant value")
+        .SetDefault(0.0);
     AddComment(R"DOC(
 
 RMSprop
 
-MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad
-ParamOut = Param - LearningRate * Grad / (sqrt(MomentOut) + epsilon)
+MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad
+MomentOut = momentum * Moment +
+            LearningRate * Grad / sqrt(MeanSquareOut + epsilon)
+ParamOut = Param - MomentOut
 
-The original slide (Slide 29 of
+The original slides that proposed RMSprop: Slide 29 of
 http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
-does not have the epsilon attribute. It is added here for numerical stability
-to avoid division by zero.
 
 )DOC");
   }
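For reference, the update rule in the new DOC string can be written out directly in NumPy. This is only an illustrative sketch of the documented formulas, not part of the commit; the argument names are chosen to mirror the op's inputs and attributes.

    import numpy as np

    def rmsprop_step(param, mean_square, moment, grad, lr,
                     epsilon=1e-10, decay=0.9, momentum=0.0):
        # MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad
        mean_square_out = decay * mean_square + (1 - decay) * grad * grad
        # MomentOut = momentum * Moment +
        #             LearningRate * Grad / sqrt(MeanSquareOut + epsilon)
        moment_out = momentum * moment + lr * grad / np.sqrt(mean_square_out + epsilon)
        # ParamOut = Param - MomentOut
        param_out = param - moment_out
        return param_out, moment_out, mean_square_out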
paddle/operators/rmsprop_op.h
@@ -30,23 +30,30 @@ class RmspropOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext &ctx) const override {
     auto param_out = ctx.Output<Tensor>("ParamOut");
     auto moment_out = ctx.Output<Tensor>("MomentOut");
+    auto mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
 
     param_out->mutable_data<T>(ctx.GetPlace());
     moment_out->mutable_data<T>(ctx.GetPlace());
+    mean_square_out->mutable_data<T>(ctx.GetPlace());
 
     float epsilon = ctx.Attr<float>("epsilon");
-    float decay = ctx.Attr<float>("decayRate");
+    float rho = ctx.Attr<float>("decay");
+    float momentum = ctx.Attr<float>("momentum");
 
     auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
-    auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
-    auto m = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
+    auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
     float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
+    auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
+    auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
 
     auto p_out = EigenVector<T>::Flatten(*param_out);
-    auto m_out = EigenVector<T>::Flatten(*moment_out);
+    auto mom_out = EigenVector<T>::Flatten(*moment_out);
+    auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
     auto place = ctx.GetEigenDevice<Place>();
 
-    m_out.device(place) = decay * m + (1 - decay) * g * g;
-    p_out.device(place) = p - lr * g / (m_out.sqrt() + epsilon);
+    ms_out.device(place) = rho * ms + (1 - rho) * g * g;
+    mom_out.device(place) =
+        momentum * mom + lr * g / (ms_out + epsilon).sqrt();
+    p_out.device(place) = p - mom_out;
   }
 };
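A note on the numerical-stability change in this kernel: the running accumulator is now the mean of squared gradients (MeanSquare) rather than the moment, and epsilon moves inside the square root, so the denominator changes from sqrt(MomentOut) + epsilon to sqrt(MeanSquareOut + epsilon). The small NumPy sketch below (illustrative values only, not part of the commit) shows the two denominators side by side.

    import numpy as np

    grad = np.array([1e-3, 1e-6, 0.0])    # illustrative gradient values
    acc = grad * grad                      # squared-gradient accumulator after one step
    epsilon = 1e-6

    old_denom = np.sqrt(acc) + epsilon     # pre-commit: epsilon added outside the sqrt
    new_denom = np.sqrt(acc + epsilon)     # post-commit: epsilon added inside the sqrt
    # Both keep the division well-defined, but they scale small gradients differently,
    # which is why the Python reference in test_rmsprop_op.py is updated to match.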
python/paddle/v2/framework/tests/test_rmsprop_op.py
@@ -8,27 +8,35 @@ class TestRmspropOp(OpTest):
         self.op_type = "rmsprop"
 
         param = np.random.random((123, 321)).astype("float32")
+        mean_square = np.random.random((123, 321)).astype("float32")
+        learning_rate = np.array([0.01]).astype("float32")
         grad = np.random.random((123, 321)).astype("float32")
         moment = np.zeros((123, 321)).astype("float32")
-        learning_rate = np.array([0.01]).astype("float32")
 
         epsilon = 1e-6
-        decay_rate = 0.9
+        decay = 0.9
+        momentum = 0.0
 
         self.inputs = {
             'Param': param,
+            'MeanSquare': mean_square,
+            'LearningRate': learning_rate,
             'Grad': grad,
             'Moment': moment,
-            'LearningRate': learning_rate
         }
 
-        self.attrs = {'epsilon': epsilon, 'decayRate': decay_rate}
+        self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum}
 
-        moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad
-        param_out = param - learning_rate * grad / (np.sqrt(moment_out) +
-                                                    epsilon)
+        ms_out = decay * mean_square + (1 - decay) * grad * grad
+        moment_out = momentum * moment + \
+            learning_rate * grad / np.sqrt(ms_out + epsilon)
+        param_out = param - moment_out
 
-        self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
+        self.outputs = {
+            'ParamOut': param_out,
+            'MomentOut': moment_out,
+            'MeanSquareOut': ms_out
+        }
 
     def test_check_output(self):
         self.check_output()