Unverified commit 4a7aa7c3, authored by Thomas Young, committed by GitHub

move lamb_op to phi (#44899)

Parent 8537edaa
...@@ -88,6 +88,8 @@ no_amp_list = [
    'rmsprop',
    'sgd_',
    'sgd',
    'lamb_',
    'lamb',
    'assign_value_',
    'sparse_momentum_',
    'sparse_momentum',
......
...@@ -15,15 +15,18 @@
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/tensor_to_string.h"
namespace paddle {
namespace operators {
using phi::funcs::FlattenToString;
using phi::funcs::ToVector;
struct ParamGradInfo {
  framework::Tensor *param_t{nullptr};
  framework::Tensor *grad_t{nullptr};
......
...@@ -19,12 +19,12 @@
#include "paddle/fluid/operators/optimizers/cast_with_ptr.h"
#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
#include "paddle/fluid/operators/optimizers/multi_tensor_apply.h"
#include "paddle/fluid/operators/tensor_to_string.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/tensor_to_string.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
...@@ -43,6 +43,8 @@ namespace operators {
template <typename T>
using MasterT = typename details::MPTypeTrait<T>::Type;
using phi::funcs::FlattenToString;
using phi::funcs::ToVector;
template <typename T>
static void FillZeroWithPtr(T *x, size_t n, gpuStream_t stream) {
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...@@ -12,11 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lamb_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/lamb_kernel.h"
namespace paddle {
namespace operators {
...@@ -25,125 +29,6 @@ class LambOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("Param"),
true,
platform::errors::NotFound(
"Input(Param) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"),
true,
platform::errors::NotFound(
"Input(Grad) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Moment1"),
true,
platform::errors::NotFound(
"Input(Moment1) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Moment2"),
true,
platform::errors::NotFound(
"Input(Moment2) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"),
true,
platform::errors::NotFound(
"Input(LearningRate) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Beta1Pow"),
true,
platform::errors::NotFound(
"Input(Beta1Pow) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Beta2Pow"),
true,
platform::errors::NotFound(
"Input(Beta2Pow) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"),
true,
platform::errors::NotFound(
"Output(ParamOut) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment1Out"),
true,
platform::errors::NotFound(
"Output(Moment1Out) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment2Out"),
true,
platform::errors::NotFound(
"Output(Moment2Out) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Beta1PowOut"),
true,
platform::errors::NotFound(
"Output(Beta1PowOut) of LambOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Beta2PowOut"),
true,
platform::errors::NotFound(
"Output(Beta2PowOut) of LambOp should not be null."));
auto lr_dims = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_NE(
phi::product(lr_dims),
0,
platform::errors::InvalidArgument(
"The number of LearningRate shall not be 0, but received %d. Maybe "
"the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
phi::product(lr_dims)));
PADDLE_ENFORCE_EQ(
phi::product(lr_dims),
1,
platform::errors::InvalidArgument(
"Learning rate should have 1 dimension, but received %d.",
phi::product(lr_dims)));
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
PADDLE_ENFORCE_GE(phi::product(beta1_pow_dims),
1,
platform::errors::InvalidArgument(
"The size of Beta1 power accumulator should be "
"greater than 0, but received %d.",
phi::product(beta1_pow_dims)));
auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
PADDLE_ENFORCE_GE(phi::product(beta2_pow_dims),
1,
platform::errors::InvalidArgument(
"The size of Beta2 power accumulator should be "
"greater than 0, but received %d.",
phi::product(beta2_pow_dims)));
auto param_dims = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dims,
ctx->GetInputDim("Grad"),
platform::errors::InvalidArgument(
"Param and Grad input of LambOp should have same dimension. But "
"received Param dims: [%s], Grad dims: [%s].",
param_dims,
ctx->GetInputDim("Grad")));
}
PADDLE_ENFORCE_EQ(
param_dims,
ctx->GetInputDim("Moment1"),
platform::errors::InvalidArgument(
"Param and Moment1 input of LambOp should have same dimension. But "
"received Param dims: [%s], Moment1 dims: [%s].",
param_dims,
ctx->GetInputDim("Moment1")));
PADDLE_ENFORCE_EQ(
param_dims,
ctx->GetInputDim("Moment2"),
platform::errors::InvalidArgument(
"Param and Moment2 input of LambOp should have same dimension. But "
"received Param dims: [%s], Moment2 dims: [%s].",
param_dims,
ctx->GetInputDim("Moment2")));
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims);
ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims);
ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims);
}
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const {
  auto input_data_type =
...@@ -246,10 +131,16 @@ learning rate, $\lambda$ the weight decay rate.
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(lamb, ops::LambOp, ops::LambOpMaker);
REGISTER_OP_CPU_KERNEL(lamb,
                       ops::LambOpKernel<phi::CPUContext, float>,
                       ops::LambOpKernel<phi::CPUContext, double>);
DECLARE_INFER_SHAPE_FUNCTOR(lamb,
                            LambInferMetaFunctor,
                            PD_INFER_META(phi::LambInferMeta));
REGISTER_OPERATOR(
    lamb,
    ops::LambOp,
    ops::LambOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    LambInferMetaFunctor);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(lamb).AddCheckpoint(
......
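For reference, the update rule the op documentation above refers to (with $\eta$ the learning rate and $\lambda$ the weight decay rate) is the standard LAMB step; a sketch in LaTeX, paraphrased rather than copied verbatim from the op comment:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
r_t = \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon} + \lambda w_{t-1}
w_t = w_{t-1} - \eta \, \frac{\| w_{t-1} \|}{\| r_t \|} \, r_t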
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/lamb_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lamb,
ops::LambOpKernel<phi::GPUContext, paddle::platform::float16>,
ops::LambOpKernel<phi::GPUContext, float>,
ops::LambOpKernel<phi::GPUContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "gflags/gflags.h"
#include "paddle/fluid/operators/optimizers/lamb_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
......
...@@ -1327,6 +1327,18 @@
optional : prior_dist
backward : label_smooth_grad
- api : lamb_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, float weight_decay, float beta1, float beta2, float epsilon, bool multi_precision)
output : Tensor(param_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(master_param_outs)
infer_meta :
func : LambInferMeta
kernel :
func : lamb {dense, dense, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense},
lamb_sr {dense, selected_rows, dense, dense, dense, dense, dense, dense, dense -> dense, dense, dense, dense, dense, dense}
data_type : param
optional : master_param, skip_update
inplace : (param -> param_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (master_param -> master_param_outs)
- api : layer_norm
  args : (Tensor x, Tensor scale, Tensor bias, float epsilon, int begin_norm_axis, bool is_test)
  output : Tensor(out), Tensor(mean), Tensor(variance)
......
...@@ -1642,6 +1642,105 @@ void InterpolateInferMeta(
}
}
void LambInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& moment1,
const MetaTensor& moment2,
const MetaTensor& beta1_pow,
const MetaTensor& beta2_pow,
const MetaTensor& master_param,
const MetaTensor& skip_update,
float weight_decay,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* moment1_out,
MetaTensor* moment2_out,
MetaTensor* beta1_pow_out,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs) {
auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_NE(
phi::product(lr_dims),
0,
phi::errors::InvalidArgument(
"The number of LearningRate shall not be 0, but received %d. Maybe "
"the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function.",
phi::product(lr_dims)));
PADDLE_ENFORCE_EQ(
phi::product(lr_dims),
1,
phi::errors::InvalidArgument(
"Learning rate should have 1 dimension, but received %d.",
phi::product(lr_dims)));
auto beta1_pow_dims = beta1_pow.dims();
PADDLE_ENFORCE_GE(phi::product(beta1_pow_dims),
1,
phi::errors::InvalidArgument(
"The size of Beta1 power accumulator should be "
"greater than 0, but received %d.",
phi::product(beta1_pow_dims)));
auto beta2_pow_dims = beta2_pow.dims();
PADDLE_ENFORCE_GE(phi::product(beta2_pow_dims),
1,
phi::errors::InvalidArgument(
"The size of Beta2 power accumulator should be "
"greater than 0, but received %d.",
phi::product(beta2_pow_dims)));
auto param_dims = param.dims();
PADDLE_ENFORCE_EQ(
param_dims,
moment1.dims(),
phi::errors::InvalidArgument(
"Param and Moment1 input of LambOp should have same dimension. But "
"received Param dims: [%s], Moment1 dims: [%s].",
param_dims,
moment1.dims()));
PADDLE_ENFORCE_EQ(
param_dims,
moment2.dims(),
errors::InvalidArgument(
"Param and Moment2 input of AdamOp should have same dimension. But "
"received Param dims: [%s], Moment2 dims: [%s].",
param_dims,
moment2.dims()));
PADDLE_ENFORCE_NOT_NULL(
param_out, errors::NotFound("The output param_out can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
moment1_out,
errors::NotFound("The output moment1_out can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
moment2_out,
errors::NotFound("The output moment2_out can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
beta1_pow_out,
errors::NotFound("The output beta1_pow_out can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
beta2_pow_out,
errors::NotFound("The output beta2_pow_out can not be nullptr"));
param_out->set_dims(param_dims);
param_out->set_dtype(param.dtype());
moment1_out->set_dims(param_dims);
moment1_out->set_dtype(moment1.dtype());
moment2_out->set_dims(param_dims);
moment2_out->set_dtype(moment2.dtype());
beta1_pow_out->set_dims(beta1_pow_dims);
beta1_pow_out->set_dtype(beta1_pow.dtype());
beta2_pow_out->set_dims(beta2_pow_dims);
beta2_pow_out->set_dtype(beta2_pow.dtype());
}
void LogspaceInferMeta(const MetaTensor& start,
                       const MetaTensor& stop,
                       const MetaTensor& number,
......
...@@ -269,6 +269,27 @@ void InterpolateInferMeta(
MetaTensor* output,
MetaConfig config = MetaConfig());
void LambInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& moment1,
const MetaTensor& moment2,
const MetaTensor& beta1_pow,
const MetaTensor& beta2_pow,
const MetaTensor& master_param,
const MetaTensor& skip_update,
float weight_decay,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* moment1_out,
MetaTensor* moment2_out,
MetaTensor* beta1_pow_out,
MetaTensor* beta2_pow_out,
MetaTensor* master_param_outs);
void LogspaceInferMeta(const MetaTensor& start,
                       const MetaTensor& stop,
                       const MetaTensor& number,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/lamb_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lamb_kernel_impl.h"
PD_REGISTER_KERNEL(lamb, CPU, ALL_LAYOUT, phi::LambKernel, float, double) {}
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...@@ -18,25 +18,26 @@ limitations under the License. */
#include <Eigen/Dense>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/buffer.h" #include "paddle/fluid/memory/buffer.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/operators/tensor_to_string.h" #include "paddle/phi/common/amp_type_traits.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/algorithm.h" #include "paddle/phi/kernels/funcs/algorithm.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h" #include "paddle/phi/kernels/funcs/eigen/extensions.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/squared_l2_norm.h" #include "paddle/phi/kernels/funcs/squared_l2_norm.h"
#include "paddle/phi/kernels/funcs/tensor_to_string.h"
namespace paddle { namespace phi {
namespace operators {
namespace scatter = paddle::operators::math::scatter; namespace scatter = paddle::operators::math::scatter;
template <typename T, bool IsMultiPrecision> template <typename T, bool IsMultiPrecision>
struct LambMomentREGUpdateFunctor { struct LambMomentREGUpdateFunctor {
using MT = typename std::conditional<IsMultiPrecision, using MT =
typename details::MPTypeTrait<T>::Type, typename std::conditional<IsMultiPrecision,
typename phi::dtype::MPTypeTrait<T>::Type,
T>::type; T>::type;
MT weight_decay_; MT weight_decay_;
...@@ -112,8 +113,9 @@ struct LambMomentREGUpdateFunctor {
template <typename T, bool IsMultiPrecision>
struct LambMomentMENUpdateFunctor {
  using MT = typename std::conditional<IsMultiPrecision,
                                       typename details::MPTypeTrait<T>::Type,
                                       T>::type;
  using MT =
      typename std::conditional<IsMultiPrecision,
                                typename phi::dtype::MPTypeTrait<T>::Type,
                                T>::type;
  MT weight_decay_;
...@@ -458,356 +460,4 @@ struct LambParamUpateFunctor
}
};
} // namespace phi
template <typename DeviceContext, typename T>
class LambOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using MT = typename details::MPTypeTrait<T>::Type;
bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
ComputeImpl<MT, true>(ctx);
} else {
ComputeImpl<T, false>(ctx);
}
}
private:
template <typename MT, bool IsMultiPrecision>
void ComputeImpl(const framework::ExecutionContext& ctx) const {
if (!IsMultiPrecision) {
constexpr auto kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(
kIsSameType,
true,
platform::errors::InvalidArgument(
"When multi_precision=False, T and MT must be the same type."));
}
const auto* skip_update = ctx.Input<framework::LoDTensor>("SkipUpdate");
const bool* skip_update_flag = skip_update && skip_update->IsInitialized()
? skip_update->data<bool>()
: nullptr;
if (skip_update_flag && platform::is_cpu_place(skip_update->place()) &&
(*skip_update_flag)) {
return;
}
auto weight_decay = static_cast<MT>(ctx.Attr<float>("weight_decay"));
auto beta1 = static_cast<MT>(ctx.Attr<float>("beta1"));
auto beta2 = static_cast<MT>(ctx.Attr<float>("beta2"));
auto epsilon = static_cast<MT>(ctx.Attr<float>("epsilon"));
const auto& param = GET_DATA_SAFELY(
ctx.Input<framework::LoDTensor>("Param"), "Input", "Param", "Lamb");
const auto* grad_var = ctx.InputVar("Grad");
const auto& mom1 = GET_DATA_SAFELY(
ctx.Input<framework::LoDTensor>("Moment1"), "Input", "Moment1", "Lamb");
const auto& mom2 = GET_DATA_SAFELY(
ctx.Input<framework::LoDTensor>("Moment2"), "Input", "Moment2", "Lamb");
const auto& lr =
GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("LearningRate"),
"Input",
"LearningRate",
"Lamb");
const auto& beta1_pow =
GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("Beta1Pow"),
"Input",
"Beta1Pow",
"Lamb");
const auto& beta2_pow =
GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("Beta2Pow"),
"Input",
"Beta2Pow",
"Lamb");
auto& param_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("ParamOut"),
"Output",
"ParamOut",
"Lamb");
auto& mom1_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Moment1Out"),
"Output",
"Moment1Out",
"Lamb");
auto& mom2_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Moment2Out"),
"Output",
"Moment2Out",
"Lamb");
auto& beta1_pow_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Beta1PowOut"),
"Output",
"Beta1PowOut",
"Lamb");
auto& beta2_pow_out =
GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Beta2PowOut"),
"Output",
"Beta2PowOut",
"Lamb");
const auto* master_param =
IsMultiPrecision ? ctx.Input<framework::LoDTensor>("MasterParam")
: nullptr;
auto* master_param_out =
IsMultiPrecision ? ctx.Output<framework::LoDTensor>("MasterParamOut")
: nullptr;
if (IsMultiPrecision) {
PADDLE_ENFORCE_NOT_NULL(master_param,
platform::errors::InvalidArgument(
"Input(MasterParam) must be provided when "
"multi_precision=True."));
PADDLE_ENFORCE_NOT_NULL(master_param_out,
platform::errors::InvalidArgument(
"Output(MasterParamOut) must be provided "
"when multi_precision=True."));
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto numel = param.numel();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
auto trust_ratio_div =
ctx.AllocateTmpTensor<MT, DeviceContext>(param.dims(), dev_ctx);
auto* trust_ratio_div_ptr = trust_ratio_div.template data<MT>();
const void* param_ptr = param.data();
const void* master_param_ptr =
master_param ? master_param->data() : nullptr;
void* param_out_ptr = param_out.template mutable_data<T>(ctx.GetPlace());
void* master_param_out_ptr =
master_param_out
? master_param_out->template mutable_data<MT>(ctx.GetPlace())
: nullptr;
// Update moments
bool should_update_beta_pow_later = false;
const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
<< " , Beta2Pow place: " << beta2_pow.place();
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = grad_var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(ctx.GetPlace()) &&
beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay,
beta1,
beta2,
epsilon,
*beta1_pow.template data<MT>(),
*beta2_pow.template data<MT>(),
mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
grad.template data<T>(),
static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
trust_ratio_div_ptr,
skip_update_flag);
for_range(moment_update_functor);
beta1_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
beta1 * beta1_pow.template data<MT>()[0];
beta2_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
beta2 * beta2_pow.template data<MT>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr =
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
beta2_pow_out_ptr =
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
should_update_beta_pow_later = true;
LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay,
beta1,
beta2,
epsilon,
static_cast<const MT*>(beta1_pow_ptr),
static_cast<const MT*>(beta2_pow_ptr),
mom1.template data<MT>(),
mom1_out.template mutable_data<MT>(ctx.GetPlace()),
mom2.template data<MT>(),
mom2_out.template mutable_data<MT>(ctx.GetPlace()),
grad.template data<T>(),
static_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
trust_ratio_div_ptr,
skip_update_flag);
for_range(moment_update_functor);
}
} else if (grad_var->IsType<phi::SelectedRows>()) {
PADDLE_ENFORCE_EQ(IsMultiPrecision,
false,
platform::errors::Unimplemented(
"SelectedRows gradient is not supported when "
"multi_precision=True."));
constexpr bool kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(kIsSameType,
true,
platform::errors::Unimplemented(
"SelectedRows gradient is not supported when "
"multi_precision=True."));
auto& grad = GET_DATA_SAFELY(
ctx.Input<phi::SelectedRows>("Grad"), "Input", "Grad", "Lamb");
if (grad.rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}
std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
bool is_strict_sorted = true;
for (size_t i = 1; i < cpu_rows.size(); ++i) {
if (cpu_rows[i - 1] >= cpu_rows[i]) {
is_strict_sorted = false;
break;
}
}
phi::SelectedRows tmp_grad_merge;
const phi::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = &grad;
} else {
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(dev_ctx, grad, &tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
}
auto& grad_merge = *grad_merge_ptr;
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
auto* grad_merge_rows = &grad_merge.rows();
paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(
grad_merge_rows);
const int64_t* rows = mixv_grad_merge_rows.Data(ctx.GetPlace());
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
if (platform::is_gpu_place(ctx.GetPlace()) &&
beta1_pow.place() == platform::CPUPlace() &&
beta2_pow.place() == platform::CPUPlace()) {
SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay),
static_cast<T>(beta1),
static_cast<T>(beta2),
static_cast<T>(epsilon),
*beta1_pow.template data<T>(),
*beta2_pow.template data<T>(),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
grad_data,
param.template data<T>(),
trust_ratio_div.template data<T>(),
rows,
row_numel,
grad_merge.rows().size(),
skip_update_flag);
for_range(moment_update_functor);
beta1_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
static_cast<T>(beta1) * beta1_pow.template data<T>()[0];
beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr =
beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
beta2_pow_out_ptr =
beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
should_update_beta_pow_later = true;
SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay),
static_cast<T>(beta1),
static_cast<T>(beta2),
static_cast<T>(epsilon),
reinterpret_cast<const T*>(beta1_pow_ptr),
reinterpret_cast<const T*>(beta2_pow_ptr),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
grad_data,
param.template data<T>(),
trust_ratio_div.template data<T>(),
rows,
row_numel,
grad_merge.rows().size(),
skip_update_flag);
for_range(moment_update_functor);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Variable type not supported by lamb_op. Expect LoDTensor or "
"SelectedRows, but got %s",
framework::ToTypeName(grad_var->Type())));
}
// Update parameter
auto p_norm_t = ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
auto* p_norm_ptr = p_norm_t.template data<MT>();
auto trust_ratio_div_norm_t =
ctx.AllocateTmpTensor<MT, DeviceContext>({1}, dev_ctx);
auto* trust_ratio_div_norm_ptr = trust_ratio_div_norm_t.template data<MT>();
// TODO(zengjinle): remove the following Eigen operations when
// *skip_update == true.
memory::Buffer buffer(dev_ctx.GetPlace());
phi::funcs::SquaredL2Norm(
dev_ctx,
reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
p_norm_ptr,
numel,
&buffer);
phi::funcs::SquaredL2Norm(
dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
if (VLOG_IS_ON(1)) {
const auto& name = ctx.GetOp().Input("Param");
auto pn = ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
auto tn = ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
auto dtype =
framework::DataTypeToString(framework::DataTypeTrait<T>::DataType());
VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
<< " , tn = " << tn[0];
}
#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow) \
do { \
LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
param_update_functor(lr.template data<MT>(), \
static_cast<const T*>(param_ptr), \
static_cast<const MT*>(master_param_ptr), \
p_norm_ptr, \
trust_ratio_div_ptr, \
trust_ratio_div_norm_ptr, \
static_cast<T*>(param_out_ptr), \
static_cast<MT*>(master_param_out_ptr), \
skip_update_flag); \
if (__should_update_beta_pow) { \
param_update_functor.SetBetaPows(beta1_pow_ptr, \
beta2_pow_ptr, \
beta1_pow_out_ptr, \
beta2_pow_out_ptr, \
beta1, \
beta2); \
} \
for_range(param_update_functor); \
} while (0)
if (should_update_beta_pow_later) {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
} else {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
}
#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
}
};
} // namespace operators
} // namespace paddle
...@@ -16,13 +16,14 @@
#include <sstream>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/string/string_helper.h"
namespace paddle {
namespace operators {
namespace phi {
namespace funcs {
template <typename T>
static const std::vector<T> &ToVector(const std::vector<T> &vec) {
...@@ -30,17 +31,15 @@ static const std::vector<T> &ToVector(const std::vector<T> &vec) {
}
template <typename T>
static std::vector<T> ToVector(const T *x,
                               size_t n,
                               const platform::Place &place) {
static std::vector<T> ToVector(const T *x, size_t n, const phi::Place &place) {
#ifdef __NVCC__
  if (platform::is_gpu_place(place)) {
  if (paddle::platform::is_gpu_place(place)) {
    using CopyT = typename std::
        conditional<std::is_same<T, bool>::value, uint8_t, T>::type;
    std::vector<CopyT> cpu_x(n);
    auto *dev_ctx = static_cast<phi::GPUContext *>(
        platform::DeviceContextPool::Instance().Get(place));
        phi::DeviceContextPool::Instance().Get(place));
    memory::Copy(platform::CPUPlace(),
    paddle::memory::Copy(phi::CPUPlace(),
                 cpu_x.data(),
                 place,
                 x,
...@@ -54,7 +53,7 @@ static std::vector<T> ToVector(const T *x,
}
template <typename T>
static std::vector<T> ToVector(const framework::Tensor &src) {
static std::vector<T> ToVector(const DenseTensor &src) {
  if (!src.IsInitialized()) {
    return {};
  }
...@@ -64,8 +63,8 @@ static std::vector<T> ToVector(const framework::Tensor &src) {
template <typename... Args>
static std::string FlattenToString(Args &&...args) {
  const auto &vec = ToVector(std::forward<Args>(args)...);
  return "[" + string::join_strings(vec, ',') + "]";
  return "[" + paddle::string::join_strings(vec, ',') + "]";
}
} // namespace operators
} // namespace funcs
} // namespace paddle
} // namespace phi
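The helpers keep their call shape after the move; a minimal usage sketch in C++ (assuming a valid device pointer `ptr` of length `n`, an initialized `place`, and VLOG configured; these names are illustrative, not from the patch):

// Sketch only: copy n values from (possibly device) memory to the host,
// then render them as "[v0,v1,...]" for logging.
std::vector<float> host_vals = phi::funcs::ToVector(ptr, n, place);
VLOG(1) << "values = " << phi::funcs::FlattenToString(ptr, n, place);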
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/lamb_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lamb_kernel_impl.h"
PD_REGISTER_KERNEL(lamb,
GPU,
ALL_LAYOUT,
phi::LambKernel,
phi::dtype::float16,
float,
double) {
  // Beta1Pow and Beta2Pow (inputs 5 and 6) may live on the host even when
  // the kernel itself runs on GPU, so accept them from any backend.
  kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/lamb_functors.h"
namespace phi {
template <typename T, typename MT, typename Context, bool IsMultiPrecision>
void ComputeImpl(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& lr,
const DenseTensor& mom1,
const DenseTensor& mom2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param_opt,
const paddle::optional<DenseTensor>& skip_update_opt,
float weight_decay_f,
float beta1_f,
float beta2_f,
float epsilon_f,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* mom1_out,
DenseTensor* mom2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_out);
template <typename T, typename Context>
void LambKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment1,
const DenseTensor& moment2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param,
const paddle::optional<DenseTensor>& skip_update,
float weight_decay,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment1_out,
DenseTensor* moment2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
if (multi_precision) {
ComputeImpl<T, MT, Context, true>(dev_ctx,
param,
grad,
learning_rate,
moment1,
moment2,
beta1_pow,
beta2_pow,
master_param,
skip_update,
weight_decay,
beta1,
beta2,
epsilon,
multi_precision,
param_out,
moment1_out,
moment2_out,
beta1_pow_out,
beta2_pow_out,
master_param_outs);
} else {
ComputeImpl<T, T, Context, false>(dev_ctx,
param,
grad,
learning_rate,
moment1,
moment2,
beta1_pow,
beta2_pow,
master_param,
skip_update,
weight_decay,
beta1,
beta2,
epsilon,
multi_precision,
param_out,
moment1_out,
moment2_out,
beta1_pow_out,
beta2_pow_out,
master_param_outs);
}
}
template <typename T, typename MT, typename Context, bool IsMultiPrecision>
void ComputeImpl(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& lr,
const DenseTensor& mom1,
const DenseTensor& mom2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param_opt,
const paddle::optional<DenseTensor>& skip_update_opt,
float weight_decay_f,
float beta1_f,
float beta2_f,
float epsilon_f,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* mom1_out,
DenseTensor* mom2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_out) {
if (!IsMultiPrecision) {
constexpr auto kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(
kIsSameType,
true,
phi::errors::InvalidArgument(
"When multi_precision=False, T and MT must be the same type."));
}
const auto* master_param =
IsMultiPrecision ? master_param_opt.get_ptr() : nullptr;
const auto* skip_update = skip_update_opt.get_ptr();
const bool* skip_update_flag = skip_update && skip_update->IsInitialized()
? skip_update->data<bool>()
: nullptr;
if (skip_update_flag &&
paddle::platform::is_cpu_place(skip_update->place()) &&
(*skip_update_flag)) {
return;
}
auto weight_decay = static_cast<MT>(weight_decay_f);
auto beta1 = static_cast<MT>(beta1_f);
auto beta2 = static_cast<MT>(beta2_f);
auto epsilon = static_cast<MT>(epsilon_f);
auto numel = param.numel();
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
DenseTensor trust_ratio_div;
trust_ratio_div.Resize(param.dims());
auto* trust_ratio_div_ptr = dev_ctx.template Alloc<MT>(&trust_ratio_div);
const void* param_ptr = param.data();
const void* master_param_ptr = master_param ? master_param->data() : nullptr;
void* param_out_ptr = dev_ctx.template Alloc<T>(param_out);
void* master_param_out_ptr =
master_param_out ? dev_ctx.template Alloc<MT>(master_param_out) : nullptr;
// Update moments
bool should_update_beta_pow_later = false;
const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
<< " , Beta2Pow place: " << beta2_pow.place();
// Diff from here
if (paddle::platform::is_gpu_place(dev_ctx.GetPlace()) &&
beta1_pow.place() == phi::CPUPlace() &&
beta2_pow.place() == phi::CPUPlace()) {
LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay,
beta1,
beta2,
epsilon,
*beta1_pow.template data<MT>(),
*beta2_pow.template data<MT>(),
mom1.template data<MT>(),
dev_ctx.template Alloc<MT>(mom1_out),
mom2.template data<MT>(),
dev_ctx.template Alloc<MT>(mom2_out),
grad.template data<T>(),
static_cast<const MT*>(IsMultiPrecision ? master_param_ptr : param_ptr),
trust_ratio_div_ptr,
skip_update_flag);
for_range(moment_update_functor);
MT* beta1_pow_out_data = dev_ctx.template HostAlloc<MT>(beta1_pow_out);
beta1_pow_out_data[0] = beta1 * beta1_pow.template data<MT>()[0];
MT* beta2_pow_out_data = dev_ctx.template HostAlloc<MT>(beta2_pow_out);
beta2_pow_out_data[0] = beta2 * beta2_pow.template data<MT>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr = dev_ctx.template Alloc<MT>(beta1_pow_out);
beta2_pow_out_ptr = dev_ctx.template Alloc<MT>(beta2_pow_out);
should_update_beta_pow_later = true;
LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
weight_decay,
beta1,
beta2,
epsilon,
static_cast<const MT*>(beta1_pow_ptr),
static_cast<const MT*>(beta2_pow_ptr),
mom1.template data<MT>(),
dev_ctx.template Alloc<MT>(mom1_out),
mom2.template data<MT>(),
dev_ctx.template Alloc<MT>(mom2_out),
grad.template data<T>(),
static_cast<const MT*>(IsMultiPrecision ? master_param_ptr : param_ptr),
trust_ratio_div_ptr,
skip_update_flag);
for_range(moment_update_functor);
}
// Same from here
// Update parameter
// The code in the following part is exactly the same as that in
// paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h Please modify it
// together
DenseTensor p_norm_t;
p_norm_t.Resize(phi::make_ddim({1}));
auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
DenseTensor trust_ratio_div_norm_t;
trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
auto* trust_ratio_div_norm_ptr =
dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
// TODO(zengjinle): remove the following Eigen operations when
// *skip_update == true.
paddle::memory::Buffer buffer(dev_ctx.GetPlace());
phi::funcs::SquaredL2Norm(
dev_ctx,
reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
p_norm_ptr,
numel,
&buffer);
phi::funcs::SquaredL2Norm(
dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
if (VLOG_IS_ON(1)) {
const auto& name = "Param";
auto pn = phi::funcs::ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
auto tn =
phi::funcs::ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
auto dtype = paddle::framework::DataTypeToString(
paddle::framework::DataTypeTrait<T>::DataType());
VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
<< " , tn = " << tn[0];
}
#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow) \
do { \
LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
param_update_functor(lr.template data<MT>(), \
static_cast<const T*>(param_ptr), \
static_cast<const MT*>(master_param_ptr), \
p_norm_ptr, \
trust_ratio_div_ptr, \
trust_ratio_div_norm_ptr, \
static_cast<T*>(param_out_ptr), \
static_cast<MT*>(master_param_out_ptr), \
skip_update_flag); \
if (__should_update_beta_pow) { \
param_update_functor.SetBetaPows(beta1_pow_ptr, \
beta2_pow_ptr, \
beta1_pow_out_ptr, \
beta2_pow_out_ptr, \
beta1, \
beta2); \
} \
for_range(param_update_functor); \
} while (0)
if (should_update_beta_pow_later) {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
} else {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
}
#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
}
} // namespace phi
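With the kernel in phi, it can in principle be invoked directly; a minimal sketch, assuming an initialized `phi::CPUContext` named `dev_ctx`, pre-allocated shape-matched float DenseTensors, and that `paddle::none` is accepted for the optional inputs (all variable names here are illustrative, not from the patch):

// Sketch: one dense LAMB step on CPU, no master param, no skip flag.
phi::LambKernel<float, phi::CPUContext>(
    dev_ctx, param, grad, learning_rate, moment1, moment2,
    beta1_pow, beta2_pow,
    paddle::none,  // master_param (unused: multi_precision=false)
    paddle::none,  // skip_update (not provided)
    /*weight_decay=*/0.01f, /*beta1=*/0.9f, /*beta2=*/0.999f,
    /*epsilon=*/1e-6f, /*multi_precision=*/false,
    &param_out, &moment1_out, &moment2_out,
    &beta1_pow_out, &beta2_pow_out, /*master_param_outs=*/nullptr);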
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LambKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment1,
const DenseTensor& moment2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param,
const paddle::optional<DenseTensor>& skip_update,
float weight_decay,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment1_out,
DenseTensor* moment2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/lamb_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h"
PD_REGISTER_KERNEL(
lamb_sr, CPU, ALL_LAYOUT, phi::sr::LambKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/lamb_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/selected_rows/impl/lamb_kernel_impl.h"
PD_REGISTER_KERNEL(lamb_sr,
GPU,
ALL_LAYOUT,
phi::sr::LambKernel,
phi::dtype::float16,
float,
double) {
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/kernels/funcs/lamb_functors.h"
namespace phi {
namespace sr {
template <typename T, typename MT, typename Context, bool IsMultiPrecision>
void ComputeRowImpl(const Context& dev_ctx,
const DenseTensor& param,
const SelectedRows& grad,
const DenseTensor& lr,
const DenseTensor& mom1,
const DenseTensor& mom2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param_opt,
const paddle::optional<DenseTensor>& skip_update_opt,
float weight_decay_f,
float beta1_f,
float beta2_f,
float epsilon_f,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* mom1_out,
DenseTensor* mom2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_out);
template <typename T, typename Context>
void LambKernel(const Context& dev_ctx,
const DenseTensor& param,
const SelectedRows& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment1,
const DenseTensor& moment2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param,
const paddle::optional<DenseTensor>& skip_update,
float weight_decay,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment1_out,
DenseTensor* moment2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
if (multi_precision) {
ComputeRowImpl<T, MT, Context, true>(dev_ctx,
param,
grad,
learning_rate,
moment1,
moment2,
beta1_pow,
beta2_pow,
master_param,
skip_update,
weight_decay,
beta1,
beta2,
epsilon,
multi_precision,
param_out,
moment1_out,
moment2_out,
beta1_pow_out,
beta2_pow_out,
master_param_outs);
} else {
ComputeRowImpl<T, T, Context, false>(dev_ctx,
param,
grad,
learning_rate,
moment1,
moment2,
beta1_pow,
beta2_pow,
master_param,
skip_update,
weight_decay,
beta1,
beta2,
epsilon,
multi_precision,
param_out,
moment1_out,
moment2_out,
beta1_pow_out,
beta2_pow_out,
master_param_outs);
}
}
template <typename T, typename MT, typename Context, bool IsMultiPrecision>
void ComputeRowImpl(const Context& dev_ctx,
const DenseTensor& param,
const SelectedRows& grad,
const DenseTensor& lr,
const DenseTensor& mom1,
const DenseTensor& mom2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param_opt,
const paddle::optional<DenseTensor>& skip_update_opt,
float weight_decay_f,
float beta1_f,
float beta2_f,
float epsilon_f,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* mom1_out,
DenseTensor* mom2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_out) {
if (!IsMultiPrecision) {
constexpr auto kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(
kIsSameType,
true,
phi::errors::InvalidArgument(
"When multi_precision=False, T and MT must be the same type."));
}
const auto* master_param =
IsMultiPrecision ? master_param_opt.get_ptr() : nullptr;
const auto* skip_update = skip_update_opt.get_ptr();
const bool* skip_update_flag = skip_update && skip_update->IsInitialized()
? skip_update->data<bool>()
: nullptr;
if (skip_update_flag &&
paddle::platform::is_cpu_place(skip_update->place()) &&
(*skip_update_flag)) {
return;
}
auto weight_decay = static_cast<MT>(weight_decay_f);
auto beta1 = static_cast<MT>(beta1_f);
auto beta2 = static_cast<MT>(beta2_f);
auto epsilon = static_cast<MT>(epsilon_f);
auto numel = param.numel();
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
DenseTensor trust_ratio_div;
trust_ratio_div.Resize(param.dims());
/*auto trust_ratio_div =
ctx.AllocateTmpTensor<MT, DeviceContext>(param.dims(), dev_ctx);*/
auto* trust_ratio_div_ptr = dev_ctx.template Alloc<MT>(&trust_ratio_div);
const void* param_ptr = param.data();
const void* master_param_ptr = master_param ? master_param->data() : nullptr;
void* param_out_ptr = dev_ctx.template Alloc<T>(param_out);
void* master_param_out_ptr =
master_param_out ? dev_ctx.template Alloc<MT>(master_param_out) : nullptr;
// Update moments
bool should_update_beta_pow_later = false;
const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
<< " , Beta2Pow place: " << beta2_pow.place();
// Diff from here
PADDLE_ENFORCE_EQ(
IsMultiPrecision,
false,
phi::errors::Unimplemented("SelectedRows gradient is not supported when "
"multi_precision=True."));
constexpr bool kIsSameType = std::is_same<T, MT>::value;
PADDLE_ENFORCE_EQ(
kIsSameType,
true,
phi::errors::Unimplemented("SelectedRows gradient is not supported when "
"multi_precision=True."));
if (grad.rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}
std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
bool is_strict_sorted = true;
for (size_t i = 1; i < cpu_rows.size(); ++i) {
if (cpu_rows[i - 1] >= cpu_rows[i]) {
is_strict_sorted = false;
break;
}
}
phi::SelectedRows tmp_grad_merge;
const phi::SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = &grad;
} else {
// merge duplicated rows if any.
// The rows of grad_merge have been sorted inside MergeAdd functor
paddle::operators::math::scatter::MergeAdd<Context, T> merge_func;
merge_func(dev_ctx, grad, &tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
}
auto& grad_merge = *grad_merge_ptr;
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
auto* grad_merge_rows = &grad_merge.rows();
paddle::framework::MixVector<int64_t> mixv_grad_merge_rows(grad_merge_rows);
const int64_t* rows = mixv_grad_merge_rows.Data(dev_ctx.GetPlace());
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
if (paddle::platform::is_gpu_place(dev_ctx.GetPlace()) &&
beta1_pow.place() == phi::CPUPlace() &&
beta2_pow.place() == phi::CPUPlace()) {
SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay),
static_cast<T>(beta1),
static_cast<T>(beta2),
static_cast<T>(epsilon),
*beta1_pow.template data<T>(),
*beta2_pow.template data<T>(),
mom1.template data<T>(),
dev_ctx.template Alloc<T>(mom1_out),
mom2.template data<T>(),
dev_ctx.template Alloc<T>(mom2_out),
grad_data,
param.template data<T>(),
trust_ratio_div.template data<T>(),
rows,
row_numel,
grad_merge.rows().size(),
skip_update_flag);
for_range(moment_update_functor);
T* beta1_pow_out_data = dev_ctx.template HostAlloc<T>(beta1_pow_out);
beta1_pow_out_data[0] =
static_cast<T>(beta1) * beta1_pow.template data<T>()[0];
T* beta2_pow_out_data = dev_ctx.template HostAlloc<T>(beta2_pow_out);
beta2_pow_out_data[0] =
static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
} else {
beta1_pow_ptr = beta1_pow.template data<MT>();
beta2_pow_ptr = beta2_pow.template data<MT>();
beta1_pow_out_ptr = dev_ctx.template Alloc<MT>(beta1_pow_out);
beta2_pow_out_ptr = dev_ctx.template Alloc<MT>(beta2_pow_out);
should_update_beta_pow_later = true;
SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
static_cast<T>(weight_decay),
static_cast<T>(beta1),
static_cast<T>(beta2),
static_cast<T>(epsilon),
reinterpret_cast<const T*>(beta1_pow_ptr),
reinterpret_cast<const T*>(beta2_pow_ptr),
mom1.template data<T>(),
dev_ctx.template Alloc<T>(mom1_out),
mom2.template data<T>(),
dev_ctx.template Alloc<T>(mom2_out),
grad_data,
param.template data<T>(),
trust_ratio_div.template data<T>(),
rows,
row_numel,
grad_merge.rows().size(),
skip_update_flag);
for_range(moment_update_functor);
}
// Same from here
// Update parameter
// The code in the following part is exactly the same as that in
// paddle/phi/kernels/impl/lamb_kernel_impl.h Please modify it together
DenseTensor p_norm_t;
p_norm_t.Resize(phi::make_ddim({1}));
auto* p_norm_ptr = dev_ctx.template Alloc<MT>(&p_norm_t);
DenseTensor trust_ratio_div_norm_t;
trust_ratio_div_norm_t.Resize(phi::make_ddim({1}));
auto* trust_ratio_div_norm_ptr =
dev_ctx.template Alloc<MT>(&trust_ratio_div_norm_t);
// TODO(zengjinle): remove the following Eigen operations when
// *skip_update == true.
paddle::memory::Buffer buffer(dev_ctx.GetPlace());
phi::funcs::SquaredL2Norm(
dev_ctx,
reinterpret_cast<const MT*>(IsMultiPrecision ? master_param_ptr
: param_ptr),
p_norm_ptr,
numel,
&buffer);
phi::funcs::SquaredL2Norm(
dev_ctx, trust_ratio_div_ptr, trust_ratio_div_norm_ptr, numel, &buffer);
if (VLOG_IS_ON(1)) {
const auto& name = "Param";
auto pn = phi::funcs::ToVector(p_norm_ptr, 1, dev_ctx.GetPlace());
auto tn =
phi::funcs::ToVector(trust_ratio_div_norm_ptr, 1, dev_ctx.GetPlace());
auto dtype = paddle::framework::DataTypeToString(
paddle::framework::DataTypeTrait<T>::DataType());
VLOG(1) << "Param " << dtype << " " << name << " pn = " << pn[0]
<< " , tn = " << tn[0];
}
#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow) \
do { \
LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow> \
param_update_functor(lr.template data<MT>(), \
static_cast<const T*>(param_ptr), \
static_cast<const MT*>(master_param_ptr), \
p_norm_ptr, \
trust_ratio_div_ptr, \
trust_ratio_div_norm_ptr, \
static_cast<T*>(param_out_ptr), \
static_cast<MT*>(master_param_out_ptr), \
skip_update_flag); \
if (__should_update_beta_pow) { \
param_update_functor.SetBetaPows(beta1_pow_ptr, \
beta2_pow_ptr, \
beta1_pow_out_ptr, \
beta2_pow_out_ptr, \
beta1, \
beta2); \
} \
for_range(param_update_functor); \
} while (0)
if (should_update_beta_pow_later) {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
} else {
CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
}
#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
}
} // namespace sr
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void LambKernel(const Context& dev_ctx,
const DenseTensor& param,
const SelectedRows& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment1,
const DenseTensor& moment2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param,
const paddle::optional<DenseTensor>& skip_update,
float weight_decay,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment1_out,
DenseTensor* moment2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs);
} // namespace sr
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/utils/small_vector.h"
namespace phi {
KernelSignature LambOpArgumentMapping(const ArgumentMappingContext& ctx) {
paddle::small_vector<const char*> in_names = {"Param",
"Grad",
"LearningRate",
"Moment1",
"Moment2",
"Beta1Pow",
"Beta2Pow",
"MasterParam",
"SkipUpdate"};
paddle::small_vector<const char*> out_names = {"ParamOut",
"Moment1Out",
"Moment2Out",
"Beta1PowOut",
"Beta2PowOut",
"MasterParamOut"};
paddle::small_vector<const char*> attr_names;
attr_names.emplace_back("weight_decay");
attr_names.emplace_back("beta1");
attr_names.emplace_back("beta2");
attr_names.emplace_back("epsilon");
attr_names.emplace_back("multi_precision");
if (ctx.IsSelectedRowsInput("Grad")) {
return KernelSignature("lamb_sr",
std::move(in_names),
std::move(attr_names),
std::move(out_names));
} else if (ctx.IsDenseTensorInput("Grad")) {
return KernelSignature("lamb",
std::move(in_names),
std::move(attr_names),
std::move(out_names));
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(lamb, phi::LambOpArgumentMapping);
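Spelled out, the dense-tensor branch above returns the following signature; the SelectedRows branch differs only in the kernel name ("lamb_sr"):

// Equivalent long form of the dense-"Grad" branch:
KernelSignature("lamb",
                {"Param", "Grad", "LearningRate", "Moment1", "Moment2",
                 "Beta1Pow", "Beta2Pow", "MasterParam", "SkipUpdate"},
                {"weight_decay", "beta1", "beta2", "epsilon",
                 "multi_precision"},
                {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut",
                 "Beta2PowOut", "MasterParamOut"});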
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -21,6 +21,7 @@ from ..fluid import unique_name
from ..fluid.layer_helper import LayerHelper
from paddle import _C_ops
from paddle.fluid.executor import global_scope
import paddle
__all__ = []
...@@ -266,6 +267,13 @@ class Lamb(Optimizer):
        master_weight = None
        found_inf = self._get_auxiliary_var('found_inf')
        if framework.in_dygraph_mode():
            _C_ops.final_state_lamb_(param_and_grad[0], param_and_grad[1], lr,
                                     moment1, moment2, beta1_pow_acc,
                                     beta2_pow_acc, master_weight, found_inf,
                                     weight_decay, self._beta1, self._beta2,
                                     self._epsilon, find_master)
            return None
        if framework._non_static_mode():
            _C_ops.lamb(param_and_grad[0], param_and_grad[1], lr, moment1,
                        moment2, beta1_pow_acc, beta2_pow_acc, master_weight,
......