From 7f7dfccf20347eb9f0600b15a6472c32f1c34c4b Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Fri, 8 Jan 2021 20:35:18 +0800 Subject: [PATCH] Support pure fp16 training for AMP API. (#29544) * add cast ops before and after unsupported fp16 ops. * Keep partial net in FP32 pattern. * Support check_finite_and_unscale and update_loss_scaling for FP16 calculation mode. * Add fp16 support for adam op. * add multi precision attr for adam. * Fix the bug of test_multi_precision_fp16_train UT. * Code format for CI. * Fix the redefine error about MPTypeTrait on windows. * fix bugs of the _create_accumulators func in Momentum. * fix bug when inserting post cast op. * Add the update_loss_scaling op in allow_set of UnusedVarCheck. * Update for ci coverage. * Add some doc for OptimizerWithMixedPrecision. * Fix the code style. * Imporve the doc of `amp_init`. * Change for fp16 testing if users have the infer program defined in separate way. --- paddle/fluid/framework/unused_var_check.cc | 3 +- .../amp/check_finite_and_unscale_op.cu | 30 +- paddle/fluid/operators/amp/fp16_type_traits.h | 37 +++ .../operators/amp/update_loss_scaling_op.cc | 6 +- .../operators/amp/update_loss_scaling_op.cu | 5 +- .../operators/amp/update_loss_scaling_op.h | 34 ++- paddle/fluid/operators/optimizers/adam_op.cc | 20 ++ paddle/fluid/operators/optimizers/adam_op.cu | 257 +++++++++++------- paddle/fluid/operators/optimizers/adam_op.h | 76 +++--- .../fluid/operators/optimizers/momentum_op.cc | 3 +- .../fluid/operators/optimizers/momentum_op.h | 12 +- .../fluid/contrib/mixed_precision/amp_nn.py | 13 +- .../contrib/mixed_precision/decorator.py | 181 ++++++++++-- .../contrib/mixed_precision/fp16_lists.py | 14 +- .../contrib/mixed_precision/fp16_utils.py | 224 +++++++++++---- .../tests/test_multi_precision_fp16_train.py | 95 +++---- python/paddle/optimizer/adam.py | 111 ++++++-- python/paddle/optimizer/adamw.py | 5 +- python/paddle/optimizer/momentum.py | 30 +- 19 files changed, 815 insertions(+), 341 deletions(-) create mode 100644 paddle/fluid/operators/amp/fp16_type_traits.h diff --git a/paddle/fluid/framework/unused_var_check.cc b/paddle/fluid/framework/unused_var_check.cc index ac455b9ffd..dc20632824 100644 --- a/paddle/fluid/framework/unused_var_check.cc +++ b/paddle/fluid/framework/unused_var_check.cc @@ -73,7 +73,8 @@ static const std::unordered_set &GetOpWithUnusedVarAllowSet() { "fused_batch_norm_act", // 2 "fused_batch_norm_act_grad", // 2 "data_norm", // 0 - "data_norm_grad", // 0); + "data_norm_grad", // 0 + "update_loss_scaling", // 0 }); return *allow_set; } diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu index 6b60d989d2..e28a3c1b6d 100644 --- a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu @@ -15,6 +15,8 @@ limitations under the License. */ #include #include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -25,15 +27,16 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) { *found_inf = false; } -template -__global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num, +template +__global__ void CheckFiniteAndUnscale(const T* in, const MT* scale, int num, bool* found_inf, T* out) { const int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx < num) { - T val = in[idx] * (*scale); - out[idx] = val; - if (!isfinite(val)) { + MT val = static_cast(in[idx]) * (*scale); + T narrow_val = static_cast(val); + out[idx] = narrow_val; + if (!isfinite(narrow_val)) { *found_inf = true; } } @@ -41,6 +44,8 @@ __global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num, template class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel { + using MPDType = typename details::MPTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& ctx) const { auto& dev_ctx = ctx.template device_context(); @@ -49,14 +54,15 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel { auto outs = ctx.MultiOutput("Out"); auto* found_inf = ctx.Output("FoundInfinite"); - const T* scale_data = scale->data(); + const MPDType* scale_data = scale->data(); bool* found_inf_data = found_inf->mutable_data(dev_ctx.GetPlace()); framework::Tensor inverse_scale = - ctx.AllocateTmpTensor({1}, dev_ctx); - T* inverse_scale_v = inverse_scale.template data(); + ctx.AllocateTmpTensor({1}, + dev_ctx); + MPDType* inverse_scale_v = inverse_scale.template data(); - InverseAndMemset<<<1, 1, 0, dev_ctx.stream()>>>( + InverseAndMemset<<<1, 1, 0, dev_ctx.stream()>>>( scale_data, inverse_scale_v, found_inf_data); for (size_t i = 0; i < xs.size(); ++i) { @@ -69,7 +75,7 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel { int block = 1024; int grid = (num + block - 1) / block; VLOG(3) << "launch kernel"; - CheckFiniteAndUnscale<<>>( + CheckFiniteAndUnscale<<>>( x_data, inverse_scale_v, num, found_inf_data, out_data); VLOG(3) << "finish kernel"; } @@ -79,6 +85,8 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(check_finite_and_unscale, ops::CheckFiniteAndUnscaleGpuKernel, - ops::CheckFiniteAndUnscaleGpuKernel); + ops::CheckFiniteAndUnscaleGpuKernel, + ops::CheckFiniteAndUnscaleGpuKernel); diff --git a/paddle/fluid/operators/amp/fp16_type_traits.h b/paddle/fluid/operators/amp/fp16_type_traits.h new file mode 100644 index 0000000000..f7aa0de975 --- /dev/null +++ b/paddle/fluid/operators/amp/fp16_type_traits.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { +namespace details { + +template +class MPTypeTrait { + public: + using Type = T; +}; + +template <> +class MPTypeTrait { + public: + using Type = float; +}; + +} // namespace details +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cc b/paddle/fluid/operators/amp/update_loss_scaling_op.cc index e4d9042151..1ac286ef4a 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.cc +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.cc @@ -54,8 +54,7 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "PrevLossScaling"), - ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } }; @@ -107,6 +106,9 @@ class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker { "the received is %f", decr_ratio)); }); + AddAttr("stop_update", + "Stop updating loss scaling, and just zero inputs.") + .SetDefault(false); AddComment(R"DOC( Update loss scaling according to overall gradients. If all gradients is finite after incr_every_n_steps, loss scaling will increase by incr_ratio. diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.cu b/paddle/fluid/operators/amp/update_loss_scaling_op.cu index ee6186e1f9..b48b0e7889 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.cu +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.cu @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/amp/update_loss_scaling_op.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -83,8 +84,10 @@ class LazyZeros { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; using GPU = paddle::platform::CUDADeviceContext; REGISTER_OP_CUDA_KERNEL(update_loss_scaling, ops::UpdateLossScalingKernel, - ops::UpdateLossScalingKernel); + ops::UpdateLossScalingKernel, + ops::UpdateLossScalingKernel); diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.h b/paddle/fluid/operators/amp/update_loss_scaling_op.h index 89de9c645f..db768f3f87 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.h +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.h @@ -17,6 +17,7 @@ #include #include #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" @@ -79,30 +80,38 @@ class LazyZeros { template class UpdateLossScalingKernel : public framework::OpKernel { + using MPDType = typename details::MPTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + const auto xs = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput("Out"); const auto* found_inf = ctx.Input("FoundInfinite"); + PADDLE_ENFORCE_EQ(found_inf->numel(), 1, + platform::errors::InvalidArgument( + "FoundInfinite must has only one element.")); + const bool* found_inf_data = found_inf->data(); + + LazyZeros{}(dev_ctx, found_inf_data, xs, outs); + const bool stop_update = ctx.Attr("stop_update"); + if (stop_update) { + return; + } + const auto* pre_loss_scaling = ctx.Input("PrevLossScaling"); const auto* good_in = ctx.Input("InGoodSteps"); const auto* bad_in = ctx.Input("InBadSteps"); - auto outs = ctx.MultiOutput("Out"); auto* updated_loss_scaling = ctx.Output("LossScaling"); auto* good_out = ctx.Output("OutGoodSteps"); auto* bad_out = ctx.Output("OutBadSteps"); - - PADDLE_ENFORCE_EQ(found_inf->numel(), 1, - platform::errors::InvalidArgument( - "FoundInfinite must has only one element.")); - - const bool* found_inf_data = found_inf->data(); - const T* pre_loss_scaling_data = pre_loss_scaling->data(); + const MPDType* pre_loss_scaling_data = pre_loss_scaling->data(); const int* good_in_data = good_in->data(); const int* bad_in_data = bad_in->data(); - auto& dev_ctx = ctx.template device_context(); - T* updated_loss_scaling_data = - updated_loss_scaling->mutable_data(dev_ctx.GetPlace()); + MPDType* updated_loss_scaling_data = + updated_loss_scaling->mutable_data(dev_ctx.GetPlace()); int* good_out_data = good_out->mutable_data(dev_ctx.GetPlace()); int* bad_out_data = bad_out->mutable_data(dev_ctx.GetPlace()); @@ -111,11 +120,10 @@ class UpdateLossScalingKernel : public framework::OpKernel { ctx.Attr("decr_every_n_nan_or_inf"); const float incr_ratio = ctx.Attr("incr_ratio"); const float decr_ratio = ctx.Attr("decr_ratio"); - UpdateLossScalingFunctor{}( + UpdateLossScalingFunctor{}( dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data, bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data); - LazyZeros{}(dev_ctx, found_inf_data, xs, outs); } }; diff --git a/paddle/fluid/operators/optimizers/adam_op.cc b/paddle/fluid/operators/optimizers/adam_op.cc index 8e4cce68ac..621920731f 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cc +++ b/paddle/fluid/operators/optimizers/adam_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/optimizers/adam_op.h" +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -150,12 +151,17 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { "as beta2, this has a higher priority than attr(beta2), the " "shape of this tensor MUST BE [1].") .AsDispensable(); + AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable(); AddOutput("ParamOut", "(Tensor) Output parameter"); AddOutput("Moment1Out", "(Tensor) Output first moment"); AddOutput("Moment2Out", "(Tensor) Output second moment"); AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator"); AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator"); + AddOutput("MasterParamOut", + "The updated FP32 master weight for AMP. " + "It shared memory with Input(MasterParam).") + .AsDispensable(); AddAttr("beta1", "(float, default 0.9) " @@ -183,6 +189,10 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { "inner_op_parallelism is larger then 0, sparse update " "will run in multithread mode") .SetDefault(1000); + AddAttr("multi_precision", + "(bool, default false) " + "Whether to use multi-precision during weight updating.") + .SetDefault(false); AddComment(R"DOC( Adam Optimizer. @@ -213,3 +223,13 @@ REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker); REGISTER_OP_CPU_KERNEL( adam, ops::AdamOpKernel, ops::AdamOpKernel); + +REGISTER_OP_VERSION(adam) + .AddCheckpoint( + R"ROC( + Upgrade adam add 1 attribute [multi_precision]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "multi_precision", + "(bool) Whether to use multi-precision during weight updating.", + false)); diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu index 0713237561..54aea67f4e 100644 --- a/paddle/fluid/operators/optimizers/adam_op.cu +++ b/paddle/fluid/operators/optimizers/adam_op.cu @@ -11,70 +11,81 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/optimizers/adam_op.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -template -__global__ void AdamKernelREG(T beta1, T beta2, T epsilon, T beta1_pow_, - T beta2_pow_, const T* moment1, T* moment1_out, - const T* moment2, T* moment2_out, const T* lr_, +template +__global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_, + MT beta2_pow_, const MT* moment1, MT* moment1_out, + const MT* moment2, MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out, + const MT* master_param, MT* master_param_out, int ndim) { - T lr = *lr_; - T beta1_pow = beta1_pow_; - T beta2_pow = beta2_pow_; + MT lr = *lr_; + MT beta1_pow = beta1_pow_; + MT beta2_pow = beta2_pow_; - lr *= - sqrt(static_cast(1.0) - beta2_pow) / (static_cast(1.0) - beta1_pow); + lr *= sqrt(static_cast(1.0) - beta2_pow) / + (static_cast(1.0) - beta1_pow); int id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { - T p = param[id]; - T g = grad[id]; - T mom1 = moment1[id]; - T mom2 = moment2[id]; - mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; - mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + MT p = master_param ? master_param[id] : static_cast(param[id]); + MT g = static_cast(grad[id]); + MT mom1 = moment1[id]; + MT mom2 = moment2[id]; + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; + mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; p -= lr * (mom1 / - (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); + (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); moment1_out[id] = mom1; moment2_out[id] = mom2; - param_out[id] = p; + param_out[id] = static_cast(p); + if (master_param_out) { + master_param_out[id] = p; + } } } -template -__global__ void AdamKernelMEM(T beta1, T beta2, T epsilon, const T* beta1_pow_, - const T* beta2_pow_, const T* moment1, - T* moment1_out, const T* moment2, T* moment2_out, - const T* lr_, const T* grad, const T* param, - T* param_out, int ndim) { - T lr = *lr_; - T beta1_pow = *beta1_pow_; - T beta2_pow = *beta2_pow_; - - lr *= - sqrt(static_cast(1.0) - beta2_pow) / (static_cast(1.0) - beta1_pow); +template +__global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon, + const MT* beta1_pow_, const MT* beta2_pow_, + const MT* moment1, MT* moment1_out, + const MT* moment2, MT* moment2_out, const MT* lr_, + const T* grad, const T* param, T* param_out, + const MT* master_param, MT* master_param_out, + int ndim) { + MT lr = *lr_; + MT beta1_pow = *beta1_pow_; + MT beta2_pow = *beta2_pow_; + + lr *= sqrt(static_cast(1.0) - beta2_pow) / + (static_cast(1.0) - beta1_pow); int id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { - T p = param[id]; - T g = grad[id]; - T mom1 = moment1[id]; - T mom2 = moment2[id]; - mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; - mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; + MT p = master_param ? master_param[id] : static_cast(param[id]); + MT g = static_cast(grad[id]); + MT mom1 = static_cast(moment1[id]); + MT mom2 = static_cast(moment2[id]); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; + mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; p -= lr * (mom1 / - (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); + (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); moment1_out[id] = mom1; moment2_out[id] = mom2; - param_out[id] = p; + param_out[id] = static_cast(p); + if (master_param_out) { + master_param_out[id] = p; + } } } template @@ -85,15 +96,17 @@ __global__ void UpdateBetaPow(T beta1, T beta2, const T* beta1_pow_, *beta2_pow_out = beta2 * beta2_pow_[0]; } -template +template __global__ void SparseAdamCUDAKernelREG( - T beta1, T beta2, T epsilon, const T beta1_pow, const T beta2_pow, - const T* mom1_, T* mom1_out_, const T* mom2_, T* mom2_out_, const T* lr_, - const T* grad_, const T* param_, T* param_out_, const int64_t* rows_, + MT beta1, MT beta2, MT epsilon, const MT beta1_pow, const MT beta2_pow, + const MT* mom1_, MT* mom1_out_, const MT* mom2_, MT* mom2_out_, + const MT* lr_, const T* grad_, const T* param_, T* param_out_, + const MT* master_param, MT* master_param_out, const int64_t* rows_, int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) { int id = blockIdx.x * blockDim.x + threadIdx.x; - T lr = *lr_; - lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); + MT lr = *lr_; + lr *= sqrt(static_cast(1.0) - beta2_pow) / + (static_cast(1.0) - beta1_pow); for (; id < ndim; id += blockDim.x * gridDim.x) { auto row_idx = @@ -101,19 +114,24 @@ __global__ void SparseAdamCUDAKernelREG( if (lazy_mode && row_idx < 0) { return; } else { - T mom1 = mom1_[id]; - T mom2 = mom2_[id]; - T p = param_[id]; - T g = row_idx >= 0 ? grad_[row_idx * row_numel + id % row_numel] : 0; - mom1 = beta1 * mom1 + (1 - beta1) * g; - mom2 = beta2 * mom2 + (1 - beta2) * g * g; + MT mom1 = mom1_[id]; + MT mom2 = mom2_[id]; + MT p = master_param ? master_param[id] : static_cast(param_[id]); + MT g = row_idx >= 0 + ? static_cast(grad_[row_idx * row_numel + id % row_numel]) + : static_cast(0); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; + mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; p -= lr * (mom1 / (sqrt(mom2) + - epsilon * sqrt(static_cast(1.0) - beta2_pow))); + epsilon * sqrt(static_cast(1.0) - beta2_pow))); // Write back to global memory mom1_out_[id] = mom1; mom2_out_[id] = mom2; - param_out_[id] = p; + param_out_[id] = static_cast(p); + if (master_param_out) { + master_param_out[id] = p; + } } } } @@ -131,11 +149,12 @@ class AdamOpCUDAKernel : public framework::OpKernel { framework::ToTypeName(param_var->Type()))); using paddle::framework::LoDTensor; + using MPDType = typename details::MPTypeTrait::Type; int64_t min_row_size_to_use_multithread = ctx.Attr("min_row_size_to_use_multithread"); bool lazy_mode = ctx.Attr("lazy_mode"); - T epsilon = static_cast(ctx.Attr("epsilon")); + MPDType epsilon = static_cast(ctx.Attr("epsilon")); auto* param = ctx.Input("Param"); auto* grad_var = ctx.InputVar("Grad"); auto* mom1 = ctx.Input("Moment1"); @@ -151,23 +170,23 @@ class AdamOpCUDAKernel : public framework::OpKernel { auto* beta1_pow_out = ctx.Output("Beta1PowOut"); auto* beta2_pow_out = ctx.Output("Beta2PowOut"); - T beta1 = static_cast(ctx.Attr("beta1")); + MPDType beta1 = static_cast(ctx.Attr("beta1")); if (ctx.HasInput("Beta1Tensor")) { auto* beta1_tensor = ctx.Input("Beta1Tensor"); PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1, platform::errors::InvalidArgument( "Input(Beta1Tensor) size must be 1, but get %d", beta1_tensor->numel())); - beta1 = static_cast(GetAttrFromTensor(beta1_tensor)); + beta1 = static_cast(GetAttrFromTensor(beta1_tensor)); } - T beta2 = static_cast(ctx.Attr("beta2")); + MPDType beta2 = static_cast(ctx.Attr("beta2")); if (ctx.HasInput("Beta2Tensor")) { auto* beta2_tensor = ctx.Input("Beta2Tensor"); PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1, platform::errors::InvalidArgument( "Input(Beta2Tensor) size must be 1, but get %d", beta2_tensor->numel())); - beta2 = static_cast(GetAttrFromTensor(beta2_tensor)); + beta2 = static_cast(GetAttrFromTensor(beta2_tensor)); } VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel() << "beta2_pow.numel() : " << beta2_pow->numel(); @@ -183,6 +202,28 @@ class AdamOpCUDAKernel : public framework::OpKernel { "beta2 pow output size should be 1, but received " "value is:%d.", beta2_pow_out->numel())); + + const bool multi_precision = ctx.Attr("multi_precision"); + const LoDTensor* master_param = nullptr; + LoDTensor* master_param_out = nullptr; + if (multi_precision) { + bool has_master = + ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut"); + PADDLE_ENFORCE_EQ(has_master, true, + platform::errors::InvalidArgument( + "The Input(MasterParam) and Output(MasterParamOut) " + "should not be null when " + "the attr `multi_precision` is true")); + master_param = ctx.Input("MasterParam"); + master_param_out = ctx.Output("MasterParamOut"); + } + const MPDType* master_in_data = + multi_precision ? master_param->data() : nullptr; + MPDType* master_out_data = + multi_precision + ? master_param_out->mutable_data(ctx.GetPlace()) + : nullptr; + auto& dev_ctx = ctx.template device_context(); if (grad_var->IsType()) { @@ -195,29 +236,36 @@ class AdamOpCUDAKernel : public framework::OpKernel { if (beta1_pow->place() == platform::CPUPlace() && beta2_pow->place() == platform::CPUPlace()) { // Compute with betapow in REG - AdamKernelREG<<>>( - beta1, beta2, epsilon, *beta1_pow->data(), *beta2_pow->data(), - mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), - mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), - lr->data(), grad->data(), param->data(), - param_out->mutable_data(ctx.GetPlace()), param->numel()); + AdamKernelREG<<>>( + beta1, beta2, epsilon, *beta1_pow->data(), + *beta2_pow->data(), mom1->data(), + mom1_out->mutable_data(ctx.GetPlace()), + mom2->data(), + mom2_out->mutable_data(ctx.GetPlace()), + lr->data(), grad->data(), param->data(), + param_out->mutable_data(ctx.GetPlace()), master_in_data, + master_out_data, param->numel()); // Cpu update - beta1_pow_out->mutable_data(platform::CPUPlace())[0] = - beta1 * beta1_pow->data()[0]; - beta2_pow_out->mutable_data(platform::CPUPlace())[0] = - beta2 * beta2_pow->data()[0]; + beta1_pow_out->mutable_data(platform::CPUPlace())[0] = + beta1 * beta1_pow->data()[0]; + beta2_pow_out->mutable_data(platform::CPUPlace())[0] = + beta2 * beta2_pow->data()[0]; } else { - AdamKernelMEM<<>>( - beta1, beta2, epsilon, beta1_pow->data(), beta2_pow->data(), - mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), - mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), - lr->data(), grad->data(), param->data(), - param_out->mutable_data(ctx.GetPlace()), param->numel()); + AdamKernelMEM<<>>( + beta1, beta2, epsilon, beta1_pow->data(), + beta2_pow->data(), mom1->data(), + mom1_out->mutable_data(ctx.GetPlace()), + mom2->data(), + mom2_out->mutable_data(ctx.GetPlace()), + lr->data(), grad->data(), param->data(), + param_out->mutable_data(ctx.GetPlace()), master_in_data, + master_out_data, param->numel()); // Update with gpu - UpdateBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( - beta1, beta2, beta1_pow->data(), beta2_pow->data(), - beta1_pow_out->mutable_data(ctx.GetPlace()), - beta2_pow_out->mutable_data(ctx.GetPlace())); + UpdateBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( + beta1, beta2, beta1_pow->data(), + beta2_pow->data(), + beta1_pow_out->mutable_data(ctx.GetPlace()), + beta2_pow_out->mutable_data(ctx.GetPlace())); } } else if (grad_var->IsType()) { @@ -260,26 +308,33 @@ class AdamOpCUDAKernel : public framework::OpKernel { int ndim = param->numel(); int blocks = (ndim + threads - 1) / threads; - SparseAdamCUDAKernelREG<<>>( - beta1, beta2, epsilon, *beta1_pow->data(), *beta2_pow->data(), - mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), - mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), - lr->data(), grad_data, param->data(), - param_out->mutable_data(ctx.GetPlace()), rows, row_numel, - grad_merge.rows().size(), lazy_mode, ndim); + SparseAdamCUDAKernelREG< + T, MPDType><<>>( + beta1, beta2, epsilon, *beta1_pow->data(), + *beta2_pow->data(), mom1->data(), + mom1_out->mutable_data(ctx.GetPlace()), + mom2->data(), + mom2_out->mutable_data(ctx.GetPlace()), + lr->data(), grad_data, param->data(), + param_out->mutable_data(ctx.GetPlace()), master_in_data, + master_out_data, rows, row_numel, grad_merge.rows().size(), + lazy_mode, ndim); // Update with cpu - beta1_pow_out->mutable_data(platform::CPUPlace())[0] = - beta1 * beta1_pow->data()[0]; - beta2_pow_out->mutable_data(platform::CPUPlace())[0] = - beta2 * beta2_pow->data()[0]; + beta1_pow_out->mutable_data(platform::CPUPlace())[0] = + beta1 * beta1_pow->data()[0]; + beta2_pow_out->mutable_data(platform::CPUPlace())[0] = + beta2 * beta2_pow->data()[0]; } else { - SparseAdamFunctor functor( - beta1, beta2, epsilon, beta1_pow->data(), beta2_pow->data(), - mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), - mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), - lr->data(), grad_data, param->data(), - param_out->mutable_data(ctx.GetPlace()), rows, row_numel, - grad_merge.rows().size(), lazy_mode); + SparseAdamFunctor functor( + beta1, beta2, epsilon, beta1_pow->data(), + beta2_pow->data(), mom1->data(), + mom1_out->mutable_data(ctx.GetPlace()), + mom2->data(), + mom2_out->mutable_data(ctx.GetPlace()), + lr->data(), grad_data, param->data(), + param_out->mutable_data(ctx.GetPlace()), master_in_data, + master_out_data, rows, row_numel, grad_merge.rows().size(), + lazy_mode); // FIXME(minqiyang): remove BinarySearch in GPU later platform::ForRange for_range( @@ -288,10 +343,11 @@ class AdamOpCUDAKernel : public framework::OpKernel { param->numel()); for_range(functor); // update beta1 and beta2 - UpdateBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( - beta1, beta2, beta1_pow->data(), beta2_pow->data(), - beta1_pow_out->mutable_data(ctx.GetPlace()), - beta2_pow_out->mutable_data(ctx.GetPlace())); + UpdateBetaPow<<<1, 32, 0, dev_ctx.stream()>>>( + beta1, beta2, beta1_pow->data(), + beta2_pow->data(), + beta1_pow_out->mutable_data(ctx.GetPlace()), + beta2_pow_out->mutable_data(ctx.GetPlace())); } } else { PADDLE_THROW(platform::errors::InvalidArgument( @@ -304,5 +360,8 @@ class AdamOpCUDAKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; + REGISTER_OP_CUDA_KERNEL(adam, ops::AdamOpCUDAKernel, - ops::AdamOpCUDAKernel); + ops::AdamOpCUDAKernel, + ops::AdamOpCUDAKernel); diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index c8b28aed24..6356911f06 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -191,26 +191,28 @@ class AdamFunctor { } }; -template +template class SparseAdamFunctor; -template -class SparseAdamFunctor { +template +class SparseAdamFunctor { private: - T beta1_; - T beta2_; - T epsilon_; - - const T* beta1_pow_; - const T* beta2_pow_; - const T* moment1_; - T* moment1_out_; - const T* moment2_; - T* moment2_out_; - const T* lr_; + MT beta1_; + MT beta2_; + MT epsilon_; + + const MT* beta1_pow_; + const MT* beta2_pow_; + const MT* moment1_; + MT* moment1_out_; + const MT* moment2_; + MT* moment2_out_; + const MT* lr_; const T* grad_; const T* param_; T* param_out_; + const MT* master_param_; + MT* master_param_out_; const int64_t* rows_; int64_t row_numel_; @@ -218,10 +220,11 @@ class SparseAdamFunctor { bool lazy_mode_; public: - SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, - const T* beta2_pow, const T* mom1, T* mom1_out, - const T* mom2, T* mom2_out, const T* lr, const T* grad, - const T* param, T* param_out, const int64_t* rows, + SparseAdamFunctor(MT beta1, MT beta2, MT epsilon, const MT* beta1_pow, + const MT* beta2_pow, const MT* mom1, MT* mom1_out, + const MT* mom2, MT* mom2_out, const MT* lr, const T* grad, + const T* param, T* param_out, const MT* master_param, + MT* master_param_out, const int64_t* rows, int64_t row_numel, int64_t row_count, bool lazy_mode) : beta1_(beta1), beta2_(beta2), @@ -236,31 +239,38 @@ class SparseAdamFunctor { grad_(grad), param_(param), param_out_(param_out), + master_param_(master_param), + master_param_out_(master_param_out), rows_(rows), row_numel_(row_numel), row_count_(row_count), lazy_mode_(lazy_mode) {} - inline HOSTDEVICE void adam_update(size_t i, T g) const { + inline HOSTDEVICE void adam_update(size_t i, MT g) const { // The following code is the same as dense - T mom1 = moment1_[i]; - T mom2 = moment2_[i]; - T lr = *lr_; - T beta1_pow = *beta1_pow_; - T beta2_pow = *beta2_pow_; - T p = param_[i]; + MT mom1 = moment1_[i]; + MT mom2 = moment2_[i]; + MT lr = *lr_; + MT beta1_pow = *beta1_pow_; + MT beta2_pow = *beta2_pow_; + MT p = master_param_ ? master_param_[i] : static_cast(param_[i]); // Calculation - lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); + lr *= sqrt(static_cast(1.0) - beta2_pow) / + (static_cast(1.0) - beta1_pow); - mom1 = beta1_ * mom1 + (1 - beta1_) * g; - mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; - p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow))); + mom1 = beta1_ * mom1 + (static_cast(1.0) - beta1_) * g; + mom2 = beta2_ * mom2 + (static_cast(1.0) - beta2_) * g * g; + p -= lr * (mom1 / (sqrt(mom2) + + epsilon_ * sqrt(static_cast(1.0) - beta2_pow))); // Write back to global memory moment1_out_[i] = mom1; moment2_out_[i] = mom2; - param_out_[i] = p; + param_out_[i] = static_cast(p); + if (master_param_out_) { + master_param_out_[i] = p; + } } inline HOSTDEVICE void operator()(size_t i) const { @@ -269,14 +279,16 @@ class SparseAdamFunctor { if (lazy_mode_ && row_idx < 0) { return; } else { - T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; + MT g = row_idx >= 0 + ? static_cast(grad_[row_idx * row_numel_ + i % row_numel_]) + : static_cast(0); adam_update(i, g); } } }; template -class SparseAdamFunctor { +class SparseAdamFunctor { private: T beta1_; T beta2_; diff --git a/paddle/fluid/operators/optimizers/momentum_op.cc b/paddle/fluid/operators/optimizers/momentum_op.cc index b9a74c1bf7..bf30d8512a 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.cc +++ b/paddle/fluid/operators/optimizers/momentum_op.cc @@ -115,7 +115,8 @@ REGISTER_OP_CPU_KERNEL( REGISTER_OP_VERSION(momentum) .AddCheckpoint( R"ROC( - Upgrade momentum add 2 attributes [regularization_method, regularization_coeff]. + Upgrade momentum add 4 attributes [regularization_method, regularization_coeff, + multi_precision, rescale_grad]. )ROC", paddle::framework::compatible::OpVersionDesc() .NewInput("MasterParam", "FP32 master weight for AMP.") diff --git a/paddle/fluid/operators/optimizers/momentum_op.h b/paddle/fluid/operators/optimizers/momentum_op.h index 64acdfe890..cbb0704fa8 100644 --- a/paddle/fluid/operators/optimizers/momentum_op.h +++ b/paddle/fluid/operators/optimizers/momentum_op.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/float16.h" @@ -32,17 +33,6 @@ struct UseNesterov; namespace details { -template -class MPTypeTrait { - public: - using Type = T; -}; -template <> -class MPTypeTrait { - public: - using Type = float; -}; - template struct CPUDenseUpdater { template diff --git a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py index d4dc968ca0..3bfc078971 100644 --- a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py +++ b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py @@ -15,6 +15,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_type from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.framework import Variable +from paddle.fluid import core __all__ = ['check_finite_and_unscale', 'update_loss_scaling'] @@ -35,7 +36,7 @@ def check_finite_and_unscale(x, scale, name=None): """ check_type(x, 'x', (tuple, list), 'check_finite_and_unscale') for e in x: - check_variable_and_dtype(e, "x", ['float32', 'float64'], + check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], 'check_finite_and_unscale') helper = LayerHelper("check_finite_and_unscale", **locals()) @@ -58,6 +59,7 @@ def update_loss_scaling(x, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, + stop_update=False, name=None): """ Update loss scaling according to overall gradients. If all gradients is @@ -90,9 +92,13 @@ def update_loss_scaling(x, ['float32', 'float64'], "update_loss_scaling") check_type(x, 'x', (tuple, list), 'update_loss_scaling') for e in x: - check_variable_and_dtype(e, "x", ['float32', 'float64'], + check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'], 'update_loss_scaling') - assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." + if e.dtype == core.VarDesc.VarType.FP16: + assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \ + "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." + else: + assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x." helper = LayerHelper("update_loss_scaling", **locals()) @@ -116,6 +122,7 @@ def update_loss_scaling(x, 'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf, 'incr_ratio': incr_ratio, 'decr_ratio': decr_ratio, + 'stop_update': stop_update } helper.append_op( diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index 2215d11aa0..bee73a9803 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -12,17 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from ... import core from ... import default_main_program from ... import default_startup_program +from ... import framework from ... import layers -from ... import unique_name from ... import program_guard +from ... import unique_name from . import fp16_utils from .fp16_utils import rewrite_program +from .fp16_utils import cast_model_to_fp16 +from .fp16_utils import cast_parameters_to_fp16 from .fp16_utils import update_role_var_grad from .fp16_lists import AutoMixedPrecisionLists from .amp_nn import check_finite_and_unscale from .amp_nn import update_loss_scaling +import types +import warnings __all__ = ["decorate"] @@ -50,12 +56,16 @@ class OptimizerWithMixedPrecision(object): scaling. decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. + use_pure_fp16(bool): Whether to use the pure fp16 training. Default False. + use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program. + Default None, which means that its value is equal to `use_pure_fp16`. """ def __init__(self, optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling, incr_every_n_steps, - decr_every_n_nan_or_inf, incr_ratio, decr_ratio): + decr_every_n_nan_or_inf, incr_ratio, decr_ratio, use_pure_fp16, + use_fp16_guard): self._optimizer = optimizer self._amp_lists = amp_lists self._param_grads = None @@ -68,6 +78,9 @@ class OptimizerWithMixedPrecision(object): self._use_dynamic_loss_scaling = use_dynamic_loss_scaling self._learning_rate = optimizer._learning_rate self._learning_rate_map = optimizer._learning_rate_map + self._use_pure_fp16 = use_pure_fp16 + self._use_fp16_guard = use_fp16_guard + self._to_fp16_var_names = None if self._use_dynamic_loss_scaling: self._incr_every_n_steps = incr_every_n_steps self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf @@ -151,20 +164,61 @@ class OptimizerWithMixedPrecision(object): train_program = loss.block.program self._train_program = train_program - with program_guard(train_program, startup_program): + with program_guard(self._train_program, startup_program): self._init_amp_var() - rewrite_program(train_program, self._amp_lists) - self._scaled_loss = loss * self._loss_scaling + if self._use_pure_fp16: + self._to_fp16_var_names = cast_model_to_fp16( + self._train_program, self._amp_lists, self._use_fp16_guard) + else: + rewrite_program(self._train_program, self._amp_lists) + + if loss.dtype != core.VarDesc.VarType.FP32: + loss = loss.astype('float32') + # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0, + # the model can be optimized. + if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0: + self._scaled_loss = loss * self._loss_scaling + else: + self._scaled_loss = loss + params_grads = self._optimizer.backward( self._scaled_loss, startup_program, parameter_list, no_grad_set, callbacks) return params_grads + def amp_init(self, + place, + scope=None, + test_program=None, + use_fp16_test=False): + """ + Init the amp training, such as cast fp32 parameters to fp16 type. + + Args: + place(CPUPlace|CUDAPlace): place is used to initialize + fp16 parameters with fp32 values. + scope(Scope): The scope is used to find fp32 parameters. + test_program(Program): The program is used for testing. + use_fp16_test(bool): Whether to use fp16 testing. + + """ + assert self._train_program is not None, \ + "Please call the minimize method first." + if self._use_pure_fp16: + cast_parameters_to_fp16(place, self._train_program, scope, + self._to_fp16_var_names) + if test_program is not None: + if self._use_pure_fp16: + cast_model_to_fp16(test_program, self._amp_lists, + self._use_fp16_guard) + elif use_fp16_test: + rewrite_program(test_program, self._amp_lists) + def apply_gradients(self, params_grads): """ Check scaled gradients to determine whether to update loss scaling and update - parameters by their scaled gradients, + parameters by their scaled gradients. Args: params_grads (list): A list of params and scaled grads. @@ -177,39 +231,95 @@ class OptimizerWithMixedPrecision(object): # transferred across GPUs can be FP16. update_role_var_grad(self._train_program, params_grads) + # When not using dynamic loss scaling and the init loss scaling value is equal to 1.0, + # the model can be optimized. + if not self._use_dynamic_loss_scaling and self._init_loss_scaling == 1.0: + return self._optimizer.apply_gradients(params_grads) + grads = [g for _, g in params_grads] - if not self._is_distributed: - with self._train_program._optimized_guard(grads): - grads, found_inf = check_finite_and_unscale( - grads, self._loss_scaling, name="find_infinite_scale") - else: + fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] + fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] + assert len(fp32_grads) + len(fp16_grads) == len(grads), \ + "Data types of all grads must be either fp16 or fp32." + + found_infs = [] + if self._is_distributed: # if distributed, split check_finite_and_unscale to overlap # unscale with communication - found_infs = [] for p, g in params_grads: with self._train_program._optimized_guard([p, g]): _, found_inf = check_finite_and_unscale( [g, ], self._loss_scaling, name="find_infinite_scale") found_infs.append(found_inf) + elif self._use_pure_fp16: + if fp32_grads: + with self._train_program._optimized_guard(fp32_grads): + _, fp32_found_inf = check_finite_and_unscale( + fp32_grads, + self._loss_scaling, + name="find_infinite_scale_fp32") + found_infs.append(fp32_found_inf) + if fp16_grads: + with self._train_program._optimized_guard(fp16_grads): + _, fp16_found_inf = check_finite_and_unscale( + fp16_grads, + self._loss_scaling, + name="find_infinite_scale_fp16") + found_infs.append(fp16_found_inf) + else: + with self._train_program._optimized_guard(grads): + _, found_inf = check_finite_and_unscale( + grads, self._loss_scaling, name="find_infinite_scale") if self._use_dynamic_loss_scaling: - if self._is_distributed: + if self._is_distributed or self._use_pure_fp16: with self._train_program._optimized_guard([]): all_infs = layers.concat(found_infs) found_inf = layers.reduce_any(all_infs) - with self._train_program._optimized_guard([]): - update_loss_scaling( - grads, - found_inf, - self._loss_scaling, - self._num_good_steps, - self._num_bad_steps, - self._incr_every_n_steps, - self._decr_every_n_nan_or_inf, - self._incr_ratio, - self._decr_ratio, - name="update_loss_scaling") + if self._use_pure_fp16: + stop_update = False + with self._train_program._optimized_guard([]): + if fp32_grads: + update_loss_scaling( + fp32_grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + stop_update=stop_update, + name="update_loss_scaling_fp32") + stop_update = True + if fp16_grads: + update_loss_scaling( + fp16_grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + stop_update=stop_update, + name="update_loss_scaling_fp16") + else: + with self._train_program._optimized_guard([]): + update_loss_scaling( + grads, + found_inf, + self._loss_scaling, + self._num_good_steps, + self._num_bad_steps, + self._incr_every_n_steps, + self._decr_every_n_nan_or_inf, + self._incr_ratio, + self._decr_ratio, + name="update_loss_scaling") optimize_ops = self._optimizer.apply_gradients(params_grads) return optimize_ops @@ -239,6 +349,13 @@ class OptimizerWithMixedPrecision(object): The scaled loss by scaling factor, the list of optimize ops, and a list of scaled parameters and gradients. """ + opt_dict = self._optimizer.__class__.__dict__ + if 'minimize' in opt_dict and isinstance(opt_dict['minimize'], + types.FunctionType): + warnings.warn( + "The decorated optimizer has its own `minimize` method, but it will not be executed." + ) + scaled_params_grads = self.backward( loss, startup_program=startup_program, @@ -258,7 +375,9 @@ def decorate(optimizer, decr_every_n_nan_or_inf=2, incr_ratio=2.0, decr_ratio=0.8, - use_dynamic_loss_scaling=True): + use_dynamic_loss_scaling=True, + use_pure_fp16=False, + use_fp16_guard=None): """ Decorate the given optimizer to adapt to the mixed-precision training. @@ -276,6 +395,9 @@ def decorate(optimizer, decr_ratio(float): The less-than-one-multiplier to use when decreasing the loss scaling. use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling. + use_pure_fp16(bool): Whether to use the pure fp16 training. Default False. + use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program. + Default None, which means that its value equals to `use_pure_fp16`. Returns: An optimizer acting like a normal one but with mixed-precision training @@ -295,8 +417,13 @@ def decorate(optimizer, """ if amp_lists is None: amp_lists = AutoMixedPrecisionLists() + + if use_fp16_guard is None: + use_fp16_guard = use_pure_fp16 + mp_optimizer = OptimizerWithMixedPrecision( optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling, - incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio) + incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio, + use_pure_fp16, use_fp16_guard) return mp_optimizer diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index a92d8f17db..a409595d3e 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -38,6 +38,7 @@ class AutoMixedPrecisionLists(object): self.white_list = copy.copy(white_list) self.black_list = copy.copy(black_list) self.gray_list = copy.copy(gray_list) + self.unsupported_list = copy.copy(unsupported_fp16_list) self.black_varnames = copy.copy(custom_black_varnames) self._update_list() @@ -64,6 +65,7 @@ class AutoMixedPrecisionLists(object): elif op_name in self.gray_list: self.gray_list.remove(op_name) self.black_list.add(op_name) + self.unsupported_list.add(op_name) # The three sets listed below are changed dynamiclly. They don't contain all @@ -141,10 +143,10 @@ gray_list = { 'cast', 'fused_bn_add_activation', } -''' + # The set of ops that don't support fp16 calculation unsupported_fp16_list = { - # from python/paddle/fluid/layers/io.py + # from python/paddle/fluid/layers/io.py 'send', 'send_barrier', 'recv', @@ -153,8 +155,8 @@ unsupported_fp16_list = { 'create_double_buffer_reader', 'read', 'load', - - # from python/paddle/fluid/control_flow.py + + # from python/paddle/fluid/control_flow.py 'increment', 'less_than', 'less_equal', @@ -174,7 +176,6 @@ unsupported_fp16_list = { 'while', 'ifelse', 'is_empty', - 'lstm', 'cudnn_lstm', 'lstmp', @@ -275,7 +276,6 @@ unsupported_fp16_list = { 'pixel_shuffle', 'fsp', 'cvm', - 'affine_channel', 'roi_pool', 'roi_align', @@ -283,6 +283,4 @@ unsupported_fp16_list = { 'generate_proposals', 'generate_proposal_labels', 'generate_mask_labels', - } -''' diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index c9a070a03a..e02671e219 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -15,17 +15,28 @@ from __future__ import print_function from ... import core +from ... import framework from ... import layers from ... import global_scope from ...log_helper import get_logger +from ...wrapped_decorator import signature_safe_contextmanager +from .fp16_lists import AutoMixedPrecisionLists +import collections import logging import numpy as np -__all__ = ["cast_model_to_fp16", "cast_parameters_to_fp16"] +__all__ = ["fp16_guard", "cast_model_to_fp16", "cast_parameters_to_fp16"] _logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') +_valid_types = [ + core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS, + core.VarDesc.VarType.LOD_TENSOR_ARRAY +] + +_fp16_guard_pattern = "__use_fp16__" + def _rename_arg(op, old_name, new_name): """ @@ -44,6 +55,18 @@ def _rename_arg(op, old_name, new_name): op_desc._rename_output(old_name, new_name) +def _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops): + for block in program.blocks: + ops = block.ops + block_id = block.idx + for op in ops: + if op not in origin_ops or op in keep_fp32_ops: + continue + for name in op.input_arg_names: + if name in op_var_rename_map[block_id]: + op._rename_input(name, op_var_rename_map[block_id][name]) + + def _dtype_to_str(dtype): """ Convert specific variable type to its corresponding string. @@ -72,10 +95,6 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): num_cast_op (int): The number of cast ops that have been inserted. """ num_cast_ops = 0 - valid_types = [ - core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS, - core.VarDesc.VarType.LOD_TENSOR_ARRAY - ] for in_name in op.input_names: if src_dtype == core.VarDesc.VarType.FP32 and op.type in [ @@ -85,7 +104,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): continue for in_var_name in op.input(in_name): in_var = block.var(in_var_name) - if in_var.type not in valid_types or in_var.dtype == dest_dtype: + if in_var.type not in _valid_types or in_var.dtype == dest_dtype: continue if in_var.dtype == src_dtype: cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype) @@ -119,7 +138,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): continue for out_var_name in op.output(out_name): out_var = block.var(out_var_name) - if out_var.type not in valid_types: + if out_var.type not in _valid_types: continue if out_var.dtype == core.VarDesc.VarType.FP32: out_var.desc.set_dtype(core.VarDesc.VarType.FP16) @@ -128,6 +147,38 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): return num_cast_ops +def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name, + op_var_rename_map): + num_cast_ops = 0 + + target_var = block.var(target_name) + if target_var.type not in _valid_types or target_var.dtype == dest_dtype: + return num_cast_ops + + assert target_var.dtype == src_dtype, \ + "The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)) + + cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype) + cast_var = block.vars.get(cast_name) + if cast_var is None or cast_var.dtype != dest_dtype: + cast_var = block.create_var( + name=cast_name, + dtype=dest_dtype, + persistable=False, + stop_gradient=target_var.stop_gradient) + block._insert_op( + idx, + type="cast", + inputs={"X": target_var}, + outputs={"Out": cast_var}, + attrs={"in_dtype": target_var.dtype, + "out_dtype": cast_var.dtype}) + num_cast_ops += 1 + op_var_rename_map[block.idx][target_var.name] = cast_var.name + + return num_cast_ops + + def find_true_prev_op(ops, cur_op, var_name): """ Find the true prev op that outputs var_name variable. @@ -174,9 +225,8 @@ def find_true_post_op(ops, cur_op, var_name): for in_var_name in op.input(in_name): if in_var_name == var_name: post_op.append(op) - if post_op != []: - return post_op - return None + + return post_op def find_op_index(block_desc, cur_op_desc): @@ -200,26 +250,73 @@ def _is_in_black_varnames(op, amp_lists): return False -def cast_model_to_fp16(main_program): +def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard): + if op.type in unsupported_op_list: + # the highest priority condition: If ops don't have fp16 computing kernels, + # they must be executed in fp32 calculation pattern. + return True + + # process ops about learning rate + in_out_arg_names = [] + in_out_arg_names.extend(list(op.input_arg_names)) + in_out_arg_names.extend(list(op.output_arg_names)) + for name in in_out_arg_names: + if "learning_rate" in name: + return True + + if use_fp16_guard: + if op.has_attr("op_namescope") and \ + (_fp16_guard_pattern in op.attr("op_namescope")): + # op in fp16 guard + return False + else: + # op not in fp16 guard + return True + else: + return False + + +@signature_safe_contextmanager +def fp16_guard(): + """ + As for the pure fp16 training, if users set `use_fp16_guard` to True, + only those ops created in the context manager `fp16_guard` will be + transformed as float16 type. + """ + with framework.name_scope(prefix=_fp16_guard_pattern): + yield + + +def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): """ Traverse all ops in the whole model and set their inputs and outputs to the fp16 data type. This function will do some special process for the batch normalization, which keeps the computational process of batchnorms in FP32. Args: - main_program (Program): The main program for training. + program (Program): The used program. + amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object. + use_fp16_guard(bool): Determine whether to use `fp16_guard` when + constructing the program. Default True. """ - valid_types = [ - core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS, - core.VarDesc.VarType.LOD_TENSOR_ARRAY - ] - global_block = main_program.global_block() - for block in main_program.blocks: + if amp_lists is None: + amp_lists = AutoMixedPrecisionLists() + global_block = program.global_block() + keep_fp32_ops = set() + to_fp16_var_names = set() + origin_ops = [] + for block in program.blocks: + origin_ops.extend(block.ops) + + for block in program.blocks: ops = block.ops for op in ops: if op.type == 'create_py_reader' or op.type == 'read': continue + if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard): + keep_fp32_ops.add(op) + continue # processed below for in_name in op.input_names: if op.type in { 'batch_norm', 'fused_bn_add_activation', 'layer_norm' @@ -231,19 +328,20 @@ def cast_model_to_fp16(main_program): in_var = block.var(in_var_name) except ValueError as e: _logger.debug( - "-- {}, try to get it in the global block. --". + "-- {}, try to get it in the global block --". format(e)) in_var = global_block.var(in_var_name) if in_var is not None: _logger.debug( - "-- var {} is got in the global block. --". + "-- var {} is got in the global block --". format(in_var_name)) - if in_var is None or in_var.type not in valid_types: + if in_var is None or in_var.type not in _valid_types: continue if in_var.dtype == core.VarDesc.VarType.FP32: in_var.desc.set_dtype(core.VarDesc.VarType.FP16) + to_fp16_var_names.add(in_var_name) _logger.debug( "-- op type: {}, in var name: {}, in var dtype: {} --". @@ -260,15 +358,15 @@ def cast_model_to_fp16(main_program): out_var = block.var(out_var_name) except ValueError as e: _logger.debug( - "-- {}, try to get it in the global block. --". + "-- {}, try to get it in the global block --". format(e)) out_var = global_block.var(out_var_name) if out_var is not None: _logger.debug( - "-- var {} is got in the global block. --". + "-- var {} is got in the global block --". format(out_var_name)) - if out_var is None or out_var.type not in valid_types: + if out_var is None or out_var.type not in _valid_types: continue if out_var.dtype == core.VarDesc.VarType.FP32: @@ -287,35 +385,65 @@ def cast_model_to_fp16(main_program): 'dtype') == core.VarDesc.VarType.FP32: op._set_attr('dtype', core.VarDesc.VarType.FP16) + # process ops in keep_fp32_ops + op_var_rename_map = [ + collections.OrderedDict() for _ in range(len(program.blocks)) + ] + for block in program.blocks: + ops = block.ops + idx = 0 + while idx < len(ops): + op = ops[idx] + num_cast_ops = 0 + if op in keep_fp32_ops: + pre_cast_num = _insert_cast_op(block, op, idx, + core.VarDesc.VarType.FP16, + core.VarDesc.VarType.FP32) + num_cast_ops += pre_cast_num + for out_var_name in op.output_arg_names: + out_var = block.vars.get(out_var_name) + if out_var is None or out_var.type not in _valid_types: + continue + if out_var.dtype == core.VarDesc.VarType.FP16: + out_var.desc.set_dtype(core.VarDesc.VarType.FP32) + post_ops = find_true_post_op(ops, op, out_var_name) + for post_op in post_ops: + if post_op in keep_fp32_ops: + continue + post_cast_num = _insert_cast_post_op( + block, op, idx + pre_cast_num + 1, + core.VarDesc.VarType.FP32, + core.VarDesc.VarType.FP16, out_var_name, + op_var_rename_map) + num_cast_ops += post_cast_num + idx += num_cast_ops + 1 + + _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops) + return to_fp16_var_names -def cast_parameters_to_fp16(place, main_program, scope=None): + +def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): """ - Traverse all parameters in the whole model and set them to the fp16 data type. + Traverse all parameters in the whole model and set them to the FP16 data type. Whereas, this function will keep parameters of batchnorms in FP32. Args: - place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors. - main_program (Program): The main program for training. - scope(fluid.Scope, optional): scope is used to get the weight tensor values. - Default is None. + place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors. + program (Program): The used program. + scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values. + Default is None. + to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names` + will be set to FP16. Usually, it is the returned + value of `cast_model_to_fp16` API. """ - all_ops = [] - for block in main_program.blocks: - all_ops.extend(block.ops) - bn_params = set() - for op in all_ops: - if op.type not in { - 'batch_norm', 'fused_bn_add_activation', 'layer_norm' - }: - continue - for in_name in op.input_names: - if in_name not in {'X', 'Z'}: - for in_var_name in op.input(in_name): - bn_params.add(in_var_name) - global_block = main_program.global_block() - all_parameters = global_block.all_parameters() - var_scope = scope if scope is not None else global_scope() + all_parameters = [] + for block in program.blocks: + all_parameters.extend(block.all_parameters()) + + fp16_var_names = to_fp16_var_names if to_fp16_var_names else set() + var_scope = scope if scope else global_scope() for param in all_parameters: - if param.name not in bn_params: + if param.name in fp16_var_names: + _logger.debug("---- cast {} to fp16 dtype ----".format(param.name)) param_t = var_scope.find_var(param.name).get_tensor() data = np.array(param_t) param_t.set(np.float16(data), place) @@ -458,7 +586,7 @@ def update_role_var_grad(main_prog, params_grads): if op == block.ops[-1]: continue post_ops = find_true_post_op(block.ops, op, g.name) - if post_ops is not None: + if post_ops: raise ValueError("The cast op {0}'s output should not be" "used by a non-optimize op, however, it" "is used by {1}".format(op, post_ops[0])) diff --git a/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py index 3526a3d761..15373ee7bb 100644 --- a/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py +++ b/python/paddle/fluid/contrib/tests/test_multi_precision_fp16_train.py @@ -19,8 +19,7 @@ import paddle.fluid as fluid import contextlib import unittest import numpy as np -from paddle.static.amp import cast_model_to_fp16 -from paddle.static.amp import cast_parameters_to_fp16 +from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16 paddle.enable_static() @@ -65,38 +64,19 @@ def resnet_cifar10(input, depth=32): n = (depth - 2) // 6 conv1 = conv_bn_layer( input=input, ch_out=16, filter_size=3, stride=1, padding=1) - res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) - res2 = layer_warp(basicblock, res1, 16, 32, n, 2) - res3 = layer_warp(basicblock, res2, 32, 64, n, 2) + with paddle.static.amp.fp16_guard(): + res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) + res2 = layer_warp(basicblock, res1, 16, 32, n, 2) + res3 = layer_warp(basicblock, res2, 32, 64, n, 2) pool = fluid.layers.pool2d( input=res3, pool_size=8, pool_type='avg', pool_stride=1) return pool -def compile(program, loss_name=None): - build_strategy = paddle.static.BuildStrategy() - exec_strategy = paddle.static.ExecutionStrategy() - - exec_strategy.num_threads = 1 - exec_strategy.num_iteration_per_drop_scope = 10000 - - build_strategy.fuse_bn_act_ops = True - build_strategy.fuse_elewise_add_act_ops = True - build_strategy.fuse_bn_add_act_ops = True - - compiled_program = paddle.static.CompiledProgram( - program).with_data_parallel( - loss_name=loss_name, - build_strategy=build_strategy, - exec_strategy=exec_strategy) - - return compiled_program - - -def train(use_pure_fp16=True, use_nesterov=False): +def train(use_pure_fp16=True, use_nesterov=False, use_adam=False): classdim = 10 data_shape = [3, 32, 32] - BATCH_SIZE = 128 + BATCH_SIZE = 32 PASS_NUM = 1 train_program = fluid.Program() @@ -107,28 +87,35 @@ def train(use_pure_fp16=True, use_nesterov=False): images = fluid.layers.data( name='pixel', shape=data_shape, dtype='float32') label = fluid.layers.data(name='label', shape=[1], dtype='int64') - net = resnet_cifar10(images, 32) - + net = resnet_cifar10(images) logits = fluid.layers.fc(input=net, size=classdim, act="softmax") - if use_pure_fp16: - cast_model_to_fp16(fluid.default_main_program()) - logits_fp32 = fluid.layers.cast(x=logits, dtype="float32") - else: - logits_fp32 = logits cost = fluid.layers.softmax_with_cross_entropy( - logits_fp32, label, return_softmax=False) + logits, label, return_softmax=False) sum_cost = fluid.layers.reduce_sum(cost) # Test program test_program = train_program.clone(for_test=True) - optimizer = paddle.optimizer.Momentum( - learning_rate=0.001, - momentum=0.9, - use_nesterov=use_nesterov, - weight_decay=fluid.regularizer.L2Decay(1e-4), - multi_precision=use_pure_fp16, - rescale_grad=1.0 / BATCH_SIZE) + if use_adam: + optimizer = paddle.optimizer.Adam( + learning_rate=0.001, + epsilon=1e-8, + weight_decay=0.0, + multi_precision=True) + else: + optimizer = paddle.optimizer.Momentum( + learning_rate=0.001, + momentum=0.9, + use_nesterov=use_nesterov, + weight_decay=fluid.regularizer.L2Decay(1e-4), + multi_precision=use_pure_fp16) + + if use_pure_fp16: + optimizer = paddle.static.amp.decorate( + optimizer, + init_loss_scaling=128.0, + use_dynamic_loss_scaling=True, + use_pure_fp16=True) optimizer.minimize(sum_cost) @@ -146,13 +133,13 @@ def train(use_pure_fp16=True, use_nesterov=False): def train_loop(main_program): exe.run(startup_prog) if use_pure_fp16: - cast_parameters_to_fp16(place, train_program, fluid.global_scope()) - compiled_program = compile(train_program, sum_cost.name) + optimizer.amp_init( + place, test_program=test_program, use_fp16_test=True) loss = 0.0 for pass_id in range(PASS_NUM): train_loss_list = [] for batch_id, data in enumerate(train_reader()): - loss, = exe.run(compiled_program, + loss, = exe.run(train_program, feed=feeder.feed(data), fetch_list=[sum_cost]) loss_v = loss[0] if isinstance(loss, np.ndarray) else loss @@ -182,18 +169,25 @@ class TestImageMultiPrecision(unittest.TestCase): if not fluid.core.is_compiled_with_cuda(): return - def do_test(use_nesterov=False): - suffix = "with Nesterov" if use_nesterov else "without Nesterov" + def do_test(use_nesterov=False, use_adam=False): + if use_adam: + suffix = "use Adam" + else: + suffix = "with Nesterov" if use_nesterov else "without Nesterov" with self.scope_prog_guard(): print("-----------------FP16 Train {}-----------------".format( suffix)) train_loss_fp16, test_loss_fp16 = train( - use_pure_fp16=True, use_nesterov=use_nesterov) + use_pure_fp16=True, + use_nesterov=use_nesterov, + use_adam=use_adam) with self.scope_prog_guard(): print("-----------------FP32 Train {}-----------------".format( suffix)) train_loss_fp32, test_loss_fp32 = train( - use_pure_fp16=False, use_nesterov=use_nesterov) + use_pure_fp16=False, + use_nesterov=use_nesterov, + use_adam=use_adam) self.assertTrue( np.allclose( @@ -214,6 +208,7 @@ class TestImageMultiPrecision(unittest.TestCase): do_test(use_nesterov=False) do_test(use_nesterov=True) + do_test(use_adam=True) @contextlib.contextmanager def scope_prog_guard(self): @@ -260,7 +255,7 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase): op._set_attr('out_dtype', fluid.core.VarDesc.VarType.FP32) op._set_attr('dtype', fluid.core.VarDesc.VarType.FP32) - cast_model_to_fp16(main_prog) + cast_model_to_fp16(main_prog, use_fp16_guard=False) def test_non_iterable_dataloader(self): self.decorate_with_data_loader() diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 2354a3b66a..cd6156d105 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -16,6 +16,10 @@ from .optimizer import Optimizer from ..fluid import core from ..fluid import framework from ..fluid.framework import Variable +from ..fluid import layers +from ..fluid import unique_name +from ..fluid.layer_helper import LayerHelper +import warnings from ..fluid.dygraph import base as imperative_base import paddle @@ -79,6 +83,7 @@ class Adam(Optimizer): gradient in current mini-batch, so it will be much more faster. But this mode has different semantics with the original Adam algorithm and may lead to different result. The default value is False. + multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. @@ -135,6 +140,7 @@ class Adam(Optimizer): weight_decay=None, grad_clip=None, lazy_mode=False, + multi_precision=False, name=None): assert learning_rate is not None assert beta1 is not None @@ -157,28 +163,90 @@ class Adam(Optimizer): self._beta2 = beta2 self._epsilon = epsilon self._lazy_mode = lazy_mode + self._multi_precision = multi_precision + self._master_weights = {} + + def _create_master_weight(self, param): + assert isinstance(self.helper, LayerHelper) + + var_name = param.name + "_fp32_master" + var_name = unique_name.generate(var_name) + var = layers.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True) + block = self.helper.startup_program.global_block() + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32 + }) + self._master_weights[param.name] = var + return var + + def _get_accumulator(self, name, param): + """Utility function to fetch an accumulator for a parameter + Args: + name: name of the accumulator + param: parameter variable for which accumulator is to be fetched + Returns: + accumulator variable for the parameter + """ + if self._name is not None: + name = self._name + "_" + name + find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 + target_param = self._master_weights[ + param.name] if find_master else param + target_name = target_param.name + if (name not in self._accumulators or + target_name not in self._accumulators[name]): + raise Exception("Accumulator {} does not exist for parameter {}". + format(name, target_name)) + return self._accumulators[name][target_name] + + def _add_moments_pows(self, p): + acc_dtype = p.dtype + if acc_dtype == core.VarDesc.VarType.FP16: + acc_dtype = core.VarDesc.VarType.FP32 + self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) + self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) + self._add_accumulator( + name=self._beta1_pow_acc_str, + param=p, + dtype=acc_dtype, + fill_value=0.9 if isinstance(self._beta1, Variable) \ + else self._beta1, + shape=[1], + type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + self._add_accumulator( + name=self._beta2_pow_acc_str, + param=p, + dtype=acc_dtype, + fill_value=0.999 if isinstance(self._beta2, Variable) \ + else self._beta2, + shape=[1], + type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) # Create accumulator tensors for first and second moments for p in parameters: - self._add_accumulator(self._moment1_acc_str, p) - self._add_accumulator(self._moment2_acc_str, p) - self._add_accumulator( - name=self._beta1_pow_acc_str, - param=p, - fill_value=0.9 if isinstance(self._beta1, Variable) \ - else self._beta1, - shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') - self._add_accumulator( - name=self._beta2_pow_acc_str, - param=p, - fill_value=0.999 if isinstance(self._beta2, Variable) \ - else self._beta2, - shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + master_p = self._create_master_weight(p) + self._add_moments_pows(master_p) + continue + if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision: + warnings.warn( + "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Consider using multi_precision=True option of the Momentum optimizer." + ) + self._add_moments_pows(p) def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) @@ -191,6 +259,10 @@ class Adam(Optimizer): param_and_grad[0]) beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, param_and_grad[0]) + find_master = self._multi_precision and param_and_grad[ + 0].dtype == core.VarDesc.VarType.FP16 + master_weight = (self._master_weights[param_and_grad[0].name] + if find_master else None) lr = self._create_param_lr(param_and_grad) # create the adam optimize op @@ -227,7 +299,8 @@ class Adam(Optimizer): attrs = { "epsilon": self._epsilon, "lazy_mode": self._lazy_mode, - "min_row_size_to_use_multithread": 1000 + "min_row_size_to_use_multithread": 1000, + "multi_precision": find_master } if isinstance(self._beta1, Variable): @@ -239,6 +312,10 @@ class Adam(Optimizer): else: attrs['beta2'] = self._beta2 + if find_master: + inputs["MasterParam"] = master_weight + outputs["MasterParamOut"] = master_weight + adam_op = block.append_op( type=self.type, inputs=inputs, diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index 050ac2f031..ff560e8134 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -71,6 +71,7 @@ class AdamW(Adam): gradient in current mini-batch, so it will be much more faster. But this mode has different semantics with the original Adam algorithm and may lead to different result. The default value is False. + multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false. name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. @@ -111,6 +112,7 @@ class AdamW(Adam): apply_decay_param_fun=None, grad_clip=None, lazy_mode=False, + multi_precision=False, name=None): assert learning_rate is not None assert beta1 is not None @@ -138,7 +140,8 @@ class AdamW(Adam): epsilon=epsilon, grad_clip=grad_clip, name=name, - lazy_mode=lazy_mode) + lazy_mode=lazy_mode, + multi_precision=multi_precision) def _append_decoupled_weight_decay(self, block, param_and_grad): """ diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index bfcd2bc038..5fc5506ec3 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -128,21 +128,6 @@ class Momentum(Optimizer): self.helper = LayerHelper(self.__class__.__name__) for p in parameters: self._add_accumulator(self._velocity_acc_str, p) - else: - all_parameters = fluid.default_main_program().global_block( - ).all_parameters() - self.helper = LayerHelper(self.__class__.__name__) - for p in all_parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: - master_p = self._create_master_weight(p) - self._add_accumulator(self._velocity_acc_str, master_p) - continue - if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision: - warnings.warn( - "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." - "Consider using multi_precision=True option of the Momentum optimizer." - ) - self._add_accumulator(self._velocity_acc_str, p) def _create_master_weight(self, param): assert isinstance(self.helper, LayerHelper) @@ -190,8 +175,21 @@ class Momentum(Optimizer): return self._accumulators[name][target_name] def _create_accumulators(self, block, parameters): + if framework.in_dygraph_mode(): + return + assert isinstance(block, framework.Block) - # create accumulator in init func, so no implementation here + for p in parameters: + if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + master_p = self._create_master_weight(p) + self._add_accumulator(self._velocity_acc_str, master_p) + continue + if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision: + warnings.warn( + "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Consider using multi_precision=True option of the Momentum optimizer." + ) + self._add_accumulator(self._velocity_acc_str, p) def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) -- GitLab