提交 f52cdaa0 编写于 作者: K Kavya Srinet

Updated RMSProp to have learning rate as an input and work with GPU

上级 03363041
...@@ -32,6 +32,8 @@ class RmspropOpKernel : public framework::OpKernel<T> { ...@@ -32,6 +32,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
auto* moment_out = ctx.Output<Tensor>("MomentOut"); auto* moment_out = ctx.Output<Tensor>("MomentOut");
auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut"); auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
auto grad = ctx.Input<Tensor>("Grad");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace()); moment_out->mutable_data<T>(ctx.GetPlace());
mean_square_out->mutable_data<T>(ctx.GetPlace()); mean_square_out->mutable_data<T>(ctx.GetPlace());
...@@ -42,8 +44,8 @@ class RmspropOpKernel : public framework::OpKernel<T> { ...@@ -42,8 +44,8 @@ class RmspropOpKernel : public framework::OpKernel<T> {
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param")); auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare")); auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0]; auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad")); auto g = EigenVector<T>::Flatten(*grad);
auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment")); auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
auto p_out = EigenVector<T>::Flatten(*param_out); auto p_out = EigenVector<T>::Flatten(*param_out);
...@@ -51,8 +53,12 @@ class RmspropOpKernel : public framework::OpKernel<T> { ...@@ -51,8 +53,12 @@ class RmspropOpKernel : public framework::OpKernel<T> {
auto ms_out = EigenVector<T>::Flatten(*mean_square_out); auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
auto place = ctx.GetEigenDevice<Place>(); auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
ms_out.device(place) = rho * ms + (1 - rho) * g * g; ms_out.device(place) = rho * ms + (1 - rho) * g * g;
mom_out.device(place) = momentum * mom + lr * g / (ms_out + epsilon).sqrt(); mom_out.device(place) =
momentum * mom +
lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
p_out.device(place) = p - mom_out; p_out.device(place) = p - mom_out;
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册