Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
873a50ce
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
873a50ce
编写于
7月 20, 2018
作者:
Q
qingqing01
提交者:
GitHub
7月 20, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix serious bug in nesterov momentum optimizer. (#12231)
* Fix serious bug in nesterov momentum optimizer.
上级
b42ced8e
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
8 addition
and
7 deletion
+8
-7
paddle/fluid/operators/momentum_op.cc
paddle/fluid/operators/momentum_op.cc
+1
-1
paddle/fluid/operators/momentum_op.cu
paddle/fluid/operators/momentum_op.cu
+1
-1
paddle/fluid/operators/momentum_op.h
paddle/fluid/operators/momentum_op.h
+1
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+2
-1
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+1
-1
python/paddle/fluid/tests/unittests/test_momentum_op.py
python/paddle/fluid/tests/unittests/test_momentum_op.py
+2
-2
未找到文件。
paddle/fluid/operators/momentum_op.cc
浏览文件 @
873a50ce
...
...
@@ -98,7 +98,7 @@ The update equations are as follows:
$$
velocity = mu * velocity + gradient \\
if (use\_nesterov): \\
param = param -
gradient * learning\_rate + mu * velocity
* learning\_rate \\
param = param -
(gradient + mu * velocity)
* learning\_rate \\
else: \\
param = param - learning\_rate * velocity. \\
$$
...
...
paddle/fluid/operators/momentum_op.cu
浏览文件 @
873a50ce
...
...
@@ -30,7 +30,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v,
T
g_val
=
g
[
i
];
T
v_new
=
v
[
i
]
*
mu
+
g_val
;
v_out
[
i
]
=
v_new
;
p_out
[
i
]
=
p
[
i
]
-
(
g_val
-
v_new
*
mu
)
*
lr
;
p_out
[
i
]
=
p
[
i
]
-
(
g_val
+
v_new
*
mu
)
*
lr
;
}
}
else
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
num
;
...
...
paddle/fluid/operators/momentum_op.h
浏览文件 @
873a50ce
...
...
@@ -46,7 +46,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
v_out
=
v
*
mu
+
g
;
if
(
use_nesterov
)
{
p_out
=
p
-
(
g
-
v_out
*
mu
)
*
lr
[
0
];
p_out
=
p
-
(
g
+
v_out
*
mu
)
*
lr
[
0
];
}
else
{
p_out
=
p
-
lr
[
0
]
*
v_out
;
}
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
873a50ce
...
...
@@ -166,7 +166,8 @@ def fc(input,
param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable
parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to None, no bias will be added to the output units.
of this layer. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized zero. Default: None.
act (str, default None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase.
use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
873a50ce
...
...
@@ -324,7 +324,7 @@ class MomentumOptimizer(Optimizer):
& if (use\_nesterov):
&\quad param = param -
gradient * learning\_rate + mu * velocity
* learning\_rate
&\quad param = param -
(gradient + mu * velocity)
* learning\_rate
& else:
...
...
python/paddle/fluid/tests/unittests/test_momentum_op.py
浏览文件 @
873a50ce
...
...
@@ -39,7 +39,7 @@ class TestMomentumOp1(OpTest):
velocity_out
=
mu
*
velocity
+
grad
if
use_nesterov
:
param_out
=
param
-
grad
*
learning_rate
+
\
param_out
=
param
-
grad
*
learning_rate
-
\
velocity_out
*
mu
*
learning_rate
else
:
param_out
=
param
-
learning_rate
*
velocity_out
...
...
@@ -75,7 +75,7 @@ class TestMomentumOp2(OpTest):
velocity_out
=
mu
*
velocity
+
grad
if
use_nesterov
:
param_out
=
param
-
grad
*
learning_rate
+
\
param_out
=
param
-
grad
*
learning_rate
-
\
velocity_out
*
mu
*
learning_rate
else
:
param_out
=
param
-
learning_rate
*
velocity_out
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录