未验证 提交 873a50ce 编写于 作者: Q qingqing01 提交者: GitHub

Fix serious bug in nesterov momentum optimizer. (#12231)

* Fix serious bug in nesterov momentum optimizer.
上级 b42ced8e
......@@ -98,7 +98,7 @@ The update equations are as follows:
$$
velocity = mu * velocity + gradient \\
if (use\_nesterov): \\
param = param - gradient * learning\_rate + mu * velocity * learning\_rate \\
param = param - (gradient + mu * velocity) * learning\_rate \\
else: \\
param = param - learning\_rate * velocity. \\
$$
......
......@@ -30,7 +30,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v,
T g_val = g[i];
T v_new = v[i] * mu + g_val;
v_out[i] = v_new;
p_out[i] = p[i] - (g_val - v_new * mu) * lr;
p_out[i] = p[i] - (g_val + v_new * mu) * lr;
}
} else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
......
......@@ -46,7 +46,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
v_out = v * mu + g;
if (use_nesterov) {
p_out = p - (g - v_out * mu) * lr[0];
p_out = p - (g + v_out * mu) * lr[0];
} else {
p_out = p - lr[0] * v_out;
}
......
......@@ -166,7 +166,8 @@ def fc(input,
param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable
parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to None, no bias will be added to the output units.
of this layer. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized zero. Default: None.
act (str, default None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase.
use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn
......
......@@ -324,7 +324,7 @@ class MomentumOptimizer(Optimizer):
& if (use\_nesterov):
&\quad param = param - gradient * learning\_rate + mu * velocity * learning\_rate
&\quad param = param - (gradient + mu * velocity) * learning\_rate
& else:
......
......@@ -39,7 +39,7 @@ class TestMomentumOp1(OpTest):
velocity_out = mu * velocity + grad
if use_nesterov:
param_out = param - grad * learning_rate + \
param_out = param - grad * learning_rate - \
velocity_out * mu * learning_rate
else:
param_out = param - learning_rate * velocity_out
......@@ -75,7 +75,7 @@ class TestMomentumOp2(OpTest):
velocity_out = mu * velocity + grad
if use_nesterov:
param_out = param - grad * learning_rate + \
param_out = param - grad * learning_rate - \
velocity_out * mu * learning_rate
else:
param_out = param - learning_rate * velocity_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册