diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 4fee2e5f12728c94a9f533c68e5dfda3f3820a71..f0745bd9690f0abb1affc13f4f7d92a5de596da3 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -88,6 +88,8 @@ no_amp_list = [ 'rmsprop', 'sgd_', 'sgd', + 'lamb_', + 'lamb', 'assign_value_', 'sparse_momentum_', 'sparse_momentum', diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu index e7d795ccc579c86d9dbb2b6dc898cc49af8dba98..d922b2a30cf903073544ef5d1ed9e37b6ea55bfa 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu @@ -15,15 +15,18 @@ #include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/optimizers/cast_with_ptr.h" -#include "paddle/fluid/operators/tensor_to_string.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/kernels/funcs/algorithm.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/tensor_to_string.h" namespace paddle { namespace operators { +using phi::funcs::FlattenToString; +using phi::funcs::ToVector; + struct ParamGradInfo { framework::Tensor *param_t{nullptr}; framework::Tensor *grad_t{nullptr}; diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu index 53a5fa4706cc03f07f2fc7fb4ad6c26bf7c26d4b..5e6c43aa127120bb6ed8c891e4551e173738d40d 100644 --- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu +++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu @@ -19,12 +19,12 @@ #include "paddle/fluid/operators/optimizers/cast_with_ptr.h" #include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h" #include "paddle/fluid/operators/optimizers/multi_tensor_apply.h" -#include "paddle/fluid/operators/tensor_to_string.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/string/string_helper.h" #include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/tensor_to_string.h" #ifdef __NVCC__ #include "cub/cub.cuh" @@ -43,6 +43,8 @@ namespace operators { template using MasterT = typename details::MPTypeTrait::Type; +using phi::funcs::FlattenToString; +using phi::funcs::ToVector; template static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) { diff --git a/paddle/fluid/operators/optimizers/lamb_op.cc b/paddle/fluid/operators/optimizers/lamb_op.cc index 8434da2bb0e76a389991d141d2fffab90956d962..cc3c99f9b112930f570fe18c68340cd76ad3f485 100644 --- a/paddle/fluid/operators/optimizers/lamb_op.cc +++ b/paddle/fluid/operators/optimizers/lamb_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -12,11 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/optimizers/lamb_op.h" - #include - +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/infermeta/multiary.h" +#include "paddle/phi/kernels/lamb_kernel.h" namespace paddle { namespace operators { @@ -25,125 +29,6 @@ class LambOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), - true, - platform::errors::NotFound( - "Input(Param) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), - true, - platform::errors::NotFound( - "Input(Grad) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Moment1"), - true, - platform::errors::NotFound( - "Input(Moment1) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Moment2"), - true, - platform::errors::NotFound( - "Input(Moment2) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), - true, - platform::errors::NotFound( - "Input(LearningRate) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Beta1Pow"), - true, - platform::errors::NotFound( - "Input(Beta1Pow) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Beta2Pow"), - true, - platform::errors::NotFound( - "Input(Beta2Pow) of LambOp should not be null.")); - - PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), - true, - platform::errors::NotFound( - "Output(ParamOut) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment1Out"), - true, - platform::errors::NotFound( - "Output(Moment1Out) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment2Out"), - true, - platform::errors::NotFound( - "Output(Moment2Out) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Beta1PowOut"), - true, - platform::errors::NotFound( - "Output(Beta1PowOut) of LambOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Beta2PowOut"), - true, - platform::errors::NotFound( - "Output(Beta2PowOut) of LambOp should not be null.")); - - auto lr_dims = ctx->GetInputDim("LearningRate"); - PADDLE_ENFORCE_NE( - phi::product(lr_dims), - 0, - platform::errors::InvalidArgument( - "The number of LearningRate shall not be 0, but received %d. Maybe " - "the Input variable LearningRate has not " - "been initialized. 
You may need to confirm " - "if you put exe.run(startup_program) " - "after optimizer.minimize function.", - phi::product(lr_dims))); - PADDLE_ENFORCE_EQ( - phi::product(lr_dims), - 1, - platform::errors::InvalidArgument( - "Learning rate should have 1 dimension, but received %d.", - phi::product(lr_dims))); - auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow"); - PADDLE_ENFORCE_GE(phi::product(beta1_pow_dims), - 1, - platform::errors::InvalidArgument( - "The size of Beta1 power accumulator should be " - "greater than 0, but received %d.", - phi::product(beta1_pow_dims))); - auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow"); - PADDLE_ENFORCE_GE(phi::product(beta2_pow_dims), - 1, - platform::errors::InvalidArgument( - "The size of Beta2 power accumulator should be " - "greater than 0, but received %d.", - phi::product(beta2_pow_dims))); - - auto param_dims = ctx->GetInputDim("Param"); - if (ctx->GetInputsVarType("Grad")[0] == - framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ( - param_dims, - ctx->GetInputDim("Grad"), - platform::errors::InvalidArgument( - "Param and Grad input of LambOp should have same dimension. But " - "received Param dims: [%s], Grad dims: [%s].", - param_dims, - ctx->GetInputDim("Grad"))); - } - PADDLE_ENFORCE_EQ( - param_dims, - ctx->GetInputDim("Moment1"), - platform::errors::InvalidArgument( - "Param and Moment1 input of LambOp should have same dimension. But " - "received Param dims: [%s], Moment1 dims: [%s].", - param_dims, - ctx->GetInputDim("Moment1"))); - PADDLE_ENFORCE_EQ( - param_dims, - ctx->GetInputDim("Moment2"), - platform::errors::InvalidArgument( - "Param and Moment2 input of LambOp should have same dimension. But " - "received Param dims: [%s], Moment2 dims: [%s].", - param_dims, - ctx->GetInputDim("Moment2"))); - - ctx->SetOutputDim("ParamOut", param_dims); - ctx->SetOutputDim("Moment1Out", param_dims); - ctx->SetOutputDim("Moment2Out", param_dims); - ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims); - ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims); - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const { auto input_data_type = @@ -246,10 +131,16 @@ learning rate, $\lambda$ the weight decay rate. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(lamb, ops::LambOp, ops::LambOpMaker); -REGISTER_OP_CPU_KERNEL(lamb, - ops::LambOpKernel, - ops::LambOpKernel); +DECLARE_INFER_SHAPE_FUNCTOR(lamb, + LambInferMetaFunctor, + PD_INFER_META(phi::LambInferMeta)); +REGISTER_OPERATOR( + lamb, + ops::LambOp, + ops::LambOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + LambInferMetaFunctor); /* ========================== register checkpoint ===========================*/ REGISTER_OP_VERSION(lamb).AddCheckpoint( diff --git a/paddle/fluid/operators/optimizers/lamb_op.cu b/paddle/fluid/operators/optimizers/lamb_op.cu deleted file mode 100644 index 0d60979eef0bd7820aa21c7d9dc7a2e49cf90091..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/lamb_op.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/optimizers/lamb_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - lamb, - ops::LambOpKernel, - ops::LambOpKernel, - ops::LambOpKernel); diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h deleted file mode 100644 index 0415bb7df02acde3e47b9500b3879bd587bc6103..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/optimizers/lamb_op.h +++ /dev/null @@ -1,813 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include // for sqrt in CPU and CUDA - -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/buffer.h" -#include "paddle/fluid/operators/amp/fp16_type_traits.h" -#include "paddle/fluid/operators/math/selected_rows_functor.h" -#include "paddle/fluid/operators/tensor_to_string.h" -#include "paddle/fluid/platform/for_range.h" -#include "paddle/phi/kernels/funcs/algorithm.h" -#include "paddle/phi/kernels/funcs/eigen/extensions.h" -#include "paddle/phi/kernels/funcs/squared_l2_norm.h" - -namespace paddle { -namespace operators { - -namespace scatter = paddle::operators::math::scatter; - -template -struct LambMomentREGUpdateFunctor { - using MT = typename std::conditional::Type, - T>::type; - - MT weight_decay_; - MT beta1_; - MT beta2_; - MT epsilon_; - - MT beta1_pow_; - MT* beta1_pow_out_; - MT beta2_pow_; - MT* beta2_pow_out_; - const MT* moment1_; - MT* moment1_out_; - const MT* moment2_; - MT* moment2_out_; - const T* grad_; - const MT* param_; - MT* trust_ratio_div_; - const bool* skip_update_; - - LambMomentREGUpdateFunctor(MT weight_decay, - MT beta1, - MT beta2, - MT epsilon, - MT beta1_pow, - MT beta2_pow, - const MT* mom1, - MT* mom1_out, - const MT* mom2, - MT* mom2_out, - const T* grad, - const MT* param, - MT* trust_ratio_div, - const bool* skip_update) - : weight_decay_(weight_decay), - beta1_(beta1), - beta2_(beta2), - epsilon_(epsilon), - beta1_pow_(beta1_pow), - beta2_pow_(beta2_pow), - moment1_(mom1), - moment1_out_(mom1_out), - moment2_(mom2), - moment2_out_(mom2_out), - grad_(grad), - param_(param), - trust_ratio_div_(trust_ratio_div), - skip_update_(skip_update) {} - - inline HOSTDEVICE void operator()(size_t i) const { - if (skip_update_ && *skip_update_) return; - - MT g = static_cast(grad_[i]); - MT mom1 = moment1_[i]; - MT mom2 = moment2_[i]; - MT beta1_pow = beta1_pow_; - MT beta2_pow = beta2_pow_; - MT p = param_[i]; - - mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; - mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; 
- - moment1_out_[i] = mom1; - moment2_out_[i] = mom2; - - MT mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); - MT mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); - trust_ratio_div_[i] = - mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + - weight_decay_ * p; - } -}; - -template -struct LambMomentMENUpdateFunctor { - using MT = typename std::conditional::Type, - T>::type; - - MT weight_decay_; - MT beta1_; - MT beta2_; - MT epsilon_; - - const MT* beta1_pow_; - const MT* beta2_pow_; - const MT* moment1_; - MT* moment1_out_; - const MT* moment2_; - MT* moment2_out_; - const T* grad_; - const MT* param_; - MT* trust_ratio_div_; - const bool* skip_update_; - - LambMomentMENUpdateFunctor(MT weight_decay, - MT beta1, - MT beta2, - MT epsilon, - const MT* beta1_pow, - const MT* beta2_pow, - const MT* mom1, - MT* mom1_out, - const MT* mom2, - MT* mom2_out, - const T* grad, - const MT* param, - MT* trust_ratio_div, - const bool* skip_update) - : weight_decay_(weight_decay), - beta1_(beta1), - beta2_(beta2), - epsilon_(epsilon), - beta1_pow_(beta1_pow), - beta2_pow_(beta2_pow), - moment1_(mom1), - moment1_out_(mom1_out), - moment2_(mom2), - moment2_out_(mom2_out), - grad_(grad), - param_(param), - trust_ratio_div_(trust_ratio_div), - skip_update_(skip_update) {} - - inline HOSTDEVICE void operator()(size_t i) const { - if (skip_update_ && *skip_update_) return; - MT g = static_cast(grad_[i]); - MT mom1 = moment1_[i]; - MT mom2 = moment2_[i]; - MT beta1_pow = *beta1_pow_; - MT beta2_pow = *beta2_pow_; - MT p = param_[i]; - - mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; - mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; - - moment1_out_[i] = mom1; - moment2_out_[i] = mom2; - - MT mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); - MT mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); - trust_ratio_div_[i] = - mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + - weight_decay_ * p; - } -}; - -template -struct SparseLambMomentREGUpdateFunctor { - T weight_decay_; - T beta1_; - T beta2_; - T epsilon_; - - T beta1_pow_; - T beta2_pow_; - const T* moment1_; - T* moment1_out_; - const T* moment2_; - T* moment2_out_; - const T* grad_; - const T* param_; - T* trust_ratio_div_; - - const int64_t* rows_; - int64_t row_numel_; - int64_t row_count_; - - const bool* skip_update_; - - SparseLambMomentREGUpdateFunctor(T weight_decay, - T beta1, - T beta2, - T epsilon, - T beta1_pow, - T beta2_pow, - const T* mom1, - T* mom1_out, - const T* mom2, - T* mom2_out, - const T* grad, - const T* param, - T* trust_ratio_div, - const int64_t* rows, - int64_t row_numel, - int64_t row_count, - const bool* skip_update) - : weight_decay_(weight_decay), - beta1_(beta1), - beta2_(beta2), - epsilon_(epsilon), - beta1_pow_(beta1_pow), - beta2_pow_(beta2_pow), - moment1_(mom1), - moment1_out_(mom1_out), - moment2_(mom2), - moment2_out_(mom2_out), - grad_(grad), - param_(param), - trust_ratio_div_(trust_ratio_div), - rows_(rows), - row_numel_(row_numel), - row_count_(row_count), - skip_update_(skip_update) {} - - inline HOSTDEVICE void update(size_t i, T g) const { - // The following code is same as dense - T mom1 = moment1_[i]; - T mom2 = moment2_[i]; - T beta1_pow = beta1_pow_; - T beta2_pow = beta2_pow_; - T p = param_[i]; - - mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; - mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; - - moment1_out_[i] = mom1; - moment2_out_[i] = mom2; - - T mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); - T mom2_unbiased = 
mom2 / (static_cast(1) - beta2_pow); - trust_ratio_div_[i] = - mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + - weight_decay_ * p; - } - - inline HOSTDEVICE void operator()(size_t i) const { - if (skip_update_ && *skip_update_) return; - auto row_idx = - phi::funcs::BinarySearch(rows_, row_count_, i / row_numel_); - T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] - : static_cast(0); - update(i, g); - } -}; - -template -struct SparseLambMomentMENUpdateFunctor { - T weight_decay_; - T beta1_; - T beta2_; - T epsilon_; - - const T* beta1_pow_; - const T* beta2_pow_; - const T* moment1_; - T* moment1_out_; - const T* moment2_; - T* moment2_out_; - const T* grad_; - const T* param_; - T* trust_ratio_div_; - - const int64_t* rows_; - int64_t row_numel_; - int64_t row_count_; - - const bool* skip_update_; - - SparseLambMomentMENUpdateFunctor(T weight_decay, - T beta1, - T beta2, - T epsilon, - const T* beta1_pow, - const T* beta2_pow, - const T* mom1, - T* mom1_out, - const T* mom2, - T* mom2_out, - const T* grad, - const T* param, - T* trust_ratio_div, - const int64_t* rows, - int64_t row_numel, - int64_t row_count, - const bool* skip_update) - : weight_decay_(weight_decay), - beta1_(beta1), - beta2_(beta2), - epsilon_(epsilon), - beta1_pow_(beta1_pow), - beta2_pow_(beta2_pow), - moment1_(mom1), - moment1_out_(mom1_out), - moment2_(mom2), - moment2_out_(mom2_out), - grad_(grad), - param_(param), - trust_ratio_div_(trust_ratio_div), - rows_(rows), - row_numel_(row_numel), - row_count_(row_count), - skip_update_(skip_update) {} - - inline HOSTDEVICE void update(size_t i, T g) const { - // The following code is same as dense - T mom1 = moment1_[i]; - T mom2 = moment2_[i]; - T beta1_pow = *beta1_pow_; - T beta2_pow = *beta2_pow_; - T p = param_[i]; - - mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; - mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; - - moment1_out_[i] = mom1; - moment2_out_[i] = mom2; - - T mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); - T mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); - trust_ratio_div_[i] = - mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + - weight_decay_ * p; - } - - inline HOSTDEVICE void operator()(size_t i) const { - if (skip_update_ && *skip_update_) return; - auto row_idx = - phi::funcs::BinarySearch(rows_, row_count_, i / row_numel_); - T g = row_idx >= 0 ? 
grad_[row_idx * row_numel_ + i % row_numel_] - : static_cast(0); - update(i, g); - } -}; - -template -struct LambBetaPowUpdateFunctor { - void SetBetaPows(const MT* beta1pow, - const MT* beta2pow, - MT* beta1pow_out, - MT* beta2pow_out, - MT beta1, - MT beta2) { - beta1pow_ = beta1pow; - beta2pow_ = beta2pow; - beta1pow_out_ = beta1pow_out; - beta2pow_out_ = beta2pow_out; - beta1_ = beta1; - beta2_ = beta2; - } - - HOSTDEVICE void UpdateBetaPow(size_t i) const { - if (i == 0) { - beta1pow_out_[0] = beta1pow_[0] * beta1_; - beta2pow_out_[0] = beta2pow_[0] * beta2_; - } - } - - private: - const MT* beta1pow_; - const MT* beta2pow_; - MT* beta1pow_out_; - MT* beta2pow_out_; - MT beta1_; - MT beta2_; -}; - -template -struct LambBetaPowUpdateFunctor { - void SetBetaPows(const MT* beta1pow, - const MT* beta2pow, - MT* beta1pow_out, - MT* beta2pow_out, - MT beta1, - MT beta2) {} - HOSTDEVICE void UpdateBetaPow(size_t) const {} -}; - -template -struct LambParamUpateFunctor - : public LambBetaPowUpdateFunctor { - const MT* lr_; - const T* param_; - const MT* master_param_; - const MT* param_norm_; - const MT* trust_ratio_div_; - const MT* trust_ratio_div_norm_; - T* param_out_; - MT* master_param_out_; - - const bool* skip_update_; - - LambParamUpateFunctor(const MT* lr, - const T* param, - const MT* master_param, - const MT* param_norm, - const MT* trust_ratio_div, - const MT* trust_ratio_div_norm, - T* param_out, - MT* master_param_out, - const bool* skip_update) - : lr_(lr), - param_(param), - master_param_(master_param), - param_norm_(param_norm), - trust_ratio_div_(trust_ratio_div), - trust_ratio_div_norm_(trust_ratio_div_norm), - param_out_(param_out), - master_param_out_(master_param_out), - skip_update_(skip_update) {} - - inline HOSTDEVICE void operator()(size_t i) const { - if (skip_update_ && *skip_update_) return; - MT lr = *lr_; - MT pn = Eigen::numext::sqrt(*param_norm_); - MT tn = Eigen::numext::sqrt(*trust_ratio_div_norm_); - - MT r = (pn > static_cast(0) && tn > static_cast(0)) - ? pn / tn - : static_cast(1); - lr *= r; - MT p = IsMultiPrecision ? master_param_[i] : static_cast(param_[i]); - MT param_out = p - lr * trust_ratio_div_[i]; - param_out_[i] = static_cast(param_out); - if (IsMultiPrecision) { - master_param_out_[i] = param_out; - } - this->UpdateBetaPow(i); - } -}; - -template -class LambOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using MT = typename details::MPTypeTrait::Type; - bool multi_precision = ctx.Attr("multi_precision"); - if (multi_precision) { - ComputeImpl(ctx); - } else { - ComputeImpl(ctx); - } - } - - private: - template - void ComputeImpl(const framework::ExecutionContext& ctx) const { - if (!IsMultiPrecision) { - constexpr auto kIsSameType = std::is_same::value; - PADDLE_ENFORCE_EQ( - kIsSameType, - true, - platform::errors::InvalidArgument( - "When multi_precision=False, T and MT must be the same type.")); - } - const auto* skip_update = ctx.Input("SkipUpdate"); - const bool* skip_update_flag = skip_update && skip_update->IsInitialized() - ? 
skip_update->data() - : nullptr; - if (skip_update_flag && platform::is_cpu_place(skip_update->place()) && - (*skip_update_flag)) { - return; - } - - auto weight_decay = static_cast(ctx.Attr("weight_decay")); - auto beta1 = static_cast(ctx.Attr("beta1")); - auto beta2 = static_cast(ctx.Attr("beta2")); - auto epsilon = static_cast(ctx.Attr("epsilon")); - const auto& param = GET_DATA_SAFELY( - ctx.Input("Param"), "Input", "Param", "Lamb"); - const auto* grad_var = ctx.InputVar("Grad"); - const auto& mom1 = GET_DATA_SAFELY( - ctx.Input("Moment1"), "Input", "Moment1", "Lamb"); - const auto& mom2 = GET_DATA_SAFELY( - ctx.Input("Moment2"), "Input", "Moment2", "Lamb"); - const auto& lr = - GET_DATA_SAFELY(ctx.Input("LearningRate"), - "Input", - "LearningRate", - "Lamb"); - - const auto& beta1_pow = - GET_DATA_SAFELY(ctx.Input("Beta1Pow"), - "Input", - "Beta1Pow", - "Lamb"); - const auto& beta2_pow = - GET_DATA_SAFELY(ctx.Input("Beta2Pow"), - "Input", - "Beta2Pow", - "Lamb"); - - auto& param_out = - GET_DATA_SAFELY(ctx.Output("ParamOut"), - "Output", - "ParamOut", - "Lamb"); - auto& mom1_out = - GET_DATA_SAFELY(ctx.Output("Moment1Out"), - "Output", - "Moment1Out", - "Lamb"); - auto& mom2_out = - GET_DATA_SAFELY(ctx.Output("Moment2Out"), - "Output", - "Moment2Out", - "Lamb"); - auto& beta1_pow_out = - GET_DATA_SAFELY(ctx.Output("Beta1PowOut"), - "Output", - "Beta1PowOut", - "Lamb"); - auto& beta2_pow_out = - GET_DATA_SAFELY(ctx.Output("Beta2PowOut"), - "Output", - "Beta2PowOut", - "Lamb"); - const auto* master_param = - IsMultiPrecision ? ctx.Input("MasterParam") - : nullptr; - auto* master_param_out = - IsMultiPrecision ? ctx.Output("MasterParamOut") - : nullptr; - - if (IsMultiPrecision) { - PADDLE_ENFORCE_NOT_NULL(master_param, - platform::errors::InvalidArgument( - "Input(MasterParam) must be provided when " - "multi_precision=True.")); - PADDLE_ENFORCE_NOT_NULL(master_param_out, - platform::errors::InvalidArgument( - "Output(MasterParamOut) must be provided " - "when multi_precision=True.")); - } - - auto& dev_ctx = ctx.template device_context(); - auto numel = param.numel(); - platform::ForRange for_range(dev_ctx, numel); - auto trust_ratio_div = - ctx.AllocateTmpTensor(param.dims(), dev_ctx); - auto* trust_ratio_div_ptr = trust_ratio_div.template data(); - - const void* param_ptr = param.data(); - const void* master_param_ptr = - master_param ? master_param->data() : nullptr; - void* param_out_ptr = param_out.template mutable_data(ctx.GetPlace()); - void* master_param_out_ptr = - master_param_out - ? master_param_out->template mutable_data(ctx.GetPlace()) - : nullptr; - - // Update moments - bool should_update_beta_pow_later = false; - const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr; - MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr; - VLOG(10) << "Beta1Pow place: " << beta1_pow.place() - << " , Beta2Pow place: " << beta2_pow.place(); - if (grad_var->IsType()) { - auto& grad = grad_var->Get(); - if (platform::is_gpu_place(ctx.GetPlace()) && - beta1_pow.place() == platform::CPUPlace() && - beta2_pow.place() == platform::CPUPlace()) { - LambMomentREGUpdateFunctor moment_update_functor( - weight_decay, - beta1, - beta2, - epsilon, - *beta1_pow.template data(), - *beta2_pow.template data(), - mom1.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2.template data(), - mom2_out.template mutable_data(ctx.GetPlace()), - grad.template data(), - static_cast(IsMultiPrecision ? 
master_param_ptr - : param_ptr), - trust_ratio_div_ptr, - skip_update_flag); - for_range(moment_update_functor); - beta1_pow_out.template mutable_data(platform::CPUPlace())[0] = - beta1 * beta1_pow.template data()[0]; - beta2_pow_out.template mutable_data(platform::CPUPlace())[0] = - beta2 * beta2_pow.template data()[0]; - } else { - beta1_pow_ptr = beta1_pow.template data(); - beta2_pow_ptr = beta2_pow.template data(); - beta1_pow_out_ptr = - beta1_pow_out.template mutable_data(ctx.GetPlace()); - beta2_pow_out_ptr = - beta2_pow_out.template mutable_data(ctx.GetPlace()); - should_update_beta_pow_later = true; - LambMomentMENUpdateFunctor moment_update_functor( - weight_decay, - beta1, - beta2, - epsilon, - static_cast(beta1_pow_ptr), - static_cast(beta2_pow_ptr), - mom1.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2.template data(), - mom2_out.template mutable_data(ctx.GetPlace()), - grad.template data(), - static_cast(IsMultiPrecision ? master_param_ptr - : param_ptr), - trust_ratio_div_ptr, - skip_update_flag); - for_range(moment_update_functor); - } - } else if (grad_var->IsType()) { - PADDLE_ENFORCE_EQ(IsMultiPrecision, - false, - platform::errors::Unimplemented( - "SelectedRows gradient is not supported when " - "multi_precision=True.")); - constexpr bool kIsSameType = std::is_same::value; - PADDLE_ENFORCE_EQ(kIsSameType, - true, - platform::errors::Unimplemented( - "SelectedRows gradient is not supported when " - "multi_precision=True.")); - auto& grad = GET_DATA_SAFELY( - ctx.Input("Grad"), "Input", "Grad", "Lamb"); - if (grad.rows().size() == 0) { - VLOG(3) << "grad row size is 0!!"; - return; - } - - std::vector cpu_rows(grad.rows().begin(), grad.rows().end()); - bool is_strict_sorted = true; - for (size_t i = 1; i < cpu_rows.size(); ++i) { - if (cpu_rows[i - 1] >= cpu_rows[i]) { - is_strict_sorted = false; - break; - } - } - - phi::SelectedRows tmp_grad_merge; - const phi::SelectedRows* grad_merge_ptr; - if (is_strict_sorted) { - grad_merge_ptr = &grad; - } else { - // merge duplicated rows if any. 
- // The rows of grad_merge have been sorted inside MergeAdd functor - scatter::MergeAdd merge_func; - merge_func(dev_ctx, grad, &tmp_grad_merge, true); - grad_merge_ptr = &tmp_grad_merge; - } - - auto& grad_merge = *grad_merge_ptr; - auto& grad_tensor = grad_merge.value(); - const T* grad_data = grad_tensor.template data(); - auto* grad_merge_rows = &grad_merge.rows(); - paddle::framework::MixVector mixv_grad_merge_rows( - grad_merge_rows); - const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace()); - auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); - if (platform::is_gpu_place(ctx.GetPlace()) && - beta1_pow.place() == platform::CPUPlace() && - beta2_pow.place() == platform::CPUPlace()) { - SparseLambMomentREGUpdateFunctor moment_update_functor( - static_cast(weight_decay), - static_cast(beta1), - static_cast(beta2), - static_cast(epsilon), - *beta1_pow.template data(), - *beta2_pow.template data(), - mom1.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2.template data(), - mom2_out.template mutable_data(ctx.GetPlace()), - grad_data, - param.template data(), - trust_ratio_div.template data(), - rows, - row_numel, - grad_merge.rows().size(), - skip_update_flag); - for_range(moment_update_functor); - beta1_pow_out.template mutable_data(platform::CPUPlace())[0] = - static_cast(beta1) * beta1_pow.template data()[0]; - beta2_pow_out.template mutable_data(platform::CPUPlace())[0] = - static_cast(beta2) * beta2_pow.template data()[0]; - } else { - beta1_pow_ptr = beta1_pow.template data(); - beta2_pow_ptr = beta2_pow.template data(); - beta1_pow_out_ptr = - beta1_pow_out.template mutable_data(ctx.GetPlace()); - beta2_pow_out_ptr = - beta2_pow_out.template mutable_data(ctx.GetPlace()); - should_update_beta_pow_later = true; - SparseLambMomentMENUpdateFunctor moment_update_functor( - static_cast(weight_decay), - static_cast(beta1), - static_cast(beta2), - static_cast(epsilon), - reinterpret_cast(beta1_pow_ptr), - reinterpret_cast(beta2_pow_ptr), - mom1.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2.template data(), - mom2_out.template mutable_data(ctx.GetPlace()), - grad_data, - param.template data(), - trust_ratio_div.template data(), - rows, - row_numel, - grad_merge.rows().size(), - skip_update_flag); - for_range(moment_update_functor); - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Variable type not supported by lamb_op. Expect LoDTensor or " - "SelectedRows, but got %s", - framework::ToTypeName(grad_var->Type()))); - } - - // Update parameter - auto p_norm_t = ctx.AllocateTmpTensor({1}, dev_ctx); - auto* p_norm_ptr = p_norm_t.template data(); - - auto trust_ratio_div_norm_t = - ctx.AllocateTmpTensor({1}, dev_ctx); - auto* trust_ratio_div_norm_ptr = trust_ratio_div_norm_t.template data(); - - // TODO(zengjinle): remove the following Eigen operations when - // *skip_update == true. - memory::Buffer buffer(dev_ctx.GetPlace()); - phi::funcs::SquaredL2Norm( - dev_ctx, - reinterpret_cast(IsMultiPrecision ? 
master_param_ptr - : param_ptr), - p_norm_ptr, - numel, - &buffer); - phi::funcs::SquaredL2Norm( - dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer); - - if (VLOG_IS_ON(1)) { - const auto& name = ctx.GetOp().Input("Param"); - auto pn = ToVector(p_norm_ptr, 1, dev_ctx.GetPlace()); - auto tn = ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace()); - auto dtype = - framework::DataTypeToString(framework::DataTypeTrait::DataType()); - VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0] - << " , tn = " << tn[0]; - } - -#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow) \ - do { \ - LambParamUpateFunctor \ - param_update_functor(lr.template data(), \ - static_cast(param_ptr), \ - static_cast(master_param_ptr), \ - p_norm_ptr, \ - trust_ratio_div_ptr, \ - trust_ratio_div_norm_ptr, \ - static_cast(param_out_ptr), \ - static_cast(master_param_out_ptr), \ - skip_update_flag); \ - if (__should_update_beta_pow) { \ - param_update_functor.SetBetaPows(beta1_pow_ptr, \ - beta2_pow_ptr, \ - beta1_pow_out_ptr, \ - beta2_pow_out_ptr, \ - beta1, \ - beta2); \ - } \ - for_range(param_update_functor); \ - } while (0) - - if (should_update_beta_pow_later) { - CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true); - } else { - CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false); - } - -#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/optimizers/lamb_op_xpu.cc b/paddle/fluid/operators/optimizers/lamb_op_xpu.cc index e0233fadb8858abcd674c097a18af45838074bb0..bfeb42a221fa77dbec16b9d9e407ec7f7f6db3bd 100644 --- a/paddle/fluid/operators/optimizers/lamb_op_xpu.cc +++ b/paddle/fluid/operators/optimizers/lamb_op_xpu.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "gflags/gflags.h" -#include "paddle/fluid/operators/optimizers/lamb_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/device_wrapper.h" namespace paddle { diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index b1ae62603d6a04ac63ea99af0937371657ca8df0..307edbefd03ca9a2717732553271feab327d2d28 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1327,6 +1327,18 @@ optional : prior_dist backward : label_smooth_grad +- api : lamb_ + args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1, float beta2, float epsilon, bool multi_precision) + output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs) + infer_meta : + func : LambInferMeta + kernel : + func : lamb {dense, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense}, + lamb_sr {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense} + data_type : param + optional : master_param, skip_update + inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs) + - api : layer_norm args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis, bool is_test) output : Tensor(out), Tensor(mean), Tensor(variance) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 4e0db07cc6b3ae42a4d1e5cad2d69d7906464c89..6e4f2dce35f96b6f5835d943502e8c103969a210 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1642,6 +1642,105 @@ void InterpolateInferMeta( } } +void LambInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& moment1, + const MetaTensor& moment2, + const MetaTensor& beta1_pow, + const MetaTensor& beta2_pow, + const MetaTensor& master_param, + const MetaTensor& skip_update, + float weight_decay, + float beta1, + float beta2, + float epsilon, + bool multi_precision, + MetaTensor* param_out, + MetaTensor* moment1_out, + MetaTensor* moment2_out, + MetaTensor* beta1_pow_out, + MetaTensor* beta2_pow_out, + MetaTensor* master_param_outs) { + auto lr_dims = learning_rate.dims(); + PADDLE_ENFORCE_NE( + phi::product(lr_dims), + 0, + phi::errors::InvalidArgument( + "The number of LearningRate shall not be 0, but received %d. Maybe " + "the Input variable LearningRate has not " + "been initialized. 
You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.", + phi::product(lr_dims))); + PADDLE_ENFORCE_EQ( + phi::product(lr_dims), + 1, + phi::errors::InvalidArgument( + "Learning rate should have 1 dimension, but received %d.", + phi::product(lr_dims))); + auto beta1_pow_dims = beta1_pow.dims(); + PADDLE_ENFORCE_GE(phi::product(beta1_pow_dims), + 1, + phi::errors::InvalidArgument( + "The size of Beta1 power accumulator should be " + "greater than 0, but received %d.", + phi::product(beta1_pow_dims))); + auto beta2_pow_dims = beta2_pow.dims(); + PADDLE_ENFORCE_GE(phi::product(beta2_pow_dims), + 1, + phi::errors::InvalidArgument( + "The size of Beta2 power accumulator should be " + "greater than 0, but received %d.", + phi::product(beta2_pow_dims))); + + auto param_dims = param.dims(); + PADDLE_ENFORCE_EQ( + param_dims, + moment1.dims(), + phi::errors::InvalidArgument( + "Param and Moment1 input of LambOp should have same dimension. But " + "received Param dims: [%s], Moment1 dims: [%s].", + param_dims, + moment1.dims())); + PADDLE_ENFORCE_EQ( + param_dims, + moment2.dims(), + errors::InvalidArgument( + "Param and Moment2 input of AdamOp should have same dimension. But " + "received Param dims: [%s], Moment2 dims: [%s].", + param_dims, + moment2.dims())); + + PADDLE_ENFORCE_NOT_NULL( + param_out, errors::NotFound("The output param_out can not be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + moment1_out, + errors::NotFound("The output moment1_out can not be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + moment2_out, + errors::NotFound("The output moment2_out can not be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + beta1_pow_out, + errors::NotFound("The output beta1_pow_out can not be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + beta2_pow_out, + errors::NotFound("The output beta2_pow_out can not be nullptr")); + + param_out->set_dims(param_dims); + param_out->set_dtype(param.dtype()); + + moment1_out->set_dims(param_dims); + moment1_out->set_dtype(moment1.dtype()); + moment2_out->set_dims(param_dims); + moment2_out->set_dtype(moment2.dtype()); + + beta1_pow_out->set_dims(beta1_pow_dims); + beta1_pow_out->set_dtype(beta1_pow.dtype()); + beta2_pow_out->set_dims(beta2_pow_dims); + beta2_pow_out->set_dtype(beta2_pow.dtype()); +} + void LogspaceInferMeta(const MetaTensor& start, const MetaTensor& stop, const MetaTensor& number, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 3bf4288cc7637a86ae80c3cd7c5301f631fa2d05..472d665050bde1c4130b9a3fdefb89a730385ce0 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -269,6 +269,27 @@ void InterpolateInferMeta( MetaTensor* output, MetaConfig config = MetaConfig()); +void LambInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& moment1, + const MetaTensor& moment2, + const MetaTensor& beta1_pow, + const MetaTensor& beta2_pow, + const MetaTensor& master_param, + const MetaTensor& skip_update, + float weight_decay, + float beta1, + float beta2, + float epsilon, + bool multi_precision, + MetaTensor* param_out, + MetaTensor* moment1_out, + MetaTensor* moment2_out, + MetaTensor* beta1_pow_out, + MetaTensor* beta2_pow_out, + MetaTensor* master_param_outs); + void LogspaceInferMeta(const MetaTensor& start, const MetaTensor& stop, const MetaTensor& number, diff --git a/paddle/phi/kernels/cpu/lamb_kernel.cc b/paddle/phi/kernels/cpu/lamb_kernel.cc new file mode 100644 index 
0000000000000000000000000000000000000000..1394f8e5b910c07c5a7b5d60eda8eeac2ab402e1 --- /dev/null +++ b/paddle/phi/kernels/cpu/lamb_kernel.cc @@ -0,0 +1,20 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/lamb_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/lamb_kernel_impl.h" + +PD_REGISTER_KERNEL(lamb, CPU, ALL_LAYOUT, phi::LambKernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/lamb_functors.h b/paddle/phi/kernels/funcs/lamb_functors.h new file mode 100644 index 0000000000000000000000000000000000000000..5abc86bfb777c48189813a031880a76c05a8606b --- /dev/null +++ b/paddle/phi/kernels/funcs/lamb_functors.h @@ -0,0 +1,463 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include // for sqrt in CPU and CUDA + +#include +#include + +#include "paddle/fluid/memory/buffer.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/algorithm.h" +#include "paddle/phi/kernels/funcs/eigen/extensions.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/funcs/squared_l2_norm.h" +#include "paddle/phi/kernels/funcs/tensor_to_string.h" + +namespace phi { + +namespace scatter = paddle::operators::math::scatter; + +template +struct LambMomentREGUpdateFunctor { + using MT = + typename std::conditional::Type, + T>::type; + + MT weight_decay_; + MT beta1_; + MT beta2_; + MT epsilon_; + + MT beta1_pow_; + MT* beta1_pow_out_; + MT beta2_pow_; + MT* beta2_pow_out_; + const MT* moment1_; + MT* moment1_out_; + const MT* moment2_; + MT* moment2_out_; + const T* grad_; + const MT* param_; + MT* trust_ratio_div_; + const bool* skip_update_; + + LambMomentREGUpdateFunctor(MT weight_decay, + MT beta1, + MT beta2, + MT epsilon, + MT beta1_pow, + MT beta2_pow, + const MT* mom1, + MT* mom1_out, + const MT* mom2, + MT* mom2_out, + const T* grad, + const MT* param, + MT* trust_ratio_div, + const bool* skip_update) + : weight_decay_(weight_decay), + beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + grad_(grad), + param_(param), + trust_ratio_div_(trust_ratio_div), + skip_update_(skip_update) {} + + inline HOSTDEVICE void operator()(size_t i) const { + if (skip_update_ && *skip_update_) return; + + MT g = static_cast(grad_[i]); + MT mom1 = moment1_[i]; + MT mom2 = moment2_[i]; + MT beta1_pow = beta1_pow_; + MT beta2_pow = beta2_pow_; + MT p = param_[i]; + + mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; + mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; + + moment1_out_[i] = mom1; + moment2_out_[i] = mom2; + + MT mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); + MT mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); + trust_ratio_div_[i] = + mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + + weight_decay_ * p; + } +}; + +template +struct LambMomentMENUpdateFunctor { + using MT = + typename std::conditional::Type, + T>::type; + + MT weight_decay_; + MT beta1_; + MT beta2_; + MT epsilon_; + + const MT* beta1_pow_; + const MT* beta2_pow_; + const MT* moment1_; + MT* moment1_out_; + const MT* moment2_; + MT* moment2_out_; + const T* grad_; + const MT* param_; + MT* trust_ratio_div_; + const bool* skip_update_; + + LambMomentMENUpdateFunctor(MT weight_decay, + MT beta1, + MT beta2, + MT epsilon, + const MT* beta1_pow, + const MT* beta2_pow, + const MT* mom1, + MT* mom1_out, + const MT* mom2, + MT* mom2_out, + const T* grad, + const MT* param, + MT* trust_ratio_div, + const bool* skip_update) + : weight_decay_(weight_decay), + beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + grad_(grad), + param_(param), + trust_ratio_div_(trust_ratio_div), + skip_update_(skip_update) {} + + inline HOSTDEVICE void operator()(size_t i) const { + if (skip_update_ && *skip_update_) return; + MT g = static_cast(grad_[i]); + MT mom1 = moment1_[i]; + MT mom2 = moment2_[i]; + 
MT beta1_pow = *beta1_pow_; + MT beta2_pow = *beta2_pow_; + MT p = param_[i]; + + mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; + mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; + + moment1_out_[i] = mom1; + moment2_out_[i] = mom2; + + MT mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); + MT mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); + trust_ratio_div_[i] = + mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + + weight_decay_ * p; + } +}; + +template +struct SparseLambMomentREGUpdateFunctor { + T weight_decay_; + T beta1_; + T beta2_; + T epsilon_; + + T beta1_pow_; + T beta2_pow_; + const T* moment1_; + T* moment1_out_; + const T* moment2_; + T* moment2_out_; + const T* grad_; + const T* param_; + T* trust_ratio_div_; + + const int64_t* rows_; + int64_t row_numel_; + int64_t row_count_; + + const bool* skip_update_; + + SparseLambMomentREGUpdateFunctor(T weight_decay, + T beta1, + T beta2, + T epsilon, + T beta1_pow, + T beta2_pow, + const T* mom1, + T* mom1_out, + const T* mom2, + T* mom2_out, + const T* grad, + const T* param, + T* trust_ratio_div, + const int64_t* rows, + int64_t row_numel, + int64_t row_count, + const bool* skip_update) + : weight_decay_(weight_decay), + beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + grad_(grad), + param_(param), + trust_ratio_div_(trust_ratio_div), + rows_(rows), + row_numel_(row_numel), + row_count_(row_count), + skip_update_(skip_update) {} + + inline HOSTDEVICE void update(size_t i, T g) const { + // The following code is same as dense + T mom1 = moment1_[i]; + T mom2 = moment2_[i]; + T beta1_pow = beta1_pow_; + T beta2_pow = beta2_pow_; + T p = param_[i]; + + mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; + mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; + + moment1_out_[i] = mom1; + moment2_out_[i] = mom2; + + T mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); + T mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); + trust_ratio_div_[i] = + mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + + weight_decay_ * p; + } + + inline HOSTDEVICE void operator()(size_t i) const { + if (skip_update_ && *skip_update_) return; + auto row_idx = + phi::funcs::BinarySearch(rows_, row_count_, i / row_numel_); + T g = row_idx >= 0 ? 
grad_[row_idx * row_numel_ + i % row_numel_] + : static_cast(0); + update(i, g); + } +}; + +template +struct SparseLambMomentMENUpdateFunctor { + T weight_decay_; + T beta1_; + T beta2_; + T epsilon_; + + const T* beta1_pow_; + const T* beta2_pow_; + const T* moment1_; + T* moment1_out_; + const T* moment2_; + T* moment2_out_; + const T* grad_; + const T* param_; + T* trust_ratio_div_; + + const int64_t* rows_; + int64_t row_numel_; + int64_t row_count_; + + const bool* skip_update_; + + SparseLambMomentMENUpdateFunctor(T weight_decay, + T beta1, + T beta2, + T epsilon, + const T* beta1_pow, + const T* beta2_pow, + const T* mom1, + T* mom1_out, + const T* mom2, + T* mom2_out, + const T* grad, + const T* param, + T* trust_ratio_div, + const int64_t* rows, + int64_t row_numel, + int64_t row_count, + const bool* skip_update) + : weight_decay_(weight_decay), + beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + grad_(grad), + param_(param), + trust_ratio_div_(trust_ratio_div), + rows_(rows), + row_numel_(row_numel), + row_count_(row_count), + skip_update_(skip_update) {} + + inline HOSTDEVICE void update(size_t i, T g) const { + // The following code is same as dense + T mom1 = moment1_[i]; + T mom2 = moment2_[i]; + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; + T p = param_[i]; + + mom1 = beta1_ * mom1 + (static_cast(1) - beta1_) * g; + mom2 = beta2_ * mom2 + (static_cast(1) - beta2_) * g * g; + + moment1_out_[i] = mom1; + moment2_out_[i] = mom2; + + T mom1_unbiased = mom1 / (static_cast(1) - beta1_pow); + T mom2_unbiased = mom2 / (static_cast(1) - beta2_pow); + trust_ratio_div_[i] = + mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) + + weight_decay_ * p; + } + + inline HOSTDEVICE void operator()(size_t i) const { + if (skip_update_ && *skip_update_) return; + auto row_idx = + phi::funcs::BinarySearch(rows_, row_count_, i / row_numel_); + T g = row_idx >= 0 ? 
grad_[row_idx * row_numel_ + i % row_numel_] + : static_cast(0); + update(i, g); + } +}; + +template +struct LambBetaPowUpdateFunctor { + void SetBetaPows(const MT* beta1pow, + const MT* beta2pow, + MT* beta1pow_out, + MT* beta2pow_out, + MT beta1, + MT beta2) { + beta1pow_ = beta1pow; + beta2pow_ = beta2pow; + beta1pow_out_ = beta1pow_out; + beta2pow_out_ = beta2pow_out; + beta1_ = beta1; + beta2_ = beta2; + } + + HOSTDEVICE void UpdateBetaPow(size_t i) const { + if (i == 0) { + beta1pow_out_[0] = beta1pow_[0] * beta1_; + beta2pow_out_[0] = beta2pow_[0] * beta2_; + } + } + + private: + const MT* beta1pow_; + const MT* beta2pow_; + MT* beta1pow_out_; + MT* beta2pow_out_; + MT beta1_; + MT beta2_; +}; + +template +struct LambBetaPowUpdateFunctor { + void SetBetaPows(const MT* beta1pow, + const MT* beta2pow, + MT* beta1pow_out, + MT* beta2pow_out, + MT beta1, + MT beta2) {} + HOSTDEVICE void UpdateBetaPow(size_t) const {} +}; + +template +struct LambParamUpateFunctor + : public LambBetaPowUpdateFunctor { + const MT* lr_; + const T* param_; + const MT* master_param_; + const MT* param_norm_; + const MT* trust_ratio_div_; + const MT* trust_ratio_div_norm_; + T* param_out_; + MT* master_param_out_; + + const bool* skip_update_; + + LambParamUpateFunctor(const MT* lr, + const T* param, + const MT* master_param, + const MT* param_norm, + const MT* trust_ratio_div, + const MT* trust_ratio_div_norm, + T* param_out, + MT* master_param_out, + const bool* skip_update) + : lr_(lr), + param_(param), + master_param_(master_param), + param_norm_(param_norm), + trust_ratio_div_(trust_ratio_div), + trust_ratio_div_norm_(trust_ratio_div_norm), + param_out_(param_out), + master_param_out_(master_param_out), + skip_update_(skip_update) {} + + inline HOSTDEVICE void operator()(size_t i) const { + if (skip_update_ && *skip_update_) return; + MT lr = *lr_; + MT pn = Eigen::numext::sqrt(*param_norm_); + MT tn = Eigen::numext::sqrt(*trust_ratio_div_norm_); + + MT r = (pn > static_cast(0) && tn > static_cast(0)) + ? pn / tn + : static_cast(1); + lr *= r; + MT p = IsMultiPrecision ? 
master_param_[i] : static_cast(param_[i]); + MT param_out = p - lr * trust_ratio_div_[i]; + param_out_[i] = static_cast(param_out); + if (IsMultiPrecision) { + master_param_out_[i] = param_out; + } + this->UpdateBetaPow(i); + } +}; + +} // namespace phi diff --git a/paddle/fluid/operators/tensor_to_string.h b/paddle/phi/kernels/funcs/tensor_to_string.h similarity index 66% rename from paddle/fluid/operators/tensor_to_string.h rename to paddle/phi/kernels/funcs/tensor_to_string.h index ef8a041fc5adcf74839b8769da09771ff68e0608..2f1fb574930f3e0f19c517bee3d93005d8ab9417 100644 --- a/paddle/fluid/operators/tensor_to_string.h +++ b/paddle/phi/kernels/funcs/tensor_to_string.h @@ -16,13 +16,14 @@ #include -#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/string/string_helper.h" -namespace paddle { -namespace operators { +namespace phi { +namespace funcs { template static const std::vector &ToVector(const std::vector &vec) { @@ -30,22 +31,20 @@ static const std::vector &ToVector(const std::vector &vec) { } template -static std::vector ToVector(const T *x, - size_t n, - const platform::Place &place) { +static std::vector ToVector(const T *x, size_t n, const phi::Place &place) { #ifdef __NVCC__ - if (platform::is_gpu_place(place)) { + if (paddle::platform::is_gpu_place(place)) { using CopyT = typename std:: conditional::value, uint8_t, T>::type; std::vector cpu_x(n); auto *dev_ctx = static_cast( - platform::DeviceContextPool::Instance().Get(place)); - memory::Copy(platform::CPUPlace(), - cpu_x.data(), - place, - x, - n * sizeof(T), - dev_ctx->stream()); + phi::DeviceContextPool::Instance().Get(place)); + paddle::memory::Copy(phi::CPUPlace(), + cpu_x.data(), + place, + x, + n * sizeof(T), + dev_ctx->stream()); dev_ctx->Wait(); return std::vector(cpu_x.data(), cpu_x.data() + n); } @@ -54,7 +53,7 @@ static std::vector ToVector(const T *x, } template -static std::vector ToVector(const framework::Tensor &src) { +static std::vector ToVector(const DenseTensor &src) { if (!src.IsInitialized()) { return {}; } @@ -64,8 +63,8 @@ static std::vector ToVector(const framework::Tensor &src) { template static std::string FlattenToString(Args &&...args) { const auto &vec = ToVector(std::forward(args)...); - return "[" + string::join_strings(vec, ',') + "]"; + return "[" + paddle::string::join_strings(vec, ',') + "]"; } -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/lamb_kernel.cu b/paddle/phi/kernels/gpu/lamb_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..0b11b94fdb725f28192cd55ff17331ea214f6af3 --- /dev/null +++ b/paddle/phi/kernels/gpu/lamb_kernel.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/lamb_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/lamb_kernel_impl.h" + +PD_REGISTER_KERNEL(lamb, + GPU, + ALL_LAYOUT, + phi::LambKernel, + phi::dtype::float16, + float, + double) { + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/impl/lamb_kernel_impl.h b/paddle/phi/kernels/impl/lamb_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..f3a76c6a7f1dd52745e509d1f31b77898b935d55 --- /dev/null +++ b/paddle/phi/kernels/impl/lamb_kernel_impl.h @@ -0,0 +1,296 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/lamb_functors.h" + +namespace phi { + +template +void ComputeImpl(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& lr, + const DenseTensor& mom1, + const DenseTensor& mom2, + const DenseTensor& beta1_pow, + const DenseTensor& beta2_pow, + const paddle::optional& master_param_opt, + const paddle::optional& skip_update_opt, + float weight_decay_f, + float beta1_f, + float beta2_f, + float epsilon_f, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* mom1_out, + DenseTensor* mom2_out, + DenseTensor* beta1_pow_out, + DenseTensor* beta2_pow_out, + DenseTensor* master_param_out); + +template +void LambKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& learning_rate, + const DenseTensor& moment1, + const DenseTensor& moment2, + const DenseTensor& beta1_pow, + const DenseTensor& beta2_pow, + const paddle::optional& master_param, + const paddle::optional& skip_update, + float weight_decay, + float beta1, + float beta2, + float epsilon, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* moment1_out, + DenseTensor* moment2_out, + DenseTensor* beta1_pow_out, + DenseTensor* beta2_pow_out, + DenseTensor* master_param_outs) { + using MT = typename phi::dtype::MPTypeTrait::Type; + if (multi_precision) { + ComputeImpl(dev_ctx, + param, + grad, + learning_rate, + moment1, + moment2, + beta1_pow, + beta2_pow, + master_param, + skip_update, + weight_decay, + beta1, + beta2, + epsilon, + multi_precision, + param_out, + moment1_out, + moment2_out, + beta1_pow_out, + beta2_pow_out, + master_param_outs); + } else { + ComputeImpl(dev_ctx, + param, + grad, + learning_rate, + moment1, + moment2, + beta1_pow, + beta2_pow, + master_param, + skip_update, + weight_decay, + beta1, + beta2, + epsilon, + multi_precision, + param_out, + moment1_out, + moment2_out, + beta1_pow_out, + beta2_pow_out, + master_param_outs); + } +} 
+
+template <typename T, typename MT, typename Context, bool IsMultiPrecision>
+void ComputeImpl(const Context& dev_ctx,
+                 const DenseTensor& param,
+                 const DenseTensor& grad,
+                 const DenseTensor& lr,
+                 const DenseTensor& mom1,
+                 const DenseTensor& mom2,
+                 const DenseTensor& beta1_pow,
+                 const DenseTensor& beta2_pow,
+                 const paddle::optional<DenseTensor>& master_param_opt,
+                 const paddle::optional<DenseTensor>& skip_update_opt,
+                 float weight_decay_f,
+                 float beta1_f,
+                 float beta2_f,
+                 float epsilon_f,
+                 bool multi_precision,
+                 DenseTensor* param_out,
+                 DenseTensor* mom1_out,
+                 DenseTensor* mom2_out,
+                 DenseTensor* beta1_pow_out,
+                 DenseTensor* beta2_pow_out,
+                 DenseTensor* master_param_out) {
+  if (!IsMultiPrecision) {
+    constexpr auto kIsSameType = std::is_same<T, MT>::value;
+    PADDLE_ENFORCE_EQ(
+        kIsSameType,
+        true,
+        phi::errors::InvalidArgument(
+            "When multi_precision=False, T and MT must be the same type."));
+  }
+
+  const auto* master_param =
+      IsMultiPrecision ? master_param_opt.get_ptr() : nullptr;
+  const auto* skip_update = skip_update_opt.get_ptr();
+  const bool* skip_update_flag = skip_update && skip_update->IsInitialized()
+                                     ? skip_update->data<bool>()
+                                     : nullptr;
+  if (skip_update_flag &&
+      paddle::platform::is_cpu_place(skip_update->place()) &&
+      (*skip_update_flag)) {
+    return;
+  }
+
+  auto weight_decay = static_cast<MT>(weight_decay_f);
+  auto beta1 = static_cast<MT>(beta1_f);
+  auto beta2 = static_cast<MT>(beta2_f);
+  auto epsilon = static_cast<MT>(epsilon_f);
+  auto numel = param.numel();
+  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+  DenseTensor trust_ratio_div;
+  trust_ratio_div.Resize(param.dims());
+  auto* trust_ratio_div_ptr = dev_ctx.template Alloc<MT>(&trust_ratio_div);
+
+  const void* param_ptr = param.data();
+  const void* master_param_ptr = master_param ? master_param->data() : nullptr;
+  void* param_out_ptr = dev_ctx.template Alloc<T>(param_out);
+  void* master_param_out_ptr =
+      master_param_out ? dev_ctx.template Alloc<MT>(master_param_out) : nullptr;
+  // Update moments
+  bool should_update_beta_pow_later = false;
+  const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
+  MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
+  VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
+           << " , Beta2Pow place: " << beta2_pow.place();
+  // Diff from here
+
+  if (paddle::platform::is_gpu_place(dev_ctx.GetPlace()) &&
+      beta1_pow.place() == phi::CPUPlace() &&
+      beta2_pow.place() == phi::CPUPlace()) {
+    LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
+        weight_decay,
+        beta1,
+        beta2,
+        epsilon,
+        *beta1_pow.template data<MT>(),
+        *beta2_pow.template data<MT>(),
+        mom1.template data<MT>(),
+        dev_ctx.template Alloc<MT>(mom1_out),
+        mom2.template data<MT>(),
+        dev_ctx.template Alloc<MT>(mom2_out),
+        grad.template data<T>(),
+        static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
+                                                : param_ptr),
+        trust_ratio_div_ptr,
+        skip_update_flag);
+    for_range(moment_update_functor);
+    MT* beta1_pow_out_data = dev_ctx.template HostAlloc<MT>(beta1_pow_out);
+    beta1_pow_out_data[0] = beta1 * beta1_pow.template data<MT>()[0];
+    MT* beta2_pow_out_data = dev_ctx.template HostAlloc<MT>(beta2_pow_out);
+    beta2_pow_out_data[0] = beta2 * beta2_pow.template data<MT>()[0];
+  } else {
+    beta1_pow_ptr = beta1_pow.template data<MT>();
+    beta2_pow_ptr = beta2_pow.template data<MT>();
+    beta1_pow_out_ptr = dev_ctx.template Alloc<MT>(beta1_pow_out);
+    beta2_pow_out_ptr = dev_ctx.template Alloc<MT>(beta2_pow_out);
+    should_update_beta_pow_later = true;
+    LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
+        weight_decay,
+        beta1,
+        beta2,
+        epsilon,
+        static_cast<const MT*>(beta1_pow_ptr),
+        static_cast<const MT*>(beta2_pow_ptr),
+        mom1.template data<MT>(),
+        dev_ctx.template Alloc<MT>(mom1_out),
+        mom2.template data<MT>(),
+        dev_ctx.template Alloc<MT>(mom2_out),
+        grad.template data<T>(),
+        static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
+                                                : param_ptr),
+        trust_ratio_div_ptr,
+        skip_update_flag);
+    for_range(moment_update_functor);
+  }
+
+  // Same from here
+  // Update parameter
+  // The code in the following part is exactly the same as that in
+  // paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h Please modify it
+  // together
+  DenseTensor p_norm_t;
+  p_norm_t.Resize(phi::make_ddim({1}));
+  auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
+
+  DenseTensor trust_ratio_div_norm_t;
+  trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
+  auto* trust_ratio_div_norm_ptr =
+      dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
+
+  // TODO(zengjinle): remove the following Eigen operations when
+  // *skip_update == true.
+  paddle::memory::Buffer buffer(dev_ctx.GetPlace());
+  phi::funcs::SquaredL2Norm(
+      dev_ctx,
+      reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
+                                                   : param_ptr),
+      p_norm_ptr,
+      numel,
+      &buffer);
+  phi::funcs::SquaredL2Norm(
+      dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
+
+  if (VLOG_IS_ON(1)) {
+    const auto& name = "Param";
+    auto pn = phi::funcs::ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
+    auto tn =
+        phi::funcs::ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
+    auto dtype = paddle::framework::DataTypeToString(
+        paddle::framework::DataTypeTrait<T>::DataType());
+    VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
+            << " , tn = " << tn[0];
+  }
+
+#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow)         \
+  do {                                                                       \
+    LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
+        param_update_functor(lr.template data<MT>(),                         \
+                             static_cast<const T*>(param_ptr),               \
+                             static_cast<const MT*>(master_param_ptr),       \
+                             p_norm_ptr,                                     \
+                             trust_ratio_div_ptr,                            \
+                             trust_ratio_div_norm_ptr,                       \
+                             static_cast<T*>(param_out_ptr),                 \
+                             static_cast<MT*>(master_param_out_ptr),         \
+                             skip_update_flag);                              \
+    if (__should_update_beta_pow) {                                          \
+      param_update_functor.SetBetaPows(beta1_pow_ptr,                        \
+                                       beta2_pow_ptr,                        \
+                                       beta1_pow_out_ptr,                    \
+                                       beta2_pow_out_ptr,                    \
+                                       beta1,                                \
+                                       beta2);                               \
+    }                                                                        \
+    for_range(param_update_functor);                                         \
+  } while (0)
+
+  if (should_update_beta_pow_later) {
+    CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
+  } else {
+    CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
+  }
+
+#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/lamb_kernel.h b/paddle/phi/kernels/lamb_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..f69948453d9b62c3a0f28fd841f527ae6dbd7bda
--- /dev/null
+++ b/paddle/phi/kernels/lamb_kernel.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LambKernel(const Context& dev_ctx,
+                const DenseTensor& param,
+                const DenseTensor& grad,
+                const DenseTensor& learning_rate,
+                const DenseTensor& moment1,
+                const DenseTensor& moment2,
+                const DenseTensor& beta1_pow,
+                const DenseTensor& beta2_pow,
+                const paddle::optional<DenseTensor>& master_param,
+                const paddle::optional<DenseTensor>& skip_update,
+                float weight_decay,
+                float beta1,
+                float beta2,
+                float epsilon,
+                bool multi_precision,
+                DenseTensor* param_out,
+                DenseTensor* moment1_out,
+                DenseTensor* moment2_out,
+                DenseTensor* beta1_pow_out,
+                DenseTensor* beta2_pow_out,
+                DenseTensor* master_param_outs);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/selected_rows/cpu/lamb_kernel.cc b/paddle/phi/kernels/selected_rows/cpu/lamb_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e30c0cf9704776ff00c3a59167b8bbcd40e482d6
--- /dev/null
+++ b/paddle/phi/kernels/selected_rows/cpu/lamb_kernel.cc
@@ -0,0 +1,21 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/selected_rows/lamb_kernel.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h"
+
+PD_REGISTER_KERNEL(
+    lamb_sr, CPU, ALL_LAYOUT, phi::sr::LambKernel, float, double) {}
diff --git a/paddle/phi/kernels/selected_rows/gpu/lamb_kernel.cu b/paddle/phi/kernels/selected_rows/gpu/lamb_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b76d116f7f63ff4c5ddc97fe0e35024f3b1fb9d0
--- /dev/null
+++ b/paddle/phi/kernels/selected_rows/gpu/lamb_kernel.cu
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/kernels/selected_rows/lamb_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h" + +PD_REGISTER_KERNEL(lamb_sr, + GPU, + ALL_LAYOUT, + phi::sr::LambKernel, + phi::dtype::float16, + float, + double) { + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h b/paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..5623d0dbdbe6967744124c338303f1e342d3679c --- /dev/null +++ b/paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h @@ -0,0 +1,351 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/selected_rows.h" +#include "paddle/phi/kernels/funcs/lamb_functors.h" + +namespace phi { +namespace sr { + +template +void ComputeRowImpl(const Context& dev_ctx, + const DenseTensor& param, + const SelectedRows& grad, + const DenseTensor& lr, + const DenseTensor& mom1, + const DenseTensor& mom2, + const DenseTensor& beta1_pow, + const DenseTensor& beta2_pow, + const paddle::optional& master_param_opt, + const paddle::optional& skip_update_opt, + float weight_decay_f, + float beta1_f, + float beta2_f, + float epsilon_f, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* mom1_out, + DenseTensor* mom2_out, + DenseTensor* beta1_pow_out, + DenseTensor* beta2_pow_out, + DenseTensor* master_param_out); + +template +void LambKernel(const Context& dev_ctx, + const DenseTensor& param, + const SelectedRows& grad, + const DenseTensor& learning_rate, + const DenseTensor& moment1, + const DenseTensor& moment2, + const DenseTensor& beta1_pow, + const DenseTensor& beta2_pow, + const paddle::optional& master_param, + const paddle::optional& skip_update, + float weight_decay, + float beta1, + float beta2, + float epsilon, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* moment1_out, + DenseTensor* moment2_out, + DenseTensor* beta1_pow_out, + DenseTensor* beta2_pow_out, + DenseTensor* master_param_outs) { + using MT = typename phi::dtype::MPTypeTrait::Type; + if (multi_precision) { + ComputeRowImpl(dev_ctx, + param, + grad, + learning_rate, + moment1, + moment2, + beta1_pow, + beta2_pow, + master_param, + skip_update, + weight_decay, + beta1, + beta2, + epsilon, + multi_precision, + param_out, + moment1_out, + moment2_out, + beta1_pow_out, + beta2_pow_out, + master_param_outs); + } else { + ComputeRowImpl(dev_ctx, + param, + grad, + learning_rate, + moment1, + moment2, + beta1_pow, + beta2_pow, + master_param, + skip_update, + weight_decay, + beta1, + beta2, + epsilon, + multi_precision, 
+                                         param_out,
+                                         moment1_out,
+                                         moment2_out,
+                                         beta1_pow_out,
+                                         beta2_pow_out,
+                                         master_param_outs);
+  }
+}
+
+template <typename T, typename MT, typename Context, bool IsMultiPrecision>
+void ComputeRowImpl(const Context& dev_ctx,
+                    const DenseTensor& param,
+                    const SelectedRows& grad,
+                    const DenseTensor& lr,
+                    const DenseTensor& mom1,
+                    const DenseTensor& mom2,
+                    const DenseTensor& beta1_pow,
+                    const DenseTensor& beta2_pow,
+                    const paddle::optional<DenseTensor>& master_param_opt,
+                    const paddle::optional<DenseTensor>& skip_update_opt,
+                    float weight_decay_f,
+                    float beta1_f,
+                    float beta2_f,
+                    float epsilon_f,
+                    bool multi_precision,
+                    DenseTensor* param_out,
+                    DenseTensor* mom1_out,
+                    DenseTensor* mom2_out,
+                    DenseTensor* beta1_pow_out,
+                    DenseTensor* beta2_pow_out,
+                    DenseTensor* master_param_out) {
+  if (!IsMultiPrecision) {
+    constexpr auto kIsSameType = std::is_same<T, MT>::value;
+    PADDLE_ENFORCE_EQ(
+        kIsSameType,
+        true,
+        phi::errors::InvalidArgument(
+            "When multi_precision=False, T and MT must be the same type."));
+  }
+
+  const auto* master_param =
+      IsMultiPrecision ? master_param_opt.get_ptr() : nullptr;
+  const auto* skip_update = skip_update_opt.get_ptr();
+  const bool* skip_update_flag = skip_update && skip_update->IsInitialized()
+                                     ? skip_update->data<bool>()
+                                     : nullptr;
+  if (skip_update_flag &&
+      paddle::platform::is_cpu_place(skip_update->place()) &&
+      (*skip_update_flag)) {
+    return;
+  }
+
+  auto weight_decay = static_cast<MT>(weight_decay_f);
+  auto beta1 = static_cast<MT>(beta1_f);
+  auto beta2 = static_cast<MT>(beta2_f);
+  auto epsilon = static_cast<MT>(epsilon_f);
+  auto numel = param.numel();
+  phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
+  DenseTensor trust_ratio_div;
+  trust_ratio_div.Resize(param.dims());
+  /*auto trust_ratio_div =
+      ctx.AllocateTmpTensor<MT, DeviceContext>(param.dims(), dev_ctx);*/
+  auto* trust_ratio_div_ptr = dev_ctx.template Alloc<MT>(&trust_ratio_div);
+
+  const void* param_ptr = param.data();
+  const void* master_param_ptr = master_param ? master_param->data() : nullptr;
+  void* param_out_ptr = dev_ctx.template Alloc<T>(param_out);
+  void* master_param_out_ptr =
+      master_param_out ? dev_ctx.template Alloc<MT>(master_param_out) : nullptr;
+  // Update moments
+  bool should_update_beta_pow_later = false;
+  const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
+  MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
+  VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
+           << " , Beta2Pow place: " << beta2_pow.place();
+  // Diff from here
+  PADDLE_ENFORCE_EQ(
+      IsMultiPrecision,
+      false,
+      phi::errors::Unimplemented("SelectedRows gradient is not supported when "
+                                 "multi_precision=True."));
+  constexpr bool kIsSameType = std::is_same<T, MT>::value;
+  PADDLE_ENFORCE_EQ(
+      kIsSameType,
+      true,
+      phi::errors::Unimplemented("SelectedRows gradient is not supported when "
+                                 "multi_precision=True."));
+  if (grad.rows().size() == 0) {
+    VLOG(3) << "grad row size is 0!!";
+    return;
+  }
+
+  std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
+  bool is_strict_sorted = true;
+  for (size_t i = 1; i < cpu_rows.size(); ++i) {
+    if (cpu_rows[i - 1] >= cpu_rows[i]) {
+      is_strict_sorted = false;
+      break;
+    }
+  }
+
+  phi::SelectedRows tmp_grad_merge;
+  const phi::SelectedRows* grad_merge_ptr;
+  if (is_strict_sorted) {
+    grad_merge_ptr = &grad;
+  } else {
+    // merge duplicated rows if any.
+    // The rows of grad_merge have been sorted inside MergeAdd functor
+    paddle::operators::math::scatter::MergeAdd<Context, T> merge_func;
+    merge_func(dev_ctx, grad, &tmp_grad_merge, true);
+    grad_merge_ptr = &tmp_grad_merge;
+  }
+
+  auto& grad_merge = *grad_merge_ptr;
+  auto& grad_tensor = grad_merge.value();
+  const T* grad_data = grad_tensor.template data<T>();
+  auto* grad_merge_rows = &grad_merge.rows();
+  paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(grad_merge_rows);
+  const int64_t* rows = mixv_grad_merge_rows.Data(dev_ctx.GetPlace());
+  auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
+  if (paddle::platform::is_gpu_place(dev_ctx.GetPlace()) &&
+      beta1_pow.place() == phi::CPUPlace() &&
+      beta2_pow.place() == phi::CPUPlace()) {
+    SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
+        static_cast<T>(weight_decay),
+        static_cast<T>(beta1),
+        static_cast<T>(beta2),
+        static_cast<T>(epsilon),
+        *beta1_pow.template data<T>(),
+        *beta2_pow.template data<T>(),
+        mom1.template data<T>(),
+        dev_ctx.template Alloc<T>(mom1_out),
+        mom2.template data<T>(),
+        dev_ctx.template Alloc<T>(mom2_out),
+        grad_data,
+        param.template data<T>(),
+        trust_ratio_div.template data<T>(),
+        rows,
+        row_numel,
+        grad_merge.rows().size(),
+        skip_update_flag);
+    for_range(moment_update_functor);
+    T* beta1_pow_out_data = dev_ctx.template HostAlloc<T>(beta1_pow_out);
+    beta1_pow_out_data[0] =
+        static_cast<T>(beta1) * beta1_pow.template data<T>()[0];
+    T* beta2_pow_out_data = dev_ctx.template HostAlloc<T>(beta2_pow_out);
+    beta2_pow_out_data[0] =
+        static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
+  } else {
+    beta1_pow_ptr = beta1_pow.template data<MT>();
+    beta2_pow_ptr = beta2_pow.template data<MT>();
+    beta1_pow_out_ptr = dev_ctx.template Alloc<MT>(beta1_pow_out);
+    beta2_pow_out_ptr = dev_ctx.template Alloc<MT>(beta2_pow_out);
+    should_update_beta_pow_later = true;
+    SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
+        static_cast<T>(weight_decay),
+        static_cast<T>(beta1),
+        static_cast<T>(beta2),
+        static_cast<T>(epsilon),
+        reinterpret_cast<const T*>(beta1_pow_ptr),
+        reinterpret_cast<const T*>(beta2_pow_ptr),
+        mom1.template data<T>(),
+        dev_ctx.template Alloc<T>(mom1_out),
+        mom2.template data<T>(),
+        dev_ctx.template Alloc<T>(mom2_out),
+        grad_data,
+        param.template data<T>(),
+        trust_ratio_div.template data<T>(),
+        rows,
+        row_numel,
+        grad_merge.rows().size(),
+        skip_update_flag);
+    for_range(moment_update_functor);
+  }
+  // Same from here
+  // Update parameter
+  // The code in the following part is exactly the same as that in
+  // paddle/phi/kernels/impl/lamb_kernel_impl.h Please modify it together
+  DenseTensor p_norm_t;
+  p_norm_t.Resize(phi::make_ddim({1}));
+  auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
+
+  DenseTensor trust_ratio_div_norm_t;
+  trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
+  auto* trust_ratio_div_norm_ptr =
+      dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
+
+  // TODO(zengjinle): remove the following Eigen operations when
+  // *skip_update == true.
+  paddle::memory::Buffer buffer(dev_ctx.GetPlace());
+  phi::funcs::SquaredL2Norm(
+      dev_ctx,
+      reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
+                                                   : param_ptr),
+      p_norm_ptr,
+      numel,
+      &buffer);
+  phi::funcs::SquaredL2Norm(
+      dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
+
+  if (VLOG_IS_ON(1)) {
+    const auto& name = "Param";
+    auto pn = phi::funcs::ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
+    auto tn =
+        phi::funcs::ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
+    auto dtype = paddle::framework::DataTypeToString(
+        paddle::framework::DataTypeTrait<T>::DataType());
+    VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
+            << " , tn = " << tn[0];
+  }
+
+#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow)         \
+  do {                                                                       \
+    LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
+        param_update_functor(lr.template data<MT>(),                         \
+                             static_cast<const T*>(param_ptr),               \
+                             static_cast<const MT*>(master_param_ptr),       \
+                             p_norm_ptr,                                     \
+                             trust_ratio_div_ptr,                            \
+                             trust_ratio_div_norm_ptr,                       \
+                             static_cast<T*>(param_out_ptr),                 \
+                             static_cast<MT*>(master_param_out_ptr),         \
+                             skip_update_flag);                              \
+    if (__should_update_beta_pow) {                                          \
+      param_update_functor.SetBetaPows(beta1_pow_ptr,                        \
+                                       beta2_pow_ptr,                        \
+                                       beta1_pow_out_ptr,                    \
+                                       beta2_pow_out_ptr,                    \
+                                       beta1,                                \
+                                       beta2);                               \
+    }                                                                        \
+    for_range(param_update_functor);                                         \
+  } while (0)
+
+  if (should_update_beta_pow_later) {
+    CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
+  } else {
+    CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
+  }
+
+#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
+}
+
+}  // namespace sr
+}  // namespace phi
diff --git a/paddle/phi/kernels/selected_rows/lamb_kernel.h b/paddle/phi/kernels/selected_rows/lamb_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..306f1ca0ff79bce523783d1459cd9803252f0b69
--- /dev/null
+++ b/paddle/phi/kernels/selected_rows/lamb_kernel.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/selected_rows.h"
+
+namespace phi {
+namespace sr {
+
+template <typename T, typename Context>
+void LambKernel(const Context& dev_ctx,
+                const DenseTensor& param,
+                const SelectedRows& grad,
+                const DenseTensor& learning_rate,
+                const DenseTensor& moment1,
+                const DenseTensor& moment2,
+                const DenseTensor& beta1_pow,
+                const DenseTensor& beta2_pow,
+                const paddle::optional<DenseTensor>& master_param,
+                const paddle::optional<DenseTensor>& skip_update,
+                float weight_decay,
+                float beta1,
+                float beta2,
+                float epsilon,
+                bool multi_precision,
+                DenseTensor* param_out,
+                DenseTensor* moment1_out,
+                DenseTensor* moment2_out,
+                DenseTensor* beta1_pow_out,
+                DenseTensor* beta2_pow_out,
+                DenseTensor* master_param_outs);
+
+}  // namespace sr
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/lamb_sig.cc b/paddle/phi/ops/compat/lamb_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a59ae6155c1832114c4e06161312d755c66909cd
--- /dev/null
+++ b/paddle/phi/ops/compat/lamb_sig.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <string>
+
+#include "paddle/phi/core/compat/op_utils.h"
+#include "paddle/utils/small_vector.h"
+
+namespace phi {
+
+KernelSignature LambOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  paddle::small_vector<const char*> in_names = {"Param",
+                                                "Grad",
+                                                "LearningRate",
+                                                "Moment1",
+                                                "Moment2",
+                                                "Beta1Pow",
+                                                "Beta2Pow",
+                                                "MasterParam",
+                                                "SkipUpdate"};
+  paddle::small_vector<const char*> out_names = {"ParamOut",
+                                                 "Moment1Out",
+                                                 "Moment2Out",
+                                                 "Beta1PowOut",
+                                                 "Beta2PowOut",
+                                                 "MasterParamOut"};
+  paddle::small_vector<const char*> attr_names;
+
+  attr_names.emplace_back("weight_decay");
+  attr_names.emplace_back("beta1");
+  attr_names.emplace_back("beta2");
+  attr_names.emplace_back("epsilon");
+  attr_names.emplace_back("multi_precision");
+
+  if (ctx.IsSelectedRowsInput("Grad")) {
+    return KernelSignature("lamb_sr",
+                           std::move(in_names),
+                           std::move(attr_names),
+                           std::move(out_names));
+  } else if (ctx.IsDenseTensorInput("Grad")) {
+    return KernelSignature("lamb",
+                           std::move(in_names),
+                           std::move(attr_names),
+                           std::move(out_names));
+  } else {
+    return KernelSignature("unregistered", {}, {}, {});
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(lamb, phi::LambOpArgumentMapping);
diff --git a/python/paddle/optimizer/lamb.py b/python/paddle/optimizer/lamb.py
index 29233e6ced0a2d32d1eaad1083277f03679da01b..5a5f52bb3ef3d4dc606087ee8e5a3e0f26b6e4d7 100644
--- a/python/paddle/optimizer/lamb.py
+++ b/python/paddle/optimizer/lamb.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ from ..fluid import unique_name
 from ..fluid.layer_helper import LayerHelper
 from paddle import _C_ops
 from paddle.fluid.executor import global_scope
+import paddle
 
 __all__ = []
 
@@ -266,6 +267,13 @@ class Lamb(Optimizer):
             master_weight = None
         found_inf = self._get_auxiliary_var('found_inf')
 
+        if framework.in_dygraph_mode():
+            _C_ops.final_state_lamb_(param_and_grad[0], param_and_grad[1], lr,
+                                     moment1, moment2, beta1_pow_acc,
+                                     beta2_pow_acc, master_weight, found_inf,
+                                     weight_decay, self._beta1, self._beta2,
+                                     self._epsilon, find_master)
+            return None
         if framework._non_static_mode():
             _C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1,
                         moment2, beta1_pow_acc, beta2_pow_acc, master_weight,