diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index 547fa02326bec36858717c8f66a268551423dbaa..e07393c47f9562bdd404ba13c7d123b5a0640e8d 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -113,26 +113,40 @@ AutoCastGuard::~AutoCastGuard() { tracer_->SetAmpLevel(pre_amp_level_); }
 AmpOperators::AmpOperators()
     : allow_ops_(new std::unordered_set<std::string>()),
       block_ops_(new std::unordered_set<std::string>()),
-      unsupported_fp16_ops_(new std::unordered_set<std::string>()) {
+      unsupported_fp16_ops_(new std::unordered_set<std::string>()),
+      unsupported_bf16_ops_(new std::unordered_set<std::string>()) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  auto unsupported_ops_gpu = std::get<2>(
+  auto unsupported_ops_gpu_fp16 = std::get<2>(
       OpSupportedInfos("GPU", paddle::framework::proto::VarType::FP16));
-  unsupported_fp16_ops_->insert(unsupported_ops_gpu.begin(),
-                                unsupported_ops_gpu.end());
+  unsupported_fp16_ops_->insert(unsupported_ops_gpu_fp16.begin(),
+                                unsupported_ops_gpu_fp16.end());
+  auto unsupported_ops_gpu_bf16 = std::get<2>(
+      OpSupportedInfos("GPU", paddle::framework::proto::VarType::BF16));
+  unsupported_bf16_ops_->insert(unsupported_ops_gpu_bf16.begin(),
+                                unsupported_ops_gpu_bf16.end());
 // NOTE: GPU/NPU/XPU is compiled seperatly.
 #elif defined(PADDLE_WITH_ASCEND_CL)
-  auto unsupported_ops_npu = std::get<2>(
+  auto unsupported_ops_npu_fp16 = std::get<2>(
       OpSupportedInfos("NPU", paddle::framework::proto::VarType::FP16));
-  unsupported_fp16_ops_->insert(unsupported_ops_npu.begin(),
-                                unsupported_ops_npu.end());
+  unsupported_fp16_ops_->insert(unsupported_ops_npu_fp16.begin(),
+                                unsupported_ops_npu_fp16.end());
+  auto unsupported_ops_npu_bf16 = std::get<2>(
+      OpSupportedInfos("NPU", paddle::framework::proto::VarType::BF16));
+  unsupported_bf16_ops_->insert(unsupported_ops_npu_bf16.begin(),
+                                unsupported_ops_npu_bf16.end());
 #elif defined(PADDLE_WITH_XPU)
-  auto unsupported_ops_xpu = std::get<2>(
+  auto unsupported_ops_xpu_fp16 = std::get<2>(
      OpSupportedInfos("XPU", paddle::framework::proto::VarType::FP16));
-  unsupported_fp16_ops_->insert(unsupported_ops_xpu.begin(),
-                                unsupported_ops_xpu.end());
+  unsupported_fp16_ops_->insert(unsupported_ops_xpu_fp16.begin(),
+                                unsupported_ops_xpu_fp16.end());
+  auto unsupported_ops_xpu_bf16 = std::get<2>(
+      OpSupportedInfos("XPU", paddle::framework::proto::VarType::BF16));
+  unsupported_bf16_ops_->insert(unsupported_ops_xpu_bf16.begin(),
+                                unsupported_ops_xpu_bf16.end());
 #endif
   VLOG(4) << allow_ops_->size() << " " << block_ops_->size() << " "
-          << unsupported_fp16_ops_->size();
+          << unsupported_fp16_ops_->size() << " "
+          << unsupported_bf16_ops_->size();
 }
 
 AmpOperators::~AmpOperators() {}
@@ -157,6 +171,11 @@ AmpOperators::GetMutableUnsupportedFp16Ops() {
   return unsupported_fp16_ops_;
 }
 
+std::shared_ptr<std::unordered_set<std::string>>
+AmpOperators::GetMutableUnsupportedBf16Ops() {
+  return unsupported_bf16_ops_;
+}
+
 std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
   os << "allow ops: ";
   auto allow_ops = ops.GetMutableAllowOps();
@@ -172,6 +191,11 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
   auto unsupported_fp16_ops = ops.GetMutableUnsupportedFp16Ops();
   std::copy((*unsupported_fp16_ops).begin(), (*unsupported_fp16_ops).end(),
             std::ostream_iterator<std::string>(os, " "));
+  os << "\n";
+  os << "unsupported bf16 ops: ";
+  auto unsupported_bf16_ops = ops.GetMutableUnsupportedBf16Ops();
+  std::copy((*unsupported_bf16_ops).begin(), (*unsupported_bf16_ops).end(),
+            std::ostream_iterator<std::string>(os, " "));
   return os;
 }
@@ -188,7 +212,8 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
       paddle::platform::is_xpu_place(place)) {
     // CudaPinndePlace is added for varbase created by dataloader
     if (data_type == paddle::framework::proto::VarType::FP32 ||
-        data_type == paddle::framework::proto::VarType::FP16) {
+        data_type == paddle::framework::proto::VarType::FP16 ||
+        data_type == paddle::framework::proto::VarType::BF16) {
       return true;
     }
   }
@@ -236,6 +261,16 @@ static inline std::shared_ptr<VarType> CastToFP32(
   return var;
 }
 
+template <typename VarType>
+static inline std::shared_ptr<VarType> CastToBF16(
+    const std::shared_ptr<VarType>& var) {
+  auto dst_type = framework::proto::VarType::BF16;
+  if (NeedCast(var) && (GetDataType<VarType>(var) != dst_type)) {
+    return CastToType(var, dst_type);
+  }
+  return var;
+}
+
 template <typename VarType>
 static inline framework::proto::VarType::Type GetPromoteType(
     const std::string& op_type, const NameVarMap<VarType>& ins) {
@@ -386,5 +421,62 @@ template NameVarMap<VarBase> CastPureFp16Inputs<VarBase>(
     const std::string& op_type, const NameVarMap<VarBase>& ins);
 template NameVarMap<egr::EagerTensor> CastPureFp16Inputs<egr::EagerTensor>(
     const std::string& op_type, const NameVarMap<egr::EagerTensor>& ins);
+
+template <typename VarType>
+NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
+                                       const NameVarMap<VarType>& ins) {
+  NameVarMap<VarType> new_ins(ins);
+  if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
+    for (auto& pair : new_ins) {
+      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
+              << GetDtypeStr(*pair.second.cbegin()) << " to bfloat16";
+      for (auto& var : pair.second) {
+        var = CastToBF16<VarType>(var);
+      }
+    }
+    return new_ins;
+  } else {
+    for (auto& pair : new_ins) {
+      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
+              << GetDtypeStr(*pair.second.cbegin()) << " to float";
+      for (auto& var : pair.second) {
+        var = CastToFP32<VarType>(var);
+      }
+    }
+    return new_ins;
+  }
+  return new_ins;
+}
+template NameVarMap<VarBase> AutoCastBF16Inputs<VarBase>(
+    const std::string& op_type, const NameVarMap<VarBase>& ins);
+template NameVarMap<egr::EagerTensor> AutoCastBF16Inputs<egr::EagerTensor>(
+    const std::string& op_type, const NameVarMap<egr::EagerTensor>& ins);
+
+template <typename VarType>
+NameVarMap<VarType> CastPureBf16Inputs(const std::string& op_type,
+                                       const NameVarMap<VarType>& ins) {
+  NameVarMap<VarType> new_ins(ins);
+  auto dst_type = framework::proto::VarType::BF16;
+  if (AmpOperators::Instance().GetMutableUnsupportedBf16Ops()->count(op_type) ||
+      AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
+    dst_type = framework::proto::VarType::FP32;
+  }
+  for (auto& pair : new_ins) {
+    VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
+            << GetDtypeStr(*pair.second.cbegin()) << " to "
+            << framework::DataTypeToString(dst_type);
+    for (auto& var : pair.second) {
+      var = (dst_type == framework::proto::VarType::FP32
+                 ? CastToFP32<VarType>(var)
+                 : CastToBF16<VarType>(var));
+    }
+  }
+  return new_ins;
+}
+template NameVarMap<VarBase> CastPureBf16Inputs<VarBase>(
+    const std::string& op_type, const NameVarMap<VarBase>& ins);
+template NameVarMap<egr::EagerTensor> CastPureBf16Inputs<egr::EagerTensor>(
+    const std::string& op_type, const NameVarMap<egr::EagerTensor>& ins);
+
 } // namespace imperative
 } // namespace paddle
diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h
index 775f9f973ae12f0a810b7c1e66d35b7b3f0bee90..2eef6591a440ea575ca217424923464471d0dea2 100644
--- a/paddle/fluid/imperative/amp_auto_cast.h
+++ b/paddle/fluid/imperative/amp_auto_cast.h
@@ -56,6 +56,9 @@ class AmpOperators {
   std::shared_ptr<std::unordered_set<std::string>>
   GetMutableUnsupportedFp16Ops();
 
+  std::shared_ptr<std::unordered_set<std::string>>
+  GetMutableUnsupportedBf16Ops();
+
  private:
   AmpOperators();  // forbid calling default constructor
 
@@ -69,6 +72,9 @@ class AmpOperators {
   // The set of ops that has no fp16 CUDA kennel.
   std::shared_ptr<std::unordered_set<std::string>> unsupported_fp16_ops_;
+
+  // The set of ops that has no bf16 CUDA kernel.
+  std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;
 };
 
 std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
@@ -95,6 +101,12 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
 template <typename VarType>
 NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
                                        const NameVarMap<VarType>& ins);
+template <typename VarType>
+NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
+                                       const NameVarMap<VarType>& ins);
+template <typename VarType>
+NameVarMap<VarType> CastPureBf16Inputs(const std::string& op_type,
+                                       const NameVarMap<VarType>& ins);
 
 } // namespace imperative
 } // namespace paddle
diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index 7e61d3dab1622e0400febd161e9764de118d1674..382d1b0591cbe48725ad4d027d9cb582dae85054 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -24,6 +24,7 @@
 #include "paddle/fluid/imperative/layer.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/float16.h"
@@ -420,6 +421,22 @@ void TensorAdd(const VarType& src, VarType* dst) {
           src_tensor, dst_tensor, place);
     }
   }
+  if (data_type == framework::proto::VarType::BF16) {
+    if (platform::is_gpu_place(place)) {
+#if defined(PADDLE_WITH_CUDA)
+      return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>(
+          src_tensor, dst_tensor, place);
+#else
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Gradient accumulation of data type (%s) on place (%s) is not "
+          "supported in imperative mode",
+          framework::DataTypeToString(data_type), place));
+#endif
+    } else if (platform::is_cpu_place(place)) {
+      return TensorAddImpl<platform::CPUDeviceContext, platform::bfloat16>(
+          src_tensor, dst_tensor, place);
+    }
+  }
   PADDLE_THROW(platform::errors::Unimplemented(
       "Gradient accumulation of data type (%s) on place (%s) is not "
       "supported in imperative mode",
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index a600720ef78edb5175bb7d17821f5d8e229d1a93..c913afebd8d4fae38d7a2722d59ec54f81038a8d 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -35,6 +35,8 @@ thread_local bool Tracer::has_grad_ = true;
 
 thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;
 
+thread_local pten::DataType Tracer::amp_dtype_ = pten::DataType::FLOAT32;
+
 static std::shared_ptr<Tracer> g_current_tracer(nullptr);
 
 const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }
@@ -200,10 +202,18 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
   NameVarMap<VarType> new_ins = ins;
   if (amp_level_ == AmpLevel::O1) {
     VLOG(5) << "Auto mixed precision run operator: " << type;
-    new_ins = AutoCastInputs<VarType>(type, ins);
+    if (amp_dtype_ == pten::DataType::FLOAT16) {
+      new_ins = AutoCastInputs<VarType>(type, ins);
+    } else if (amp_dtype_ == pten::DataType::BFLOAT16) {
+      new_ins = AutoCastBF16Inputs<VarType>(type, ins);
+    }
   } else if (amp_level_ == AmpLevel::O2) {
     VLOG(5) << "Pure fp16 run operator: " << type;
-    new_ins = CastPureFp16Inputs<VarType>(type, ins);
+    if (amp_dtype_ == pten::DataType::FLOAT16) {
+      new_ins = CastPureFp16Inputs<VarType>(type, ins);
+    } else if (amp_dtype_ == pten::DataType::BFLOAT16) {
+      new_ins = CastPureBf16Inputs<VarType>(type, ins);
+    }
   }
 
   try {
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index 3a9a1b630ce9cbc89f57b746e6e1e1445f6bd318..b7b22721560545034554b816c9a2c7cae37adfa4 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -34,6 +34,8 @@ namespace imperative {
 
 enum class AmpLevel;
 
+enum class AmpDtype;
+
 using GarbageCollectorMap =
     std::map<platform::Place, std::unique_ptr<framework::GarbageCollector>>;
@@ -131,6 +133,27 @@ class Tracer {
 
   AmpLevel GetAmpLevel() const { return amp_level_; }
 
+  void SetAmpDtype(std::string amp_dtype) {
+    VLOG(4) << "set amp_dtype to " << amp_dtype;
+    if (amp_dtype == "float16") {
+      amp_dtype_ = pten::DataType::FLOAT16;
+    } else if (amp_dtype == "bfloat16") {
+      amp_dtype_ = pten::DataType::BFLOAT16;
+    } else {
+      amp_dtype_ = pten::DataType::FLOAT32;
+    }
+  }
+
+  std::string GetAmpDtype() const {
+    if (amp_dtype_ == pten::DataType::FLOAT16) {
+      return std::string("float16");
+    } else if (amp_dtype_ == pten::DataType::BFLOAT16) {
+      return std::string("bfloat16");
+    } else {
+      return std::string("float32");
+    }
+  }
+
   paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
       const platform::Place& place);
 
@@ -143,6 +166,7 @@ class Tracer {
   GarbageCollectorMap gcs_;
   static thread_local bool has_grad_;
   static thread_local AmpLevel amp_level_;
+  static thread_local pten::DataType amp_dtype_;
 };
 
 // To access static variable current_tracer
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index f4ed1ee3424f229d77c293d19edca911aea31f69..5dc163bb8b187036fb3b717d192388400a94dda4 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -2230,6 +2230,8 @@ void BindImperative(py::module *m_ptr) {
            &imperative::Tracer::SetEnableProgramDescTracing)
       .def_property("_amp_level", &imperative::Tracer::GetAmpLevel,
                     &imperative::Tracer::SetAmpLevel)
+      .def_property("_amp_dtype", &imperative::Tracer::GetAmpDtype,
+                    &imperative::Tracer::SetAmpDtype)
       .def_property("_has_grad", &imperative::Tracer::HasGrad,
                     &imperative::Tracer::SetHasGrad)
       .def_property(
diff --git a/paddle/pten/kernels/funcs/math_function.cc b/paddle/pten/kernels/funcs/math_function.cc
index facb26a552019df6e485c2cdbfb5ddda77dc6be5..4a09157fac0f674e8aaf737d6b7697c5fd7d860f 100644
--- a/paddle/pten/kernels/funcs/math_function.cc
+++ b/paddle/pten/kernels/funcs/math_function.cc
@@ -359,6 +359,8 @@ struct ElementwiseAddTo {
 
 template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
                                  paddle::platform::float16>;
+template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
+                                 paddle::platform::bfloat16>;
 
 } // namespace funcs
 } // namespace pten
diff --git a/paddle/pten/kernels/funcs/math_function.cu b/paddle/pten/kernels/funcs/math_function.cu
index d019a382d77173185e9ce0a7e76d7c6ae5fcf773..d202e46da8bd95d84037d90a3c20277a20f0938f 100644
--- a/paddle/pten/kernels/funcs/math_function.cu
+++ b/paddle/pten/kernels/funcs/math_function.cu
@@ -381,6 +381,8 @@ struct ElementwiseAddTo {
 
 template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext,
                                  paddle::platform::float16>;
+template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext,
+                                 paddle::platform::bfloat16>;
 
 } // namespace funcs
 } // namespace pten
diff --git a/python/paddle/amp/auto_cast.py b/python/paddle/amp/auto_cast.py
index 297baa1762386d802e839b105e41437c0c999305..9ca29d509f60e5c3eb435e816a32cc14aa92e921 100644
--- a/python/paddle/amp/auto_cast.py
+++ b/python/paddle/amp/auto_cast.py
@@ -21,7 +21,8 @@ __all__ = []
 def auto_cast(enable=True,
               custom_white_list=None,
               custom_black_list=None,
-              level='O1'):
+              level='O1',
+              dtype='float16'):
     """
     Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
     If enabled, the input data type (float32 or float16) of each operator is decided
@@ -40,7 +41,8 @@ def auto_cast(enable=True,
              observed in downstream ops. These ops will not be converted to fp16.
         level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
-
+        dtype(str, optional): The data type used in auto mixed precision, which can be 'float16' or 'bfloat16'. Default is 'float16'.
+
     Examples:
 
      .. code-block:: python
 
@@ -73,7 +75,7 @@ def auto_cast(enable=True,
             print(d.dtype) # FP16
 
     """
-    return amp_guard(enable, custom_white_list, custom_black_list, level)
+    return amp_guard(enable, custom_white_list, custom_black_list, level, dtype)
 
 
 def decorate(models,
diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py
index c83c350b9e4f781cb1a17179b01d5ec5dea70397..dccd7f6205302663117c0957c19138270bf32feb 100755
--- a/python/paddle/distributed/fleet/utils/recompute.py
+++ b/python/paddle/distributed/fleet/utils/recompute.py
@@ -107,6 +107,15 @@ class RecomputeFunction(PyLayer):
             else:
                 raise ValueError("unsupported amp level: {}".format(
                     tracer._amp_level))
+
+            if tracer._amp_dtype == 'float16':
+                ctx.amp_dtype = 'float16'
+            elif tracer._amp_dtype in ('bfloat16', 'float32'):
+                ctx.amp_dtype = 'bfloat16'
+            else:
+                raise ValueError("unsupported amp dtype: {}".format(
+                    tracer._amp_dtype))
+
             ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
 
         with paddle.no_grad():
@@ -137,7 +146,8 @@ class RecomputeFunction(PyLayer):
                     enable=ctx.is_fw_autocast,
                     custom_white_list=ctx.amp_white_list,
                     custom_black_list=ctx.amp_black_list,
-                    level=ctx.amp_level):
+                    level=ctx.amp_level,
+                    dtype=ctx.amp_dtype):
                 detached_inputs = detach_variable(tuple(inputs))
                 outputs = ctx.run_function(*detached_inputs)
         else:
@@ -145,7 +155,8 @@ class RecomputeFunction(PyLayer):
                     enable=ctx.is_fw_autocast,
                     custom_white_list=ctx.amp_white_list,
                     custom_black_list=ctx.amp_black_list,
-                    level=ctx.amp_level):
+                    level=ctx.amp_level,
+                    dtype=ctx.amp_dtype):
                 detached_inputs = detach_variable(tuple(inputs))
                 outputs = ctx.run_function(*detached_inputs)
 
diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py
index 3bcefc41d2e781aa904f7ab581af3d72bc97b0d9..c11ebf7f8eae6021829d8b541eaa6917f0e657cb 100644
--- a/python/paddle/fluid/data_feeder.py
+++ b/python/paddle/fluid/data_feeder.py
@@ -67,6 +67,9 @@ def convert_dtype(dtype):
         # however, jointly supporting python2 and python3, (as well as python4 maybe)
         # may still be a long-lasting problem.
        return str(dtype)
+    # NOTE(zhangbo): Now numpy does not support bfloat16, and paddle uses uint16 to represent bfloat16; their binary representations are consistent.
+    if dtype in ['bfloat16']:
+        return 'uint16'
 
     raise TypeError(
         "dtype must be any of [bool, float16, uint16, float32, float64, int8, int16, "
diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py
index f09e210c3c161540d842073cc878f8ba131b598f..01d64550321d5e96d3dddeebb2509e3d96f3237b 100644
--- a/python/paddle/fluid/dygraph/amp/auto_cast.py
+++ b/python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -75,19 +75,29 @@ PURE_FP16_BLACK_LIST = {
     'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
 }
 
+BF16_WHITE_LIST = {'conv2d'}
+BF16_BLACK_LIST = {' '}
+
 #NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
 # The reason why not use AutoMixedPrecisionLists is that custom_black_varnames is not suitable for imperative mode.
-def _update_list(custom_white_list, custom_black_list, level='O1'):
+def _update_list(custom_white_list,
+                 custom_black_list,
+                 level='O1',
+                 dtype='float16'):
     """
     Update black and white list according to users' custom list.
     """
-    if level == 'O1':
-        _white_list = copy.copy(WHITE_LIST)
-        _black_list = copy.copy(BLACK_LIST)
+    if dtype == 'float16':
+        if level == 'O1':
+            _white_list = copy.copy(WHITE_LIST)
+            _black_list = copy.copy(BLACK_LIST)
+        else:
+            _white_list = copy.copy(PURE_FP16_WHITE_LIST)
+            _black_list = copy.copy(PURE_FP16_BLACK_LIST)
     else:
-        _white_list = copy.copy(PURE_FP16_WHITE_LIST)
-        _black_list = copy.copy(PURE_FP16_BLACK_LIST)
+        _white_list = copy.copy(BF16_WHITE_LIST)
+        _black_list = copy.copy(BF16_BLACK_LIST)
     if custom_white_list and custom_black_list:
         for op_name in custom_white_list:
             if op_name in custom_black_list:
@@ -125,6 +135,27 @@ def _in_pure_fp16_guard():
     return tracer and tracer._amp_level == core.AmpLevel.O2
 
 
+def _is_gpu_float16_supported():
+    """
+    Judge whether the current GPU supports float16 amp.
+    """
+    prop = paddle.device.cuda.get_device_capability()
+    return prop[0] >= 7
+
+
+def _is_gpu_bfloat16_supported():
+    """
+    Judge whether the current GPU supports bfloat16 amp.
+    """
+    prop = paddle.device.cuda.get_device_capability()
+    cuda_version = paddle.version.cuda()
+    if cuda_version is not None:
+        cuda_version_check = int(cuda_version.split('.')[0]) >= 11
+    else:
+        cuda_version_check = False
+    return prop[0] >= 8 and cuda_version_check
+
+
 @dygraph_only
 def pure_fp16_initialize(models):
     for idx in range(len(models)):
@@ -165,7 +196,8 @@ def check_optimizers(optimizers):
 def amp_guard(enable=True,
               custom_white_list=None,
               custom_black_list=None,
-              level='O1'):
+              level='O1',
+              dtype='float16'):
     """
     :api_attr: imperative
 
@@ -186,6 +218,7 @@ def amp_guard(enable=True,
             observed in downstream ops. These ops will not be converted to fp16.
        level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represent mixed precision, the input data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators parameters and input data will be casted to fp16, except operators in black_list, don't support fp16 kernel and batchnorm. Default is O1(amp)
+        dtype(str, optional): The data type used in auto mixed precision, which can be 'float16' or 'bfloat16'. Default is 'float16'.
 
     Examples:
 
@@ -207,49 +240,88 @@ def amp_guard(enable=True,
             print(conv.dtype) # FP32
 
     """
+    # check amp_level: O0-O2
+    level = level.upper()
     if not (level in ['O0', 'O1', 'O2']):
         raise ValueError(
-            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16 train mode."
+            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
         )
 
+    # check amp_dtype: float16 or bfloat16
+    dtype = dtype.lower()
+    if not (dtype in ['float16', 'bfloat16']):
+        raise ValueError("dtype should be 'float16' or 'bfloat16'.")
+
+    # check tracer
     tracer = _dygraph_tracer()
     if not tracer:
         raise ValueError(
             "current_tracer is None, maybe it is not in imperative mode.")
 
+    # check device_type:
+    # NOTE: Now, amp only supports gpu for float16 and bfloat16, and xpu for float16.
+    # Maybe we will support cpu for bfloat16.
     if enable and not (tracer._expected_place.is_gpu_place() or
                        tracer._expected_place.is_xpu_place()):
         warnings.warn(
             'amp_guard can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.'
             % tracer._expected_place)
         enable = False
-
+    # For xpu:
+    if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
+        warnings.warn('XPUPlace only supports float16 amp.')
+        enable = False
+    # For gpu float16: Compute Capability should be >= 7.
+    # For gpu bfloat16: Compute Capability should be >= 8 and CUDA Version should be >= 11.
     if tracer._expected_place.is_gpu_place():
-        prop = paddle.device.cuda.get_device_capability()
-        if prop[0] < 7:
+        if (dtype == 'float16') and not _is_gpu_float16_supported():
+            prop = paddle.device.cuda.get_device_capability()
             warnings.warn(
-                "AMP only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
+                "For float16, amp only supports NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
                 % (paddle.device.cuda.get_device_name(), prop[0], prop[1]))
+        elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
+            prop = paddle.device.cuda.get_device_capability()
+            cuda_version = paddle.version.cuda()
+            warnings.warn(
+                "For bfloat16, amp only supports NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
+                % (paddle.device.cuda.get_device_name(), prop[0], prop[1],
+                   cuda_version))
+
+    amp_dtype = dtype
 
     if level == 'O1':
         amp_level = AMP_LEVEL.O1
-        _white_list = WHITE_LIST
-        _black_list = BLACK_LIST
+        if dtype == 'float16':
+            _white_list = WHITE_LIST
+            _black_list = BLACK_LIST
+        elif dtype == 'bfloat16':
+            _white_list = BF16_WHITE_LIST
+            _black_list = BF16_BLACK_LIST
+
     elif level == 'O2':
         amp_level = AMP_LEVEL.O2
-        _white_list = PURE_FP16_WHITE_LIST
-        _black_list = PURE_FP16_BLACK_LIST
+        if dtype == 'float16':
+            _white_list = PURE_FP16_WHITE_LIST
+            _black_list = PURE_FP16_BLACK_LIST
+        elif dtype == 'bfloat16':
+            _white_list = BF16_WHITE_LIST
+            _black_list = BF16_BLACK_LIST
     elif level == 'O0':
         amp_level = AMP_LEVEL.O0
-        _white_list = WHITE_LIST
-        _black_list = BLACK_LIST
+        if dtype == 'float16':
+            _white_list = WHITE_LIST
+            _black_list = BLACK_LIST
+        elif dtype == 'bfloat16':
+            _white_list = BF16_WHITE_LIST
+            _black_list = BF16_BLACK_LIST
 
     if custom_white_list or custom_black_list:
         _white_list, _black_list = _update_list(custom_white_list,
-                                                custom_black_list, level)
+                                                custom_black_list, level, dtype)
 
     if not enable:
         amp_level = AMP_LEVEL.O0
+        amp_dtype = "float32"
 
     if tracer:
         # enable auto_cast
@@ -268,6 +340,10 @@
         # original_flags = get_flags(AMP_RELATED_FLAGS)
         # set_flags(AMP_RELATED_FLAGS_SETTING)
 
+        # set amp dtype
+        original_amp_dtype = tracer._amp_dtype
+        tracer._amp_dtype = amp_dtype
+
     # restore status
     try:
         yield
@@ -276,6 +352,7 @@
             tracer._amp_level = original_amp_level
             tracer._set_amp_op_list(original_white_list, original_black_list)
             # set_flags(original_flags)
+            tracer._amp_dtype = original_amp_dtype
 
 
 class StateDictHook(object):
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index 62b40f88571d4119e0deaa1121e6d718baca7b7e..306c6b4707e8a3d7386bd8af3e32e55d09d563c4 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -860,7 +860,6 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
             results = exe.run(inference_program,
                               feed={feed_target_names[0]: tensor_img},
                               fetch_list=fetch_targets)
-            self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5))
 
 
@@ -1126,5 +1125,26 @@ class TestLayerNormFp16(unittest.TestCase):
         self.assertTrue(out.dtype == fluid.core.VarDesc.VarType.FP16)
 
 
+class TestBf16(unittest.TestCase):
+    '''
+    test amp for BF16
+    '''
+
+    def train(self, enable_amp=True):
+        paddle.seed(100)
+        input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
+        conv = paddle.nn.Conv2D(4, 6, (3, 3))
+        with paddle.amp.auto_cast(
+                enable=enable_amp, level='O2', dtype='bfloat16'):
+            output = conv(input)
+            output = output.cast('float32')
+        return output.numpy()
+
+    def test_bf16(self):
+        out_fp32 = self.train(enable_amp=False)
+        out_bf16 = self.train(enable_amp=True)
+        self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-2))
+
+
 if __name__ == '__main__':
     unittest.main()
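
Usage note (reviewer sketch, not part of the patch): the snippet below exercises the new dtype argument end to end in dygraph mode, mirroring the TestBf16 case added above; the layer shapes and tolerances are illustrative assumptions, not values required by the patch. Running the bfloat16 branch needs an NVIDIA GPU with Compute Capability 8.0+ and CUDA 11+, per the checks added in amp_guard.

    import numpy as np
    import paddle

    # Build a small conv layer and run it once in fp32 and once under bfloat16 AMP (level O2).
    paddle.seed(100)
    x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
    conv = paddle.nn.Conv2D(4, 6, (3, 3))

    out_fp32 = conv(x)

    with paddle.amp.auto_cast(enable=True, level='O2', dtype='bfloat16'):
        # conv2d is on the BF16 white list, so its inputs are cast to bfloat16 here.
        out_bf16 = conv(x).cast('float32')

    # bfloat16 keeps fp32's exponent range but has fewer mantissa bits, so compare loosely.
    print(np.allclose(out_fp32.numpy(), out_bf16.numpy(), rtol=1e-3, atol=1e-2))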