未验证 提交 d06bbb12 编写于 作者: Z zhangchao 提交者: GitHub

Merge pull request #7203 from peterzhang2029/fix_adagrad

Fix adagrad op by removing broadcast of Eigen.
...@@ -47,8 +47,7 @@ class AdagradOpKernel : public framework::OpKernel<T> { ...@@ -47,8 +47,7 @@ class AdagradOpKernel : public framework::OpKernel<T> {
*ctx.Input<framework::Tensor>("Grad")); *ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten( auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment")); *ctx.Input<framework::Tensor>("Moment"));
auto lr = framework::EigenVector<T>::Flatten( auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
*ctx.Input<framework::Tensor>("LearningRate"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor); auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor); auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
...@@ -56,8 +55,16 @@ class AdagradOpKernel : public framework::OpKernel<T> { ...@@ -56,8 +55,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
moment_out.device(*place) = moment + grad * grad; moment_out.device(*place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel()); Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
param_out.device(*place) = if (platform::is_cpu_place(ctx.GetPlace())) {
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); auto* lr = learning_rate->data<T>();
param_out.device(*place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
} else {
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
param_out.device(*place) =
param -
lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
}
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
auto* param_tensor = ctx.Input<framework::Tensor>("Param"); auto* param_tensor = ctx.Input<framework::Tensor>("Param");
PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor); PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册