Unverified commit 7d6d3848, authored by zhangbo9674, committed by GitHub

[AMP] support GPU BF16 amp for dygraph (#39029)

* support dtype param for auto_cast

* add amp_dtype for tracer

* add unsupported bf16 list

* support bf16 amp for O2

* refine python interface for bfloat16

* refine code

* refine code

* refine unittest

* refine code

* refine code

* add bf16 o1

* refine code by comment

* add gradient accumulator

* add recompute
Parent: 7e4ed848
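
For context, a minimal usage sketch of the new dtype argument, pieced together from the Python interface and the unit test added in this commit (tensor shapes are illustrative only):

import paddle

conv = paddle.nn.Conv2D(4, 6, (3, 3))
x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)

# O1: only ops on the bf16 white list (e.g. conv2d) receive bfloat16 inputs.
with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
    out_o1 = conv(x)

# O2: inputs of all ops are cast to bfloat16, except black-listed or unsupported ops.
with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
    out_o2 = conv(x)

# Outputs can be cast back to float32 for inspection, as the new unit test does.
print(out_o1.cast('float32').numpy().shape)
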
......@@ -113,26 +113,40 @@ AutoCastGuard::~AutoCastGuard() { tracer_->SetAmpLevel(pre_amp_level_); }
AmpOperators::AmpOperators()
: allow_ops_(new std::unordered_set<std::string>()),
block_ops_(new std::unordered_set<std::string>()),
unsupported_fp16_ops_(new std::unordered_set<std::string>()) {
unsupported_fp16_ops_(new std::unordered_set<std::string>()),
unsupported_bf16_ops_(new std::unordered_set<std::string>()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto unsupported_ops_gpu = std::get<2>(
auto unsupported_ops_gpu_fp16 = std::get<2>(
OpSupportedInfos("GPU", paddle::framework::proto::VarType::FP16));
unsupported_fp16_ops_->insert(unsupported_ops_gpu.begin(),
unsupported_ops_gpu.end());
unsupported_fp16_ops_->insert(unsupported_ops_gpu_fp16.begin(),
unsupported_ops_gpu_fp16.end());
auto unsupported_ops_gpu_bf16 = std::get<2>(
OpSupportedInfos("GPU", paddle::framework::proto::VarType::BF16));
unsupported_bf16_ops_->insert(unsupported_ops_gpu_bf16.begin(),
unsupported_ops_gpu_bf16.end());
// NOTE: GPU/NPU/XPU are compiled separately.
#elif defined(PADDLE_WITH_ASCEND_CL)
auto unsupported_ops_npu = std::get<2>(
auto unsupported_ops_npu_fp16 = std::get<2>(
OpSupportedInfos("NPU", paddle::framework::proto::VarType::FP16));
unsupported_fp16_ops_->insert(unsupported_ops_npu.begin(),
unsupported_ops_npu.end());
unsupported_fp16_ops_->insert(unsupported_ops_npu_fp16.begin(),
unsupported_ops_npu_fp16.end());
auto unsupported_ops_npu_bf16 = std::get<2>(
OpSupportedInfos("NPU", paddle::framework::proto::VarType::BF16));
unsupported_bf16_ops_->insert(unsupported_ops_npu_bf16.begin(),
unsupported_ops_npu_bf16.end());
#elif defined(PADDLE_WITH_XPU)
auto unsupported_ops_xpu = std::get<2>(
auto unsupported_ops_xpu_fp16 = std::get<2>(
OpSupportedInfos("XPU", paddle::framework::proto::VarType::FP16));
unsupported_fp16_ops_->insert(unsupported_ops_xpu.begin(),
unsupported_ops_xpu.end());
unsupported_fp16_ops_->insert(unsupported_ops_xpu_fp16.begin(),
unsupported_ops_xpu_fp16.end());
auto unsupported_ops_xpu_bf16 = std::get<2>(
OpSupportedInfos("XPU", paddle::framework::proto::VarType::BF16));
unsupported_bf16_ops_->insert(unsupported_ops_xpu_bf16.begin(),
unsupported_ops_xpu_bf16.end());
#endif
VLOG(4) << allow_ops_->size() << " " << block_ops_->size() << " "
<< unsupported_fp16_ops_->size();
<< unsupported_fp16_ops_->size() << " "
<< unsupported_bf16_ops_->size();
}
AmpOperators::~AmpOperators() {}
......@@ -157,6 +171,11 @@ AmpOperators::GetMutableUnsupportedFp16Ops() {
return unsupported_fp16_ops_;
}
std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableUnsupportedBf16Ops() {
return unsupported_bf16_ops_;
}
std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
os << "allow ops: ";
auto allow_ops = ops.GetMutableAllowOps();
......@@ -172,6 +191,11 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
auto unsupported_fp16_ops = ops.GetMutableUnsupportedFp16Ops();
std::copy((*unsupported_fp16_ops).begin(), (*unsupported_fp16_ops).end(),
std::ostream_iterator<std::string>(os, " "));
os << "\n";
os << "unsupported bf16 ops: ";
auto unsupported_bf16_ops = ops.GetMutableUnsupportedBf16Ops();
std::copy((*unsupported_bf16_ops).begin(), (*unsupported_bf16_ops).end(),
std::ostream_iterator<std::string>(os, " "));
return os;
}
......@@ -188,7 +212,8 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
paddle::platform::is_xpu_place(place)) {
// CudaPinnedPlace is added for varbase created by dataloader
if (data_type == paddle::framework::proto::VarType::FP32 ||
data_type == paddle::framework::proto::VarType::FP16) {
data_type == paddle::framework::proto::VarType::FP16 ||
data_type == paddle::framework::proto::VarType::BF16) {
return true;
}
}
......@@ -236,6 +261,16 @@ static inline std::shared_ptr<VarType> CastToFP32(
return var;
}
template <typename VarType>
static inline std::shared_ptr<VarType> CastToBF16(
const std::shared_ptr<VarType>& var) {
auto dst_type = framework::proto::VarType::BF16;
if (NeedCast(var) && (GetDataType<VarType>(var) != dst_type)) {
return CastToType(var, dst_type);
}
return var;
}
template <typename VarType>
static inline framework::proto::VarType::Type GetPromoteType(
const std::string& op_type, const NameVarMap<VarType>& ins) {
......@@ -386,5 +421,62 @@ template NameVarMap<VarBase> CastPureFp16Inputs<VarBase>(
const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerVariable> CastPureFp16Inputs<egr::EagerVariable>(
const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);
template <typename VarType>
NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins) {
NameVarMap<VarType> new_ins(ins);
if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to bfloat16";
for (auto& var : pair.second) {
var = CastToBF16<VarType>(var);
}
}
return new_ins;
} else {
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to float";
for (auto& var : pair.second) {
var = CastToFP32<VarType>(var);
}
}
return new_ins;
}
return new_ins;
}
template NameVarMap<VarBase> AutoCastBF16Inputs<VarBase>(
const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerVariable> AutoCastBF16Inputs<egr::EagerVariable>(
const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);
template <typename VarType>
NameVarMap<VarType> CastPureBf16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins) {
NameVarMap<VarType> new_ins(ins);
auto dst_type = framework::proto::VarType::BF16;
if (AmpOperators::Instance().GetMutableUnsupportedBf16Ops()->count(op_type) ||
AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
dst_type = framework::proto::VarType::FP32;
}
for (auto& pair : new_ins) {
VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
<< GetDtypeStr(*pair.second.cbegin()) << " to "
<< framework::DataTypeToString(dst_type);
for (auto& var : pair.second) {
var = (dst_type == framework::proto::VarType::FP32
? CastToFP32<VarType>(var)
: CastToBF16<VarType>(var));
}
}
return new_ins;
}
template NameVarMap<VarBase> CastPureBf16Inputs<VarBase>(
const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerVariable> CastPureBf16Inputs<egr::EagerVariable>(
const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);
} // namespace imperative
} // namespace paddle
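
For readers who do not want to trace the templates above, the casting policy implemented by AutoCastBF16Inputs (O1) and CastPureBf16Inputs (O2) boils down to the following Python sketch; the set names are stand-ins for the AmpOperators sets, not real Paddle symbols:

# Illustrative restatement of the C++ cast decisions above (not Paddle code).
def o1_bf16_input_dtype(op_type, allow_ops):
    # O1: white-listed ops get bfloat16 inputs, everything else falls back to float32.
    return 'bfloat16' if op_type in allow_ops else 'float32'

def o2_bf16_input_dtype(op_type, block_ops, unsupported_bf16_ops):
    # O2: run everything in bfloat16 unless the op is blocked or has no bf16 kernel.
    if op_type in block_ops or op_type in unsupported_bf16_ops:
        return 'float32'
    return 'bfloat16'
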
......@@ -56,6 +56,9 @@ class AmpOperators {
std::shared_ptr<std::unordered_set<std::string>>
GetMutableUnsupportedFp16Ops();
std::shared_ptr<std::unordered_set<std::string>>
GetMutableUnsupportedBf16Ops();
private:
AmpOperators(); // forbid calling default constructor
......@@ -69,6 +72,9 @@ class AmpOperators {
// The set of ops that have no fp16 CUDA kernel.
std::shared_ptr<std::unordered_set<std::string>> unsupported_fp16_ops_;
// The set of ops that have no bf16 CUDA kernel.
std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;
};
std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
......@@ -95,6 +101,12 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
template <typename VarType>
NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins);
template <typename VarType>
NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins);
template <typename VarType>
NameVarMap<VarType> CastPureBf16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins);
} // namespace imperative
} // namespace paddle
......@@ -24,6 +24,7 @@
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
......@@ -420,6 +421,22 @@ void TensorAdd(const VarType& src, VarType* dst) {
src_tensor, dst_tensor, place);
}
}
if (data_type == framework::proto::VarType::BF16) {
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA)
return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>(
src_tensor, dst_tensor, place);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type), place));
#endif
} else if (platform::is_cpu_place(place)) {
return TensorAddImpl<platform::CPUDeviceContext, platform::bfloat16>(
src_tensor, dst_tensor, place);
}
}
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
......
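
The TensorAdd branch above is what backs bf16 gradient accumulation in dygraph: when a bfloat16 variable feeds more than one op, the partial gradients coming back from its consumers are summed with this kernel. A hedged sketch of a program intended to exercise that path, under the pure-bf16 guard added later in this PR (shapes are illustrative):

import paddle

conv = paddle.nn.Conv2D(4, 6, (3, 3))
x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)

with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
    y = conv(x)      # bfloat16 activation
    z = y * y + y    # y has two consumers, so its gradient is accumulated
loss = z.cast('float32').mean()
loss.backward()
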
......@@ -35,6 +35,8 @@ thread_local bool Tracer::has_grad_ = true;
thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;
thread_local pten::DataType Tracer::amp_dtype_ = pten::DataType::FLOAT32;
static std::shared_ptr<Tracer> g_current_tracer(nullptr);
const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }
......@@ -200,10 +202,18 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
NameVarMap<VarType> new_ins = ins;
if (amp_level_ == AmpLevel::O1) {
VLOG(5) << "Auto mixed precision run operator: " << type;
new_ins = AutoCastInputs<VarType>(type, ins);
if (amp_dtype_ == pten::DataType::FLOAT16) {
new_ins = AutoCastInputs<VarType>(type, ins);
} else if (amp_dtype_ == pten::DataType::BFLOAT16) {
new_ins = AutoCastBF16Inputs<VarType>(type, ins);
}
} else if (amp_level_ == AmpLevel::O2) {
VLOG(5) << "Pure fp16 run operator: " << type;
new_ins = CastPureFp16Inputs<VarType>(type, ins);
if (amp_dtype_ == pten::DataType::FLOAT16) {
new_ins = CastPureFp16Inputs<VarType>(type, ins);
} else if (amp_dtype_ == pten::DataType::BFLOAT16) {
new_ins = CastPureBf16Inputs<VarType>(type, ins);
}
}
try {
......
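
The new TraceOp branches amount to a two-key dispatch on (amp_level, amp_dtype); an illustrative summary in Python (names only, not executable Paddle internals):

# (amp_level, amp_dtype) -> input-casting routine chosen by Tracer::TraceOp
CAST_DISPATCH = {
    ('O1', 'float16'):  'AutoCastInputs',
    ('O1', 'bfloat16'): 'AutoCastBF16Inputs',
    ('O2', 'float16'):  'CastPureFp16Inputs',
    ('O2', 'bfloat16'): 'CastPureBf16Inputs',
}
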
......@@ -34,6 +34,8 @@ namespace imperative {
enum class AmpLevel;
enum class AmpDtype;
using GarbageCollectorMap =
std::map<platform::Place,
std::unique_ptr<paddle::framework::GarbageCollector>>;
......@@ -131,6 +133,27 @@ class Tracer {
AmpLevel GetAmpLevel() const { return amp_level_; }
void SetAmpDtype(std::string amp_dtype) {
VLOG(4) << "set amp_dtype to " << amp_dtype;
if (amp_dtype == "float16") {
amp_dtype_ = pten::DataType::FLOAT16;
} else if (amp_dtype == "bfloat16") {
amp_dtype_ = pten::DataType::BFLOAT16;
} else {
amp_dtype_ = pten::DataType::FLOAT32;
}
}
std::string GetAmpDtype() const {
if (amp_dtype_ == pten::DataType::FLOAT16) {
return std::string("float16");
} else if (amp_dtype_ == pten::DataType::BFLOAT16) {
return std::string("bfloat16");
} else {
return std::string("float32");
}
}
paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
const platform::Place& place);
......@@ -143,6 +166,7 @@ class Tracer {
GarbageCollectorMap gcs_;
static thread_local bool has_grad_;
static thread_local AmpLevel amp_level_;
static thread_local pten::DataType amp_dtype_;
};
// To access static variable current_tracer
......
......@@ -2230,6 +2230,8 @@ void BindImperative(py::module *m_ptr) {
&imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_amp_level", &imperative::Tracer::GetAmpLevel,
&imperative::Tracer::SetAmpLevel)
.def_property("_amp_dtype", &imperative::Tracer::GetAmpDtype,
&imperative::Tracer::SetAmpDtype)
.def_property("_has_grad", &imperative::Tracer::HasGrad,
&imperative::Tracer::SetHasGrad)
.def_property(
......
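
The `_amp_dtype` property exposed here is how the Python-side amp_guard reads and restores the tracer's dtype. A hedged sketch of observing it (the `_dygraph_tracer` helper is the same one amp_guard itself uses):

import paddle
from paddle.fluid.framework import _dygraph_tracer

with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
    print(_dygraph_tracer()._amp_dtype)   # 'bfloat16' inside the guard
print(_dygraph_tracer()._amp_dtype)       # restored (e.g. 'float32') outside the guard
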
......@@ -359,6 +359,8 @@ struct ElementwiseAddTo<paddle::platform::CPUDeviceContext, T> {
template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
pten::dtype::float16>;
template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
pten::dtype::bfloat16>;
} // namespace funcs
} // namespace pten
......@@ -381,6 +381,8 @@ struct ElementwiseAddTo<paddle::platform::CUDADeviceContext, T> {
template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext,
pten::dtype::float16>;
template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext,
pten::dtype::bfloat16>;
} // namespace funcs
} // namespace pten
......@@ -21,7 +21,8 @@ __all__ = []
def auto_cast(enable=True,
custom_white_list=None,
custom_black_list=None,
level='O1'):
level='O1',
dtype='float16'):
"""
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
If enabled, the input data type (float32 or float16) of each operator is decided
......@@ -40,7 +41,8 @@ def auto_cast(enable=True,
observed in downstream ops. These ops will not be converted to fp16.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision, where the input data type of each operator is cast according to the white_list and black_list;
O2 represents pure fp16, where all operator parameters and input data are cast to fp16, except for operators in the black_list, operators without an fp16 kernel, and batch norm. Default is O1 (amp).
dtype(str, optional): Which data type to use, 'float16' or 'bfloat16'. Default is 'float16'.
Examples:
.. code-block:: python
......@@ -73,7 +75,7 @@ def auto_cast(enable=True,
print(d.dtype) # FP16
"""
return amp_guard(enable, custom_white_list, custom_black_list, level)
return amp_guard(enable, custom_white_list, custom_black_list, level, dtype)
def decorate(models,
......
......@@ -107,6 +107,15 @@ class RecomputeFunction(PyLayer):
else:
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
if tracer._amp_dtype == 'float16':
ctx.amp_dtype = 'float16'
elif tracer._amp_dtype in ('bfloat16', 'float32'):
ctx.amp_dtype = 'bfloat16'
else:
raise ValueError("unsupported amp dtype: {}".format(
tracer._amp_dtype))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
with paddle.no_grad():
......@@ -137,7 +146,8 @@ class RecomputeFunction(PyLayer):
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level):
level=ctx.amp_level,
dtype=ctx.amp_dtype):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
else:
......@@ -145,7 +155,8 @@ class RecomputeFunction(PyLayer):
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level):
level=ctx.amp_level,
dtype=ctx.amp_dtype):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
......
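
With amp_dtype now captured in the recompute context, the re-forward pass run during backward uses the same dtype as the original forward. A hedged usage sketch, assuming the public recompute helper from paddle.distributed.fleet.utils wraps this PyLayer:

import paddle
from paddle.distributed.fleet.utils import recompute

block = paddle.nn.Sequential(paddle.nn.Linear(16, 16), paddle.nn.ReLU())
x = paddle.uniform((4, 16), dtype='float32', min=-1., max=1.)
x.stop_gradient = False

with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
    y = recompute(block, x)   # backward will re-run block under the saved amp dtype
y.sum().backward()
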
......@@ -67,6 +67,9 @@ def convert_dtype(dtype):
# however, jointly supporting python2 and python3, (as well as python4 maybe)
# may still be a long-lasting problem.
return str(dtype)
# NOTE(zhangbo): numpy does not support bfloat16 yet; paddle uses uint16 to represent bfloat16, and their binary representations are consistent.
if dtype in ['bfloat16']:
return 'uint16'
raise TypeError(
"dtype must be any of [bool, float16, uint16, float32, float64, int8, int16, "
......
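
Because numpy has no native bfloat16 type, the mapping added above surfaces as uint16 on the numpy side. A hedged sketch (assuming convert_dtype lives in paddle.fluid.data_feeder, as in this branch, and that the cast op has a bf16 kernel on the current device):

import paddle
from paddle.fluid.data_feeder import convert_dtype

print(convert_dtype('bfloat16'))          # 'uint16'
t = paddle.ones([2, 2], dtype='float32').cast('bfloat16')
print(t.numpy().dtype)                    # uint16: the raw bf16 bit pattern
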
......@@ -75,19 +75,29 @@ PURE_FP16_BLACK_LIST = {
'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
}
BF16_WHITE_LIST = {'conv2d'}
BF16_BLACK_LIST = {' '}
#NOTE(zhiqiu): similar to paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
# The reason we do not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
def _update_list(custom_white_list, custom_black_list, level='O1'):
def _update_list(custom_white_list,
custom_black_list,
level='O1',
dtype='float16'):
"""
Update black and white list according to users' custom list.
"""
if level == 'O1':
_white_list = copy.copy(WHITE_LIST)
_black_list = copy.copy(BLACK_LIST)
if dtype == 'float16':
if level == 'O1':
_white_list = copy.copy(WHITE_LIST)
_black_list = copy.copy(BLACK_LIST)
else:
_white_list = copy.copy(PURE_FP16_WHITE_LIST)
_black_list = copy.copy(PURE_FP16_BLACK_LIST)
else:
_white_list = copy.copy(PURE_FP16_WHITE_LIST)
_black_list = copy.copy(PURE_FP16_BLACK_LIST)
_white_list = copy.copy(BF16_WHITE_LIST)
_black_list = copy.copy(BF16_BLACK_LIST)
if custom_white_list and custom_black_list:
for op_name in custom_white_list:
if op_name in custom_black_list:
......@@ -125,6 +135,27 @@ def _in_pure_fp16_guard():
return tracer and tracer._amp_level == core.AmpLevel.O2
def _is_gpu_float16_supported():
"""
Judge whether the current GPU supports float16 amp.
"""
prop = paddle.device.cuda.get_device_capability()
return prop[0] >= 7
def _is_gpu_bfloat16_supported():
"""
Judge whether the current GPU supports bfloat16 amp.
"""
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
if cuda_version is not None:
cuda_version_check = int(cuda_version.split('.')[0]) >= 11
else:
cuda_version_check = False
return prop[0] >= 8 and cuda_version_check
@dygraph_only
def pure_fp16_initialize(models):
for idx in range(len(models)):
......@@ -165,7 +196,8 @@ def check_optimizers(optimizers):
def amp_guard(enable=True,
custom_white_list=None,
custom_black_list=None,
level='O1'):
level='O1',
dtype='float16'):
"""
:api_attr: imperative
......@@ -186,6 +218,7 @@ def amp_guard(enable=True,
observed in downstream ops. These ops will not be converted to fp16.
level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision, where the input data type of each operator is cast according to the white_list and black_list;
O2 represents pure fp16, where all operator parameters and input data are cast to fp16, except for operators in the black_list, operators without an fp16 kernel, and batch norm. Default is O1 (amp).
dtype(str, optional): Which data type to use, 'float16' or 'bfloat16'. Default is 'float16'.
Examples:
......@@ -207,49 +240,88 @@ def amp_guard(enable=True,
print(conv.dtype) # FP32
"""
# check amp_level: O0-O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16 train mode."
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
# check amp_dtype: float16 or bfloat16
dtype = dtype.lower()
if not (dtype in ['float16', 'bfloat16']):
raise ValueError("dtype should be 'float16' or 'bfloat16'.")
# check tracer
tracer = _dygraph_tracer()
if not tracer:
raise ValueError(
"current_tracer is None, maybe it is not in imperative mode.")
# check device_type:
# NOTE: For now, amp only supports GPU for float16 and bfloat16, and XPU for float16.
# CPU bfloat16 may be supported later.
if enable and not (tracer._expected_place.is_gpu_place() or
tracer._expected_place.is_xpu_place()):
warnings.warn(
'amp_guard can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it has no effect.'
% tracer._expected_place)
enable = False
# For xpu:
if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
warnings.warn('XPUPlace only supports float16 amp.')
enable = False
# For gpu float16: Compute Capability should be >= 7.
# For gpu bfloat16: Compute Capability should be >= 8 and CUDA Version should be >= 11.
if tracer._expected_place.is_gpu_place():
prop = paddle.device.cuda.get_device_capability()
if prop[0] < 7:
if (dtype == 'float16') and not _is_gpu_float16_supported():
prop = paddle.device.cuda.get_device_capability()
warnings.warn(
"AMP only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
"For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
% (paddle.device.cuda.get_device_name(), prop[0], prop[1]))
elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
prop = paddle.device.cuda.get_device_capability()
cuda_version = paddle.version.cuda()
warnings.warn(
"For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
% (paddle.device.cuda.get_device_name(), prop[0], prop[1],
cuda_version))
amp_dtype = dtype
if level == 'O1':
amp_level = AMP_LEVEL.O1
_white_list = WHITE_LIST
_black_list = BLACK_LIST
if dtype == 'float16':
_white_list = WHITE_LIST
_black_list = BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O2':
amp_level = AMP_LEVEL.O2
_white_list = PURE_FP16_WHITE_LIST
_black_list = PURE_FP16_BLACK_LIST
if dtype == 'float16':
_white_list = PURE_FP16_WHITE_LIST
_black_list = PURE_FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
elif level == 'O0':
amp_level = AMP_LEVEL.O0
_white_list = WHITE_LIST
_black_list = BLACK_LIST
if dtype == 'float16':
_white_list = WHITE_LIST
_black_list = BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
if custom_white_list or custom_black_list:
_white_list, _black_list = _update_list(custom_white_list,
custom_black_list, level)
custom_black_list, level, dtype)
if not enable:
amp_level = AMP_LEVEL.O0
amp_dtype = "float32"
if tracer:
# enable auto_cast
......@@ -268,6 +340,10 @@ def amp_guard(enable=True,
# original_flags = get_flags(AMP_RELATED_FLAGS)
# set_flags(AMP_RELATED_FLAGS_SETTING)
# set amp dtype
original_amp_dtype = tracer._amp_dtype
tracer._amp_dtype = amp_dtype
# restore status
try:
yield
......@@ -276,6 +352,7 @@ def amp_guard(enable=True,
tracer._amp_level = original_amp_level
tracer._set_amp_op_list(original_white_list, original_black_list)
# set_flags(original_flags)
tracer._amp_dtype = original_amp_dtype
class StateDictHook(object):
......
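
The _is_gpu_float16_supported / _is_gpu_bfloat16_supported helpers and the device checks above suggest a simple dtype-selection policy callers can apply before entering the guard; the sketch below is illustrative and not part of this PR (returning 'float32' means "do not enable AMP"):

import paddle

def pick_amp_dtype():
    # Prefer bfloat16 on Compute Capability >= 8.0 with CUDA >= 11, else float16 on >= 7.0.
    if paddle.is_compiled_with_cuda():
        prop = paddle.device.cuda.get_device_capability()
        cuda_version = paddle.version.cuda()
        cuda11 = cuda_version is not None and int(cuda_version.split('.')[0]) >= 11
        if prop[0] >= 8 and cuda11:
            return 'bfloat16'
        if prop[0] >= 7:
            return 'float16'
    return 'float32'
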
......@@ -860,7 +860,6 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets)
self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5))
......@@ -1126,5 +1125,26 @@ class TestLayerNormFp16(unittest.TestCase):
self.assertTrue(out.dtype == fluid.core.VarDesc.VarType.FP16)
class TestBf16(unittest.TestCase):
'''
test amp for BF16
'''
def train(self, enable_amp=True):
paddle.seed(100)
input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = paddle.nn.Conv2D(4, 6, (3, 3))
with paddle.amp.auto_cast(
enable=enable_amp, level='O2', dtype='bfloat16'):
output = conv(input)
output = output.cast('float32')
return output.numpy()
def test_bf16(self):
out_fp32 = self.train(enable_amp=False)
out_bf16 = self.train(enable_amp=True)
self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-2))
if __name__ == '__main__':
unittest.main()