提交 8ebc31d9 编写于 作者: Q qiaolongfei

optimize the dsize

上级 775c6024
......@@ -33,10 +33,10 @@ class SGDOpKernel : public framework::OpKernel<T> {
auto p = framework::EigenVector<T>::Flatten(*param);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto o = framework::EigenVector<T>::Flatten(*param_out);
auto lr = framework::EigenVector<T>::From(*learning_rate);
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 2> grad_dsize(grad->dims()[0], grad->dims()[1]);
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
o.device(place) = p - lr.broadcast(grad_dsize) * g;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册