diff --git a/paddle/operators/momentum_op.cc b/paddle/operators/momentum_op.cc index fde253b0b38b3be42842cc3e3d612d24df874a2c..2ab48fedecf0cce95dcf4d0593dcd4b30bc1f505 100644 --- a/paddle/operators/momentum_op.cc +++ b/paddle/operators/momentum_op.cc @@ -71,8 +71,12 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor, default Tensor) " "Input learning rate"); - AddOutput("ParamOut", "(Tensor) Output updated parameter"); - AddOutput("VelocityOut", "(Tensor) Output updated velocity"); + AddOutput("ParamOut", + "(Tensor) This output is updated parameter. " + "It shared memory with Input(Param)."); + AddOutput("VelocityOut", + "(Tensor) This output is updated velocity. " + "It shared memory with Input(Velocity)."); AddAttr("mu", "(float) Momentum coefficient"); AddAttr("use_nesterov", diff --git a/paddle/operators/momentum_op.cu b/paddle/operators/momentum_op.cu index d856df400275e150728922d5e698e4ee4d66209f..be0c8ea071716b75eaeddab209a52b3d5f2f7e16 100644 --- a/paddle/operators/momentum_op.cu +++ b/paddle/operators/momentum_op.cu @@ -29,7 +29,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v, T g_val = g[i]; T v_new = v[i] * mu + g_val; v_out[i] = v_new; - p_out[i] = p[i] - g_val * lr + v_new * mu * lr; + p_out[i] = p[i] - (g_val - v_new * mu) * lr; } } else { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;