diff --git a/paddle/operators/sgd_op.h b/paddle/operators/sgd_op.h index 2ee21ef8f93ae68bfdc3e012f82ea806dc90b6dc..4b2d214618e5c7c15695bd66604139d805255c47 100644 --- a/paddle/operators/sgd_op.h +++ b/paddle/operators/sgd_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include "glog/logging.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/operator.h" namespace paddle { @@ -30,8 +31,10 @@ public: param_out->mutable_data(ctx.GetPlace()); - param_out->flat().device(*(ctx.GetEigenDevice())) = - param.flat() - lr * grad.flat(); + framework::EigenVector::Flatten(*param_out) + .device(*(ctx.GetEigenDevice())) = + framework::EigenVector::Flatten(param) - + lr * framework::EigenVector::Flatten(grad); } };