From d432b10d8aa5beb0e8576b8f9811048af98519bc Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 5 Dec 2017 13:45:17 +0800 Subject: [PATCH] Update cuda kernel and doc. --- paddle/operators/momentum_op.cc | 8 ++++++-- paddle/operators/momentum_op.cu | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/paddle/operators/momentum_op.cc b/paddle/operators/momentum_op.cc index fde253b0b38..2ab48fedecf 100644 --- a/paddle/operators/momentum_op.cc +++ b/paddle/operators/momentum_op.cc @@ -71,8 +71,12 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor, default Tensor) " "Input learning rate"); - AddOutput("ParamOut", "(Tensor) Output updated parameter"); - AddOutput("VelocityOut", "(Tensor) Output updated velocity"); + AddOutput("ParamOut", + "(Tensor) This output is updated parameter. " + "It shared memory with Input(Param)."); + AddOutput("VelocityOut", + "(Tensor) This output is updated velocity. " + "It shared memory with Input(Velocity)."); AddAttr("mu", "(float) Momentum coefficient"); AddAttr("use_nesterov", diff --git a/paddle/operators/momentum_op.cu b/paddle/operators/momentum_op.cu index d856df40027..be0c8ea0717 100644 --- a/paddle/operators/momentum_op.cu +++ b/paddle/operators/momentum_op.cu @@ -29,7 +29,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v, T g_val = g[i]; T v_new = v[i] * mu + g_val; v_out[i] = v_new; - p_out[i] = p[i] - g_val * lr + v_new * mu * lr; + p_out[i] = p[i] - (g_val - v_new * mu) * lr; } } else { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; -- GitLab