diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu
index 57231e1135a6a922800810ed1515dc79c316e176..3b9cf159f1b6b15dbe8237016aa0f1eb80a0f283 100644
--- a/paddle/fluid/operators/optimizers/adam_op.cu
+++ b/paddle/fluid/operators/optimizers/adam_op.cu
@@ -29,20 +29,18 @@ __global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_,
   MT beta1_pow = beta1_pow_;
   MT beta2_pow = beta2_pow_;
 
-  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
-        (static_cast<MT>(1.0) - beta1_pow);
-
   int id = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (; id < ndim; id += gridDim.x * blockDim.x) {
     MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
     MT g = static_cast<MT>(grad[id]);
-    MT mom1 = moment1[id];
-    MT mom2 = moment2[id];
+    MT mom1 = static_cast<MT>(moment1[id]);
+    MT mom2 = static_cast<MT>(moment2[id]);
     mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
     mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-    p -= lr * (mom1 /
-               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
+
+    MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
+    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
 
     moment1_out[id] = mom1;
     moment2_out[id] = mom2;
@@ -65,9 +63,6 @@ __global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
   MT beta1_pow = *beta1_pow_;
   MT beta2_pow = *beta2_pow_;
 
-  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
-        (static_cast<MT>(1.0) - beta1_pow);
-
   int id = blockIdx.x * blockDim.x + threadIdx.x;
 
   for (; id < ndim; id += gridDim.x * blockDim.x) {
@@ -77,8 +72,9 @@ __global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
     MT mom2 = static_cast<MT>(moment2[id]);
     mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
     mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-    p -= lr * (mom1 /
-               (sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
+
+    MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
+    p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
 
     moment1_out[id] = mom1;
     moment2_out[id] = mom2;
@@ -105,8 +101,6 @@ __global__ void SparseAdamCUDAKernelREG(
     int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) {
   int id = blockIdx.x * blockDim.x + threadIdx.x;
   MT lr = *lr_;
-  lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
-        (static_cast<MT>(1.0) - beta1_pow);
 
   for (; id < ndim; id += blockDim.x * gridDim.x) {
     auto row_idx =
@@ -122,8 +116,10 @@ __global__ void SparseAdamCUDAKernelREG(
                        : static_cast<MT>(0);
       mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
       mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
-      p -= lr * (mom1 / (sqrt(mom2) +
-                         epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
+
+      MT denom =
+          (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
+      p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
 
       // Write back to global memory
       mom1_out_[id] = mom1;
diff --git a/paddle/fluid/operators/optimizers/merged_adam_op.cc b/paddle/fluid/operators/optimizers/merged_adam_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..11c047305c44a0c6bcf901af25f16c8c0d9934e9
--- /dev/null
+++ b/paddle/fluid/operators/optimizers/merged_adam_op.cc
@@ -0,0 +1,138 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/optimizers/merged_adam_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class MergedAdamOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override {} + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto param_dtype = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "Param"); + return framework::OpKernelType(param_dtype, ctx.GetPlace()); + } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + if (var_name == "Beta1Pow" || var_name == "Beta2Pow" || + var_name == "SkipUpdate") { + return expected_kernel_type; + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + } +}; + +class MergedAdamOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Param", "(Tensor, default Tensor) Input parameter") + .AsDuplicable(); + AddInput("Grad", "(Tensor, default Tensor) Input gradient") + .AsDuplicable(); + AddInput("LearningRate", "(Tensor, default Tensor) Learning rate") + .AsDuplicable(); + AddInput("Moment1", "(Tensor, default Tensor) Input first moment") + .AsDuplicable(); + AddInput("Moment2", "(Tensor, default Tensor) Input second moment") + .AsDuplicable(); + AddInput("Beta1Pow", + "(Tensor, default Tensor) Input beta1 power accumulator") + .AsDuplicable(); + AddInput("Beta2Pow", + "(Tensor, default Tensor) Input beta2 power accumulator") + .AsDuplicable(); + AddInput("MasterParam", "FP32 master weight for AMP.") + .AsDispensable() + .AsDuplicable(); + + AddOutput("ParamOut", "(Tensor) Output parameter").AsDuplicable(); + AddOutput("Moment1Out", "(Tensor) Output first moment").AsDuplicable(); + AddOutput("Moment2Out", "(Tensor) Output second moment").AsDuplicable(); + AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator") + .AsDuplicable(); + AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator") + .AsDuplicable(); + AddOutput("MasterParamOut", + "The updated FP32 master weight for AMP. " + "It shared memory with Input(MasterParam).") + .AsDispensable() + .AsDuplicable(); + + AddAttr("beta1", + "(float, default 0.9) " + "Exponential decay rate for the " + "first moment estimates.") + .SetDefault(0.9f); + AddAttr("beta2", + "(float, default 0.999) " + "exponential decay rate for the " + "second moment estimates.") + .SetDefault(0.999f); + AddAttr("epsilon", + "(float, default 1.0e-8) " + "Constant for numerical stability") + .SetDefault(1.0e-8f); + AddAttr("multi_precision", + "(bool, default false) " + "Whether to use multi-precision during weight updating.") + .SetDefault(false); + // TODO(zhiqiu): We could set Beta1PowOut and Beta2PowOut + // as dispensable since they are not used when use_global_beta_pow is true. 
+ AddAttr("use_global_beta_pow", + "(bool, default false) " + "Whether to use global beta_pow for whole model instead of " + "creating beta_pow for each parameter.") + .SetDefault(false); + + AddComment(R"DOC( +Adam Optimizer. +This implements the Adam optimizer from Section 2 of the Adam +paper : https://arxiv.org/abs/1412.6980. +Adam is a first-order gradient-based optimization method based on +adaptive estimates of lower-order moments. +Adam updates: +$$ +moment\_1\_out = \beta_1 * moment\_1 + (1 - \beta_1) * grad \\ +moment\_2_\out = \beta_2 * moment\_2 + (1 - \beta_2) * grad * grad \\ +learning\_rate = learning\_rate * + \frac{\sqrt{1 - \beta_{2\_pow}}}{1 - \beta_{1\_pow}} \\ +param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon} +$$ +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(merged_adam, ops::MergedAdamOp, + ops::MergedAdamOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(merged_adamw, ops::MergedAdamOp, + ops::MergedAdamOpMaker); + +REGISTER_OP_CPU_KERNEL( + merged_adam, + ops::MergedAdamOpKernel, + ops::MergedAdamOpKernel); diff --git a/paddle/fluid/operators/optimizers/merged_adam_op.cu b/paddle/fluid/operators/optimizers/merged_adam_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..2523fb9e5c680615c8e2e9193526de8776cf2b3c --- /dev/null +++ b/paddle/fluid/operators/optimizers/merged_adam_op.cu @@ -0,0 +1,191 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/optimizers/merged_adam_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_, + MT beta2_pow_, const MT* moment1, MT* moment1_out, + const MT* moment2, MT* moment2_out, const MT* lr_, + const T* grad, const T* param, T* param_out, + const MT* master_param, MT* master_param_out, + int ndim) { + MT lr = *lr_; + MT beta1_pow = beta1_pow_; + MT beta2_pow = beta2_pow_; + + int id = blockIdx.x * blockDim.x + threadIdx.x; + + for (; id < ndim; id += gridDim.x * blockDim.x) { + MT p = master_param ? 
master_param[id] : static_cast(param[id]); + MT g = static_cast(grad[id]); + MT mom1 = static_cast(moment1[id]); + MT mom2 = static_cast(moment2[id]); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; + mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + + MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); + + moment1_out[id] = mom1; + moment2_out[id] = mom2; + param_out[id] = static_cast(p); + if (master_param_out) { + master_param_out[id] = p; + } + } +} + +template +__global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon, + const MT* beta1_pow_, const MT* beta2_pow_, + const MT* moment1, MT* moment1_out, + const MT* moment2, MT* moment2_out, const MT* lr_, + const T* grad, const T* param, T* param_out, + const MT* master_param, MT* master_param_out, + int ndim) { + MT lr = *lr_; + MT beta1_pow = *beta1_pow_; + MT beta2_pow = *beta2_pow_; + + int id = blockIdx.x * blockDim.x + threadIdx.x; + + for (; id < ndim; id += gridDim.x * blockDim.x) { + MT p = master_param ? master_param[id] : static_cast(param[id]); + MT g = static_cast(grad[id]); + MT mom1 = static_cast(moment1[id]); + MT mom2 = static_cast(moment2[id]); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; + mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + + MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); + + moment1_out[id] = mom1; + moment2_out[id] = mom2; + param_out[id] = static_cast(p); + if (master_param_out) { + master_param_out[id] = p; + } + } +} + +template +__global__ void UpdateBetaPow(T beta1, T beta2, const T* beta1_pow_, + const T* beta2_pow_, T* beta1_pow_out, + T* beta2_pow_out) { + *beta1_pow_out = beta1 * beta1_pow_[0]; + *beta2_pow_out = beta2 * beta2_pow_[0]; +} + +template +class MergedAdamOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using MPDType = typename details::MPTypeTrait::Type; + + auto param = ctx.MultiInput("Param"); + auto grad = ctx.MultiInput("Grad"); + auto lr = ctx.MultiInput("LearningRate"); + auto mom1 = ctx.MultiInput("Moment1"); + auto mom2 = ctx.MultiInput("Moment2"); + auto beta1_pow = ctx.MultiInput("Beta1Pow"); + auto beta2_pow = ctx.MultiInput("Beta2Pow"); + + auto param_out = ctx.MultiOutput("ParamOut"); + auto mom1_out = ctx.MultiOutput("Moment1Out"); + auto mom2_out = ctx.MultiOutput("Moment2Out"); + auto beta1_pow_out = ctx.MultiOutput("Beta1PowOut"); + auto beta2_pow_out = ctx.MultiOutput("Beta2PowOut"); + + MPDType beta1 = static_cast(ctx.Attr("beta1")); + MPDType beta2 = static_cast(ctx.Attr("beta2")); + MPDType epsilon = static_cast(ctx.Attr("epsilon")); + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; + + const bool multi_precision = ctx.Attr("multi_precision"); + auto master_param = ctx.MultiInput("MasterParam"); + auto master_param_out = + ctx.MultiOutput("MasterParamOut"); + + auto& dev_ctx = ctx.template device_context(); + + size_t param_num = param.size(); + for (size_t idx = 0; idx < param_num; idx++) { + const MPDType* master_in_data = + multi_precision ? master_param[idx]->data() : nullptr; + MPDType* master_out_data = + multi_precision + ? 
master_param_out[idx]->mutable_data(ctx.GetPlace()) + : nullptr; + + // update param and moment + int threads = 512; + int blocks = (param[idx]->numel() + threads - 1) / threads; + + if (beta1_pow[idx]->place() == platform::CPUPlace() && + beta2_pow[idx]->place() == platform::CPUPlace()) { + // Compute with betapow in REG + AdamKernelREG<<>>( + beta1, beta2, epsilon, *beta1_pow[idx]->data(), + *beta2_pow[idx]->data(), mom1[idx]->data(), + mom1_out[idx]->mutable_data(ctx.GetPlace()), + mom2[idx]->data(), + mom2_out[idx]->mutable_data(ctx.GetPlace()), + lr[idx]->data(), grad[idx]->data(), + param[idx]->data(), + param_out[idx]->mutable_data(ctx.GetPlace()), master_in_data, + master_out_data, param[idx]->numel()); + if (!use_global_beta_pow) { + // Cpu update + beta1_pow_out[idx]->mutable_data(platform::CPUPlace())[0] = + beta1 * beta1_pow[idx]->data()[0]; + beta2_pow_out[idx]->mutable_data(platform::CPUPlace())[0] = + beta2 * beta2_pow[idx]->data()[0]; + } + } else { + AdamKernelMEM<<>>( + beta1, beta2, epsilon, beta1_pow[idx]->data(), + beta2_pow[idx]->data(), mom1[idx]->data(), + mom1_out[idx]->mutable_data(ctx.GetPlace()), + mom2[idx]->data(), + mom2_out[idx]->mutable_data(ctx.GetPlace()), + lr[idx]->data(), grad[idx]->data(), + param[idx]->data(), + param_out[idx]->mutable_data(ctx.GetPlace()), master_in_data, + master_out_data, param[idx]->numel()); + if (!use_global_beta_pow) { + // Update with gpu + UpdateBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( + beta1, beta2, beta1_pow[idx]->data(), + beta2_pow[idx]->data(), + beta1_pow_out[idx]->mutable_data(ctx.GetPlace()), + beta2_pow_out[idx]->mutable_data(ctx.GetPlace())); + } + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(merged_adam, ops::MergedAdamOpCUDAKernel, + ops::MergedAdamOpCUDAKernel, + ops::MergedAdamOpCUDAKernel); diff --git a/paddle/fluid/operators/optimizers/merged_adam_op.h b/paddle/fluid/operators/optimizers/merged_adam_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c9417158fe772817b0c50b3eb2f4183a5f094380 --- /dev/null +++ b/paddle/fluid/operators/optimizers/merged_adam_op.h @@ -0,0 +1,104 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/optimizers/adam_op.h" + +namespace paddle { +namespace operators { + +namespace scatter = paddle::operators::math::scatter; + +template +class MergedAdamOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param = ctx.MultiInput("Param"); + size_t n = param.size(); + auto grad = ctx.MultiInput("Grad"); + PADDLE_ENFORCE_EQ(n, grad.size(), + platform::errors::InvalidArgument( + "The size of Input(Grad) must be equal to " + "Input(Param), but got the size of Input(Grad) " + "is %d, the size of Input(Param) is %d.", + grad.size(), n)); + auto lr = ctx.MultiInput("LearningRate"); + PADDLE_ENFORCE_EQ( + n, lr.size(), + platform::errors::InvalidArgument( + "The size of Input(LearningRate) must be equal to " + "Input(Param), but got the size of Input(LearningRate) " + "is %d, the size of Input(Param) is %d.", + lr.size(), n)); + auto mom1 = ctx.MultiInput("Moment1"); + PADDLE_ENFORCE_EQ(n, mom1.size(), + platform::errors::InvalidArgument( + "The size of Input(Moment1) must be equal to " + "Input(Param), but got the size of Input(Moment1) " + "is %d, the size of Input(Param) is %d.", + mom1.size(), n)); + auto mom2 = ctx.MultiInput("Moment2"); + PADDLE_ENFORCE_EQ(n, mom2.size(), + platform::errors::InvalidArgument( + "The size of Input(Moment2) must be equal to " + "Input(Param), but got the size of Input(Moment2) " + "is %d, the size of Input(Param) is %d.", + mom2.size(), n)); + auto beta1_pow = ctx.MultiInput("Beta1Pow"); + PADDLE_ENFORCE_EQ(n, beta1_pow.size(), + platform::errors::InvalidArgument( + "The size of Input(Beta1Pow) must be equal to " + "Input(Param), but got the size of Input(Beta1Pow) " + "is %d, the size of Input(Param) is %d.", + beta1_pow.size(), n)); + auto beta2_pow = ctx.MultiInput("Beta2Pow"); + PADDLE_ENFORCE_EQ(n, beta2_pow.size(), + platform::errors::InvalidArgument( + "The size of Input(Beta2Pow) must be equal to " + "Input(Param), but got the size of Input(Beta2Pow) " + "is %d, the size of Input(Param) is %d.", + beta2_pow.size(), n)); + + auto param_out = ctx.MultiOutput("ParamOut"); + auto mom1_out = ctx.MultiOutput("Moment1Out"); + auto mom2_out = ctx.MultiOutput("Moment2Out"); + auto beta1_pow_out = ctx.MultiOutput("Beta1PowOut"); + auto beta2_pow_out = ctx.MultiOutput("Beta2PowOut"); + + T beta1 = static_cast(ctx.Attr("beta1")); + T beta2 = static_cast(ctx.Attr("beta2")); + T epsilon = static_cast(ctx.Attr("epsilon")); + bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; + + size_t param_num = param.size(); + for (size_t idx = 0; idx < param_num; idx++) { + AdamFunctor functor( + beta1, beta2, epsilon, beta1_pow[idx]->data(), + beta2_pow[idx]->data(), mom1[idx]->data(), + mom1_out[idx]->mutable_data(ctx.GetPlace()), mom2[idx]->data(), + mom2_out[idx]->mutable_data(ctx.GetPlace()), lr[idx]->data(), + grad[idx]->data(), param[idx]->data(), + param_out[idx]->mutable_data(ctx.GetPlace())); + functor(param[idx]->numel()); + if (!use_global_beta_pow) { + beta1_pow_out[idx]->mutable_data(ctx.GetPlace())[0] = + beta1 * beta1_pow[idx]->data()[0]; + beta2_pow_out[idx]->mutable_data(ctx.GetPlace())[0] = + beta2 * beta2_pow[idx]->data()[0]; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h index 
e14b836bf083024e73b058663fe3d5c4d90cc53d..f83997843f433d6b150631469d807de636492227 100644 --- a/paddle/fluid/pybind/op_function_generator.h +++ b/paddle/fluid/pybind/op_function_generator.h @@ -71,6 +71,9 @@ std::map> op_ins_map = { {"adam", {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow", "MasterParam"}}, + {"merged_adam", + {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", + "Beta2Pow", "MasterParam"}}, {"adamw", {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", "Beta2Pow", "MasterParam"}}, @@ -123,6 +126,9 @@ std::map> op_outs_map = { {"adam", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, + {"merged_adam", + {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", + "MasterParamOut"}}, {"adamw", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, @@ -148,6 +154,9 @@ std::map> op_passing_outs_map = { {"adam", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, + {"merged_adam", + {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", + "MasterParamOut"}}, {"adamw", {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", "MasterParamOut"}}, diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index 70109164960a33d732063405f6dc5afbf54984dc..a06f0d390e517d6434b5232c3eb3c5d9b0115150 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -1011,5 +1011,186 @@ class TestAdamOpV2Group(TestAdamOpV2): adam.clear_gradients() +class TestMultiTensorAdam(unittest.TestCase): + def _adam_optimize_dygraph(self, + place, + use_param_attr=False, + use_param_group=False, + use_amp=False, + use_multi_tensor=False): + paddle.disable_static() + paddle.seed(10) + paddle.set_device(place) + + input = paddle.randn((5, 5)) + + weight_attr = paddle.ParamAttr( + learning_rate=0.5, + regularizer=paddle.regularizer.L2Decay(1.0), + trainable=True) + if use_param_attr: + model = paddle.nn.Linear(5, 5, weight_attr) + else: + model = paddle.nn.Linear(5, 5) + + if not use_param_group: + optimizer = paddle.optimizer.Adam( + parameters=model.parameters(), + use_multi_tensor=use_multi_tensor, + multi_precision=use_amp) + else: + optimizer = paddle.optimizer.Adam( + parameters=[{ + 'params': model.parameters(), + 'weight_decay': 0.001, + 'beta1': 0.1, + 'beta2': 0.99 + }], + use_multi_tensor=use_multi_tensor, + multi_precision=use_amp) + + for idx in range(2): + if place == 'gpu' and use_amp == True: + model = paddle.amp.decorate(models=model, level='O2') + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + + if place == 'gpu' and use_amp == True: + with paddle.amp.auto_cast(level='O2'): + output = model(input) + loss = paddle.mean(output) + scaled = scaler.scale(loss) + scaled.backward() + scaler.step(optimizer) + optimizer.clear_grad() + else: + output = model(input) + loss = paddle.mean(output) + loss.backward() + optimizer.step() + optimizer.clear_grad() + + return output, model.parameters() + + def _adam_optimize_static(self, + place, + use_amp=False, + use_multi_tensor=False): + paddle.enable_static() + paddle.seed(10) + np.random.seed(10) + if place == 'cpu': + use_amp = False + exe = paddle.static.Executor(place=place) + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + optimizer = paddle.optimizer.Adam( + 
multi_precision=use_amp, use_multi_tensor=use_multi_tensor) + if use_amp: + optimizer = paddle.static.amp.decorate( + optimizer, + init_loss_scaling=128.0, + use_dynamic_loss_scaling=True, + use_pure_fp16=True, + use_fp16_guard=False) + with paddle.static.program_guard(train_program, startup_program): + if use_amp: + data = paddle.static.data( + shape=[2, 2], name='X', dtype='float16') + else: + data = paddle.static.data( + shape=[2, 2], name='X', dtype='float32') + hidden = paddle.static.nn.fc(x=data, size=10) + loss = paddle.fluid.layers.mean(hidden) + optimizer.minimize(loss) + exe.run(startup_program) + if use_amp: + optimizer.amp_init(place=place, scope=paddle.static.global_scope()) + x = np.random.random(size=(2, 2)).astype('float16') + else: + x = np.random.random(size=(2, 2)).astype('float32') + out = [] + for idx in range(5): + loss_data, = exe.run(train_program, + feed={"X": x}, + fetch_list=[loss.name]) + out.append(loss_data) + return out + + def _get_places(self): + places = ['cpu'] + if paddle.is_compiled_with_cuda(): + places.append('gpu') + return places + + def _check_with_place_amp(self, place, use_amp): + # test dygraph mode + output_dygraph1, params_dygraph1 = self._adam_optimize_dygraph( + place=place, use_amp=use_amp, use_multi_tensor=True) + output_dygraph2, params_dygraph2 = self._adam_optimize_dygraph( + place=place, use_amp=use_amp, use_multi_tensor=False) + self.assertEqual( + np.allclose( + output_dygraph1, output_dygraph2, rtol=1e-05), True) + for idx in range(len(params_dygraph1)): + self.assertEqual( + np.allclose( + params_dygraph1[idx], params_dygraph2[idx], rtol=1e-05), + True) + # test static mode + output_static1 = self._adam_optimize_static( + place=place, use_amp=use_amp, use_multi_tensor=True) + output_static2 = self._adam_optimize_static( + place=place, use_amp=use_amp, use_multi_tensor=False) + for idx in range(len(output_static1)): + self.assertEqual( + np.allclose( + output_static1[idx], output_static2[idx], rtol=1e-05), + True) + + def _check_with_param_arrt(self, place, use_amp): + output1, params1 = self._adam_optimize_dygraph( + place=place, + use_amp=use_amp, + use_param_attr=True, + use_multi_tensor=True) + output2, params2 = self._adam_optimize_dygraph( + place=place, + use_amp=use_amp, + use_param_attr=True, + use_multi_tensor=False) + + self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True) + for idx in range(len(params1)): + self.assertEqual( + np.allclose( + params1[idx], params2[idx], rtol=1e-05), True) + + def _check_with_param_group(self, place, use_amp): + output1, params1 = self._adam_optimize_dygraph( + place=place, + use_amp=use_amp, + use_param_group=True, + use_multi_tensor=True) + output2, params2 = self._adam_optimize_dygraph( + place=place, + use_amp=use_amp, + use_param_group=True, + use_multi_tensor=False) + + self.assertEqual(np.allclose(output1, output2, rtol=1e-05), True) + for idx in range(len(params1)): + self.assertEqual( + np.allclose( + params1[idx], params2[idx], rtol=1e-05), True) + + def test_main(self): + for place in self._get_places(): + use_amp_list = [True, False] + for use_amp in use_amp_list: + self._check_with_place_amp(place, use_amp) + self._check_with_param_arrt(place, use_amp) + self._check_with_param_group(place, use_amp) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_merged_adam_op.py b/python/paddle/fluid/tests/unittests/test_merged_adam_op.py new file mode 100644 index 
0000000000000000000000000000000000000000..f515a9f95b109333e4c4b48f69e4c453ca8eb176 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_merged_adam_op.py @@ -0,0 +1,157 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import numpy as np +from paddle import _C_ops + + +def run_adam_op(params, + grads, + lrs, + moment1s, + moment2s, + beta1_pows, + beta2_pows, + master_params, + epsilon, + beta1, + beta2, + place, + multi_precision=False, + use_merged=False): + assert len(params) == len(grads) + assert len(params) == len(lrs) + assert len(params) == len(moment1s) + assert len(params) == len(moment2s) + assert len(params) == len(beta1_pows) + assert len(params) == len(beta1_pows) + assert len(params) == len(master_params) + paddle.disable_static() + paddle.set_device(place) + + param_vars = [paddle.fluid.dygraph.to_variable(p) for p in params] + grad_vars = [paddle.fluid.dygraph.to_variable(g) for g in grads] + lr_vars = [paddle.fluid.dygraph.to_variable(l) for l in lrs] + moment1_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment1s] + moment2_vars = [paddle.fluid.dygraph.to_variable(m) for m in moment2s] + beta1_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta1_pows] + beta2_pow_vars = [paddle.fluid.dygraph.to_variable(b) for b in beta2_pows] + master_param_vars = [ + paddle.fluid.dygraph.to_variable(m_p) for m_p in master_params + ] + + if not use_merged: + for i in range(len(param_vars)): + _, _, _, _, _, _ = _C_ops.adam( + param_vars[i], grad_vars[i], lr_vars[i], moment1_vars[i], + moment2_vars[i], beta1_pow_vars[i], beta2_pow_vars[i], + master_param_vars[i], param_vars[i], moment1_vars[i], + moment2_vars[i], beta1_pow_vars[i], beta2_pow_vars[i], + master_param_vars[i], 'epsilon', epsilon, 'beta1', beta1, + 'beta2', beta2, 'multi_precision', multi_precision) + else: + _, _, _, _, _, _ = _C_ops.merged_adam( + param_vars, grad_vars, lr_vars, moment1_vars, moment2_vars, + beta1_pow_vars, beta2_pow_vars, master_param_vars, param_vars, + moment1_vars, moment2_vars, beta1_pow_vars, beta2_pow_vars, + master_param_vars, 'epsilon', epsilon, 'beta1', beta1, 'beta2', + beta2, 'multi_precision', multi_precision) + + outputs = { + 'ParamOut': param_vars, + 'Moment1Out': moment1_vars, + 'Moment2Out': moment2_vars, + 'Beta1PowOut': beta1_pow_vars, + 'Beta2PowOut': beta2_pow_vars, + 'MasterParamOut': master_param_vars + } + + return outputs + + +class TestMergedAdam(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.shapes = [[3, 4], [2, 7], [5, 6], [7, 8]] + self.seed = 10 + + def gen_rand_data(self, shapes, dtype): + return [np.random.random(s).astype(dtype) for s in shapes] + + def prepare_data(self, shapes, multi_precision, seed, place): + np.random.seed(seed) + mp_dtype = np.float32 + dtype = np.float16 if multi_precision and place == 'gpu' else np.float32 + params = self.gen_rand_data(shapes, dtype) + grads = self.gen_rand_data(shapes, dtype) + lrs = 
self.gen_rand_data([[1], [1], [1], [1]], mp_dtype) + moment1s = self.gen_rand_data(shapes, mp_dtype) + moment2s = self.gen_rand_data(shapes, mp_dtype) + beta1_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype) + beta2_pows = self.gen_rand_data([[1], [1], [1], [1]], mp_dtype) + master_params = [p.astype(mp_dtype) for p in params] + return params, grads, lrs, moment1s, moment2s, beta1_pows, beta2_pows, master_params + + def check_with_place(self, place, multi_precision): + params, grads, lrs, moment1s, moment2s, beta1_pows, beta2_pows, master_params = self.prepare_data( + self.shapes, multi_precision, self.seed, place) + + def run_op(use_merged): + return run_adam_op( + params=params, + grads=grads, + lrs=lrs, + moment1s=moment1s, + moment2s=moment2s, + beta1_pows=beta1_pows, + beta2_pows=beta2_pows, + master_params=master_params, + epsilon=0.9, + beta1=0.9, + beta2=0.99, + place=place, + multi_precision=multi_precision, + use_merged=use_merged) + + outs1 = run_op(True) + outs2 = run_op(False) + self.assertEqual(len(outs1), len(outs2)) + + for key in outs1.keys(): + value1 = outs1[key] + value2 = outs2[key] + for i in range(len(value1)): + if place == 'gpu': + self.assertTrue(np.array_equal(value1[i], value2[i])) + else: + self.assertTrue( + np.allclose( + value1[i], value2[i], atol=1e-7)) + + def get_places(self): + places = ['cpu'] + if paddle.is_compiled_with_cuda(): + places.append('gpu') + return places + + def test_main(self): + for multi_precision in [False, True]: + for place in self.get_places(): + self.check_with_place(place, multi_precision) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index cc28eead522d49fd4588f6252b78e890cb3fa271..8134c9f71b6699c2491e1fe84101f3a4ecbd56db 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -92,6 +92,7 @@ class Adam(Optimizer): different semantics with the original Adam algorithm and may lead to different result. The default value is False. multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. + use_multi_tensor (bool, optional): Whether to use multi-tensor strategy to update all parameters at once . Default is false. name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. 
@@ -172,6 +173,7 @@ class Adam(Optimizer): grad_clip=None, lazy_mode=False, multi_precision=False, + use_multi_tensor=False, name=None): assert learning_rate is not None assert beta1 is not None @@ -209,6 +211,24 @@ class Adam(Optimizer): 'lazy_mode': lazy_mode, } + self._use_multi_tensor = use_multi_tensor + if self._use_multi_tensor: + self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + self._moment1_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + self._moment2_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + self._beta1_pow_acc_dict = { + 'FP32_LODTensor': [], + 'FP16_LODTensor': [] + } + self._beta2_pow_acc_dict = { + 'FP32_LODTensor': [], + 'FP16_LODTensor': [] + } + self._master_weight_dict = { + 'FP32_LODTensor': None, + 'FP16_LODTensor': [] + } + def _create_master_weight(self, param): if param.name in self._master_weights: var = self._master_weights[param.name] @@ -436,6 +456,157 @@ class Adam(Optimizer): self._apply_optimize( loss=None, startup_program=None, params_grads=params_grads) + def _multi_tensor_init(self, target_block, parameters): + """ + All parameters used for optimizer (such as: parameters, master_weight, velocity_acc for momentum) calculations are grouped into a python list by data type (float16, float32). + This function will be overridden in the corresponding optimizer file. + Args: + target_block: the block in which the loss tensor is present + parameters: list of parameter tensors for the optimizer + """ + self._create_accumulators(target_block, parameters) + for param in parameters: + moment1 = self._get_accumulator(self._moment1_acc_str, param) + moment2 = self._get_accumulator(self._moment2_acc_str, param) + beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, + param) + beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, + param) + + if param.dtype == paddle.float32: + self._param_dict['FP32_LODTensor'].append(param) + self._moment1_dict['FP32_LODTensor'].append(moment1) + self._moment2_dict['FP32_LODTensor'].append(moment2) + self._beta1_pow_acc_dict['FP32_LODTensor'].append(beta1_pow_acc) + self._beta2_pow_acc_dict['FP32_LODTensor'].append(beta2_pow_acc) + elif param.dtype == paddle.float16: + self._param_dict['FP16_LODTensor'].append(param) + self._moment1_dict['FP16_LODTensor'].append(moment1) + self._moment2_dict['FP16_LODTensor'].append(moment2) + self._beta1_pow_acc_dict['FP16_LODTensor'].append(beta1_pow_acc) + self._beta2_pow_acc_dict['FP16_LODTensor'].append(beta2_pow_acc) + if self._multi_precision: + self._master_weight_dict['FP16_LODTensor'].append( + self._master_weights[param.name]) + else: + self._master_weight_dict['FP16_LODTensor'] = None + else: + raise ValueError( + "Now multi_tensor_momentum only support fp32 and fp16 parameters and grad is LOD_TENSOR." + ) + + def _append_optimize_multi_tensor_op(self, target_block, + parameters_and_grads): + """ + For Multi Tensor, append optimize merged_operator to block. 
+ """ + assert isinstance(target_block, framework.Block) + + grad_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + lr_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} + + if isinstance(parameters_and_grads, list): + for param_and_grad in parameters_and_grads: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + if param_and_grad[ + 0].dtype == paddle.float32 and param_and_grad[ + 1].type == core.VarDesc.VarType.LOD_TENSOR: + grad_dict['FP32_LODTensor'].append(param_and_grad[1]) + lr = self._create_param_lr(param_and_grad) + lr_dict['FP32_LODTensor'].append(lr) + elif param_and_grad[ + 0].dtype == paddle.float16 and param_and_grad[ + 1].type == core.VarDesc.VarType.LOD_TENSOR: + grad_dict['FP16_LODTensor'].append(param_and_grad[1]) + lr = self._create_param_lr(param_and_grad) + lr_dict['FP16_LODTensor'].append(lr) + else: + for param_and_grad in parameters_and_grads['params']: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + param_grad_dict = dict() + param_grad_dict['params'] = param_and_grad + param_grad_dict.update({ + k: v + for k, v in parameters_and_grads.items() + if k != 'params' + }) + param_and_grad = self._update_param_group(param_grad_dict) + if param_and_grad[ + 0].dtype == paddle.float32 and param_and_grad[ + 1].type == core.VarDesc.VarType.LOD_TENSOR: + grad_dict['FP32_LODTensor'].append(param_and_grad[1]) + lr = self._create_param_lr(param_and_grad) + lr_dict['FP32_LODTensor'].append(lr) + elif param_and_grad[ + 0].dtype == paddle.float16 and param_and_grad[ + 1].type == core.VarDesc.VarType.LOD_TENSOR: + grad_dict['FP16_LODTensor'].append(param_and_grad[1]) + lr = self._create_param_lr(param_and_grad) + lr_dict['FP16_LODTensor'].append(lr) + + multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor'] + for key in multi_tensor_list: + if len(self._param_dict[key]) > 0: + if key == 'FP32_LODTensor': + self._multi_precision = False + + _beta1 = self._beta1 if not isinstance( + self._beta1, Variable) else self._beta1.numpy().item(0) + _beta2 = self._beta2 if not isinstance( + self._beta2, Variable) else self._beta2.numpy().item(0) + + if framework.in_dygraph_mode(): + _, _, _, _, _, _ = _C_ops.merged_adam( + self._param_dict[key], grad_dict[key], lr_dict[key], + self._moment1_dict[key], self._moment2_dict[key], + self._beta1_pow_acc_dict[key], + self._beta2_pow_acc_dict[key], + self._master_weight_dict[key], self._param_dict[key], + self._moment1_dict[key], self._moment2_dict[key], + self._beta1_pow_acc_dict[key], + self._beta2_pow_acc_dict[key], + self._master_weight_dict[key], 'epsilon', self._epsilon, + 'beta1', _beta1, 'beta2', _beta2, 'multi_precision', + self._multi_precision) + else: + inputs = { + "Param": self._param_dict[key], + "Grad": grad_dict[key], + "LearningRate": lr_dict[key], + "Moment1": self._moment1_dict[key], + "Moment2": self._moment2_dict[key], + "Beta1Pow": self._beta1_pow_acc_dict[key], + "Beta2Pow": self._beta2_pow_acc_dict[key] + } + outputs = { + "ParamOut": self._param_dict[key], + "Moment1Out": self._moment1_dict[key], + "Moment2Out": self._moment2_dict[key], + "Beta1PowOut": self._beta1_pow_acc_dict[key], + "Beta2PowOut": self._beta2_pow_acc_dict[key] + } + attrs = { + "epsilon": self._epsilon, + "beta1": _beta1, + "beta2": _beta2 + } + if self._multi_precision: + inputs["MasterParam"] = self._master_weight_dict[key] + outputs["MasterParamOut"] = self._master_weight_dict[ + key] + attrs["multi_precision"] = self._multi_precision + target_block.append_op( + 
type="merged_adam", + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True) + return None + def _update_param_group(self, parameters): self._beta1 = parameters.get('beta1', self._default_dict['beta1']) self._beta2 = parameters.get('beta2', self._default_dict['beta2']) diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index a711d98df6fa17cfd1c577d70334c771a0f35a18..3fc70449d15c970a57ab77d4f155dad9bce60854 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -218,7 +218,7 @@ class Optimizer(object): self._param_groups = self._parameter_list # NOTE: Multi Tensor: Pass in all parameters and gradients to the op kernel of the Optimizer at one time for updating for dygraph mode. - # Optimizer support list: [ paddle.optimizer.Momentum ]. + # Optimizer support list: [ paddle.optimizer.Momentum, paddle.optimizer.Adam]. self._use_multi_tensor = None self._param_dict = {'FP32_LODTensor': [], 'FP16_LODTensor': []} @@ -684,8 +684,10 @@ class Optimizer(object): self._create_global_learning_rate() - # NOTE: Multi Tensor support [ Momentum ] for dygraph mode - if self._use_multi_tensor and self.__class__.__name__ in ['Momentum']: + # NOTE: Multi Tensor support [ Momentum, Adam ] for dygraph mode + if self._use_multi_tensor and self.__class__.__name__ in [ + 'Momentum', 'Adam' + ]: if len(self._param_dict['FP32_LODTensor']) == 0 and len( self._param_dict['FP16_LODTensor']) == 0: if isinstance(parameters_and_grads, list):