From 873a50ce35478d76e10a0b58214755f323043a80 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Fri, 20 Jul 2018 17:26:02 +0800 Subject: [PATCH] Fix serious bug in nesterov momentum optimizer. (#12231) * Fix serious bug in nesterov momentum optimizer. --- paddle/fluid/operators/momentum_op.cc | 2 +- paddle/fluid/operators/momentum_op.cu | 2 +- paddle/fluid/operators/momentum_op.h | 2 +- python/paddle/fluid/layers/nn.py | 3 ++- python/paddle/fluid/optimizer.py | 2 +- python/paddle/fluid/tests/unittests/test_momentum_op.py | 4 ++-- 6 files changed, 8 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc index dcd73e3c3e4..5f43c581081 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/momentum_op.cc @@ -98,7 +98,7 @@ The update equations are as follows: $$ velocity = mu * velocity + gradient \\ if (use\_nesterov): \\ - param = param - gradient * learning\_rate + mu * velocity * learning\_rate \\ + param = param - (gradient + mu * velocity) * learning\_rate \\ else: \\ param = param - learning\_rate * velocity. \\ $$ diff --git a/paddle/fluid/operators/momentum_op.cu b/paddle/fluid/operators/momentum_op.cu index 5eb9d995024..a3932db1f3a 100644 --- a/paddle/fluid/operators/momentum_op.cu +++ b/paddle/fluid/operators/momentum_op.cu @@ -30,7 +30,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v, T g_val = g[i]; T v_new = v[i] * mu + g_val; v_out[i] = v_new; - p_out[i] = p[i] - (g_val - v_new * mu) * lr; + p_out[i] = p[i] - (g_val + v_new * mu) * lr; } } else { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h index 04a1929b84a..264726040fb 100644 --- a/paddle/fluid/operators/momentum_op.h +++ b/paddle/fluid/operators/momentum_op.h @@ -46,7 +46,7 @@ class MomentumOpKernel : public framework::OpKernel { v_out = v * mu + g; if (use_nesterov) { - p_out = p - (g - v_out * mu) * lr[0]; + p_out = p - (g + v_out * mu) * lr[0]; } else { p_out = p - lr[0] * v_out; } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 56124663929..ab40d0c217f 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -166,7 +166,8 @@ def fc(input, param_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for learnable parameters/weights of this layer. bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias - of this layer. If it is set to None, no bias will be added to the output units. + of this layer. If it is set to False, no bias will be added to the output units. + If it is set to None, the bias is initialized zero. Default: None. act (str, default None): Activation to be applied to the output of this layer. is_test(bool): A flag indicating whether execution is in test phase. use_mkldnn(bool): Use mkldnn kernel or not, it is valid only when the mkldnn diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 7fc8e106fb4..3fe99f55011 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -324,7 +324,7 @@ class MomentumOptimizer(Optimizer): & if (use\_nesterov): - &\quad param = param - gradient * learning\_rate + mu * velocity * learning\_rate + &\quad param = param - (gradient + mu * velocity) * learning\_rate & else: diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index aaea9c18092..c75d3bd276a 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -39,7 +39,7 @@ class TestMomentumOp1(OpTest): velocity_out = mu * velocity + grad if use_nesterov: - param_out = param - grad * learning_rate + \ + param_out = param - grad * learning_rate - \ velocity_out * mu * learning_rate else: param_out = param - learning_rate * velocity_out @@ -75,7 +75,7 @@ class TestMomentumOp2(OpTest): velocity_out = mu * velocity + grad if use_nesterov: - param_out = param - grad * learning_rate + \ + param_out = param - grad * learning_rate - \ velocity_out * mu * learning_rate else: param_out = param - learning_rate * velocity_out -- GitLab