Commit e498e1fc authored by Tomasz Patejko, committed by Tao Luo

Adam operator optimized with Eigen (#10229)

* Some changes for Adam profiling

* Adam optimization: initial Eigen optimization

* Eigen Adam: flavour of Adam can be chosen

* Eigen Adam used for CPU by default. Plain Adam used for GPU

* Eigen Adam: missing call to the Eigen functor added

* Eigen Adam: revert changes in benchmarks

* Eigen Adam: typo corrected
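
For reference, both flavours compute the same Adam update, visible in the functor bodies below; they differ only in how the loop over elements is executed (whole-array Eigen expressions on CPU, one functor call per element on GPU):

  lr_t = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
  m_t  = beta1 * m_{t-1} + (1 - beta1) * g
  v_t  = beta2 * v_{t-1} + (1 - beta2) * g^2
  p_t  = p_{t-1} - lr_t * m_t / (sqrt(v_t) + eps)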
Parent 0ecc6fa8
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #include <math.h>  // for sqrt in CPU and CUDA
+#include <Eigen/Dense>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
@@ -24,8 +25,14 @@ namespace operators {
 namespace scatter = paddle::operators::math::scatter;
 
+struct GPUAdam;
+struct CPUAdam;
+
+template <typename T, typename Flavour>
+struct AdamFunctor;
+
 template <typename T>
-struct AdamFunctor {
+struct AdamFunctor<T, GPUAdam> {
   T beta1_;
   T beta2_;
   T epsilon_;
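GPUAdam and CPUAdam are empty tag types: the primary AdamFunctor template is declared but never defined, and each flavour is provided as a full specialization, so asking for an unimplemented flavour fails at compile time instead of silently falling back. A minimal self-contained sketch of the same tag-dispatch pattern (all names here are illustrative, not from the patch):

#include <cstddef>

struct FastTag;    // flavour tags, mirroring GPUAdam/CPUAdam
struct SimpleTag;

// Primary template declared but not defined: an unsupported
// flavour is a compile-time error rather than a silent fallback.
template <typename T, typename Flavour>
struct ScaleFunctor;

template <typename T>
struct ScaleFunctor<T, SimpleTag> {
  T factor_;
  void operator()(T* data, std::size_t n) const {
    // plain element-wise loop
    for (std::size_t i = 0; i < n; ++i) data[i] *= factor_;
  }
};

template <typename T>
struct ScaleFunctor<T, FastTag> {
  T factor_;
  void operator()(T* data, std::size_t n) const {
    // a vectorized implementation (e.g. via Eigen::Map) would go here
    for (std::size_t i = 0; i < n; ++i) data[i] *= factor_;
  }
};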
@@ -71,6 +78,7 @@ struct AdamFunctor {
     // Calculation
     lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
     mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+
     mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
     p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
@@ -82,6 +90,71 @@ struct AdamFunctor {
   }
 };
 
+template <typename T>
+struct AdamFunctor<T, CPUAdam> {
+  T beta1_;
+  T beta2_;
+  T epsilon_;
+
+  const T* beta1_pow_;
+  const T* beta2_pow_;
+  const T* moment1_;
+  T* moment1_out_;
+  const T* moment2_;
+  T* moment2_out_;
+  const T* lr_;
+  const T* grad_;
+  const T* param_;
+  T* param_out_;
+
+  AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
+              const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
+              T* mom2_out, const T* lr, const T* grad, const T* param,
+              T* param_out)
+      : beta1_(beta1),
+        beta2_(beta2),
+        epsilon_(epsilon),
+        beta1_pow_(beta1_pow),
+        beta2_pow_(beta2_pow),
+        moment1_(mom1),
+        moment1_out_(mom1_out),
+        moment2_(mom2),
+        moment2_out_(mom2_out),
+        lr_(lr),
+        grad_(grad),
+        param_(param),
+        param_out_(param_out) {}
+
+  void operator()(size_t numel) const {
+    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> g{
+        grad_, static_cast<Eigen::Index>(numel)};
+    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> mom1{
+        moment1_, static_cast<Eigen::Index>(numel)};
+    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> mom2{
+        moment2_, static_cast<Eigen::Index>(numel)};
+    Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> param{
+        param_, static_cast<Eigen::Index>(numel)};
+
+    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param_out{
+        param_out_, static_cast<Eigen::Index>(numel)};
+    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> moment1_out{
+        moment1_out_, static_cast<Eigen::Index>(numel)};
+    Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> moment2_out{
+        moment2_out_, static_cast<Eigen::Index>(numel)};
+
+    T lr = *lr_;
+    T beta1_pow = *beta1_pow_;
+    T beta2_pow = *beta2_pow_;
+
+    // Calculation
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    moment1_out = beta1_ * mom1 + (1 - beta1_) * g;
+    moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g;
+    param_out = param - lr * (moment1_out / (moment2_out.sqrt() + epsilon_));
+  }
+};
+
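The CPU specialization's core trick is Eigen::Map, which wraps the operator's raw buffers as Eigen arrays so the whole update becomes a handful of vectorized array expressions instead of an element-at-a-time loop. A stripped-down, self-contained sketch of the same idea (the helper name and its in-place signature are illustrative; the Eigen calls are the real API used above):

#include <Eigen/Dense>
#include <cmath>
#include <cstddef>

// One Adam step over flat buffers, updating mom1/mom2/param in place.
template <typename T>
void adam_step(const T* grad, T* mom1, T* mom2, T* param, std::size_t n,
               T lr, T beta1, T beta2, T epsilon, T beta1_pow, T beta2_pow) {
  using Arr = Eigen::Array<T, 1, Eigen::Dynamic>;
  Eigen::Map<const Arr> g(grad, static_cast<Eigen::Index>(n));
  Eigen::Map<Arr> m1(mom1, static_cast<Eigen::Index>(n));
  Eigen::Map<Arr> m2(mom2, static_cast<Eigen::Index>(n));
  Eigen::Map<Arr> p(param, static_cast<Eigen::Index>(n));

  lr *= std::sqrt(1 - beta2_pow) / (1 - beta1_pow);  // bias correction
  m1 = beta1 * m1 + (1 - beta1) * g;                 // first moment
  m2 = beta2 * m2 + (1 - beta2) * g * g;             // second moment
  p -= lr * (m1 / (m2.sqrt() + epsilon));            // parameter update
}

Because Eigen::Map aliases the buffers directly, no copies are made, and each assignment compiles down to a single fused loop over the array that Eigen can vectorize with SIMD instructions.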
 template <typename T>
 struct SparseAdamFunctor {
   T beta1_;
@@ -134,6 +207,7 @@ struct SparseAdamFunctor {
       T p = param_[rows_[i] * row_numel_ + j];
 
       lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
       mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+
       mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
       p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
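SparseAdamFunctor stays untouched by this patch apart from the blank line above; it works on a SelectedRows gradient that stores only the rows that actually received updates. The expression param_[rows_[i] * row_numel_ + j] maps row i of the compact gradient onto row rows_[i] of the dense parameter. A minimal sketch of just that addressing scheme (names are illustrative, and the update is simplified to plain SGD to keep the focus on indexing):

#include <cstddef>
#include <cstdint>

// rows[i] gives the dense-parameter row that compact gradient row i
// belongs to; row_numel is the width of one row.
template <typename T>
void sparse_row_update(const std::int64_t* rows, std::size_t row_count,
                       std::size_t row_numel, const T* grad, T* param,
                       T scaled_lr) {
  for (std::size_t i = 0; i < row_count; ++i) {
    for (std::size_t j = 0; j < row_numel; ++j) {
      T g = grad[i * row_numel + j];                     // compact row i
      param[rows[i] * row_numel + j] -= scaled_lr * g;   // dense row rows[i]
    }
  }
}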
@@ -177,19 +251,34 @@ class AdamOpKernel : public framework::OpKernel<T> {
     if (grad_var->IsType<framework::LoDTensor>()) {
       auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
-      AdamFunctor<T> functor(
-          beta1, beta2, epsilon, beta1_pow.template data<T>(),
-          beta2_pow.template data<T>(), mom1.template data<T>(),
-          mom1_out.template mutable_data<T>(ctx.GetPlace()),
-          mom2.template data<T>(),
-          mom2_out.template mutable_data<T>(ctx.GetPlace()),
-          lr.template data<T>(), grad.template data<T>(),
-          param.template data<T>(),
-          param_out.template mutable_data<T>(ctx.GetPlace()));
-      platform::ForRange<DeviceContext> for_range(
-          static_cast<const DeviceContext&>(ctx.device_context()),
-          param.numel());
-      for_range(functor);
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        AdamFunctor<T, CPUAdam> functor(
+            beta1, beta2, epsilon, beta1_pow.template data<T>(),
+            beta2_pow.template data<T>(), mom1.template data<T>(),
+            mom1_out.template mutable_data<T>(ctx.GetPlace()),
+            mom2.template data<T>(),
+            mom2_out.template mutable_data<T>(ctx.GetPlace()),
+            lr.template data<T>(), grad.template data<T>(),
+            param.template data<T>(),
+            param_out.template mutable_data<T>(ctx.GetPlace()));
+        functor(param.numel());
+      } else if (platform::is_gpu_place(ctx.GetPlace())) {
+        AdamFunctor<T, GPUAdam> functor(
+            beta1, beta2, epsilon, beta1_pow.template data<T>(),
+            beta2_pow.template data<T>(), mom1.template data<T>(),
+            mom1_out.template mutable_data<T>(ctx.GetPlace()),
+            mom2.template data<T>(),
+            mom2_out.template mutable_data<T>(ctx.GetPlace()),
+            lr.template data<T>(), grad.template data<T>(),
+            param.template data<T>(),
+            param_out.template mutable_data<T>(ctx.GetPlace()));
+        platform::ForRange<DeviceContext> for_range(
+            static_cast<const DeviceContext&>(ctx.device_context()),
+            param.numel());
+        for_range(functor);
+      }
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       auto& grad =
           Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
...
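The kernel now branches on the placement at runtime: on CPU the Eigen functor consumes the whole buffer in a single call, while on GPU the original element-wise functor is still driven by platform::ForRange, one invocation per element. A standalone analogue of that shape (for_range here is a stand-in for platform::ForRange, and all names are illustrative, not the Paddle API):

#include <cstddef>

// Minimal stand-in for platform::ForRange: calls f(i) for each index.
template <typename F>
void for_range(std::size_t n, F f) {
  for (std::size_t i = 0; i < n; ++i) f(i);
}

// cpu_functor takes the whole buffer at once (operator()(size_t numel));
// gpu_functor is element-wise (operator()(size_t i)).
template <typename CpuF, typename GpuF>
void dispatch_adam(bool is_cpu, CpuF cpu_functor, GpuF gpu_functor,
                   std::size_t numel) {
  if (is_cpu) {
    cpu_functor(numel);              // single vectorized pass
  } else {
    for_range(numel, gpu_functor);   // one call per element
  }
}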