From 683f152aea1fc7ddc6cd12a0d7a1764a6184a87a Mon Sep 17 00:00:00 2001
From: Aurelius84
Date: Fri, 29 Apr 2022 16:58:44 +0800
Subject: [PATCH] [OP]Fix adamw not registered into AllKernels (#42391)

---
 paddle/fluid/operators/optimizers/adam_op.cc  | 167 +-----------------
 paddle/fluid/operators/optimizers/adam_op.h   | 149 ++++++++++++++++
 paddle/fluid/operators/optimizers/adamw_op.cc |  58 ++++++
 3 files changed, 209 insertions(+), 165 deletions(-)
 create mode 100644 paddle/fluid/operators/optimizers/adam_op.h
 create mode 100644 paddle/fluid/operators/optimizers/adamw_op.cc

diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc
index 8225dc8e07..36e54d741a 100644
--- a/paddle/fluid/operators/optimizers/adam_op.cc
+++ b/paddle/fluid/operators/optimizers/adam_op.cc
@@ -12,168 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/operators/optimizers/adam_op.h"
 
 #include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/multiary.h"
 
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-class AdamOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const {
-    auto input_data_type =
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
-  }
-
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string &var_name, const framework::Tensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const {
-    if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
-        var_name == "SkipUpdate") {
-      return expected_kernel_type;
-    } else {
-      return framework::OpKernelType(expected_kernel_type.data_type_,
-                                     tensor.place(), tensor.layout());
-    }
-  }
-};
-
-class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("Param", "(Tensor) Input parameter");
-    AddInput("Grad", "(Tensor) Input gradient");
-    AddInput("LearningRate", "(Tensor) Learning rate");
-    AddInput("Moment1", "(Tensor) Input first moment");
-    AddInput("Moment2", "(Tensor) Input second moment");
-    AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
-    AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
-
-    AddInput("Beta1Tensor",
-             "(Tensor, optional) If provided, Adam will use this "
-             "as beta1, this has a higher priority than attr(beta1), the "
-             "shape of this tensor MUST BE [1].")
-        .AsDispensable();
-    AddInput("Beta2Tensor",
-             "(Tensor, optional) If provided, Adam will use this "
-             "as beta2, this has a higher priority than attr(beta2), the "
-             "shape of this tensor MUST BE [1].")
-        .AsDispensable();
-    AddInput("EpsilonTensor",
-             "(Tensor, optional) If provided, Adam will use this "
-             "as epsilon, this has a higher priority than attr(epsilon), the "
-             "shape of this tensor MUST BE [1].")
-        .AsDispensable();
-    AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
-    AddInput("SkipUpdate", "(Tensor, optional), Skip the update or not.")
-        .AsDispensable();
-
AddOutput("ParamOut", "(Tensor) Output parameter"); - AddOutput("Moment1Out", "(Tensor) Output first moment"); - AddOutput("Moment2Out", "(Tensor) Output second moment"); - AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator"); - AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator"); - AddOutput("MasterParamOut", - "The updated FP32 master weight for AMP. " - "It shared memory with Input(MasterParam).") - .AsDispensable(); - - AddAttr("beta1", - "(float, default 0.9) " - "Exponential decay rate for the " - "first moment estimates.") - .SetDefault(0.9f); - AddAttr("beta2", - "(float, default 0.999) " - "exponential decay rate for the " - "second moment estimates.") - .SetDefault(0.999f); - AddAttr("epsilon", - "(float, default 1.0e-8) " - "Constant for numerical stability") - .SetDefault(1.0e-8f); - AddAttr( - "lazy_mode", - "(bool, default false) " - "only update the parameter that has gradient in sparse update") - .SetDefault(false); - AddAttr("min_row_size_to_use_multithread", - "(int64_t, default 0) " - "when not zero, if param row size is larger then " - "min_row_size_to_use_multithread and " - "inner_op_parallelism is larger then 0, sparse update " - "will run in multithread mode") - .SetDefault(1000); - AddAttr("multi_precision", - "(bool, default false) " - "Whether to use multi-precision during weight updating.") - .SetDefault(false); - // TODO(zhiqiu): We could set Beta1PowOut and Beta2PowOut - // as dispensable since they are not used when use_global_beta_pow is true. - AddAttr("use_global_beta_pow", - "(bool, default false) " - "Whether to use global beta_pow for whole model instead of " - "creating beta_pow for each parameter.") - .SetDefault(false); - - AddComment(R"DOC( -Adam Optimizer. - -This implements the Adam optimizer from Section 2 of the Adam -paper : https://arxiv.org/abs/1412.6980. -Adam is a first-order gradient-based optimization method based on -adaptive estimates of lower-order moments. 
-
-Adam updates:
-
-$$
-moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\
-moment\_2\_out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\
-learning\_rate = learning\_rate *
-                 \frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\
-param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
-$$
-
-)DOC");
-  }
-};
-
-class AdamWOp : public AdamOp {
-  using AdamOp::AdamOp;
-};
-
-class AdamWOpMaker : public AdamOpMaker {
- public:
-  void Make() {
-    AdamOpMaker::Make();
-    AddAttr<float>("lr_ratio",
-                   "(float, default 1.0) "
-                   "layerwise learning rate decay")
-        .SetDefault(1.0f);
-    AddAttr<float>("coeff",
-                   "(float, default 0.01) "
-                   "coeff of the weight decay")
-        .SetDefault(0.01f);
-    AddAttr<bool>("with_decay",
-                  "(bool, default false) "
-                  "whether to do weight decay")
-        .SetDefault(false);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;
 
 DECLARE_INFER_SHAPE_FUNCTOR(adam, AdamInferMetaFunctor,
@@ -185,14 +30,6 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     AdamInferMetaFunctor);
 
-DECLARE_INFER_SHAPE_FUNCTOR(adamw, AdamwInferMetaFunctor,
-                            PD_INFER_META(phi::AdamwInferMeta));
-REGISTER_OPERATOR(
-    adamw, ops::AdamWOp, ops::AdamWOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    AdamwInferMetaFunctor);
-
 REGISTER_OP_VERSION(adam)
     .AddCheckpoint(
         R"ROC(
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
new file mode 100644
index 0000000000..31feaa8102
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -0,0 +1,149 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class AdamOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const {
+    auto input_data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const framework::Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const {
+    if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
+        var_name == "SkipUpdate") {
+      return expected_kernel_type;
+    } else {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(), tensor.layout());
+    }
+  }
+};
+
+class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Param", "(Tensor) Input parameter");
+    AddInput("Grad", "(Tensor) Input gradient");
+    AddInput("LearningRate", "(Tensor) Learning rate");
+    AddInput("Moment1", "(Tensor) Input first moment");
+    AddInput("Moment2", "(Tensor) Input second moment");
+    AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
+    AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
+
+    AddInput("Beta1Tensor",
+             "(Tensor, optional) If provided, Adam will use this "
+             "as beta1, this has a higher priority than attr(beta1), the "
+             "shape of this tensor MUST BE [1].")
+        .AsDispensable();
+    AddInput("Beta2Tensor",
+             "(Tensor, optional) If provided, Adam will use this "
+             "as beta2, this has a higher priority than attr(beta2), the "
+             "shape of this tensor MUST BE [1].")
+        .AsDispensable();
+    AddInput("EpsilonTensor",
+             "(Tensor, optional) If provided, Adam will use this "
+             "as epsilon, this has a higher priority than attr(epsilon), the "
+             "shape of this tensor MUST BE [1].")
+        .AsDispensable();
+    AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
+    AddInput("SkipUpdate", "(Tensor, optional), Skip the update or not.")
+        .AsDispensable();
+
+    AddOutput("ParamOut", "(Tensor) Output parameter");
+    AddOutput("Moment1Out", "(Tensor) Output first moment");
+    AddOutput("Moment2Out", "(Tensor) Output second moment");
+    AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
+    AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
+    AddOutput("MasterParamOut",
+              "The updated FP32 master weight for AMP. "
+              "It shared memory with Input(MasterParam).")
+        .AsDispensable();
+
+    AddAttr<float>("beta1",
+                   "(float, default 0.9) "
+                   "Exponential decay rate for the "
+                   "first moment estimates.")
+        .SetDefault(0.9f);
+    AddAttr<float>("beta2",
+                   "(float, default 0.999) "
+                   "exponential decay rate for the "
+                   "second moment estimates.")
+        .SetDefault(0.999f);
+    AddAttr<float>("epsilon",
+                   "(float, default 1.0e-8) "
+                   "Constant for numerical stability")
+        .SetDefault(1.0e-8f);
+    AddAttr<bool>(
+        "lazy_mode",
+        "(bool, default false) "
+        "only update the parameter that has gradient in sparse update")
+        .SetDefault(false);
+    AddAttr<int64_t>("min_row_size_to_use_multithread",
+                     "(int64_t, default 0) "
+                     "when not zero, if param row size is larger then "
+                     "min_row_size_to_use_multithread and "
+                     "inner_op_parallelism is larger then 0, sparse update "
+                     "will run in multithread mode")
+        .SetDefault(1000);
+    AddAttr<bool>("multi_precision",
+                  "(bool, default false) "
+                  "Whether to use multi-precision during weight updating.")
+        .SetDefault(false);
+    // TODO(zhiqiu): We could set Beta1PowOut and Beta2PowOut
+    // as dispensable since they are not used when use_global_beta_pow is true.
+    AddAttr<bool>("use_global_beta_pow",
+                  "(bool, default false) "
+                  "Whether to use global beta_pow for whole model instead of "
+                  "creating beta_pow for each parameter.")
+        .SetDefault(false);
+
+    AddComment(R"DOC(
+Adam Optimizer.
+
+This implements the Adam optimizer from Section 2 of the Adam
+paper : https://arxiv.org/abs/1412.6980.
+Adam is a first-order gradient-based optimization method based on
+adaptive estimates of lower-order moments.
+
+Adam updates:
+
+$$
+moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\
+moment\_2\_out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\
+learning\_rate = learning\_rate *
+                 \frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\
+param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
+$$
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/optimizers/adamw_op.cc b/paddle/fluid/operators/optimizers/adamw_op.cc
new file mode 100644
index 0000000000..e2670625d4
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/adamw_op.cc
@@ -0,0 +1,58 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/operators/optimizers/adam_op.h" + +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/multiary.h" + +namespace paddle { +namespace operators { + +class AdamWOp : public AdamOp { + using AdamOp::AdamOp; +}; + +class AdamWOpMaker : public AdamOpMaker { + public: + void Make() { + AdamOpMaker::Make(); + AddAttr("lr_ratio", + "(float, default 1.0) " + "layerwise learning rate decay") + .SetDefault(1.0f); + AddAttr("coeff", + "(float, default 0.01) " + "coeff of the weight decay") + .SetDefault(0.01f); + AddAttr("with_decay", + "(bool, default false) " + "whether to do weight decay") + .SetDefault(false); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(adamw, AdamwInferMetaFunctor, + PD_INFER_META(phi::AdamwInferMeta)); +REGISTER_OPERATOR( + adamw, ops::AdamWOp, ops::AdamWOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + AdamwInferMetaFunctor); -- GitLab