Unverified commit 11e78eba authored by guofei, committed by GitHub

Modify the calculation logic of LambOptimizer (#29313)

* Modify the calculation logic of LambOptimizer
Parent c5ffad12
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lamb_op.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
...@@ -21,7 +23,7 @@ class LambOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
platform::errors::NotFound(
"Input(Param) of LambOp should not be null."));
...@@ -53,6 +55,12 @@ class LambOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment2Out"), true,
platform::errors::NotFound(
"Output(Moment2Out) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Beta1PowOut"), true,
platform::errors::NotFound(
"Output(Beta1PowOut) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Beta2PowOut"), true,
platform::errors::NotFound(
"Output(Beta2PowOut) of LambOp should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(
...@@ -108,14 +116,26 @@ class LambOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims);
ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims);
ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
if (var_name == "Beta1Pow" || var_name == "Beta2Pow") {
return expected_kernel_type;
} else {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
}
};
class LambOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -136,6 +156,10 @@ class LambOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("ParamOut", "(Tensor) Output parameter.");
AddOutput("Moment1Out", "(Tensor) Output first moment.");
AddOutput("Moment2Out", "(Tensor) Output second moment.");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator")
.AsDispensable();
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator")
.AsDispensable();
AddAttr<float>("weight_decay", "(float) Weight decay rate.");
AddAttr<float>("beta1",
"(float, default 0.9) The exponential decay rate for the "
...@@ -164,6 +188,10 @@ m_t &= \beta_1 m_{t - 1}+ (1 - \beta_1)g_t \\
v_t &= \beta_2 v_{t - 1} + (1 - \beta_2)g_t^2 \\
m_t &= \frac{m_t}{\beta_1^t} \\
v_t &= \frac{v_t}{\beta_2^t} \\
r_t &= \frac{m_t}{\sqrt{v_t}+\epsilon} \\
w_t &= w_{t-1} -\eta_t \frac{\left \| w_{t-1}\right \|}{\left \| r_t + \lambda w_{t-1}\right \|} (r_t + \lambda w_{t-1})
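For reference, a minimal NumPy sketch of the update rule documented above, including the bias-correction terms this commit introduces; the names are illustrative and mirror the lamb_step reference used in the tests further down, not the operator's internal variables:

import numpy as np

def lamb_reference_step(param, grad, m, v, beta1_pow, beta2_pow, lr,
                        beta1=0.9, beta2=0.999, epsilon=1e-6, weight_decay=0.01):
    # First and second moment estimates.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction by the beta power accumulators (the newly added step).
    m_hat = m / (1 - beta1_pow)
    v_hat = v / (1 - beta2_pow)
    # Trust-ratio update with decoupled weight decay folded into r.
    r = m_hat / (np.sqrt(v_hat) + epsilon) + weight_decay * param
    trust_ratio = np.linalg.norm(param) / np.linalg.norm(r)
    new_param = param - lr * trust_ratio * r
    # The power accumulators are now advanced and returned as outputs too.
    return new_param, m, v, beta1_pow * beta1, beta2_pow * beta2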
...@@ -183,3 +211,15 @@ REGISTER_OP_WITHOUT_GRADIENT(lamb, ops::LambOp, ops::LambOpMaker);
REGISTER_OP_CPU_KERNEL(
lamb, ops::LambOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LambOpKernel<paddle::platform::CPUDeviceContext, double>);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(lamb)
.AddCheckpoint(
R"ROC(Upgrade lamb, add two new outputs [Beta1PowOut] and [Beta2PowOut].)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewInput("Beta1PowOut",
"The Output beta1 power accumulator. 'Beta1PowOut' is "
"dispensable.")
.NewInput("Beta2PowOut",
"The Output beta2 power accumulator. 'Beta2PowOut' is "
"dispensable."));
...@@ -27,14 +27,81 @@ namespace operators {
namespace scatter = paddle::operators::math::scatter;
template <typename T>
struct LambMomentREGUpdateFunctor {
T weight_decay_;
T beta1_;
T beta2_;
T epsilon_;
T beta1_pow_;
T* beta1_pow_out_;
T beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* grad_;
const T* param_;
T* trust_ratio_div_;
LambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
T beta1_pow, T* beta1_pow_out, T beta2_pow,
T* beta2_pow_out, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* grad,
const T* param, T* trust_ratio_div)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
moment2_out_(mom2_out),
grad_(grad),
param_(param),
trust_ratio_div_(trust_ratio_div) {}
inline HOSTDEVICE void operator()(size_t i) const {
T g = grad_[i];
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T beta1_pow = beta1_pow_;
T beta2_pow = beta2_pow_;
T p = param_[i];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
T mom1_unbiased = mom1 / (1 - beta1_pow);
T mom2_unbiased = mom2 / (1 - beta2_pow);
trust_ratio_div_[i] =
mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
};
template <typename T>
struct LambMomentMENUpdateFunctor {
T weight_decay_;
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
T* beta1_pow_out_;
const T* beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
...@@ -43,16 +110,20 @@ struct LambMomentUpdateFunctor {
const T* param_;
T* trust_ratio_div_;
LambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
const T* beta1_pow, T* beta1_pow_out,
const T* beta2_pow, T* beta2_pow_out,
const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* grad, const T* param,
T* trust_ratio_div)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
...@@ -65,6 +136,8 @@ struct LambMomentUpdateFunctor {
T g = grad_[i];
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
...@@ -72,19 +145,110 @@ struct LambMomentUpdateFunctor {
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
trust_ratio_div_[i] = mom1 / (sqrt(mom2) + epsilon_) + weight_decay_ * p;
T mom1_unbiased = mom1 / (1 - beta1_pow);
T mom2_unbiased = mom2 / (1 - beta2_pow);
trust_ratio_div_[i] =
mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
};
template <typename T>
struct SparseLambMomentREGUpdateFunctor {
T weight_decay_;
T beta1_;
T beta2_;
T epsilon_;
T beta1_pow_;
T* beta1_pow_out_;
T beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* grad_;
const T* param_;
T* trust_ratio_div_;
const int64_t* rows_;
int64_t row_numel_;
int64_t row_count_;
SparseLambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
T beta1_pow, T* beta1_pow_out, T beta2_pow,
T* beta2_pow_out, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* grad,
const T* param, T* trust_ratio_div,
const int64_t* rows, int64_t row_numel,
int64_t row_count)
: weight_decay_(weight_decay),
beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
moment2_out_(mom2_out),
grad_(grad),
param_(param),
trust_ratio_div_(trust_ratio_div),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count) {}
inline HOSTDEVICE void update(size_t i, T g) const {
// The following code is same as dense
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T beta1_pow = beta1_pow_;
T beta2_pow = beta2_pow_;
T p = param_[i];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
T mom1_unbiased = mom1 / (1 - beta1_pow);
T mom2_unbiased = mom2 / (1 - beta2_pow);
trust_ratio_div_[i] =
mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
inline HOSTDEVICE void operator()(size_t i) const {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
update(i, g);
}
};
template <typename T>
struct SparseLambMomentMENUpdateFunctor {
T weight_decay_;
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
T* beta1_pow_out_;
const T* beta2_pow_;
T* beta2_pow_out_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
...@@ -97,8 +261,9 @@ struct SparseLambMomentUpdateFunctor {
int64_t row_numel_;
int64_t row_count_;
SparseLambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
const T* beta1_pow, T* beta1_pow_out,
const T* beta2_pow, T* beta2_pow_out,
const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* grad, const T* param,
T* trust_ratio_div, const int64_t* rows,
...@@ -108,7 +273,9 @@ struct SparseLambMomentUpdateFunctor {
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta1_pow_out_(beta1_pow_out),
beta2_pow_(beta2_pow),
beta2_pow_out_(beta2_pow_out),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
...@@ -124,6 +291,8 @@ struct SparseLambMomentUpdateFunctor {
// The following code is same as dense
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
...@@ -131,7 +300,15 @@ struct SparseLambMomentUpdateFunctor {
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
trust_ratio_div_[i] = mom1 / (sqrt(mom2) + epsilon_) + weight_decay_ * p;
T mom1_unbiased = mom1 / (1 - beta1_pow);
T mom2_unbiased = mom2 / (1 - beta2_pow);
trust_ratio_div_[i] =
mom1_unbiased / (sqrt(mom2_unbiased) + epsilon_) + weight_decay_ * p;
if (beta1_pow_out_ && beta2_pow_out_) {
beta1_pow_out_[0] = beta1_pow * beta1_;
beta2_pow_out_[0] = beta2_pow * beta2_;
}
}
inline HOSTDEVICE void operator()(size_t i) const {
...@@ -211,6 +388,10 @@ class LambOpKernel : public framework::OpKernel<T> {
"Output", "Moment1Out", "Lamb");
auto& mom2_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment2Out"),
"Output", "Moment2Out", "Lamb");
auto& beta1_pow_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Beta1PowOut"),
"Output", "Beta1PowOut", "Lamb");
auto& beta2_pow_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Beta2PowOut"),
"Output", "Beta2PowOut", "Lamb");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, param.numel());
...@@ -220,16 +401,37 @@ class LambOpKernel : public framework::OpKernel<T> {
// Update moments
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = *ctx.Input<LoDTensor>("Grad");
if (platform::is_gpu_place(ctx.GetPlace()) &&
beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
LambMomentREGUpdateFunctor<T> moment_update_functor(
weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<T>(),
nullptr, *beta2_pow.template data<T>(), nullptr,
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
grad.template data<T>(), param.template data<T>(),
trust_ratio_div.template data<T>());
for_range(moment_update_functor);
beta1_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
beta1 * beta1_pow.template data<T>()[0];
beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
beta2 * beta2_pow.template data<T>()[0];
} else {
LambMomentMENUpdateFunctor<T> moment_update_functor(
weight_decay, beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow.template data<T>(),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
grad.template data<T>(), param.template data<T>(),
trust_ratio_div.template data<T>());
for_range(moment_update_functor);
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"),
"Input", "Grad", "Lamb");
...@@ -264,16 +466,37 @@ class LambOpKernel : public framework::OpKernel<T> {
const T* grad_data = grad_tensor.template data<T>();
const int64_t* rows = grad_merge.rows().Data(ctx.GetPlace());
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
if (platform::is_gpu_place(ctx.GetPlace()) &&
beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<T>(),
nullptr, *beta2_pow.template data<T>(), nullptr,
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
param.template data<T>(), trust_ratio_div.template data<T>(), rows,
row_numel, grad_merge.rows().size());
for_range(moment_update_functor);
beta1_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
beta1 * beta1_pow.template data<T>()[0];
beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
beta2 * beta2_pow.template data<T>()[0];
} else {
SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
weight_decay, beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
beta2_pow.template data<T>(),
beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
param.template data<T>(), trust_ratio_div.template data<T>(), rows,
row_numel, grad_merge.rows().size());
for_range(moment_update_functor);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type not supported by lamb_op. Expect LoDTensor or "
...@@ -296,7 +519,6 @@ class LambOpKernel : public framework::OpKernel<T> {
auto* place = dev_ctx.eigen_device();
p_norm.device(*place) = p.square().sum().sqrt();
trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();
LambParamUpateFunctor<T> param_update_functor(
lr.template data<T>(), param.template data<T>(),
p_norm_t.template data<T>(), trust_ratio_div.template data<T>(),
......
...@@ -89,6 +89,8 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
{"momentum", {"ParamOut", "VelocityOut"}},
{"rnn", {"DropoutState", "Reserve", "Out", "State"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
...@@ -136,6 +138,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"update_loss_scaling",
{"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
{"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"rnn", {"DropoutState"}},
};
......
...@@ -2983,6 +2983,10 @@ class LambOptimizer(AdamOptimizer):
v_t &= \\beta_2 v_{t - 1} + (1 - \\beta_2)g_t^2
m_t &= \\frac{m_t}{\\beta_1^t}
v_t &= \\frac{v_t}{\\beta_2^t}
r_t &= \\frac{m_t}{\\sqrt{v_t}+\\epsilon}
w_t &= w_{t-1} -\\eta_t \\frac{\\left \| w_{t-1}\\right \|}{\\left \| r_t + \\lambda w_{t-1}\\right \|} (r_t + \\lambda w_{t-1})
...@@ -3010,8 +3014,9 @@ class LambOptimizer(AdamOptimizer):
Default None, meaning there is no regularization.
grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
some derived class of ``GradientClipBase`` . There are three cliping strategies
( :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_fluid_clip_ClipGradByNorm` ,
:ref:`api_paddle_fluid_clip_ClipGradByValue` ). If you want better convergence, it is recommended
to use :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
exclude_from_weight_decay_fn (function|None): Exclude a parameter from weight
decay when **exclude_from_weight_decay_fn(parameter)** returns true.
Default None.
...@@ -3036,7 +3041,6 @@ class LambOptimizer(AdamOptimizer):
"""
_moment1_acc_str = "moment1"
_moment2_acc_str = "moment2"
# these two not used in op temporarily
_beta1_pow_acc_str = "beta1_pow_acc"
_beta2_pow_acc_str = "beta2_pow_acc"
...@@ -3087,6 +3091,16 @@ class LambOptimizer(AdamOptimizer):
weight_decay = 0.0
else:
weight_decay = self._weight_decay
lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
_, _, _, _, _ = core.ops.lamb(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay',
weight_decay)
return None
# create the lamb optimize op
lamb_op = block.append_op(
...@@ -3094,7 +3108,7 @@ class LambOptimizer(AdamOptimizer):
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"LearningRate": lr,
"Moment1": moment1,
"Moment2": moment2,
"Beta1Pow": beta1_pow_acc,
...@@ -3103,7 +3117,9 @@ class LambOptimizer(AdamOptimizer):
outputs={
"ParamOut": param_and_grad[0],
"Moment1Out": moment1,
"Moment2Out": moment2,
"Beta1PowOut": beta1_pow_acc,
"Beta2PowOut": beta2_pow_acc
},
attrs={
"beta1": self._beta1,
......
...@@ -23,7 +23,7 @@ import itertools
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import MomentumOptimizer, LarsMomentumOptimizer, AdagradOptimizer, AdamaxOptimizer, DpsgdOptimizer, DecayedAdagradOptimizer, AdadeltaOptimizer, RMSPropOptimizer, FtrlOptimizer
from paddle.fluid.optimizer import ModelAverage, DGCMomentumOptimizer, ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, RecomputeOptimizer
from paddle.fluid.dygraph import Linear
from paddle.fluid.dygraph.base import to_variable
...@@ -702,14 +702,14 @@ def exclude_fn(param):
class TestImperativeLambOptimizer(TestImperativeOptimizerBase):
def get_optimizer_dygraph(self, parameter_list):
optimizer = paddle.optimizer.Lamb(
learning_rate=0.002,
exclude_from_weight_decay_fn=exclude_fn,
parameters=parameter_list)
return optimizer
def get_optimizer(self):
optimizer = paddle.optimizer.Lamb(
learning_rate=0.002, exclude_from_weight_decay_fn=exclude_fn)
return optimizer
......
...@@ -17,9 +17,13 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.op import Operator
paddle.enable_static()
class TestLambOp1(OpTest):
def set_attrs(self):
...@@ -41,8 +45,8 @@ class TestLambOp1(OpTest):
learning_rate = 0.001
self.set_attrs()
beta1_pow = self.attrs['beta1']
beta2_pow = self.attrs['beta2']
self.inputs = {
'Param': param,
...@@ -55,13 +59,15 @@ class TestLambOp1(OpTest):
}
param_out, moment1_out, moment2_out, \
beta1_pow_out, beta2_pow_out = lamb_step(self.inputs, self.attrs)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': beta1_pow_out,
'Beta2PowOut': beta2_pow_out
}
def test_check_output(self):
...@@ -89,14 +95,16 @@ class TestLambOpMultipleSteps(TestLambOp1):
self.num_steps = 10
def test_check_output(self):
for i in range(self.num_steps):
param_out, moment1_out, moment2_out, \
beta1_pow_out, beta2_pow_out = lamb_step(self.inputs, self.attrs)
self.outputs = {
'Moment1Out': moment1_out,
'Moment2Out': moment2_out,
'ParamOut': param_out,
'Beta1PowOut': beta1_pow_out,
'Beta2PowOut': beta2_pow_out
}
# Verify output for this step
...@@ -108,8 +116,8 @@ class TestLambOpMultipleSteps(TestLambOp1):
self.inputs['Moment2'] = moment2_out
# Update powers of Beta1 and Beta2 for next time step
self.inputs['Beta1Pow'] = beta1_pow_out
self.inputs['Beta2Pow'] = beta2_pow_out
# Randomize gradient for next step
self.inputs['Grad'] = np.random.uniform(
...@@ -140,14 +148,21 @@ def lamb_step(inputs, attributes):
moment1_out = beta1 * moment1 + (1 - beta1) * grad
moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
moment1_unbiased = moment1_out / (1 - beta1_pow)
moment2_unbiased = moment2_out / (1 - beta2_pow)
r_1 = np.linalg.norm(param)
r_2 = np.linalg.norm(moment1_unbiased / (np.sqrt(moment2_unbiased) + epsilon
) + weight_decay * param)
lr_t = lr * r_1 / r_2
param_out = param - lr_t * (moment1_unbiased / (
np.sqrt(moment2_unbiased) + epsilon) + weight_decay * param)
return param_out, moment1_out, moment2_out
beta1_pow_out = beta1_pow * beta1
beta2_pow_out = beta2_pow * beta2
return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out
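The new Beta1PowOut / Beta2PowOut results exist so the accumulators can be threaded from one step into the next, which is what TestLambOpMultipleSteps does above. A rough standalone sketch of that chaining, assuming lamb_step reads the usual input and attribute keys shown in these tests (toy shapes and values, purely illustrative):

import numpy as np

inputs = {
    'Param': np.random.uniform(-1, 1, (4, 8)).astype('float32'),
    'Grad': np.random.uniform(-1, 1, (4, 8)).astype('float32'),
    'Moment1': np.zeros((4, 8), dtype='float32'),
    'Moment2': np.zeros((4, 8), dtype='float32'),
    'LearningRate': np.array([0.001], dtype='float32'),
    'Beta1Pow': np.array([0.9], dtype='float32'),
    'Beta2Pow': np.array([0.999], dtype='float32'),
}
attrs = {'beta1': 0.9, 'beta2': 0.999, 'epsilon': 1e-6, 'weight_decay': 0.01}

for _ in range(10):
    param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out = \
        lamb_step(inputs, attrs)
    # Feed every output back in, mirroring the multi-step test above.
    inputs['Param'] = param_out
    inputs['Moment1'] = moment1_out
    inputs['Moment2'] = moment2_out
    inputs['Beta1Pow'] = beta1_pow_out
    inputs['Beta2Pow'] = beta2_pow_out
    inputs['Grad'] = np.random.uniform(-1, 1, (4, 8)).astype('float32')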
def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
...@@ -174,6 +189,8 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
moment1_out = np.zeros(shape=[height, row_numel])
moment2_out = np.zeros(shape=[height, row_numel])
param_out = np.zeros(shape=[height, row_numel])
moment1_unbiased = np.zeros(shape=[height, row_numel])
moment2_unbiased = np.zeros(shape=[height, row_numel])
def update_mom(row_id, update_value):
moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1
...@@ -202,8 +219,10 @@ def lamb_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
update_mom(row_id, update_value)
update_param()
beta1_pow_out = beta1_pow * beta1
beta2_pow_out = beta2_pow * beta2
return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out
class TestSparseLambOp(unittest.TestCase):
...@@ -221,8 +240,8 @@ class TestSparseLambOp(unittest.TestCase):
"Param": np.full((height, row_numel), 5.0).astype("float32"),
"Moment1": np.full((height, row_numel), 5.0).astype("float32"),
"Moment2": np.full((height, row_numel), 5.0).astype("float32"),
'Beta1Pow': np.array([beta1]).astype("float32"),
'Beta2Pow': np.array([beta2]).astype("float32"),
"LearningRate": np.full((1), 2.0).astype("float32")
}
self.init_output = np.full((height, row_numel), 0.0).astype("float32")
...@@ -245,12 +264,14 @@ class TestSparseLambOp(unittest.TestCase):
self.sparse_inputs = ["Grad"]
param_out, mom1, mom2, beta1_pow_out, beta2_pow_out = lamb_step_sparse(
self.dense_inputs, self.attrs, height, rows, row_numel, np_array)
self.outputs = {
"ParamOut": param_out,
"Moment1Out": mom1,
"Moment2Out": mom2,
'Beta1PowOut': beta1_pow_out,
'Beta2PowOut': beta2_pow_out
}
def check_with_place(self, place):
......
...@@ -19,34 +19,140 @@ import numpy as np
from op_test import OpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
class LAMBOptimizer(paddle.optimizer.Lamb):
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, fluid.framework.Block)
block.program._use_lamb = True
m = moment1 = self._get_accumulator(self._moment1_acc_str,
param_and_grad[0])
v = self._get_accumulator(self._moment2_acc_str, param_and_grad[0])
beta_1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param_and_grad[0])
beta_2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
beta_1 = layers.fill_constant(
dtype='float32', shape=[1], value=self._beta1, name='lamb_beta_1')
beta_2 = layers.fill_constant(
dtype='float32', shape=[1], value=self._beta2, name='lamb_beta_2')
epsilon = layers.fill_constant(
dtype='float32', shape=[1], value=self._epsilon, name='epsilon')
one = paddle.ones(shape=[1]).astype('float32')
zero = paddle.zeros(shape=[1]).astype('float32')
next_m = paddle.multiply(m, beta_1) + paddle.multiply(param_and_grad[1],
one - beta_1)
next_v = paddle.multiply(v, beta_2) + paddle.multiply(
paddle.pow(param_and_grad[1], 2), one - beta_2)
beta1_correction = one - beta_1_pow_acc
beta2_correction = one - beta_2_pow_acc
next_m_unbiased = next_m / beta1_correction
next_v_unbiased = next_v / beta2_correction
update = next_m_unbiased / (paddle.sqrt(next_v_unbiased) + epsilon)
if self._exclude_from_weight_decay_fn is not None and self._exclude_from_weight_decay_fn(
param_and_grad[0]):
self._lamb_weight_decay = 0.0
update += self._lamb_weight_decay * param_and_grad[0]
w_norm = paddle.norm(param_and_grad[0], p=2)
g_norm = paddle.norm(update, p=2)
learning_rate = self._create_param_lr(param_and_grad)
ratio = paddle.where(
paddle.greater_than(w_norm, zero),
paddle.where(
paddle.greater_than(g_norm, zero), (w_norm / g_norm), one), one)
update_with_lr = ratio * learning_rate * update
next_param = param_and_grad[0] - update_with_lr
beta_1_pow_acc *= beta_1
beta_2_pow_acc *= beta_2
paddle.assign(next_m, m)
paddle.assign(next_v, v)
paddle.assign(next_param, param_and_grad[0])
return None
class TestLambOpV2(unittest.TestCase):
def test_lamb_op(self):
shape = [2, 4, 8, 8]
data = paddle.to_tensor(np.random.random(size=shape).astype("float32"))
conv = paddle.nn.Conv2D(4, 6, (3, 3))
data = conv(data)
loss = paddle.mean(data)
opt = paddle.optimizer.Lamb(
learning_rate=1e-5, epsilon=1e-8, parameters=conv.parameters())
loss.backward()
opt.minimize(loss)
assert loss.numpy() is not None
class TestLambOpWithCombinedOp(unittest.TestCase):
def test_lamb_op_with_multi_steps(self):
paddle.enable_static()
def _build_static_model(main, startup, seed=100):
with fluid.program_guard(main, startup):
main.random_seed = seed
startup.random_seed = seed
x = fluid.layers.data(name='X', shape=[13], dtype='float32')
y = fluid.layers.data(name='Y', shape=[1], dtype='float32')
prediction = fluid.layers.fc(input=x, size=1, act=None)
loss = fluid.layers.square_error_cost(input=prediction, label=y)
avg_loss = fluid.layers.mean(loss)
return avg_loss
place = fluid.CPUPlace()
num_steps = 10
exe = fluid.Executor(place)
for i in range(num_steps):
feed_x = np.random.random(size=(10, 13)).astype('float32')
feed_y = np.random.random(size=(10, 1)).astype('float32')
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
avg_loss = _build_static_model(main_program, startup_program)
lamb_kernel = paddle.optimizer.Lamb(learning_rate=0.2)
lamb_kernel.minimize(avg_loss)
executor = fluid.Executor(place)
executor.run(startup_program)
output = executor.run(program=main_program,
feed={'X': feed_x,
'Y': feed_y},
fetch_list=[avg_loss.name])
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = _build_static_model(main, startup)
lamb = LAMBOptimizer(learning_rate=0.2)
lamb.minimize(loss)
loss = fluid.layers.reduce_mean(conv)
beta1 = 0.85
beta2 = 0.95
betas = [beta1, beta2]
opt = paddle.optimizer.Lamb(
learning_rate=1e-5, beta1=beta1, beta2=beta2, epsilon=1e-8)
opt.minimize(loss)
exe = fluid.Executor(place)
exe.run(startup)
out = exe.run(program=main,
feed={'X': feed_x,
'Y': feed_y},
fetch_list=[loss.name])
self.assertTrue(np.allclose(out, output))
if __name__ == "__main__":
......
...@@ -37,6 +37,10 @@ class Lamb(Optimizer):
v_t &= \\beta_2 v_{t - 1} + (1 - \\beta_2)g_t^2
m_t &= \\frac{m_t}{\\beta_1^t}
v_t &= \\frac{v_t}{\\beta_2^t}
r_t &= \\frac{m_t}{\\sqrt{v_t}+\\epsilon}
w_t &= w_{t-1} -\\eta_t \\frac{\\left \| w_{t-1}\\right \|}{\\left \| r_t + \\lambda w_{t-1}\\right \|} (r_t + \\lambda w_{t-1})
...@@ -59,8 +63,9 @@ class Lamb(Optimizer):
The default value is None in static mode, at this time all parameters will be updated.
grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
some derived class of ``GradientClipBase`` . There are three cliping strategies
( :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` , :ref:`api_paddle_fluid_clip_ClipGradByNorm` ,
:ref:`api_paddle_fluid_clip_ClipGradByValue` ). If you want better convergence, it is recommended
to use :ref:`api_paddle_fluid_clip_ClipGradByGlobalNorm` . Default None, meaning there is no gradient clipping.
name(str|None): For detailed information, please refer to
:ref:`api_guide_Name` . Usually name is no need to set and None by default.
Examples:
...@@ -81,7 +86,6 @@ class Lamb(Optimizer):
"""
_moment1_acc_str = "moment1"
_moment2_acc_str = "moment2"
# these two not used in op temporarily
_beta1_pow_acc_str = "beta1_pow_acc"
_beta2_pow_acc_str = "beta2_pow_acc"
...@@ -93,6 +97,7 @@ class Lamb(Optimizer):
epsilon=1e-6,
parameters=None,
grad_clip=None,
exclude_from_weight_decay_fn=None,
name=None):
assert learning_rate is not None
assert beta1 is not None
...@@ -109,6 +114,7 @@ class Lamb(Optimizer):
self._beta2 = beta2
self._epsilon = epsilon
self._lamb_weight_decay = lamb_weight_decay
self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
...@@ -145,34 +151,51 @@ class Lamb(Optimizer):
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
if self._exclude_from_weight_decay_fn is not None \
and self._exclude_from_weight_decay_fn(param_and_grad[0]):
weight_decay = 0.0
else:
weight_decay = self._lamb_weight_decay
lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
_, _, _, _, _ = core.ops.lamb(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, param_and_grad[0], moment1,
moment2, beta1_pow_acc, beta2_pow_acc, 'beta1', self._beta1,
'beta2', self._beta2, 'epsilon', self._epsilon, 'weight_decay',
weight_decay)
return None
# create the lamb optimize op
inputs = {
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"LearningRate": lr,
"Moment1": moment1,
"Moment2": moment2,
"Beta1Pow": beta1_pow_acc,
"Beta2Pow": beta2_pow_acc
}
outputs = {
"ParamOut": param_and_grad[0],
"Moment1Out": moment1,
"Moment2Out": moment2,
"Beta1PowOut": beta1_pow_acc,
"Beta2PowOut": beta2_pow_acc
}
attrs = {
"beta1": self._beta1,
"beta2": self._beta2,
"epsilon": self._epsilon,
"weight_decay": weight_decay
}
lamb_op = block.append_op(
type=self.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True)
return lamb_op
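Finally, a short dygraph usage sketch of the public paddle.optimizer.Lamb API touched here, mirroring the new imperative test above; the layer, the shapes, and the exclusion predicate are illustrative assumptions rather than part of this change:

import numpy as np
import paddle

linear = paddle.nn.Linear(10, 10)
inp = paddle.to_tensor(np.random.random((4, 10)).astype('float32'))
loss = paddle.mean(linear(inp))

lamb = paddle.optimizer.Lamb(
    learning_rate=0.002,
    lamb_weight_decay=0.01,
    parameters=linear.parameters(),
    # Skip weight decay for 1-D (bias-like) parameters; any predicate works.
    exclude_from_weight_decay_fn=lambda p: len(p.shape) == 1)

loss.backward()
lamb.minimize(loss)  # takes the new core.ops.lamb path in dygraph mode
lamb.clear_grad()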