提交 ee7b3ed0，作者：qiaolongfei

use EigenScalar to read learning_rate from the GPU device

上级 15b35f9a
...@@ -23,6 +23,9 @@ using Tensor = framework::Tensor; ...@@ -23,6 +23,9 @@ using Tensor = framework::Tensor;
// NOTE(review): the diff extraction duplicated each of these lines
// (old + new side on one physical line); this is the single clean alias.
// EigenVector flattens a framework::Tensor into a rank-1 Eigen view.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// EigenScalar wraps a single-element framework::Tensor as an Eigen
// scalar expression (used below via EigenScalar<T>::From), so the
// learning rate can participate in the Eigen device expression —
// presumably avoiding a host-side read when running on GPU (per the
// commit message); confirm against framework::EigenScalar.
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel<T> { class SGDOpKernel : public framework::OpKernel<T> {
...@@ -31,13 +34,14 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -31,13 +34,14 @@ class SGDOpKernel : public framework::OpKernel<T> {
auto param = ctx.Input<Tensor>("Param"); auto param = ctx.Input<Tensor>("Param");
auto grad = ctx.Input<Tensor>("Grad"); auto grad = ctx.Input<Tensor>("Grad");
auto param_out = ctx.Output<Tensor>("ParamOut"); auto param_out = ctx.Output<Tensor>("ParamOut");
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0]; auto learning_rate = ctx.Input<Tensor>("LearningRate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
auto p = EigenVector<T>::Flatten(*param); auto p = EigenVector<T>::Flatten(*param);
auto g = EigenVector<T>::Flatten(*grad); auto g = EigenVector<T>::Flatten(*grad);
auto o = EigenVector<T>::Flatten(*param_out); auto o = EigenVector<T>::Flatten(*param_out);
auto lr = EigenScalar<T>::From(*learning_rate);
auto place = ctx.GetEigenDevice<Place>(); auto place = ctx.GetEigenDevice<Place>();
o.device(place) = p - lr * g; o.device(place) = p - lr * g;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册