diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h index 2d919573d201b8a511f657cd187fe6cc0344069a..da69532ea58bad8d3908770d82dbcc6e6b108fce 100644 --- a/paddle/operators/momentum_op.h +++ b/paddle/operators/momentum_op.h @@ -44,15 +44,11 @@ class MomentumOpKernel : public framework::OpKernel { auto g = framework::EigenVector::Flatten(*grad); auto* lr = learning_rate->data(); - auto place = ctx.GetEigenDevice(); - - Eigen::DSizes grad_dsize(grad->numel()); - - v_out.device(place) = v * mu + g; + v_out = v * mu + g; if (use_nesterov) { - p_out.device(place) = p - (g - v_out * mu) * lr[0]; + p_out = p - (g - v_out * mu) * lr[0]; } else { - p_out.device(place) = p - lr[0] * v_out; + p_out = p - lr[0] * v_out; } } };