未验证 提交 d8dfef54 编写于 作者: Z Zhen Wang 提交者: GitHub

[Cherry-Pick] Support pure fp16 training for AMP API. (#29544) (#30241)

* Support pure fp16 training for AMP API. (#29544)

* add cast ops before and after unsupported fp16 ops.

* Keep partial net in FP32 pattern.

* Support check_finite_and_unscale and update_loss_scaling for FP16 calculation mode.

* Add fp16 support for adam op.

* add multi precision attr for adam.

* Fix the bug of test_multi_precision_fp16_train UT.

* Code format for CI.

* Fix the redefine error about MPTypeTrait on windows.

* fix bugs of the _create_accumulators func in Momentum.

* fix bug when inserting post cast op.

* Add the update_loss_scaling op in allow_set of UnusedVarCheck.

* Update for ci coverage.

* Add some doc for OptimizerWithMixedPrecision.

* Fix the code style.

* Improve the doc of `amp_init`.

* Adapt fp16 testing for the case where users define the inference program separately.

* Remove tensor copy in the update_loss_scaling op. (#29426)

* remove tensor copy in the update_loss_scaling op

* not use thrust.

* fix some cuda memory access error.
上级 b4931ab1
......@@ -73,7 +73,8 @@ static const std::unordered_set<std::string> &GetOpWithUnusedVarAllowSet() {
"fused_batch_norm_act", // 2
"fused_batch_norm_act_grad", // 2
"data_norm", // 0
"data_norm_grad", // 0);
"data_norm_grad", // 0
"update_loss_scaling", // 0
});
return *allow_set;
}
......
......@@ -15,6 +15,8 @@ limitations under the License. */
#include <cuda.h>
#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -25,15 +27,16 @@ __global__ void InverseAndMemset(const T* s, T* o, bool* found_inf) {
*found_inf = false;
}
template <typename T>
__global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num,
template <typename T, typename MT>
__global__ void CheckFiniteAndUnscale(const T* in, const MT* scale, int num,
bool* found_inf, T* out) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < num) {
T val = in[idx] * (*scale);
out[idx] = val;
if (!isfinite(val)) {
MT val = static_cast<MT>(in[idx]) * (*scale);
T narrow_val = static_cast<T>(val);
out[idx] = narrow_val;
if (!isfinite(narrow_val)) {
*found_inf = true;
}
}
......@@ -41,6 +44,8 @@ __global__ void CheckFiniteAndUnscale(const T* in, const T* scale, int num,
template <typename T>
class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
......@@ -49,14 +54,15 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
const T* scale_data = scale->data<T>();
const MPDType* scale_data = scale->data<MPDType>();
bool* found_inf_data = found_inf->mutable_data<bool>(dev_ctx.GetPlace());
framework::Tensor inverse_scale =
ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({1}, dev_ctx);
T* inverse_scale_v = inverse_scale.template data<T>();
ctx.AllocateTmpTensor<MPDType, platform::CUDADeviceContext>({1},
dev_ctx);
MPDType* inverse_scale_v = inverse_scale.template data<MPDType>();
InverseAndMemset<T><<<1, 1, 0, dev_ctx.stream()>>>(
InverseAndMemset<MPDType><<<1, 1, 0, dev_ctx.stream()>>>(
scale_data, inverse_scale_v, found_inf_data);
for (size_t i = 0; i < xs.size(); ++i) {
......@@ -69,7 +75,7 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
int block = 1024;
int grid = (num + block - 1) / block;
VLOG(3) << "launch kernel";
CheckFiniteAndUnscale<T><<<grid, block, 0, dev_ctx.stream()>>>(
CheckFiniteAndUnscale<T, MPDType><<<grid, block, 0, dev_ctx.stream()>>>(
x_data, inverse_scale_v, num, found_inf_data, out_data);
VLOG(3) << "finish kernel";
}
......@@ -79,6 +85,8 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(check_finite_and_unscale,
ops::CheckFiniteAndUnscaleGpuKernel<float>,
ops::CheckFiniteAndUnscaleGpuKernel<double>);
ops::CheckFiniteAndUnscaleGpuKernel<double>,
ops::CheckFiniteAndUnscaleGpuKernel<plat::float16>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
namespace details {
// Type trait that names the "multi-precision" counterpart of an element
// type T through the nested alias `Type`.
//
// Primary template: no widening is applied, so `Type` is T itself.
// Specializations may substitute a different type for particular T.
template <typename T>
class MPTypeTrait {
 public:
  typedef T Type;
};
// Specialization for half-precision elements: the multi-precision type that
// goes with platform::float16 is `float`.
template <>
class MPTypeTrait<platform::float16> {
 public:
  typedef float Type;
};
} // namespace details
} // namespace operators
} // namespace paddle
......@@ -54,8 +54,7 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "PrevLossScaling"),
ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -107,6 +106,9 @@ class UpdateLossScalingOpMaker : public framework::OpProtoAndCheckerMaker {
"the received is %f",
decr_ratio));
});
AddAttr<bool>("stop_update",
"Stop updating loss scaling, and just zero inputs.")
.SetDefault(false);
AddComment(R"DOC(
Update loss scaling according to overall gradients. If all gradients is
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
......@@ -135,18 +137,18 @@ class UpdateLossScalingFunctor<platform::CPUDeviceContext, T> {
};
template <typename T>
class LazyZeroInputs<platform::CPUDeviceContext, T> {
class LazyZeros<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& dev_ctx,
const bool* found_inf_data,
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
if (*found_inf_data) {
VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --";
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int num = out->numel();
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int num = out->numel();
if (*found_inf_data) {
VLOG(1) << "-- UpdateLossScaling: Find infinite grads. --";
std::memset(out_data, 0, num * sizeof(T));
}
}
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -32,6 +33,17 @@ __global__ void GpuUpdateLossScaling(
updated_loss_scaling_data, good_out_data, bad_out_data);
}
// Conditionally fills a device buffer.
//
// When `*has_inf` is true, every element of data[0, num) is set to `value`;
// when it is false the kernel writes nothing.
//
// Params:
//   data    - buffer to fill; written only when the flag is set.
//   num     - number of elements in `data`.
//   value   - value stored into each element.
//   has_inf - flag read inside the kernel, so it must point to memory the
//             device can dereference.
//
// Launch layout: 1-D grid of 1-D blocks. The body is a grid-stride loop, so
// any grid/block size covers the whole buffer.
//
// The index and stride are 64-bit: `num` is an int64_t, and a 32-bit index
// would overflow (and then index out of bounds) for buffers holding more
// than INT_MAX elements.
template <typename T>
__global__ void FillIf(T* data, const int64_t num, const T value,
                       const bool* has_inf) {
  if (*has_inf) {
    const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    for (int64_t i =
             static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         i < num; i += stride) {
      data[i] = value;
    }
  }
}
template <typename T>
class UpdateLossScalingFunctor<platform::CUDADeviceContext, T> {
public:
......@@ -50,26 +62,20 @@ class UpdateLossScalingFunctor<platform::CUDADeviceContext, T> {
};
template <typename T>
class LazyZeroInputs<platform::CUDADeviceContext, T> {
class LazyZeros<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& dev_ctx,
const bool* found_inf_data,
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
const auto gpu_place =
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
bool has_inf{false};
memory::Copy(platform::CPUPlace(), &has_inf, gpu_place, found_inf_data,
sizeof(bool), dev_ctx.stream());
dev_ctx.Wait(); // wait async copy
if (has_inf) {
VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --";
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int num = out->numel();
cudaMemsetAsync(out_data, 0, num * sizeof(T), dev_ctx.stream());
}
for (size_t i = 0; i < xs.size(); ++i) {
auto* out = outs[i];
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int64_t num = out->numel();
int block = 1024;
int grid = (block - 1 + num) / block;
FillIf<<<grid, block, 0, dev_ctx.stream()>>>(
out_data, num, static_cast<T>(0), found_inf_data);
}
}
};
......@@ -78,8 +84,10 @@ class LazyZeroInputs<platform::CUDADeviceContext, T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
using GPU = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(update_loss_scaling,
ops::UpdateLossScalingKernel<GPU, float>,
ops::UpdateLossScalingKernel<GPU, double>);
ops::UpdateLossScalingKernel<GPU, double>,
ops::UpdateLossScalingKernel<GPU, plat::float16>);
......@@ -17,6 +17,7 @@
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
......@@ -70,7 +71,7 @@ class UpdateLossScalingFunctor {
};
template <typename DeviceContext, typename T>
class LazyZeroInputs {
class LazyZeros {
public:
void operator()(const DeviceContext& dev_ctx, const bool* found_inf_data,
const std::vector<const framework::Tensor*>& xs,
......@@ -79,30 +80,38 @@ class LazyZeroInputs {
template <typename DeviceContext, typename T>
class UpdateLossScalingKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const auto xs = ctx.MultiInput<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
const auto* found_inf = ctx.Input<Tensor>("FoundInfinite");
PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
platform::errors::InvalidArgument(
"FoundInfinite must has only one element."));
const bool* found_inf_data = found_inf->data<bool>();
LazyZeros<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs);
const bool stop_update = ctx.Attr<bool>("stop_update");
if (stop_update) {
return;
}
const auto* pre_loss_scaling = ctx.Input<Tensor>("PrevLossScaling");
const auto* good_in = ctx.Input<Tensor>("InGoodSteps");
const auto* bad_in = ctx.Input<Tensor>("InBadSteps");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto* updated_loss_scaling = ctx.Output<Tensor>("LossScaling");
auto* good_out = ctx.Output<Tensor>("OutGoodSteps");
auto* bad_out = ctx.Output<Tensor>("OutBadSteps");
PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
platform::errors::InvalidArgument(
"FoundInfinite must has only one element."));
const bool* found_inf_data = found_inf->data<bool>();
const T* pre_loss_scaling_data = pre_loss_scaling->data<T>();
const MPDType* pre_loss_scaling_data = pre_loss_scaling->data<MPDType>();
const int* good_in_data = good_in->data<int>();
const int* bad_in_data = bad_in->data<int>();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
T* updated_loss_scaling_data =
updated_loss_scaling->mutable_data<T>(dev_ctx.GetPlace());
MPDType* updated_loss_scaling_data =
updated_loss_scaling->mutable_data<MPDType>(dev_ctx.GetPlace());
int* good_out_data = good_out->mutable_data<int>(dev_ctx.GetPlace());
int* bad_out_data = bad_out->mutable_data<int>(dev_ctx.GetPlace());
......@@ -111,11 +120,10 @@ class UpdateLossScalingKernel : public framework::OpKernel<T> {
ctx.Attr<int>("decr_every_n_nan_or_inf");
const float incr_ratio = ctx.Attr<float>("incr_ratio");
const float decr_ratio = ctx.Attr<float>("decr_ratio");
UpdateLossScalingFunctor<DeviceContext, T>{}(
UpdateLossScalingFunctor<DeviceContext, MPDType>{}(
dev_ctx, found_inf_data, pre_loss_scaling_data, good_in_data,
bad_in_data, incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
decr_ratio, updated_loss_scaling_data, good_out_data, bad_out_data);
LazyZeroInputs<DeviceContext, T>{}(dev_ctx, found_inf_data, xs, outs);
}
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -150,12 +151,17 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
"as beta2, this has a higher priority than attr(beta2), the "
"shape of this tensor MUST BE [1].")
.AsDispensable();
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("Moment1Out", "(Tensor) Output first moment");
AddOutput("Moment2Out", "(Tensor) Output second moment");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();
AddAttr<float>("beta1",
"(float, default 0.9) "
......@@ -183,6 +189,10 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
"inner_op_parallelism is larger then 0, sparse update "
"will run in multithread mode")
.SetDefault(1000);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddComment(R"DOC(
Adam Optimizer.
......@@ -213,3 +223,13 @@ REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
REGISTER_OP_CPU_KERNEL(
adam, ops::AdamOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::AdamOpKernel<paddle::platform::CPUDeviceContext, double>);
// Registers an op-version checkpoint for the `adam` operator describing the
// newly added boolean attribute `multi_precision`. The trailing `false` is
// the value passed to NewAttr for that attribute.
// NOTE(review): presumably this is the default assumed for programs saved
// before the attribute existed -- confirm against OpVersionDesc::NewAttr.
REGISTER_OP_VERSION(adam)
.AddCheckpoint(
R"ROC(
Upgrade adam add 1 attribute [multi_precision].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"multi_precision",
"(bool) Whether to use multi-precision during weight updating.",
false));
......@@ -11,70 +11,81 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void AdamKernelREG(T beta1, T beta2, T epsilon, T beta1_pow_,
T beta2_pow_, const T* moment1, T* moment1_out,
const T* moment2, T* moment2_out, const T* lr_,
template <typename T, typename MT>
__global__ void AdamKernelREG(MT beta1, MT beta2, MT epsilon, MT beta1_pow_,
MT beta2_pow_, const MT* moment1, MT* moment1_out,
const MT* moment2, MT* moment2_out, const MT* lr_,
const T* grad, const T* param, T* param_out,
const MT* master_param, MT* master_param_out,
int ndim) {
T lr = *lr_;
T beta1_pow = beta1_pow_;
T beta2_pow = beta2_pow_;
MT lr = *lr_;
MT beta1_pow = beta1_pow_;
MT beta2_pow = beta2_pow_;
lr *=
sqrt(static_cast<T>(1.0) - beta2_pow) / (static_cast<T>(1.0) - beta1_pow);
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
T p = param[id];
T g = grad[id];
T mom1 = moment1[id];
T mom2 = moment2[id];
mom1 = beta1 * mom1 + (static_cast<T>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<T>(1.0) - beta2) * g * g;
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = moment1[id];
MT mom2 = moment2[id];
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<T>(1.0) - beta2_pow)));
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = p;
param_out[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
template <typename T>
__global__ void AdamKernelMEM(T beta1, T beta2, T epsilon, const T* beta1_pow_,
const T* beta2_pow_, const T* moment1,
T* moment1_out, const T* moment2, T* moment2_out,
const T* lr_, const T* grad, const T* param,
T* param_out, int ndim) {
T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
lr *=
sqrt(static_cast<T>(1.0) - beta2_pow) / (static_cast<T>(1.0) - beta1_pow);
template <typename T, typename MT>
__global__ void AdamKernelMEM(MT beta1, MT beta2, MT epsilon,
const MT* beta1_pow_, const MT* beta2_pow_,
const MT* moment1, MT* moment1_out,
const MT* moment2, MT* moment2_out, const MT* lr_,
const T* grad, const T* param, T* param_out,
const MT* master_param, MT* master_param_out,
int ndim) {
MT lr = *lr_;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
T p = param[id];
T g = grad[id];
T mom1 = moment1[id];
T mom2 = moment2[id];
mom1 = beta1 * mom1 + (static_cast<T>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<T>(1.0) - beta2) * g * g;
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = static_cast<MT>(moment1[id]);
MT mom2 = static_cast<MT>(moment2[id]);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<T>(1.0) - beta2_pow)));
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
param_out[id] = p;
param_out[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
template <typename T>
......@@ -85,15 +96,17 @@ __global__ void UpdateBetaPow(T beta1, T beta2, const T* beta1_pow_,
*beta2_pow_out = beta2 * beta2_pow_[0];
}
template <typename T>
template <typename T, typename MT>
__global__ void SparseAdamCUDAKernelREG(
T beta1, T beta2, T epsilon, const T beta1_pow, const T beta2_pow,
const T* mom1_, T* mom1_out_, const T* mom2_, T* mom2_out_, const T* lr_,
const T* grad_, const T* param_, T* param_out_, const int64_t* rows_,
MT beta1, MT beta2, MT epsilon, const MT beta1_pow, const MT beta2_pow,
const MT* mom1_, MT* mom1_out_, const MT* mom2_, MT* mom2_out_,
const MT* lr_, const T* grad_, const T* param_, T* param_out_,
const MT* master_param, MT* master_param_out, const int64_t* rows_,
int64_t row_numel, int64_t row_count, bool lazy_mode, int ndim) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
T lr = *lr_;
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
MT lr = *lr_;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
for (; id < ndim; id += blockDim.x * gridDim.x) {
auto row_idx =
......@@ -101,19 +114,24 @@ __global__ void SparseAdamCUDAKernelREG(
if (lazy_mode && row_idx < 0) {
return;
} else {
T mom1 = mom1_[id];
T mom2 = mom2_[id];
T p = param_[id];
T g = row_idx >= 0 ? grad_[row_idx * row_numel + id % row_numel] : 0;
mom1 = beta1 * mom1 + (1 - beta1) * g;
mom2 = beta2 * mom2 + (1 - beta2) * g * g;
MT mom1 = mom1_[id];
MT mom2 = mom2_[id];
MT p = master_param ? master_param[id] : static_cast<MT>(param_[id]);
MT g = row_idx >= 0
? static_cast<MT>(grad_[row_idx * row_numel + id % row_numel])
: static_cast<MT>(0);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p -= lr * (mom1 / (sqrt(mom2) +
epsilon * sqrt(static_cast<T>(1.0) - beta2_pow)));
epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
// Write back to global memory
mom1_out_[id] = mom1;
mom2_out_[id] = mom2;
param_out_[id] = p;
param_out_[id] = static_cast<T>(p);
if (master_param_out) {
master_param_out[id] = p;
}
}
}
}
......@@ -131,11 +149,12 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor;
using MPDType = typename details::MPTypeTrait<T>::Type;
int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread");
bool lazy_mode = ctx.Attr<bool>("lazy_mode");
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
MPDType epsilon = static_cast<MPDType>(ctx.Attr<float>("epsilon"));
auto* param = ctx.Input<LoDTensor>("Param");
auto* grad_var = ctx.InputVar("Grad");
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
......@@ -151,23 +170,23 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
MPDType beta1 = static_cast<MPDType>(ctx.Attr<float>("beta1"));
if (ctx.HasInput("Beta1Tensor")) {
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(Beta1Tensor) size must be 1, but get %d",
beta1_tensor->numel()));
beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
beta1 = static_cast<MPDType>(GetAttrFromTensor(beta1_tensor));
}
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
MPDType beta2 = static_cast<MPDType>(ctx.Attr<float>("beta2"));
if (ctx.HasInput("Beta2Tensor")) {
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
platform::errors::InvalidArgument(
"Input(Beta2Tensor) size must be 1, but get %d",
beta2_tensor->numel()));
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
beta2 = static_cast<MPDType>(GetAttrFromTensor(beta2_tensor));
}
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
<< "beta2_pow.numel() : " << beta2_pow->numel();
......@@ -183,6 +202,28 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
"beta2 pow output size should be 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
const bool multi_precision = ctx.Attr<bool>("multi_precision");
const LoDTensor* master_param = nullptr;
LoDTensor* master_param_out = nullptr;
if (multi_precision) {
bool has_master =
ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
PADDLE_ENFORCE_EQ(has_master, true,
platform::errors::InvalidArgument(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"));
master_param = ctx.Input<LoDTensor>("MasterParam");
master_param_out = ctx.Output<LoDTensor>("MasterParamOut");
}
const MPDType* master_in_data =
multi_precision ? master_param->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision
? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
: nullptr;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
if (grad_var->IsType<framework::LoDTensor>()) {
......@@ -195,29 +236,36 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
if (beta1_pow->place() == platform::CPUPlace() &&
beta2_pow->place() == platform::CPUPlace()) {
// Compute with betapow in REG
AdamKernelREG<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, *beta1_pow->data<T>(), *beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad->data<T>(), param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), param->numel());
AdamKernelREG<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, *beta1_pow->data<MPDType>(),
*beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, param->numel());
// Cpu update
beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
beta1 * beta1_pow->data<T>()[0];
beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
beta2 * beta2_pow->data<T>()[0];
beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta1 * beta1_pow->data<MPDType>()[0];
beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta2 * beta2_pow->data<MPDType>()[0];
} else {
AdamKernelMEM<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad->data<T>(), param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), param->numel());
AdamKernelMEM<T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
lr->data<MPDType>(), grad->data<T>(), param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, param->numel());
// Update with gpu
UpdateBetaPow<T><<<1, 32, 0, dev_ctx.stream()>>>(
beta1, beta2, beta1_pow->data<T>(), beta2_pow->data<T>(),
beta1_pow_out->mutable_data<T>(ctx.GetPlace()),
beta2_pow_out->mutable_data<T>(ctx.GetPlace()));
UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
beta1, beta2, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(),
beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
......@@ -260,26 +308,33 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
int ndim = param->numel();
int blocks = (ndim + threads - 1) / threads;
SparseAdamCUDAKernelREG<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, *beta1_pow->data<T>(), *beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad_data, param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode, ndim);
SparseAdamCUDAKernelREG<
T, MPDType><<<blocks, threads, 0, dev_ctx.stream()>>>(
beta1, beta2, epsilon, *beta1_pow->data<MPDType>(),
*beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
lr->data<MPDType>(), grad_data, param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, rows, row_numel, grad_merge.rows().size(),
lazy_mode, ndim);
// Update with cpu
beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
beta1 * beta1_pow->data<T>()[0];
beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
beta2 * beta2_pow->data<T>()[0];
beta1_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta1 * beta1_pow->data<MPDType>()[0];
beta2_pow_out->mutable_data<MPDType>(platform::CPUPlace())[0] =
beta2 * beta2_pow->data<MPDType>()[0];
} else {
SparseAdamFunctor<T, GPUAdam> functor(
beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
lr->data<T>(), grad_data, param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), rows, row_numel,
grad_merge.rows().size(), lazy_mode);
SparseAdamFunctor<T, GPUAdam, MPDType> functor(
beta1, beta2, epsilon, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(), mom1->data<MPDType>(),
mom1_out->mutable_data<MPDType>(ctx.GetPlace()),
mom2->data<MPDType>(),
mom2_out->mutable_data<MPDType>(ctx.GetPlace()),
lr->data<MPDType>(), grad_data, param->data<T>(),
param_out->mutable_data<T>(ctx.GetPlace()), master_in_data,
master_out_data, rows, row_numel, grad_merge.rows().size(),
lazy_mode);
// FIXME(minqiyang): remove BinarySearch in GPU later
platform::ForRange<platform::CUDADeviceContext> for_range(
......@@ -288,10 +343,11 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
param->numel());
for_range(functor);
// update beta1 and beta2
UpdateBetaPow<T><<<1, 32, 0, dev_ctx.stream()>>>(
beta1, beta2, beta1_pow->data<T>(), beta2_pow->data<T>(),
beta1_pow_out->mutable_data<T>(ctx.GetPlace()),
beta2_pow_out->mutable_data<T>(ctx.GetPlace()));
UpdateBetaPow<MPDType><<<1, 32, 0, dev_ctx.stream()>>>(
beta1, beta2, beta1_pow->data<MPDType>(),
beta2_pow->data<MPDType>(),
beta1_pow_out->mutable_data<MPDType>(ctx.GetPlace()),
beta2_pow_out->mutable_data<MPDType>(ctx.GetPlace()));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -304,5 +360,8 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(adam, ops::AdamOpCUDAKernel<float>,
ops::AdamOpCUDAKernel<double>);
ops::AdamOpCUDAKernel<double>,
ops::AdamOpCUDAKernel<plat::float16>);
......@@ -191,26 +191,28 @@ class AdamFunctor<T, CPUAdam> {
}
};
template <typename T, typename Flavour>
template <typename T, typename Flavour, typename MT = T>
class SparseAdamFunctor;
template <typename T>
class SparseAdamFunctor<T, GPUAdam> {
template <typename T, typename MT>
class SparseAdamFunctor<T, GPUAdam, MT> {
private:
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
const T* beta2_pow_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* lr_;
MT beta1_;
MT beta2_;
MT epsilon_;
const MT* beta1_pow_;
const MT* beta2_pow_;
const MT* moment1_;
MT* moment1_out_;
const MT* moment2_;
MT* moment2_out_;
const MT* lr_;
const T* grad_;
const T* param_;
T* param_out_;
const MT* master_param_;
MT* master_param_out_;
const int64_t* rows_;
int64_t row_numel_;
......@@ -218,10 +220,11 @@ class SparseAdamFunctor<T, GPUAdam> {
bool lazy_mode_;
public:
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad,
const T* param, T* param_out, const int64_t* rows,
SparseAdamFunctor(MT beta1, MT beta2, MT epsilon, const MT* beta1_pow,
const MT* beta2_pow, const MT* mom1, MT* mom1_out,
const MT* mom2, MT* mom2_out, const MT* lr, const T* grad,
const T* param, T* param_out, const MT* master_param,
MT* master_param_out, const int64_t* rows,
int64_t row_numel, int64_t row_count, bool lazy_mode)
: beta1_(beta1),
beta2_(beta2),
......@@ -236,31 +239,38 @@ class SparseAdamFunctor<T, GPUAdam> {
grad_(grad),
param_(param),
param_out_(param_out),
master_param_(master_param),
master_param_out_(master_param_out),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count),
lazy_mode_(lazy_mode) {}
inline HOSTDEVICE void adam_update(size_t i, T g) const {
inline HOSTDEVICE void adam_update(size_t i, MT g) const {
// The following code is the same as dense
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];
MT mom1 = moment1_[i];
MT mom2 = moment2_[i];
MT lr = *lr_;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
MT p = master_param_ ? master_param_[i] : static_cast<MT>(param_[i]);
// Calculation
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_ * sqrt(1 - beta2_pow)));
mom1 = beta1_ * mom1 + (static_cast<MT>(1.0) - beta1_) * g;
mom2 = beta2_ * mom2 + (static_cast<MT>(1.0) - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) +
epsilon_ * sqrt(static_cast<MT>(1.0) - beta2_pow)));
// Write back to global memory
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
param_out_[i] = p;
param_out_[i] = static_cast<T>(p);
if (master_param_out_) {
master_param_out_[i] = p;
}
}
inline HOSTDEVICE void operator()(size_t i) const {
......@@ -269,14 +279,16 @@ class SparseAdamFunctor<T, GPUAdam> {
if (lazy_mode_ && row_idx < 0) {
return;
} else {
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
MT g = row_idx >= 0
? static_cast<MT>(grad_[row_idx * row_numel_ + i % row_numel_])
: static_cast<MT>(0);
adam_update(i, g);
}
}
};
template <typename T>
class SparseAdamFunctor<T, CPUAdam> {
class SparseAdamFunctor<T, CPUAdam, T> {
private:
T beta1_;
T beta2_;
......
......@@ -115,7 +115,8 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_VERSION(momentum)
.AddCheckpoint(
R"ROC(
Upgrade momentum add 2 attributes [regularization_method, regularization_coeff].
Upgrade momentum add 4 attributes [regularization_method, regularization_coeff,
multi_precision, rescale_grad].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("regularization_method",
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/float16.h"
......@@ -32,17 +33,6 @@ struct UseNesterov;
namespace details {
template <typename T>
class MPTypeTrait {
public:
using Type = T;
};
template <>
class MPTypeTrait<platform::float16> {
public:
using Type = float;
};
template <typename T>
struct CPUDenseUpdater {
template <typename G>
......
......@@ -15,6 +15,7 @@
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import Variable
from paddle.fluid import core
__all__ = ['check_finite_and_unscale', 'update_loss_scaling']
......@@ -35,7 +36,7 @@ def check_finite_and_unscale(x, scale, name=None):
"""
check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
for e in x:
check_variable_and_dtype(e, "x", ['float32', 'float64'],
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'check_finite_and_unscale')
helper = LayerHelper("check_finite_and_unscale", **locals())
......@@ -58,6 +59,7 @@ def update_loss_scaling(x,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
stop_update=False,
name=None):
"""
Update loss scaling according to overall gradients. If all gradients is
......@@ -90,9 +92,13 @@ def update_loss_scaling(x,
['float32', 'float64'], "update_loss_scaling")
check_type(x, 'x', (tuple, list), 'update_loss_scaling')
for e in x:
check_variable_and_dtype(e, "x", ['float32', 'float64'],
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'update_loss_scaling')
assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
if e.dtype == core.VarDesc.VarType.FP16:
assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else:
assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
helper = LayerHelper("update_loss_scaling", **locals())
......@@ -116,6 +122,7 @@ def update_loss_scaling(x,
'decr_every_n_nan_or_inf': decr_every_n_nan_or_inf,
'incr_ratio': incr_ratio,
'decr_ratio': decr_ratio,
'stop_update': stop_update
}
helper.append_op(
......
......@@ -12,17 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ... import core
from ... import default_main_program
from ... import default_startup_program
from ... import framework
from ... import layers
from ... import unique_name
from ... import program_guard
from ... import unique_name
from . import fp16_utils
from .fp16_utils import rewrite_program
from .fp16_utils import cast_model_to_fp16
from .fp16_utils import cast_parameters_to_fp16
from .fp16_utils import update_role_var_grad
from .fp16_lists import AutoMixedPrecisionLists
from .amp_nn import check_finite_and_unscale
from .amp_nn import update_loss_scaling
import types
import warnings
__all__ = ["decorate"]
......@@ -50,12 +56,16 @@ class OptimizerWithMixedPrecision(object):
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling.
use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
Default None, which means that its value is equal to `use_pure_fp16`.
"""
def __init__(self, optimizer, amp_lists, init_loss_scaling,
use_dynamic_loss_scaling, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
decr_every_n_nan_or_inf, incr_ratio, decr_ratio, use_pure_fp16,
use_fp16_guard):
self._optimizer = optimizer
self._amp_lists = amp_lists
self._param_grads = None
......@@ -68,6 +78,9 @@ class OptimizerWithMixedPrecision(object):
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map
self._use_pure_fp16 = use_pure_fp16
self._use_fp16_guard = use_fp16_guard
self._to_fp16_var_names = None
if self._use_dynamic_loss_scaling:
self._incr_every_n_steps = incr_every_n_steps
self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
......@@ -151,20 +164,61 @@ class OptimizerWithMixedPrecision(object):
train_program = loss.block.program
self._train_program = train_program
with program_guard(train_program, startup_program):
with program_guard(self._train_program, startup_program):
self._init_amp_var()
rewrite_program(train_program, self._amp_lists)
self._scaled_loss = loss * self._loss_scaling
if self._use_pure_fp16:
self._to_fp16_var_names = cast_model_to_fp16(
self._train_program, self._amp_lists, self._use_fp16_guard)
else:
rewrite_program(self._train_program, self._amp_lists)
if loss.dtype != core.VarDesc.VarType.FP32:
loss = loss.astype('float32')
# When not using dynamic loss scaling and the init loss scaling value is equal to 1.0,
# the model can be optimized.
if self._use_dynamic_loss_scaling or self._init_loss_scaling != 1.0:
self._scaled_loss = loss * self._loss_scaling
else:
self._scaled_loss = loss
params_grads = self._optimizer.backward(
self._scaled_loss, startup_program, parameter_list, no_grad_set,
callbacks)
return params_grads
def amp_init(self,
             place,
             scope=None,
             test_program=None,
             use_fp16_test=False):
    """
    Prepare the programs for amp training once `minimize` has been called.

    In pure fp16 mode the parameters recorded in `self._to_fp16_var_names`
    are cast to fp16 through `cast_parameters_to_fp16`. If a test program is
    given, it is converted as well: with `cast_model_to_fp16` in pure fp16
    mode, otherwise with `rewrite_program` when `use_fp16_test` is True.

    Args:
        place(CPUPlace|CUDAPlace): place is used to initialize
            fp16 parameters with fp32 values.
        scope(Scope): The scope is used to find fp32 parameters.
        test_program(Program): The program is used for testing.
        use_fp16_test(bool): Whether to use fp16 testing. It is only
            consulted when pure fp16 training is disabled.
    """
    assert self._train_program is not None, \
        "Please call the minimize method first."

    pure_fp16 = self._use_pure_fp16
    if pure_fp16:
        cast_parameters_to_fp16(place, self._train_program, scope,
                                self._to_fp16_var_names)

    # Without a test program there is nothing else to convert.
    if test_program is None:
        return

    if pure_fp16:
        cast_model_to_fp16(test_program, self._amp_lists,
                           self._use_fp16_guard)
    elif use_fp16_test:
        rewrite_program(test_program, self._amp_lists)
def apply_gradients(self, params_grads):
"""
Check scaled gradients to determine whether to update loss scaling and update
parameters by their scaled gradients,
parameters by their scaled gradients.
Args:
params_grads (list): A list of params and scaled grads.
......@@ -177,39 +231,95 @@ class OptimizerWithMixedPrecision(object):
# transferred across GPUs can be FP16.
update_role_var_grad(self._train_program, params_grads)
# When not using dynamic loss scaling and the init loss scaling value is equal to 1.0,
# the model can be optimized.
if not self._use_dynamic_loss_scaling and self._init_loss_scaling == 1.0:
return self._optimizer.apply_gradients(params_grads)
grads = [g for _, g in params_grads]
if not self._is_distributed:
with self._train_program._optimized_guard(grads):
grads, found_inf = check_finite_and_unscale(
grads, self._loss_scaling, name="find_infinite_scale")
else:
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
assert len(fp32_grads) + len(fp16_grads) == len(grads), \
"Data types of all grads must be either fp16 or fp32."
found_infs = []
if self._is_distributed:
# if distributed, split check_finite_and_unscale to overlap
# unscale with communication
found_infs = []
for p, g in params_grads:
with self._train_program._optimized_guard([p, g]):
_, found_inf = check_finite_and_unscale(
[g, ], self._loss_scaling, name="find_infinite_scale")
found_infs.append(found_inf)
elif self._use_pure_fp16:
if fp32_grads:
with self._train_program._optimized_guard(fp32_grads):
_, fp32_found_inf = check_finite_and_unscale(
fp32_grads,
self._loss_scaling,
name="find_infinite_scale_fp32")
found_infs.append(fp32_found_inf)
if fp16_grads:
with self._train_program._optimized_guard(fp16_grads):
_, fp16_found_inf = check_finite_and_unscale(
fp16_grads,
self._loss_scaling,
name="find_infinite_scale_fp16")
found_infs.append(fp16_found_inf)
else:
with self._train_program._optimized_guard(grads):
_, found_inf = check_finite_and_unscale(
grads, self._loss_scaling, name="find_infinite_scale")
if self._use_dynamic_loss_scaling:
if self._is_distributed:
if self._is_distributed or self._use_pure_fp16:
with self._train_program._optimized_guard([]):
all_infs = layers.concat(found_infs)
found_inf = layers.reduce_any(all_infs)
with self._train_program._optimized_guard([]):
update_loss_scaling(
grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
name="update_loss_scaling")
if self._use_pure_fp16:
stop_update = False
with self._train_program._optimized_guard([]):
if fp32_grads:
update_loss_scaling(
fp32_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=stop_update,
name="update_loss_scaling_fp32")
stop_update = True
if fp16_grads:
update_loss_scaling(
fp16_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=stop_update,
name="update_loss_scaling_fp16")
else:
with self._train_program._optimized_guard([]):
update_loss_scaling(
grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
name="update_loss_scaling")
optimize_ops = self._optimizer.apply_gradients(params_grads)
return optimize_ops
......@@ -239,6 +349,13 @@ class OptimizerWithMixedPrecision(object):
The scaled loss by scaling factor, the list of optimize ops, and a
list of scaled parameters and gradients.
"""
opt_dict = self._optimizer.__class__.__dict__
if 'minimize' in opt_dict and isinstance(opt_dict['minimize'],
types.FunctionType):
warnings.warn(
"The decorated optimizer has its own `minimize` method, but it will not be executed."
)
scaled_params_grads = self.backward(
loss,
startup_program=startup_program,
......@@ -258,7 +375,9 @@ def decorate(optimizer,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
use_dynamic_loss_scaling=True):
use_dynamic_loss_scaling=True,
use_pure_fp16=False,
use_fp16_guard=None):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
......@@ -276,6 +395,9 @@ def decorate(optimizer,
decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling.
use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
Default None, which means that its value equals to `use_pure_fp16`.
Returns:
An optimizer acting like a normal one but with mixed-precision training
......@@ -295,8 +417,13 @@ def decorate(optimizer,
"""
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists()
if use_fp16_guard is None:
use_fp16_guard = use_pure_fp16
mp_optimizer = OptimizerWithMixedPrecision(
optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling,
incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio)
incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio,
use_pure_fp16, use_fp16_guard)
return mp_optimizer
......@@ -38,6 +38,7 @@ class AutoMixedPrecisionLists(object):
self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(unsupported_fp16_list)
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()
......@@ -64,6 +65,7 @@ class AutoMixedPrecisionLists(object):
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.black_list.add(op_name)
self.unsupported_list.add(op_name)
# The three sets listed below are changed dynamically. They don't contain all
......@@ -141,10 +143,10 @@ gray_list = {
'cast',
'fused_bn_add_activation',
}
'''
# The set of ops that don't support fp16 calculation
unsupported_fp16_list = {
# from python/paddle/fluid/layers/io.py
# from python/paddle/fluid/layers/io.py
'send',
'send_barrier',
'recv',
......@@ -153,8 +155,8 @@ unsupported_fp16_list = {
'create_double_buffer_reader',
'read',
'load',
# from python/paddle/fluid/control_flow.py
# from python/paddle/fluid/control_flow.py
'increment',
'less_than',
'less_equal',
......@@ -174,7 +176,6 @@ unsupported_fp16_list = {
'while',
'ifelse',
'is_empty',
'lstm',
'cudnn_lstm',
'lstmp',
......@@ -275,7 +276,6 @@ unsupported_fp16_list = {
'pixel_shuffle',
'fsp',
'cvm',
'affine_channel',
'roi_pool',
'roi_align',
......@@ -283,6 +283,4 @@ unsupported_fp16_list = {
'generate_proposals',
'generate_proposal_labels',
'generate_mask_labels',
}
'''
......@@ -15,17 +15,28 @@
from __future__ import print_function
from ... import core
from ... import framework
from ... import layers
from ... import global_scope
from ...log_helper import get_logger
from ...wrapped_decorator import signature_safe_contextmanager
from .fp16_lists import AutoMixedPrecisionLists
import collections
import logging
import numpy as np
__all__ = ["cast_model_to_fp16", "cast_parameters_to_fp16"]
__all__ = ["fp16_guard", "cast_model_to_fp16", "cast_parameters_to_fp16"]
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
_valid_types = [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
]
_fp16_guard_pattern = "__use_fp16__"
def _rename_arg(op, old_name, new_name):
"""
......@@ -44,6 +55,18 @@ def _rename_arg(op, old_name, new_name):
op_desc._rename_output(old_name, new_name)
def _rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops):
for block in program.blocks:
ops = block.ops
block_id = block.idx
for op in ops:
if op not in origin_ops or op in keep_fp32_ops:
continue
for name in op.input_arg_names:
if name in op_var_rename_map[block_id]:
op._rename_input(name, op_var_rename_map[block_id][name])
def _dtype_to_str(dtype):
"""
Convert specific variable type to its corresponding string.
......@@ -72,10 +95,6 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
num_cast_op (int): The number of cast ops that have been inserted.
"""
num_cast_ops = 0
valid_types = [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
]
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
......@@ -85,7 +104,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
continue
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
if in_var.type not in valid_types or in_var.dtype == dest_dtype:
if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
......@@ -119,7 +138,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
continue
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
if out_var.type not in valid_types:
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
......@@ -128,6 +147,38 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
return num_cast_ops
# Cast the variable `target_name` of `block` from `src_dtype` to `dest_dtype`
# by inserting a "cast" op at position `idx`, and record the mapping from the
# original variable name to the cast variable name in
# `op_var_rename_map[block.idx]`.
# Returns the number of cast ops that were inserted.
# NOTE(review): the `op` argument is not read anywhere in this body.
def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
op_var_rename_map):
num_cast_ops = 0
target_var = block.var(target_name)
# Nothing to do for variable types outside `_valid_types` or when the
# variable already has the destination dtype.
if target_var.type not in _valid_types or target_var.dtype == dest_dtype:
return num_cast_ops
# Any variable that reaches this point must have the declared source dtype.
assert target_var.dtype == src_dtype, \
"The real dtype({}) is not equal to the src dtype({})".format(_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
# The cast result is named "<target name>.cast_<destination dtype string>".
cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
cast_var = block.vars.get(cast_name)
# A new cast variable is created only when the block has no variable of that
# name with the destination dtype.
if cast_var is None or cast_var.dtype != dest_dtype:
cast_var = block.create_var(
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=target_var.stop_gradient)
# NOTE(review): confirm whether the op insertion and the bookkeeping below are
# meant to run only when a new cast variable was created above.
block._insert_op(
idx,
type="cast",
inputs={"X": target_var},
outputs={"Out": cast_var},
attrs={"in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype})
num_cast_ops += 1
# Consumers of the original variable are pointed at the cast variable later
# through this per-block rename map.
op_var_rename_map[block.idx][target_var.name] = cast_var.name
return num_cast_ops
def find_true_prev_op(ops, cur_op, var_name):
"""
Find the true prev op that outputs var_name variable.
......@@ -174,9 +225,8 @@ def find_true_post_op(ops, cur_op, var_name):
for in_var_name in op.input(in_name):
if in_var_name == var_name:
post_op.append(op)
if post_op != []:
return post_op
return None
return post_op
def find_op_index(block_desc, cur_op_desc):
......@@ -200,26 +250,73 @@ def _is_in_black_varnames(op, amp_lists):
return False
def cast_model_to_fp16(main_program):
def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
if op.type in unsupported_op_list:
# the highest priority condition: If ops don't have fp16 computing kernels,
# they must be executed in fp32 calculation pattern.
return True
# process ops about learning rate
in_out_arg_names = []
in_out_arg_names.extend(list(op.input_arg_names))
in_out_arg_names.extend(list(op.output_arg_names))
for name in in_out_arg_names:
if "learning_rate" in name:
return True
if use_fp16_guard:
if op.has_attr("op_namescope") and \
(_fp16_guard_pattern in op.attr("op_namescope")):
# op in fp16 guard
return False
else:
# op not in fp16 guard
return True
else:
return False
@signature_safe_contextmanager
def fp16_guard():
    """
    Context manager that marks the ops eligible for pure fp16 training.

    When users set `use_fp16_guard` to True, only the ops created inside
    this context are transformed to the float16 type. The context opens a
    name scope whose prefix is `_fp16_guard_pattern`.
    """
    guard_scope = framework.name_scope(prefix=_fp16_guard_pattern)
    with guard_scope:
        yield
def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
"""
Traverse all ops in the whole model and set their inputs and outputs
to the fp16 data type. This function will do some special process for
the batch normalization, which keeps the computational process of
batchnorms in FP32.
Args:
main_program (Program): The main program for training.
program (Program): The used program.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
use_fp16_guard(bool): Determine whether to use `fp16_guard` when
constructing the program. Default True.
"""
valid_types = [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
]
global_block = main_program.global_block()
for block in main_program.blocks:
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists()
global_block = program.global_block()
keep_fp32_ops = set()
to_fp16_var_names = set()
origin_ops = []
for block in program.blocks:
origin_ops.extend(block.ops)
for block in program.blocks:
ops = block.ops
for op in ops:
if op.type == 'create_py_reader' or op.type == 'read':
continue
if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard):
keep_fp32_ops.add(op)
continue # processed below
for in_name in op.input_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
......@@ -231,19 +328,20 @@ def cast_model_to_fp16(main_program):
in_var = block.var(in_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block. --".
"-- {}, try to get it in the global block --".
format(e))
in_var = global_block.var(in_var_name)
if in_var is not None:
_logger.debug(
"-- var {} is got in the global block. --".
"-- var {} is got in the global block --".
format(in_var_name))
if in_var is None or in_var.type not in valid_types:
if in_var is None or in_var.type not in _valid_types:
continue
if in_var.dtype == core.VarDesc.VarType.FP32:
in_var.desc.set_dtype(core.VarDesc.VarType.FP16)
to_fp16_var_names.add(in_var_name)
_logger.debug(
"-- op type: {}, in var name: {}, in var dtype: {} --".
......@@ -260,15 +358,15 @@ def cast_model_to_fp16(main_program):
out_var = block.var(out_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block. --".
"-- {}, try to get it in the global block --".
format(e))
out_var = global_block.var(out_var_name)
if out_var is not None:
_logger.debug(
"-- var {} is got in the global block. --".
"-- var {} is got in the global block --".
format(out_var_name))
if out_var is None or out_var.type not in valid_types:
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
......@@ -287,35 +385,65 @@ def cast_model_to_fp16(main_program):
'dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
# process ops in keep_fp32_ops
op_var_rename_map = [
collections.OrderedDict() for _ in range(len(program.blocks))
]
for block in program.blocks:
ops = block.ops
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op in keep_fp32_ops:
pre_cast_num = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32)
num_cast_ops += pre_cast_num
for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
post_ops = find_true_post_op(ops, op, out_var_name)
for post_op in post_ops:
if post_op in keep_fp32_ops:
continue
post_cast_num = _insert_cast_post_op(
block, op, idx + pre_cast_num + 1,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, out_var_name,
op_var_rename_map)
num_cast_ops += post_cast_num
idx += num_cast_ops + 1
_rename_op_input(program, op_var_rename_map, origin_ops, keep_fp32_ops)
return to_fp16_var_names
def cast_parameters_to_fp16(place, main_program, scope=None):
def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
"""
Traverse all parameters in the whole model and set them to the fp16 data type.
Traverse all parameters in the whole model and set them to the FP16 data type.
Whereas, this function will keep parameters of batchnorms in FP32.
Args:
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
main_program (Program): The main program for training.
scope(fluid.Scope, optional): scope is used to get the weight tensor values.
Default is None.
place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors.
program (Program): The used program.
scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
Default is None.
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
will be set to FP16. Usually, it is the returned
value of `cast_model_to_fp16` API.
"""
all_ops = []
for block in main_program.blocks:
all_ops.extend(block.ops)
bn_params = set()
for op in all_ops:
if op.type not in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
}:
continue
for in_name in op.input_names:
if in_name not in {'X', 'Z'}:
for in_var_name in op.input(in_name):
bn_params.add(in_var_name)
global_block = main_program.global_block()
all_parameters = global_block.all_parameters()
var_scope = scope if scope is not None else global_scope()
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())
fp16_var_names = to_fp16_var_names if to_fp16_var_names else set()
var_scope = scope if scope else global_scope()
for param in all_parameters:
if param.name not in bn_params:
if param.name in fp16_var_names:
_logger.debug("---- cast {} to fp16 dtype ----".format(param.name))
param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t)
param_t.set(np.float16(data), place)
......@@ -458,7 +586,7 @@ def update_role_var_grad(main_prog, params_grads):
if op == block.ops[-1]:
continue
post_ops = find_true_post_op(block.ops, op, g.name)
if post_ops is not None:
if post_ops:
raise ValueError("The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0]))
......
......@@ -19,8 +19,7 @@ import paddle.fluid as fluid
import contextlib
import unittest
import numpy as np
from paddle.static.amp import cast_model_to_fp16
from paddle.static.amp import cast_parameters_to_fp16
from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_model_to_fp16
paddle.enable_static()
......@@ -65,38 +64,19 @@ def resnet_cifar10(input, depth=32):
n = (depth - 2) // 6
conv1 = conv_bn_layer(
input=input, ch_out=16, filter_size=3, stride=1, padding=1)
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
with paddle.static.amp.fp16_guard():
res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
pool = fluid.layers.pool2d(
input=res3, pool_size=8, pool_type='avg', pool_stride=1)
return pool
def compile(program, loss_name=None):
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000
build_strategy.fuse_bn_act_ops = True
build_strategy.fuse_elewise_add_act_ops = True
build_strategy.fuse_bn_add_act_ops = True
compiled_program = paddle.static.CompiledProgram(
program).with_data_parallel(
loss_name=loss_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
return compiled_program
def train(use_pure_fp16=True, use_nesterov=False):
def train(use_pure_fp16=True, use_nesterov=False, use_adam=False):
classdim = 10
data_shape = [3, 32, 32]
BATCH_SIZE = 128
BATCH_SIZE = 32
PASS_NUM = 1
train_program = fluid.Program()
......@@ -107,28 +87,35 @@ def train(use_pure_fp16=True, use_nesterov=False):
images = fluid.layers.data(
name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
net = resnet_cifar10(images, 32)
net = resnet_cifar10(images)
logits = fluid.layers.fc(input=net, size=classdim, act="softmax")
if use_pure_fp16:
cast_model_to_fp16(fluid.default_main_program())
logits_fp32 = fluid.layers.cast(x=logits, dtype="float32")
else:
logits_fp32 = logits
cost = fluid.layers.softmax_with_cross_entropy(
logits_fp32, label, return_softmax=False)
logits, label, return_softmax=False)
sum_cost = fluid.layers.reduce_sum(cost)
# Test program
test_program = train_program.clone(for_test=True)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
momentum=0.9,
use_nesterov=use_nesterov,
weight_decay=fluid.regularizer.L2Decay(1e-4),
multi_precision=use_pure_fp16,
rescale_grad=1.0 / BATCH_SIZE)
if use_adam:
optimizer = paddle.optimizer.Adam(
learning_rate=0.001,
epsilon=1e-8,
weight_decay=0.0,
multi_precision=True)
else:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001,
momentum=0.9,
use_nesterov=use_nesterov,
weight_decay=fluid.regularizer.L2Decay(1e-4),
multi_precision=use_pure_fp16)
if use_pure_fp16:
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True)
optimizer.minimize(sum_cost)
......@@ -146,13 +133,13 @@ def train(use_pure_fp16=True, use_nesterov=False):
def train_loop(main_program):
exe.run(startup_prog)
if use_pure_fp16:
cast_parameters_to_fp16(place, train_program, fluid.global_scope())
compiled_program = compile(train_program, sum_cost.name)
optimizer.amp_init(
place, test_program=test_program, use_fp16_test=True)
loss = 0.0
for pass_id in range(PASS_NUM):
train_loss_list = []
for batch_id, data in enumerate(train_reader()):
loss, = exe.run(compiled_program,
loss, = exe.run(train_program,
feed=feeder.feed(data),
fetch_list=[sum_cost])
loss_v = loss[0] if isinstance(loss, np.ndarray) else loss
......@@ -182,18 +169,25 @@ class TestImageMultiPrecision(unittest.TestCase):
if not fluid.core.is_compiled_with_cuda():
return
def do_test(use_nesterov=False):
suffix = "with Nesterov" if use_nesterov else "without Nesterov"
def do_test(use_nesterov=False, use_adam=False):
if use_adam:
suffix = "use Adam"
else:
suffix = "with Nesterov" if use_nesterov else "without Nesterov"
with self.scope_prog_guard():
print("-----------------FP16 Train {}-----------------".format(
suffix))
train_loss_fp16, test_loss_fp16 = train(
use_pure_fp16=True, use_nesterov=use_nesterov)
use_pure_fp16=True,
use_nesterov=use_nesterov,
use_adam=use_adam)
with self.scope_prog_guard():
print("-----------------FP32 Train {}-----------------".format(
suffix))
train_loss_fp32, test_loss_fp32 = train(
use_pure_fp16=False, use_nesterov=use_nesterov)
use_pure_fp16=False,
use_nesterov=use_nesterov,
use_adam=use_adam)
self.assertTrue(
np.allclose(
......@@ -214,6 +208,7 @@ class TestImageMultiPrecision(unittest.TestCase):
do_test(use_nesterov=False)
do_test(use_nesterov=True)
do_test(use_adam=True)
@contextlib.contextmanager
def scope_prog_guard(self):
......@@ -260,7 +255,7 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase):
op._set_attr('out_dtype', fluid.core.VarDesc.VarType.FP32)
op._set_attr('dtype', fluid.core.VarDesc.VarType.FP32)
cast_model_to_fp16(main_prog)
cast_model_to_fp16(main_prog, use_fp16_guard=False)
def test_non_iterable_dataloader(self):
self.decorate_with_data_loader()
......
......@@ -35,7 +35,7 @@ class TestUpdateLossScalingOp(OpTest):
}
self.outputs = {
'Out': [('out0', np.zeros_like(x))],
'Out': [('out0', x)],
'LossScaling': self.prev_loss_scaling * self.incr_ratio,
'OutGoodSteps': self.zero_steps,
'OutBadSteps': self.zero_steps
......
......@@ -16,6 +16,10 @@ from .optimizer import Optimizer
from ..fluid import core
from ..fluid import framework
from ..fluid.framework import Variable
from ..fluid import layers
from ..fluid import unique_name
from ..fluid.layer_helper import LayerHelper
import warnings
from ..fluid.dygraph import base as imperative_base
import paddle
......@@ -79,6 +83,7 @@ class Adam(Optimizer):
gradient in current mini-batch, so it will be much faster. But this mode has
different semantics from the original Adam algorithm and may lead to different results.
The default value is False.
multi_precision (bool, optional): Whether to use multi-precision during weight updating. The default value is False.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
......@@ -135,6 +140,7 @@ class Adam(Optimizer):
weight_decay=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
name=None):
assert learning_rate is not None
assert beta1 is not None
......@@ -157,28 +163,90 @@ class Adam(Optimizer):
self._beta2 = beta2
self._epsilon = epsilon
self._lazy_mode = lazy_mode
self._multi_precision = multi_precision
self._master_weights = {}
def _create_master_weight(self, param):
    """Create and register the FP32 master variable for ``param``.

    A persistable ``float32`` global variable with the same shape as
    ``param`` is created under a unique name derived from ``param.name``.
    A ``cast`` op taking ``param`` as input and the new variable as output
    is appended to the global block of the startup program, and the
    variable is recorded in ``self._master_weights`` under ``param.name``.

    Args:
        param: parameter for which the master weight is created.

    Returns:
        The newly created master-weight variable.
    """
    assert isinstance(self.helper, LayerHelper)

    # Unique name so that repeated creation never collides.
    master_name = unique_name.generate(param.name + "_fp32_master")
    master_var = layers.create_global_var(
        name=master_name,
        shape=param.shape,
        value=0,
        dtype='float32',
        persistable=True)

    # The cast lives in the startup program, not in the main program.
    startup_block = self.helper.startup_program.global_block()
    cast_attrs = {
        "in_dtype": param.dtype,
        "out_dtype": core.VarDesc.VarType.FP32
    }
    startup_block.append_op(
        type="cast",
        inputs={"X": [param]},
        outputs={"Out": [master_var]},
        attrs=cast_attrs)

    self._master_weights[param.name] = master_var
    return master_var
def _get_accumulator(self, name, param):
"""Utility function to fetch an accumulator for a parameter
Args:
name: name of the accumulator
param: parameter variable for which accumulator is to be fetched
Returns:
accumulator variable for the parameter
"""
if self._name is not None:
name = self._name + "_" + name
find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
target_param = self._master_weights[
param.name] if find_master else param
target_name = target_param.name
if (name not in self._accumulators or
target_name not in self._accumulators[name]):
raise Exception("Accumulator {} does not exist for parameter {}".
format(name, target_name))
return self._accumulators[name][target_name]
def _add_moments_pows(self, p):
    """Register the four Adam accumulators for ``p``.

    Adds the first-moment and second-moment accumulators, followed by the
    beta1-power and beta2-power accumulators. The two power accumulators
    are created with shape ``[1]``, type ``LOD_TENSOR`` and device ``'cpu'``.
    Their fill value is the configured beta, or 0.9 / 0.999 respectively
    when the beta is a ``Variable``.

    Accumulators use ``p``'s dtype, except that an FP16 ``p`` gets FP32
    accumulators.

    Args:
        p: parameter (or master weight) the accumulators belong to.
    """
    if p.dtype == core.VarDesc.VarType.FP16:
        acc_dtype = core.VarDesc.VarType.FP32
    else:
        acc_dtype = p.dtype

    for moment_name in (self._moment1_acc_str, self._moment2_acc_str):
        self._add_accumulator(moment_name, p, dtype=acc_dtype)

    # (accumulator name, configured beta, fill value used for Variable betas)
    beta_pow_specs = (
        (self._beta1_pow_acc_str, self._beta1, 0.9),
        (self._beta2_pow_acc_str, self._beta2, 0.999),
    )
    for pow_name, beta, variable_fill in beta_pow_specs:
        if isinstance(beta, Variable):
            fill = variable_fill
        else:
            fill = beta
        self._add_accumulator(
            name=pow_name,
            param=p,
            dtype=acc_dtype,
            fill_value=fill,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu')
def _create_accumulators(self, block, parameters):
    """Create the Adam accumulators for every parameter.

    For an FP16 parameter with multi-precision enabled, a master weight is
    created and the accumulators are registered against that master weight.
    For every other parameter the accumulators are registered against the
    parameter itself; an FP16 parameter without multi-precision additionally
    triggers a warning.

    Args:
        block: the block being optimized; must be a ``framework.Block``.
        parameters: iterable of parameters to create accumulators for.
    """
    assert isinstance(block, framework.Block)

    # Create accumulator tensors for first and second moments
    for p in parameters:
        is_fp16 = p.dtype == core.VarDesc.VarType.FP16
        if self._multi_precision and is_fp16:
            # Accumulators belong to the master weight, not to the FP16
            # parameter; each parameter is registered exactly once.
            master_p = self._create_master_weight(p)
            self._add_moments_pows(master_p)
            continue
        if is_fp16 and not self._multi_precision:
            warnings.warn(
                "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence. "
                "Consider using multi_precision=True option of the Adam optimizer."
            )
        self._add_moments_pows(p)
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......@@ -191,6 +259,10 @@ class Adam(Optimizer):
param_and_grad[0])
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)
lr = self._create_param_lr(param_and_grad)
# create the adam optimize op
......@@ -227,7 +299,8 @@ class Adam(Optimizer):
attrs = {
"epsilon": self._epsilon,
"lazy_mode": self._lazy_mode,
"min_row_size_to_use_multithread": 1000
"min_row_size_to_use_multithread": 1000,
"multi_precision": find_master
}
if isinstance(self._beta1, Variable):
......@@ -239,6 +312,10 @@ class Adam(Optimizer):
else:
attrs['beta2'] = self._beta2
if find_master:
inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight
adam_op = block.append_op(
type=self.type,
inputs=inputs,
......
......@@ -71,6 +71,7 @@ class AdamW(Adam):
gradient in current mini-batch, so it will be much faster. But this mode has
different semantics from the original Adam algorithm and may lead to different results.
The default value is False.
multi_precision (bool, optional): Whether to use multi-precision during weight updating. The default value is False.
name (str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.
The default value is None.
......@@ -111,6 +112,7 @@ class AdamW(Adam):
apply_decay_param_fun=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
name=None):
assert learning_rate is not None
assert beta1 is not None
......@@ -138,7 +140,8 @@ class AdamW(Adam):
epsilon=epsilon,
grad_clip=grad_clip,
name=name,
lazy_mode=lazy_mode)
lazy_mode=lazy_mode,
multi_precision=multi_precision)
def _append_decoupled_weight_decay(self, block, param_and_grad):
"""
......
......@@ -129,21 +129,6 @@ class Momentum(Optimizer):
self.helper = LayerHelper(self.__class__.__name__)
for p in parameters:
self._add_accumulator(self._velocity_acc_str, p)
else:
all_parameters = fluid.default_main_program().global_block(
).all_parameters()
self.helper = LayerHelper(self.__class__.__name__)
for p in all_parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
master_p = self._create_master_weight(p)
self._add_accumulator(self._velocity_acc_str, master_p)
continue
if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Momentum optimizer."
)
self._add_accumulator(self._velocity_acc_str, p)
def _create_master_weight(self, param):
assert isinstance(self.helper, LayerHelper)
......@@ -191,8 +176,21 @@ class Momentum(Optimizer):
return self._accumulators[name][target_name]
def _create_accumulators(self, block, parameters):
    """Create the velocity accumulator for every parameter.

    Does nothing in dygraph mode.
    NOTE(review): presumably the dygraph accumulators are created in
    ``__init__`` instead -- confirm against the constructor.

    Otherwise, for an FP16 parameter with multi-precision enabled, a master
    weight is created and the velocity accumulator is registered against
    it. For every other parameter the accumulator is registered against the
    parameter itself; an FP16 parameter without multi-precision additionally
    triggers a warning.

    Args:
        block: the block being optimized; must be a ``framework.Block``.
        parameters: iterable of parameters to create accumulators for.
    """
    if framework.in_dygraph_mode():
        return

    assert isinstance(block, framework.Block)

    for p in parameters:
        is_fp16 = p.dtype == core.VarDesc.VarType.FP16
        if self._multi_precision and is_fp16:
            # The velocity is kept for the master weight, not for the FP16
            # parameter.
            master_p = self._create_master_weight(p)
            self._add_accumulator(self._velocity_acc_str, master_p)
            continue
        if is_fp16 and not self._multi_precision:
            warnings.warn(
                "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence. "
                "Consider using multi_precision=True option of the Momentum optimizer."
            )
        self._add_accumulator(self._velocity_acc_str, p)
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册