Unverified commit 7d6d3848, authored by zhangbo9674, committed by GitHub

[AMP] support GPU BF16 amp for dygraph (#39029)

* support dtype param for auto_cast

* add amp_dtype for tracer

* add unsupported bf16 list

* support bf16 amp for O2

* refine python interface for bfloat16

* refine code

* refine code

* refine unittest

* refine code

* refine code

* add bf16 o1

* refine code by comment

* add gradient accumulator

* add recompute
Parent commit: 7e4ed848
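For orientation (not part of the diff itself), a minimal usage sketch of the user-facing change: paddle.amp.auto_cast now accepts a dtype argument, so dygraph AMP can run in bfloat16 on supported GPUs. The shapes and layer below are arbitrary and mirror the unit test added by this change.

import paddle

conv = paddle.nn.Conv2D(4, 6, (3, 3))
x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)

# O1 mixed precision, now selectable between 'float16' (default) and 'bfloat16'
with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
    out = conv(x)

loss = out.cast('float32').mean()
loss.backward()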
@@ -113,26 +113,40 @@ AutoCastGuard::~AutoCastGuard() { tracer_->SetAmpLevel(pre_amp_level_); }
AmpOperators::AmpOperators()
    : allow_ops_(new std::unordered_set<std::string>()),
      block_ops_(new std::unordered_set<std::string>()),
      unsupported_fp16_ops_(new std::unordered_set<std::string>()),
      unsupported_bf16_ops_(new std::unordered_set<std::string>()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  auto unsupported_ops_gpu_fp16 = std::get<2>(
      OpSupportedInfos("GPU", paddle::framework::proto::VarType::FP16));
  unsupported_fp16_ops_->insert(unsupported_ops_gpu_fp16.begin(),
                                unsupported_ops_gpu_fp16.end());
  auto unsupported_ops_gpu_bf16 = std::get<2>(
      OpSupportedInfos("GPU", paddle::framework::proto::VarType::BF16));
  unsupported_bf16_ops_->insert(unsupported_ops_gpu_bf16.begin(),
                                unsupported_ops_gpu_bf16.end());
// NOTE: GPU/NPU/XPU is compiled separately.
#elif defined(PADDLE_WITH_ASCEND_CL)
  auto unsupported_ops_npu_fp16 = std::get<2>(
      OpSupportedInfos("NPU", paddle::framework::proto::VarType::FP16));
  unsupported_fp16_ops_->insert(unsupported_ops_npu_fp16.begin(),
                                unsupported_ops_npu_fp16.end());
  auto unsupported_ops_npu_bf16 = std::get<2>(
      OpSupportedInfos("NPU", paddle::framework::proto::VarType::BF16));
  unsupported_bf16_ops_->insert(unsupported_ops_npu_bf16.begin(),
                                unsupported_ops_npu_bf16.end());
#elif defined(PADDLE_WITH_XPU)
  auto unsupported_ops_xpu_fp16 = std::get<2>(
      OpSupportedInfos("XPU", paddle::framework::proto::VarType::FP16));
  unsupported_fp16_ops_->insert(unsupported_ops_xpu_fp16.begin(),
                                unsupported_ops_xpu_fp16.end());
  auto unsupported_ops_xpu_bf16 = std::get<2>(
      OpSupportedInfos("XPU", paddle::framework::proto::VarType::BF16));
  unsupported_bf16_ops_->insert(unsupported_ops_xpu_bf16.begin(),
                                unsupported_ops_xpu_bf16.end());
#endif
  VLOG(4) << allow_ops_->size() << " " << block_ops_->size() << " "
          << unsupported_fp16_ops_->size() << " "
          << unsupported_bf16_ops_->size();
}

AmpOperators::~AmpOperators() {}
@@ -157,6 +171,11 @@ AmpOperators::GetMutableUnsupportedFp16Ops() {
  return unsupported_fp16_ops_;
}

std::shared_ptr<std::unordered_set<std::string>>
AmpOperators::GetMutableUnsupportedBf16Ops() {
  return unsupported_bf16_ops_;
}

std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
  os << "allow ops: ";
  auto allow_ops = ops.GetMutableAllowOps();
@@ -172,6 +191,11 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops) {
  auto unsupported_fp16_ops = ops.GetMutableUnsupportedFp16Ops();
  std::copy((*unsupported_fp16_ops).begin(), (*unsupported_fp16_ops).end(),
            std::ostream_iterator<std::string>(os, " "));
  os << "\n";
  os << "unsupported bf16 ops: ";
  auto unsupported_bf16_ops = ops.GetMutableUnsupportedBf16Ops();
  std::copy((*unsupported_bf16_ops).begin(), (*unsupported_bf16_ops).end(),
            std::ostream_iterator<std::string>(os, " "));
  return os;
}
@@ -188,7 +212,8 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
      paddle::platform::is_xpu_place(place)) {
    // CudaPinnedPlace is added for varbase created by dataloader
    if (data_type == paddle::framework::proto::VarType::FP32 ||
        data_type == paddle::framework::proto::VarType::FP16 ||
        data_type == paddle::framework::proto::VarType::BF16) {
      return true;
    }
  }
@@ -236,6 +261,16 @@ static inline std::shared_ptr<VarType> CastToFP32(
  return var;
}

template <typename VarType>
static inline std::shared_ptr<VarType> CastToBF16(
    const std::shared_ptr<VarType>& var) {
  auto dst_type = framework::proto::VarType::BF16;
  if (NeedCast(var) && (GetDataType<VarType>(var) != dst_type)) {
    return CastToType(var, dst_type);
  }
  return var;
}

template <typename VarType>
static inline framework::proto::VarType::Type GetPromoteType(
    const std::string& op_type, const NameVarMap<VarType>& ins) {
@@ -386,5 +421,62 @@ template NameVarMap<VarBase> CastPureFp16Inputs<VarBase>(
    const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerVariable> CastPureFp16Inputs<egr::EagerVariable>(
    const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);
template <typename VarType>
NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
                                       const NameVarMap<VarType>& ins) {
  NameVarMap<VarType> new_ins(ins);
  if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) {
    for (auto& pair : new_ins) {
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to bfloat16";
      for (auto& var : pair.second) {
        var = CastToBF16<VarType>(var);
      }
    }
    return new_ins;
  } else {
    for (auto& pair : new_ins) {
      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
              << GetDtypeStr(*pair.second.cbegin()) << " to float";
      for (auto& var : pair.second) {
        var = CastToFP32<VarType>(var);
      }
    }
    return new_ins;
  }
  return new_ins;
}
template NameVarMap<VarBase> AutoCastBF16Inputs<VarBase>(
    const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerVariable> AutoCastBF16Inputs<egr::EagerVariable>(
    const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);
template <typename VarType>
NameVarMap<VarType> CastPureBf16Inputs(const std::string& op_type,
                                       const NameVarMap<VarType>& ins) {
  NameVarMap<VarType> new_ins(ins);
  auto dst_type = framework::proto::VarType::BF16;
  if (AmpOperators::Instance().GetMutableUnsupportedBf16Ops()->count(op_type) ||
      AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
    dst_type = framework::proto::VarType::FP32;
  }
  for (auto& pair : new_ins) {
    VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
            << GetDtypeStr(*pair.second.cbegin()) << " to "
            << framework::DataTypeToString(dst_type);
    for (auto& var : pair.second) {
      var = (dst_type == framework::proto::VarType::FP32
                 ? CastToFP32<VarType>(var)
                 : CastToBF16<VarType>(var));
    }
  }
  return new_ins;
}
template NameVarMap<VarBase> CastPureBf16Inputs<VarBase>(
    const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerVariable> CastPureBf16Inputs<egr::EagerVariable>(
    const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);

}  // namespace imperative
}  // namespace paddle
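Illustration (an assumption-laden sketch: it presumes an Ampere-or-newer GPU with a CUDA 11 build): a Python-level view of what the two new casting routines do. Under O1, AutoCastBF16Inputs feeds bfloat16 only to white-listed ops such as conv2d and casts everything else back to float32; under O2, CastPureBf16Inputs casts all inputs to bfloat16 unless the op is black-listed or has no bf16 kernel. The dtype checks use the same VarDesc enum the unit tests in this change use.

import paddle
from paddle.fluid import core

x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = paddle.nn.Conv2D(4, 6, (3, 3))

with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
    y = conv(x)                       # conv2d is white-listed -> bf16 inputs and kernel
    z = paddle.nn.functional.relu(y)  # not white-listed -> inputs cast back to fp32

print(y.dtype == core.VarDesc.VarType.BF16)  # True on a bf16-capable setup
print(z.dtype == core.VarDesc.VarType.FP32)  # True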
@@ -56,6 +56,9 @@ class AmpOperators {
  std::shared_ptr<std::unordered_set<std::string>>
  GetMutableUnsupportedFp16Ops();

  std::shared_ptr<std::unordered_set<std::string>>
  GetMutableUnsupportedBf16Ops();

 private:
  AmpOperators();  // forbid calling default constructor
@@ -69,6 +72,9 @@ class AmpOperators {
  // The set of ops that have no fp16 CUDA kernel.
  std::shared_ptr<std::unordered_set<std::string>> unsupported_fp16_ops_;

  // The set of ops that have no bf16 CUDA kernel.
  std::shared_ptr<std::unordered_set<std::string>> unsupported_bf16_ops_;
};

std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
@@ -95,6 +101,12 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
template <typename VarType>
NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
                                       const NameVarMap<VarType>& ins);

template <typename VarType>
NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
                                       const NameVarMap<VarType>& ins);

template <typename VarType>
NameVarMap<VarType> CastPureBf16Inputs(const std::string& op_type,
                                       const NameVarMap<VarType>& ins);

}  // namespace imperative
}  // namespace paddle
@@ -24,6 +24,7 @@
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
@@ -420,6 +421,22 @@ void TensorAdd(const VarType& src, VarType* dst) {
          src_tensor, dst_tensor, place);
    }
  }

  if (data_type == framework::proto::VarType::BF16) {
    if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA)
      return TensorAddImpl<platform::CUDADeviceContext, platform::bfloat16>(
          src_tensor, dst_tensor, place);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "Gradient accumulation of data type (%s) on place (%s) is not "
          "supported in imperative mode",
          framework::DataTypeToString(data_type), place));
#endif
    } else if (platform::is_cpu_place(place)) {
      return TensorAddImpl<platform::CPUDeviceContext, platform::bfloat16>(
          src_tensor, dst_tensor, place);
    }
  }

  PADDLE_THROW(platform::errors::Unimplemented(
      "Gradient accumulation of data type (%s) on place (%s) is not "
      "supported in imperative mode",
      framework::DataTypeToString(data_type), place));
...
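A hedged sketch (assuming a CUDA build with bfloat16 kernels) of a pattern that should exercise the new bf16 accumulation branch of TensorAdd above: a bfloat16 intermediate consumed twice receives two gradient contributions, and the imperative gradient accumulator has to add them in bfloat16.

import paddle

conv = paddle.nn.Conv2D(4, 6, (3, 3))
x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)

with paddle.amp.auto_cast(level='O2', dtype='bfloat16'):
    z = conv(x)   # bfloat16 output
    y = z * z     # z is used twice, so its gradient is accumulated in bfloat16

y.cast('float32').mean().backward()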
@@ -35,6 +35,8 @@ thread_local bool Tracer::has_grad_ = true;

thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;

thread_local pten::DataType Tracer::amp_dtype_ = pten::DataType::FLOAT32;

static std::shared_ptr<Tracer> g_current_tracer(nullptr);

const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }
@@ -200,10 +202,18 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
  NameVarMap<VarType> new_ins = ins;
  if (amp_level_ == AmpLevel::O1) {
    VLOG(5) << "Auto mixed precision run operator: " << type;
    if (amp_dtype_ == pten::DataType::FLOAT16) {
      new_ins = AutoCastInputs<VarType>(type, ins);
    } else if (amp_dtype_ == pten::DataType::BFLOAT16) {
      new_ins = AutoCastBF16Inputs<VarType>(type, ins);
    }
  } else if (amp_level_ == AmpLevel::O2) {
    VLOG(5) << "Pure fp16 run operator: " << type;
    if (amp_dtype_ == pten::DataType::FLOAT16) {
      new_ins = CastPureFp16Inputs<VarType>(type, ins);
    } else if (amp_dtype_ == pten::DataType::BFLOAT16) {
      new_ins = CastPureBf16Inputs<VarType>(type, ins);
    }
  }

  try {
...
@@ -34,6 +34,8 @@ namespace imperative {

enum class AmpLevel;

enum class AmpDtype;

using GarbageCollectorMap =
    std::map<platform::Place,
             std::unique_ptr<paddle::framework::GarbageCollector>>;
@@ -131,6 +133,27 @@ class Tracer {
  AmpLevel GetAmpLevel() const { return amp_level_; }

  void SetAmpDtype(std::string amp_dtype) {
    VLOG(4) << "set amp_dtype to " << amp_dtype;
    if (amp_dtype == "float16") {
      amp_dtype_ = pten::DataType::FLOAT16;
    } else if (amp_dtype == "bfloat16") {
      amp_dtype_ = pten::DataType::BFLOAT16;
    } else {
      amp_dtype_ = pten::DataType::FLOAT32;
    }
  }

  std::string GetAmpDtype() const {
    if (amp_dtype_ == pten::DataType::FLOAT16) {
      return std::string("float16");
    } else if (amp_dtype_ == pten::DataType::BFLOAT16) {
      return std::string("bfloat16");
    } else {
      return std::string("float32");
    }
  }

  paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
      const platform::Place& place);
@@ -143,6 +166,7 @@ class Tracer {
  GarbageCollectorMap gcs_;
  static thread_local bool has_grad_;
  static thread_local AmpLevel amp_level_;
  static thread_local pten::DataType amp_dtype_;
};

// To access static variable current_tracer
...
@@ -2230,6 +2230,8 @@ void BindImperative(py::module *m_ptr) {
           &imperative::Tracer::SetEnableProgramDescTracing)
      .def_property("_amp_level", &imperative::Tracer::GetAmpLevel,
                    &imperative::Tracer::SetAmpLevel)
      .def_property("_amp_dtype", &imperative::Tracer::GetAmpDtype,
                    &imperative::Tracer::SetAmpDtype)
      .def_property("_has_grad", &imperative::Tracer::HasGrad,
                    &imperative::Tracer::SetHasGrad)
      .def_property(
...
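The new _amp_dtype property is an internal tracer attribute; a hedged sketch of how it can be observed from Python (the _dygraph_tracer import path is the one the fluid AMP module itself uses, shown here purely for illustration):

import paddle
from paddle.fluid.framework import _dygraph_tracer

with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
    print(_dygraph_tracer()._amp_dtype)   # 'bfloat16' when AMP is actually enabled
print(_dygraph_tracer()._amp_dtype)       # restored to 'float32' after the guard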
@@ -359,6 +359,8 @@ struct ElementwiseAddTo<paddle::platform::CPUDeviceContext, T> {
template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
                                 pten::dtype::float16>;
template struct ElementwiseAddTo<paddle::platform::CPUDeviceContext,
                                 pten::dtype::bfloat16>;

}  // namespace funcs
}  // namespace pten
@@ -381,6 +381,8 @@ struct ElementwiseAddTo<paddle::platform::CUDADeviceContext, T> {
template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext,
                                 pten::dtype::float16>;
template struct ElementwiseAddTo<paddle::platform::CUDADeviceContext,
                                 pten::dtype::bfloat16>;

}  // namespace funcs
}  // namespace pten
@@ -21,7 +21,8 @@ __all__ = []
def auto_cast(enable=True,
              custom_white_list=None,
              custom_black_list=None,
              level='O1',
              dtype='float16'):
    """
    Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
    If enabled, the input data type (float32 or float16) of each operator is decided
@@ -40,7 +41,8 @@ def auto_cast(enable=True,
            observed in downstream ops. These ops will not be converted to fp16.
        level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision, in which the input data type of each operator is cast according to the white_list and black_list;
            O2 represents pure fp16, in which the parameters and input data of all operators are cast to fp16, except for operators in the black_list, operators without fp16 kernels, and batchnorm. Default is O1(amp).
        dtype(str, optional): The data type used for AMP, either 'float16' or 'bfloat16'. Default is 'float16'.

    Examples:

     .. code-block:: python
@@ -73,7 +75,7 @@ def auto_cast(enable=True,
            print(d.dtype) # FP16

    """
    return amp_guard(enable, custom_white_list, custom_black_list, level, dtype)


def decorate(models,
...
@@ -107,6 +107,15 @@ class RecomputeFunction(PyLayer):
        else:
            raise ValueError("unsupported amp level: {}".format(
                tracer._amp_level))

        if tracer._amp_dtype == 'float16':
            ctx.amp_dtype = 'float16'
        elif tracer._amp_dtype in ('bfloat16', 'float32'):
            ctx.amp_dtype = 'bfloat16'
        else:
            raise ValueError("unsupported amp dtype: {}".format(
                tracer._amp_dtype))

        ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

        with paddle.no_grad():
@@ -137,7 +146,8 @@ class RecomputeFunction(PyLayer):
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
                        level=ctx.amp_level,
                        dtype=ctx.amp_dtype):
                    detached_inputs = detach_variable(tuple(inputs))
                    outputs = ctx.run_function(*detached_inputs)
            else:
@@ -145,7 +155,8 @@ class RecomputeFunction(PyLayer):
                        enable=ctx.is_fw_autocast,
                        custom_white_list=ctx.amp_white_list,
                        custom_black_list=ctx.amp_black_list,
                        level=ctx.amp_level,
                        dtype=ctx.amp_dtype):
                    detached_inputs = detach_variable(tuple(inputs))
                    outputs = ctx.run_function(*detached_inputs)
...
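A hedged sketch of the recompute change (the recompute import path is an assumption based on the file being modified, and the wrapped block is arbitrary): the recorded amp dtype is replayed in the backward-time forward pass, so a block recomputed inside a bfloat16 autocast region is recomputed in bfloat16 as well.

import paddle
from paddle.distributed.fleet.utils import recompute

block = paddle.nn.Sequential(paddle.nn.Conv2D(4, 6, (3, 3)), paddle.nn.ReLU())
x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
x.stop_gradient = False

with paddle.amp.auto_cast(level='O1', dtype='bfloat16'):
    y = recompute(block, x)

y.cast('float32').mean().backward()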
@@ -67,6 +67,9 @@ def convert_dtype(dtype):
        # however, jointly supporting python2 and python3, (as well as python4 maybe)
        # may still be a long-lasting problem.
        return str(dtype)
    # NOTE(zhangbo): numpy does not support bfloat16 yet, so paddle uses uint16 to
    # represent bfloat16; their binary layouts are consistent.
    if dtype in ['bfloat16']:
        return 'uint16'

    raise TypeError(
        "dtype must be any of [bool, float16, uint16, float32, float64, int8, int16, "
...
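A small sketch of the helper changed above (an internal fluid API, shown only to illustrate the mapping): convert_dtype now maps the string 'bfloat16' to 'uint16', since numpy has no native bfloat16 dtype.

from paddle.fluid.data_feeder import convert_dtype

print(convert_dtype('bfloat16'))  # 'uint16'
print(convert_dtype('float16'))   # 'float16'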
@@ -75,19 +75,29 @@ PURE_FP16_BLACK_LIST = {
    'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
}

BF16_WHITE_LIST = {'conv2d'}
BF16_BLACK_LIST = {' '}


#NOTE(zhiqiu): similar to paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
# AutoMixedPrecisionLists is not used here because custom_black_varnames is not suitable for imperative mode.
def _update_list(custom_white_list,
                 custom_black_list,
                 level='O1',
                 dtype='float16'):
    """
    Update black and white list according to users' custom list.
    """
    if dtype == 'float16':
        if level == 'O1':
            _white_list = copy.copy(WHITE_LIST)
            _black_list = copy.copy(BLACK_LIST)
        else:
            _white_list = copy.copy(PURE_FP16_WHITE_LIST)
            _black_list = copy.copy(PURE_FP16_BLACK_LIST)
    else:
        _white_list = copy.copy(BF16_WHITE_LIST)
        _black_list = copy.copy(BF16_BLACK_LIST)
    if custom_white_list and custom_black_list:
        for op_name in custom_white_list:
            if op_name in custom_black_list:
@@ -125,6 +135,27 @@ def _in_pure_fp16_guard():
    return tracer and tracer._amp_level == core.AmpLevel.O2


def _is_gpu_float16_supported():
    """
    Judge whether the current GPU supports float16 amp.
    """
    prop = paddle.device.cuda.get_device_capability()
    return prop[0] >= 7


def _is_gpu_bfloat16_supported():
    """
    Judge whether the current GPU supports bfloat16 amp.
    """
    prop = paddle.device.cuda.get_device_capability()
    cuda_version = paddle.version.cuda()
    if cuda_version is not None:
        cuda_version_check = int(cuda_version.split('.')[0]) >= 11
    else:
        cuda_version_check = False
    return prop[0] >= 8 and cuda_version_check
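A hedged sketch of picking an AMP dtype with the same public APIs these private helpers rely on (Compute Capability >= 7.0 for float16, >= 8.0 plus CUDA >= 11 for bfloat16); the selection logic mirrors, but does not call, the helpers above.

import paddle

amp_dtype = 'float16'
if paddle.device.cuda.device_count() > 0:
    major, _ = paddle.device.cuda.get_device_capability()
    cuda_version = paddle.version.cuda()
    cuda11 = cuda_version is not None and int(cuda_version.split('.')[0]) >= 11
    if major >= 8 and cuda11:
        amp_dtype = 'bfloat16'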
@dygraph_only
def pure_fp16_initialize(models):
    for idx in range(len(models)):
@@ -165,7 +196,8 @@ def check_optimizers(optimizers):
def amp_guard(enable=True,
              custom_white_list=None,
              custom_black_list=None,
              level='O1',
              dtype='float16'):
    """
    :api_attr: imperative
@@ -186,6 +218,7 @@ def amp_guard(enable=True,
            observed in downstream ops. These ops will not be converted to fp16.
        level(str, optional): Auto mixed precision level. Accepted values are "O1" and "O2": O1 represents mixed precision, in which the input data type of each operator is cast according to the white_list and black_list;
            O2 represents pure fp16, in which the parameters and input data of all operators are cast to fp16, except for operators in the black_list, operators without fp16 kernels, and batchnorm. Default is O1(amp).
        dtype(str, optional): The data type used for AMP, either 'float16' or 'bfloat16'. Default is 'float16'.

    Examples:

@@ -207,49 +240,88 @@ def amp_guard(enable=True,
            print(conv.dtype) # FP32

    """
    # check amp_level: O0-O2
    level = level.upper()
    if not (level in ['O0', 'O1', 'O2']):
        raise ValueError(
            "level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
        )

    # check amp_dtype: float16 or bfloat16
    dtype = dtype.lower()
    if not (dtype in ['float16', 'bfloat16']):
        raise ValueError("dtype should be 'float16' or 'bfloat16'.")

    # check tracer
    tracer = _dygraph_tracer()
    if not tracer:
        raise ValueError(
            "current_tracer is None, maybe it is not in imperative mode.")

    # check device_type:
    # NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16.
    # Maybe we will support cpu for bfloat16.
    if enable and not (tracer._expected_place.is_gpu_place() or
                       tracer._expected_place.is_xpu_place()):
        warnings.warn(
            'amp_guard can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.'
            % tracer._expected_place)
        enable = False
    # For xpu:
    if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
        warnings.warn('XPUPlace only support float16 amp.')
        enable = False
    # For gpu float16: Compute Capability should >= 7.
    # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
    if tracer._expected_place.is_gpu_place():
        if (dtype == 'float16') and not _is_gpu_float16_supported():
            prop = paddle.device.cuda.get_device_capability()
            warnings.warn(
                "For float16, amp only support NVIDIA GPU with Compute Capability 7.0 or higher, current GPU is: %s, with Compute Capability: %d.%d."
                % (paddle.device.cuda.get_device_name(), prop[0], prop[1]))
        elif (dtype == 'bfloat16') and not _is_gpu_bfloat16_supported():
            prop = paddle.device.cuda.get_device_capability()
            cuda_version = paddle.version.cuda()
            warnings.warn(
                "For bfloat16, amp only support NVIDIA GPU with Compute Capability 8.0 or higher and CUDA Version 11.0 or higher, current GPU is: %s, with Compute Capability: %d.%d, current CUDA Version is: %s."
                % (paddle.device.cuda.get_device_name(), prop[0], prop[1],
                   cuda_version))

    amp_dtype = dtype

    if level == 'O1':
        amp_level = AMP_LEVEL.O1
        if dtype == 'float16':
            _white_list = WHITE_LIST
            _black_list = BLACK_LIST
        elif dtype == 'bfloat16':
            _white_list = BF16_WHITE_LIST
            _black_list = BF16_BLACK_LIST
    elif level == 'O2':
        amp_level = AMP_LEVEL.O2
        if dtype == 'float16':
            _white_list = PURE_FP16_WHITE_LIST
            _black_list = PURE_FP16_BLACK_LIST
        elif dtype == 'bfloat16':
            _white_list = BF16_WHITE_LIST
            _black_list = BF16_BLACK_LIST
    elif level == 'O0':
        amp_level = AMP_LEVEL.O0
        if dtype == 'float16':
            _white_list = WHITE_LIST
            _black_list = BLACK_LIST
        elif dtype == 'bfloat16':
            _white_list = BF16_WHITE_LIST
            _black_list = BF16_BLACK_LIST

    if custom_white_list or custom_black_list:
        _white_list, _black_list = _update_list(custom_white_list,
                                                custom_black_list, level, dtype)

    if not enable:
        amp_level = AMP_LEVEL.O0
        amp_dtype = "float32"

    if tracer:
        # enable auto_cast
@@ -268,6 +340,10 @@
        # original_flags = get_flags(AMP_RELATED_FLAGS)
        # set_flags(AMP_RELATED_FLAGS_SETTING)

        # set amp dtype
        original_amp_dtype = tracer._amp_dtype
        tracer._amp_dtype = amp_dtype

    # restore status
    try:
        yield
@@ -276,6 +352,7 @@
        tracer._amp_level = original_amp_level
        tracer._set_amp_op_list(original_white_list, original_black_list)
        # set_flags(original_flags)
        tracer._amp_dtype = original_amp_dtype


class StateDictHook(object):
...
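A hedged sketch of how the custom lists interact with the new dtype-aware defaults handled by _update_list earlier in this file: black-listing conv2d removes it from the bf16 white list for the scope, so it runs in float32 even under bfloat16 AMP (the op choice is purely illustrative).

import paddle

conv = paddle.nn.Conv2D(4, 6, (3, 3))
x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)

with paddle.amp.auto_cast(
        custom_black_list={'conv2d'}, level='O1', dtype='bfloat16'):
    out = conv(x)   # conv2d is forced back to float32 in this scope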
@@ -860,7 +860,6 @@ class TestPureFp16InferenceSaveLoad(unittest.TestCase):
        results = exe.run(inference_program,
                          feed={feed_target_names[0]: tensor_img},
                          fetch_list=fetch_targets)
        self.assertTrue(np.allclose(pred.numpy(), results, atol=1.e-5))
@@ -1126,5 +1125,26 @@ class TestLayerNormFp16(unittest.TestCase):
            self.assertTrue(out.dtype == fluid.core.VarDesc.VarType.FP16)


class TestBf16(unittest.TestCase):
    '''
    test amp for BF16
    '''

    def train(self, enable_amp=True):
        paddle.seed(100)
        input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
        conv = paddle.nn.Conv2D(4, 6, (3, 3))
        with paddle.amp.auto_cast(
                enable=enable_amp, level='O2', dtype='bfloat16'):
            output = conv(input)
            output = output.cast('float32')
        return output.numpy()

    def test_bf16(self):
        out_fp32 = self.train(enable_amp=False)
        out_bf16 = self.train(enable_amp=True)
        self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-2))


if __name__ == '__main__':
    unittest.main()