Unverified commit 3672480b authored by sneaxiy, committed by GitHub

fix lamb beta1pow beta2pow update (#38518)

Parent 72a41e50
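For context: Beta1Pow and Beta2Pow hold the running powers β1^t and β2^t that LAMB, like Adam, uses for bias correction, and they must advance exactly once per optimizer step. Before this commit the moment functors advanced them from inside their per-element update; the commit moves that scalar update into a dedicated LambBetaPowUpdateFunctor that runs once from the parameter-update kernel. A minimal standalone sketch of the math involved, matching the formulas visible in the diff below (hypothetical function and names, not the Paddle API):

```cpp
#include <cmath>

// One LAMB moment step for a single element (sketch). beta1_pow == beta1^t
// and beta2_pow == beta2^t are read here but advanced elsewhere, once per step.
void lamb_moment_step(float g, float p, float beta1, float beta2, float epsilon,
                      float weight_decay, float* mom1, float* mom2,
                      float beta1_pow, float beta2_pow, float* trust_ratio_div) {
  *mom1 = beta1 * (*mom1) + (1.0f - beta1) * g;      // first moment
  *mom2 = beta2 * (*mom2) + (1.0f - beta2) * g * g;  // second moment
  float mom1_unbiased = *mom1 / (1.0f - beta1_pow);  // bias correction
  float mom2_unbiased = *mom2 / (1.0f - beta2_pow);
  *trust_ratio_div =
      mom1_unbiased / (std::sqrt(mom2_unbiased) + epsilon) + weight_decay * p;
}
// Once per step, after all elements: beta1_pow *= beta1; beta2_pow *= beta2;
```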
@@ -52,19 +52,16 @@ struct LambMomentREGUpdateFunctor {
   const bool* skip_update_;
 
   LambMomentREGUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
-                             MT beta1_pow, MT* beta1_pow_out, MT beta2_pow,
-                             MT* beta2_pow_out, const MT* mom1, MT* mom1_out,
-                             const MT* mom2, MT* mom2_out, const T* grad,
-                             const MT* param, MT* trust_ratio_div,
-                             const bool* skip_update)
+                             MT beta1_pow, MT beta2_pow, const MT* mom1,
+                             MT* mom1_out, const MT* mom2, MT* mom2_out,
+                             const T* grad, const MT* param,
+                             MT* trust_ratio_div, const bool* skip_update)
       : weight_decay_(weight_decay),
         beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),
@@ -95,10 +92,6 @@ struct LambMomentREGUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta2_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 };
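Why the deleted block mattered (a hedged reading of the diff, not an authoritative changelog): the update function runs once per tensor element, so the old code stored to beta1_pow_out_[0] and beta2_pow_out_[0] from every element, i.e. from every GPU thread. If the *_pow_out buffer aliases the input (an in-place beta pow update), concurrent read-modify-writes of one scalar race; even without aliasing the stores are redundant traffic to a single slot. An illustrative snippet of the hazard (hypothetical names):

```cpp
#include <cstdint>

// Each loop iteration models one GPU thread in the old per-element functor:
// numel concurrent writes land on one scalar slot, and if beta1_pow_out
// aliases beta1_pow the read and the write race with other threads.
void per_element_beta_pow(const float* beta1_pow, float* beta1_pow_out,
                          float beta1, int64_t numel) {
  for (int64_t i = 0; i < numel; ++i) {
    beta1_pow_out[0] = beta1_pow[0] * beta1;
  }
}
```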
@@ -113,9 +106,7 @@ struct LambMomentMENUpdateFunctor {
   MT epsilon_;
 
   const MT* beta1_pow_;
-  MT* beta1_pow_out_;
   const MT* beta2_pow_;
-  MT* beta2_pow_out_;
   const MT* moment1_;
   MT* moment1_out_;
   const MT* moment2_;
@@ -126,8 +117,7 @@ struct LambMomentMENUpdateFunctor {
   const bool* skip_update_;
 
   LambMomentMENUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
-                             const MT* beta1_pow, MT* beta1_pow_out,
-                             const MT* beta2_pow, MT* beta2_pow_out,
+                             const MT* beta1_pow, const MT* beta2_pow,
                              const MT* mom1, MT* mom1_out, const MT* mom2,
                              MT* mom2_out, const T* grad, const MT* param,
                              MT* trust_ratio_div, const bool* skip_update)
@@ -136,9 +126,7 @@ struct LambMomentMENUpdateFunctor {
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),
@@ -168,10 +156,6 @@ struct LambMomentMENUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta2_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 };
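Two variants of each moment functor exist, and the LambOpKernel hunks further down choose between them: when the op runs on GPU but the beta pow tensors live on the CPU, their scalar values are dereferenced on the host and passed by value (the REG functor); otherwise device pointers are passed and dereferenced inside the kernel (this MEN functor). Reading "REG" as register/by-value and "MEN" as memory/by-pointer is an inference from the signatures, not something the diff states. A sketch of that dispatch (illustrative names):

```cpp
// Mirrors the placement check visible later in this diff: dereference on the
// host only when the pointer is host memory; otherwise hand the device
// pointer to the kernel and let it dereference.
void choose_functor(bool on_gpu, bool beta_pow_on_cpu, const float* beta1_pow) {
  if (on_gpu && beta_pow_on_cpu) {
    float beta1_pow_value = *beta1_pow;  // host pointer: safe host dereference
    (void)beta1_pow_value;  // would be baked by value into a REG-style functor
  } else {
    // beta1_pow may be device memory: pass the pointer itself to a MEN-style
    // functor and dereference it inside the kernel.
  }
}
```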
@@ -183,9 +167,7 @@ struct SparseLambMomentREGUpdateFunctor {
   T epsilon_;
 
   T beta1_pow_;
-  T* beta1_pow_out_;
   T beta2_pow_;
-  T* beta2_pow_out_;
   const T* moment1_;
   T* moment1_out_;
   const T* moment2_;
@@ -201,20 +183,18 @@ struct SparseLambMomentREGUpdateFunctor {
   const bool* skip_update_;
 
   SparseLambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
-                                   T beta1_pow, T* beta1_pow_out, T beta2_pow,
-                                   T* beta2_pow_out, const T* mom1, T* mom1_out,
-                                   const T* mom2, T* mom2_out, const T* grad,
-                                   const T* param, T* trust_ratio_div,
-                                   const int64_t* rows, int64_t row_numel,
-                                   int64_t row_count, const bool* skip_update)
+                                   T beta1_pow, T beta2_pow, const T* mom1,
+                                   T* mom1_out, const T* mom2, T* mom2_out,
+                                   const T* grad, const T* param,
+                                   T* trust_ratio_div, const int64_t* rows,
+                                   int64_t row_numel, int64_t row_count,
+                                   const bool* skip_update)
       : weight_decay_(weight_decay),
         beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),
@@ -246,10 +226,6 @@ struct SparseLambMomentREGUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta1_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 
   inline HOSTDEVICE void operator()(size_t i) const {
@@ -270,9 +246,7 @@ struct SparseLambMomentMENUpdateFunctor {
   T epsilon_;
 
   const T* beta1_pow_;
-  T* beta1_pow_out_;
   const T* beta2_pow_;
-  T* beta2_pow_out_;
   const T* moment1_;
   T* moment1_out_;
   const T* moment2_;
@@ -288,8 +262,7 @@ struct SparseLambMomentMENUpdateFunctor {
   const bool* skip_update_;
 
   SparseLambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
-                                   const T* beta1_pow, T* beta1_pow_out,
-                                   const T* beta2_pow, T* beta2_pow_out,
+                                   const T* beta1_pow, const T* beta2_pow,
                                    const T* mom1, T* mom1_out, const T* mom2,
                                    T* mom2_out, const T* grad, const T* param,
                                    T* trust_ratio_div, const int64_t* rows,
@@ -300,9 +273,7 @@ struct SparseLambMomentMENUpdateFunctor {
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),
@@ -334,10 +305,6 @@ struct SparseLambMomentMENUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta1_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 
   inline HOSTDEVICE void operator()(size_t i) const {
@@ -350,11 +317,44 @@ struct SparseLambMomentMENUpdateFunctor {
   }
 };
 
-template <typename T, bool IsMultiPrecision>
-struct LambParamUpateFunctor {
-  using MT = typename std::conditional<
-      IsMultiPrecision, typename details::MPTypeTrait<T>::Type, T>::type;
+template <typename MT, bool NeedUpdateBetaPow /*=true*/>
+struct LambBetaPowUpdateFunctor {
+  void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
+                   MT* beta2pow_out, MT beta1, MT beta2) {
+    beta1pow_ = beta1pow;
+    beta2pow_ = beta2pow;
+    beta1pow_out_ = beta1pow_out;
+    beta2pow_out_ = beta2pow_out;
+    beta1_ = beta1;
+    beta2_ = beta2;
+  }
+
+  HOSTDEVICE void UpdateBetaPow(size_t i) const {
+    if (i == 0) {
+      beta1pow_out_[0] = beta1pow_[0] * beta1_;
+      beta2pow_out_[0] = beta2pow_[0] * beta2_;
+    }
+  }
+
+ private:
+  const MT* beta1pow_;
+  const MT* beta2pow_;
+  MT* beta1pow_out_;
+  MT* beta2pow_out_;
+  MT beta1_;
+  MT beta2_;
+};
+
+template <typename MT>
+struct LambBetaPowUpdateFunctor<MT, /*NeedUpdateBetaPow=*/false> {
+  void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
+                   MT* beta2pow_out, MT beta1, MT beta2) {}
+  HOSTDEVICE void UpdateBetaPow(size_t) const {}
+};
+
+template <typename T, typename MT, bool IsMultiPrecision, bool UpdateBetaPow>
+struct LambParamUpateFunctor
+    : public LambBetaPowUpdateFunctor<MT, UpdateBetaPow> {
   const MT* lr_;
   const T* param_;
   const MT* master_param_;
@@ -396,6 +396,7 @@ struct LambParamUpateFunctor {
     if (IsMultiPrecision) {
       master_param_out_[i] = param_out;
     }
+    this->UpdateBetaPow(i);
   }
 };
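Taken together, these two hunks fuse the beta pow update into the parameter-update kernel: this->UpdateBetaPow(i) is called for every element, but the i == 0 guard means exactly one thread performs the two scalar multiplies per step, and the NeedUpdateBetaPow = false specialization reduces SetBetaPows and UpdateBetaPow to empty inline bodies that cost nothing in instantiations that skip the update. A minimal standalone sketch of this compile-time on/off pattern (hypothetical names):

```cpp
#include <cstddef>
#include <cstdio>

// Primary template: the hook does real work, guarded to run exactly once.
template <bool Enabled>
struct BetaPowHook {
  void Update(std::size_t i) const {
    if (i == 0) std::printf("advance beta pows once per step\n");
  }
};

// Specialization: empty inline body the compiler removes entirely.
template <>
struct BetaPowHook<false> {
  void Update(std::size_t) const {}
};

template <bool Enabled>
struct ParamUpdate : public BetaPowHook<Enabled> {
  void operator()(std::size_t i) const {
    // ... per-element parameter math would go here ...
    this->Update(i);  // no-op when Enabled == false
  }
};
```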
@@ -501,6 +502,11 @@ class LambOpKernel : public framework::OpKernel<T> {
                                 : nullptr;
 
     // Update moments
+    bool should_update_beta_pow_later = false;
+    const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
+    MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
+    VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
+             << " , Beta2Pow place: " << beta2_pow.place();
     if (grad_var->IsType<framework::LoDTensor>()) {
       auto& grad = grad_var->Get<framework::LoDTensor>();
       if (platform::is_gpu_place(ctx.GetPlace()) &&
@@ -508,8 +514,7 @@ class LambOpKernel : public framework::OpKernel<T> {
           beta2_pow.place() == platform::CPUPlace()) {
         LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
             weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<MT>(),
-            nullptr, *beta2_pow.template data<MT>(), nullptr,
-            mom1.template data<MT>(),
+            *beta2_pow.template data<MT>(), mom1.template data<MT>(),
             mom1_out.template mutable_data<MT>(ctx.GetPlace()),
             mom2.template data<MT>(),
             mom2_out.template mutable_data<MT>(ctx.GetPlace()),
@@ -523,12 +528,17 @@ class LambOpKernel : public framework::OpKernel<T> {
         beta2_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
             beta2 * beta2_pow.template data<MT>()[0];
       } else {
+        beta1_pow_ptr = beta1_pow.template data<MT>();
+        beta2_pow_ptr = beta2_pow.template data<MT>();
+        beta1_pow_out_ptr =
+            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        beta2_pow_out_ptr =
+            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        should_update_beta_pow_later = true;
         LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
-            weight_decay, beta1, beta2, epsilon, beta1_pow.template data<MT>(),
-            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace()),
-            beta2_pow.template data<MT>(),
-            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace()),
-            mom1.template data<MT>(),
+            weight_decay, beta1, beta2, epsilon,
+            static_cast<const MT*>(beta1_pow_ptr),
+            static_cast<const MT*>(beta2_pow_ptr), mom1.template data<MT>(),
             mom1_out.template mutable_data<MT>(ctx.GetPlace()),
             mom2.template data<MT>(),
             mom2_out.template mutable_data<MT>(ctx.GetPlace()),
@@ -542,7 +552,12 @@ class LambOpKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(IsMultiPrecision, false,
                         platform::errors::Unimplemented(
                             "SelectedRows gradient is not supported when "
-                            "multi_precision=True"));
+                            "multi_precision=True."));
+      constexpr bool kIsSameType = std::is_same<T, MT>::value;
+      PADDLE_ENFORCE_EQ(kIsSameType, true,
+                        platform::errors::Unimplemented(
+                            "SelectedRows gradient is not supported when "
+                            "multi_precision=True."));
       auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"),
                                    "Input", "Grad", "Lamb");
       if (grad.rows().size() == 0) {
@@ -582,8 +597,8 @@ class LambOpKernel : public framework::OpKernel<T> {
         SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
             static_cast<T>(weight_decay), static_cast<T>(beta1),
             static_cast<T>(beta2), static_cast<T>(epsilon),
-            *beta1_pow.template data<T>(), nullptr,
-            *beta2_pow.template data<T>(), nullptr, mom1.template data<T>(),
+            *beta1_pow.template data<T>(), *beta2_pow.template data<T>(),
+            mom1.template data<T>(),
             mom1_out.template mutable_data<T>(ctx.GetPlace()),
             mom2.template data<T>(),
             mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -595,14 +610,18 @@ class LambOpKernel : public framework::OpKernel<T> {
         beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
             static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
       } else {
+        beta1_pow_ptr = beta1_pow.template data<MT>();
+        beta2_pow_ptr = beta2_pow.template data<MT>();
+        beta1_pow_out_ptr =
+            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        beta2_pow_out_ptr =
+            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        should_update_beta_pow_later = true;
         SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
             static_cast<T>(weight_decay), static_cast<T>(beta1),
             static_cast<T>(beta2), static_cast<T>(epsilon),
-            beta1_pow.template data<T>(),
-            beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
-            beta2_pow.template data<T>(),
-            beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
-            mom1.template data<T>(),
+            reinterpret_cast<const T*>(beta1_pow_ptr),
+            reinterpret_cast<const T*>(beta2_pow_ptr), mom1.template data<T>(),
             mom1_out.template mutable_data<T>(ctx.GetPlace()),
             mom2.template data<T>(),
             mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -639,14 +658,31 @@ class LambOpKernel : public framework::OpKernel<T> {
     }
     trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();
 
-    LambParamUpateFunctor<T, IsMultiPrecision> param_update_functor(
-        lr.template data<MT>(), static_cast<const T*>(param_ptr),
-        static_cast<const MT*>(master_param_ptr), p_norm_t.template data<MT>(),
-        trust_ratio_div.template data<MT>(),
-        trust_ratio_div_norm_t.template data<MT>(),
-        static_cast<T*>(param_out_ptr), static_cast<MT*>(master_param_out_ptr),
-        skip_update_flag);
-    for_range(param_update_functor);
+#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow)           \
+  do {                                                                         \
+    LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow>   \
+        param_update_functor(                                                  \
+            lr.template data<MT>(), static_cast<const T*>(param_ptr),          \
+            static_cast<const MT*>(master_param_ptr),                          \
+            p_norm_t.template data<MT>(), trust_ratio_div.template data<MT>(), \
+            trust_ratio_div_norm_t.template data<MT>(),                        \
+            static_cast<T*>(param_out_ptr),                                    \
+            static_cast<MT*>(master_param_out_ptr), skip_update_flag);         \
+    if (__should_update_beta_pow) {                                            \
+      param_update_functor.SetBetaPows(beta1_pow_ptr, beta2_pow_ptr,           \
+                                       beta1_pow_out_ptr, beta2_pow_out_ptr,   \
+                                       beta1, beta2);                          \
+    }                                                                          \
+    for_range(param_update_functor);                                           \
+  } while (0)
+
+    if (should_update_beta_pow_later) {
+      CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
+    } else {
+      CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
+    }
+#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
   }
 };
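The macro is boilerplate for a runtime-to-compile-time dispatch: should_update_beta_pow_later is only known at runtime, while UpdateBetaPow is a template parameter, so both functor variants are instantiated and a single branch picks one. The same idiom stripped down (hypothetical names):

```cpp
// Two instantiations are compiled; the runtime flag selects which one runs.
// In the real kernel the <false> variant contains no beta pow code at all.
template <bool UpdateBetaPow>
void LaunchLambParamUpdate() { /* launch the fused kernel variant */ }

void Dispatch(bool should_update_beta_pow_later) {
  if (should_update_beta_pow_later) {
    LaunchLambParamUpdate<true>();
  } else {
    LaunchLambParamUpdate<false>();
  }
}
```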