提交 d432b10d 编写于 作者: D dangqingqing

Update cuda kernel and doc.

上级 e03b574e
...@@ -71,8 +71,12 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -71,8 +71,12 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor, default Tensor<float>) " "(Tensor, default Tensor<float>) "
"Input learning rate"); "Input learning rate");
AddOutput("ParamOut", "(Tensor) Output updated parameter"); AddOutput("ParamOut",
AddOutput("VelocityOut", "(Tensor) Output updated velocity"); "(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddAttr<float>("mu", "(float) Momentum coefficient"); AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov", AddAttr<bool>("use_nesterov",
......
...@@ -29,7 +29,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v, ...@@ -29,7 +29,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v,
T g_val = g[i]; T g_val = g[i];
T v_new = v[i] * mu + g_val; T v_new = v[i] * mu + g_val;
v_out[i] = v_new; v_out[i] = v_new;
p_out[i] = p[i] - g_val * lr + v_new * mu * lr; p_out[i] = p[i] - (g_val - v_new * mu) * lr;
} }
} else { } else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册