提交 d432b10d 编写于 作者: D dangqingqing

Update cuda kernel and doc.

上级 e03b574e
......@@ -71,8 +71,12 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor, default Tensor<float>) "
"Input learning rate");
AddOutput("ParamOut", "(Tensor) Output updated parameter");
AddOutput("VelocityOut", "(Tensor) Output updated velocity");
AddOutput("ParamOut",
"(Tensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(Tensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<bool>("use_nesterov",
......
......@@ -29,7 +29,7 @@ __global__ void MomentumKernel(const T* p, const T* g, const T* v,
T g_val = g[i];
T v_new = v[i] * mu + g_val;
v_out[i] = v_new;
p_out[i] = p[i] - g_val * lr + v_new * mu * lr;
p_out[i] = p[i] - (g_val - v_new * mu) * lr;
}
} else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册