From 189e0d44eaa3ef7833d1f7ed351ebcbc3113b83a Mon Sep 17 00:00:00 2001 From: wangzhen38 <41941775+wangzhen38@users.noreply.github.com> Date: Wed, 12 Apr 2023 11:01:45 +0800 Subject: [PATCH] Patch del (#52754) * [DO NOT MERGE] adadelta lr support * [DO NOT MERGE] gpu support * [test] follow torch * fix acc update order * for ci * [bug fix] update master para * [bug fix] update test * [bug fix] for ci test * for ci * fix xpu * [adadelta fix] del fluid head file * for ci * del notes --- .../phi/kernels/impl/adadelta_kernel_impl.h | 34 +++++++------------ 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/paddle/phi/kernels/impl/adadelta_kernel_impl.h b/paddle/phi/kernels/impl/adadelta_kernel_impl.h index c432c72d832..18fcd953d65 100644 --- a/paddle/phi/kernels/impl/adadelta_kernel_impl.h +++ b/paddle/phi/kernels/impl/adadelta_kernel_impl.h @@ -13,10 +13,6 @@ // limitations under the License. #pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/eigen/eigen_function.h" - #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/adadelta_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" @@ -67,26 +63,20 @@ void AdadeltaKernel(const Context& dev_ctx, -(((eigen_avg_squared_update + epsilon_).sqrt()) / ((eigen_avg_squared_grad_out + epsilon_).sqrt()) * eigen_grad_cast); Eigen::DSizes m_dsize(avg_squared_update_out->numel()); - if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) { - auto* lr = learning_rate.data(); + auto lr = EigenVector::Flatten(learning_rate); + if (multi_precision) { + auto eigen_master_param_out = + EigenVector::Flatten(*master_param_outs); + auto eigen_master_param = EigenVector::Flatten(*master_param); + + eigen_master_param_out.device(place) = + eigen_master_param + lr.broadcast(m_dsize) * update; eigen_param_out.device(place) = - eigen_param + lr[0] * update.template cast(); + (eigen_param.template cast() + 
lr.broadcast(m_dsize) * update) .template cast<T>(); } else { - auto lr = EigenVector<MPDType>::Flatten(learning_rate); - if (multi_precision) { - auto eigen_master_param_out = - EigenVector<MPDType>::Flatten(*master_param_outs); - auto eigen_master_param = EigenVector<MPDType>::Flatten(*master_param); - - eigen_master_param_out.device(place) = - eigen_master_param + lr.broadcast(m_dsize) * update; - eigen_param_out.device(place) = (eigen_param.template cast<MPDType>() + - lr.broadcast(m_dsize) * update) - .template cast<T>(); - } else { - eigen_param_out.device(place) = - eigen_param + (lr.broadcast(m_dsize) * update).template cast<T>(); - } + eigen_param_out.device(place) = + eigen_param + (lr.broadcast(m_dsize) * update).template cast<T>(); } eigen_avg_squared_update_out.device(place) = rho_ * eigen_avg_squared_update + (1 - rho_) * update.square(); -- GitLab