From 59e425cd2d8f2fdc331cc79e6c33726dfeec3249 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Wed, 13 Oct 2021 14:39:42 +0800
Subject: [PATCH] [Amp] refine code of amp level (#36362)

* refine amp level

* fix typo

* update tracer._amp_level
---
 paddle/fluid/imperative/amp_auto_cast.cc      | 13 +++++++++-
 paddle/fluid/imperative/amp_auto_cast.h       | 24 +++++++++----------
 paddle/fluid/imperative/tracer.cc             |  4 ++--
 paddle/fluid/imperative/tracer.h              |  9 ++++---
 paddle/fluid/pybind/imperative.cc             | 11 +++++++--
 .../fleet/meta_parallel/pp_utils/utils.py     |  2 +-
 .../distributed/fleet/utils/recompute.py      |  2 +-
 python/paddle/fluid/dygraph/amp/auto_cast.py  | 10 ++++----
 8 files changed, 49 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index 48e5e430b1..b0d86f6db9 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -24,6 +24,17 @@ namespace imperative {
 
 class VarBase;
 
+AutoCastGuard::AutoCastGuard(std::shared_ptr<Tracer> tracer, AmpLevel level)
+    : tracer_(tracer) {
+  pre_amp_level_ = tracer_->GetAmpLevel();
+
+  if (pre_amp_level_ != level) {
+    tracer_->SetAmpLevel(level);
+  }
+}
+
+AutoCastGuard::~AutoCastGuard() { tracer_->SetAmpLevel(pre_amp_level_); }
+
 AmpOperators::AmpOperators()
     : allow_ops_(new std::unordered_set<std::string>()),
       block_ops_(new std::unordered_set<std::string>()),
@@ -117,7 +128,7 @@ static inline std::shared_ptr<VarBase> CastToType(
   imperative::NameVarBaseMap outs = {{"Out", {out}}};
 
   {
-    AutoCastGuard guard(tracer, 0);
+    AutoCastGuard guard(tracer, AmpLevel::O0);
     tracer->TraceOp("cast", ins, outs, std::move(attrs));
   }
 
diff --git a/paddle/fluid/imperative/amp_auto_cast.h b/paddle/fluid/imperative/amp_auto_cast.h
index 79bc83a777..903e265288 100644
--- a/paddle/fluid/imperative/amp_auto_cast.h
+++ b/paddle/fluid/imperative/amp_auto_cast.h
@@ -19,15 +19,22 @@
 #include <tuple>
 #include <unordered_set>
 
-#include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/imperative/type_defs.h"
 
 namespace paddle {
 namespace imperative {
 
-// Singleton implementation with C++ 11
+// NOTE(zhiqiu): only O1 and O2 are valid now
+enum class AmpLevel {
+  O0 = 0,  // fp32
+  O1,      // amp, mixed fp32-fp16
+  O2,      // almost fp16
+  O3,      // fp16
+};
+
 class Tracer;
 
+// Singleton implementation with C++ 11
 class AmpOperators {
  public:
   ~AmpOperators();
@@ -63,16 +70,9 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
 // NOTE(zhiqiu): AutoCastGuard is used for RAII.
 class AutoCastGuard {
  public:
-  AutoCastGuard(std::shared_ptr<Tracer> tracer, int guard_level)
-      : tracer_(tracer) {
-    pre_amp_level_ = tracer_->AMPLevel();
-
-    if (pre_amp_level_ != guard_level) {
-      tracer_->SetAMPLevel(guard_level);
-    }
-  }
+  AutoCastGuard(std::shared_ptr<Tracer> tracer, AmpLevel guard_level);
 
-  ~AutoCastGuard() { tracer_->SetAMPLevel(pre_amp_level_); }
+  ~AutoCastGuard();
 
   // forbid copy and operator=
   AutoCastGuard(const AutoCastGuard& guard) = delete;
@@ -80,7 +80,7 @@ class AutoCastGuard {
 
  private:
   std::shared_ptr<Tracer> tracer_;
-  int pre_amp_level_;
+  AmpLevel pre_amp_level_;
 };
 
 NameVarBaseMap AutoCastInputs(const std::string& op_type,
diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 49e079c58c..0f363d0ea1 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -176,10 +176,10 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
                               : attr_checker->GetDefaultAttrMap();
 
   NameVarBaseMap new_ins = ins;
-  if (amp_level_ == 1) {
+  if (amp_level_ == AmpLevel::O1) {
     VLOG(5) << "Auto mixed precision run operator: " << type;
     new_ins = AutoCastInputs(type, ins);
-  } else if (amp_level_ == 2) {
+  } else if (amp_level_ == AmpLevel::O2) {
     VLOG(5) << "Pure fp16 run operator: " << type;
     new_ins = CastPureFp16Inputs(type, ins);
   }
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index e77623d7a4..418b2069b5 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -23,6 +23,7 @@
 #include <vector>
 #include "ThreadPool.h"
 #include "paddle/fluid/framework/garbage_collector.h"
+#include "paddle/fluid/imperative/amp_auto_cast.h"
 #include "paddle/fluid/imperative/basic_engine.h"
 #include "paddle/fluid/imperative/jit/program_desc_tracer.h"
 #include "paddle/fluid/imperative/layer.h"
@@ -31,6 +32,8 @@
 namespace paddle {
 namespace imperative {
 
+enum class AmpLevel;
+
 using GarbageCollectorMap =
     std::map<platform::Place,
              std::unique_ptr<paddle::framework::GarbageCollector>>;
@@ -105,9 +108,9 @@ class Tracer {
 
   void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }
 
-  void SetAMPLevel(int level) { amp_level_ = level; }
+  void SetAmpLevel(AmpLevel level) { amp_level_ = level; }
 
-  int AMPLevel() const { return amp_level_; }
+  AmpLevel GetAmpLevel() const { return amp_level_; }
 
   paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
       const platform::Place& place);
@@ -120,7 +123,7 @@ class Tracer {
   platform::Place expected_place_;
   GarbageCollectorMap gcs_;
   static thread_local bool has_grad_;
-  int amp_level_{0};
+  AmpLevel amp_level_{AmpLevel::O0};
 };
 
 // To access static variable current_tracer
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 5aae05db8c..2e22ee9013 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -1940,6 +1940,13 @@ void BindImperative(py::module *m_ptr) {
            &imperative::jit::ProgramDescTracer::CreateProgramDesc)
       .def("reset", &imperative::jit::ProgramDescTracer::Reset);
 
+  py::enum_<paddle::imperative::AmpLevel>(m, "AmpLevel", py::arithmetic())
+      .value("O0", paddle::imperative::AmpLevel::O0)
+      .value("O1", paddle::imperative::AmpLevel::O1)
+      .value("O2", paddle::imperative::AmpLevel::O2)
+      .value("O3", paddle::imperative::AmpLevel::O3)
+      .export_values();
+
   py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
       m, "Tracer", R"DOC()DOC")
       .def("__init__",
@@ -1947,8 +1954,8 @@ void BindImperative(py::module *m_ptr) {
       .def_property("_enable_program_desc_tracing",
                     &imperative::Tracer::IsProgramDescTracingEnabled,
                     &imperative::Tracer::SetEnableProgramDescTracing)
-      .def_property("_amp_level", &imperative::Tracer::AMPLevel,
-                    &imperative::Tracer::SetAMPLevel)
+      .def_property("_amp_level", &imperative::Tracer::GetAmpLevel,
+                    &imperative::Tracer::SetAmpLevel)
       .def_property("_has_grad", &imperative::Tracer::HasGrad,
                     &imperative::Tracer::SetHasGrad)
       .def_property(
diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
index b29b0b3e27..0826609654 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -198,7 +198,7 @@ class _HPRecomputeFunction(PyLayer):
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        if tracer._amp_level == 0:
+        if tracer._amp_level == core.AmpLevel.O0:
             ctx.is_fw_autocast = False
         else:
             ctx.is_fw_autocast = True
diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py
index 302877e51f..56a64049b1 100755
--- a/python/paddle/distributed/fleet/utils/recompute.py
+++ b/python/paddle/distributed/fleet/utils/recompute.py
@@ -98,7 +98,7 @@ class RecomputeFunction(PyLayer):
 
         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        if tracer._amp_level == 0:
+        if tracer._amp_level == core.AmpLevel.O0:
             ctx.is_fw_autocast = False
         else:
             ctx.is_fw_autocast = True
diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py
index 0d02a383c1..d218e6b749 100644
--- a/python/paddle/fluid/dygraph/amp/auto_cast.py
+++ b/python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -24,6 +24,8 @@ import paddle
 import operator
 import types
 
+AMP_LEVEL = core.AmpLevel
+
 __all__ = ['amp_guard', 'amp_decorate']
 
 # The set of ops that support fp16 calculation and are considered numerically-
@@ -108,7 +110,7 @@ def _in_amp_guard():
     """
     tracer = _dygraph_tracer()
     if tracer:
-        if tracer._amp_level == 1:
+        if tracer._amp_level == core.AmpLevel.O1:
             return True
         else:
             return False
@@ -251,11 +253,11 @@ def amp_guard(enable=True,
         enable = False
 
     if level == 'O1':
-        amp_level = 1
+        amp_level = AMP_LEVEL.O1
         _white_list = WHITE_LIST
         _black_list = BLACK_LIST
     else:
-        amp_level = 2
+        amp_level = AMP_LEVEL.O2
         _white_list = PURE_FP16_WHITE_LIST
         _black_list = PURE_FP16_BLACK_LIST
 
@@ -264,7 +266,7 @@ def amp_guard(enable=True,
             custom_black_list, level)
 
     if not enable:
-        amp_level = AMP_LEVEL.O0
+        amp_level = AMP_LEVEL.O0
 
     if tracer:
         # enable auto_cast
-- 
GitLab
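
Note (not part of the patch): a minimal usage sketch of the rebound `_amp_level`
property after this change, assuming a Paddle build that includes this commit.
`core.AmpLevel` is the pybind11 enum registered above, `framework._dygraph_tracer()`
is Paddle's internal tracer accessor, and `amp_guard` is the context manager touched
in auto_cast.py; on a CPU-only build amp_guard disables itself, so the O1 assertion
holds only when CUDA is available.

    import paddle
    from paddle.fluid import core, framework
    from paddle.fluid.dygraph.amp import amp_guard

    paddle.disable_static()
    tracer = framework._dygraph_tracer()

    # _amp_level is now an AmpLevel enum value, not a bare int.
    assert tracer._amp_level == core.AmpLevel.O0

    with amp_guard(enable=True, level='O1'):
        # Inside the guard the tracer runs at O1 (mixed fp32/fp16).
        assert tracer._amp_level == core.AmpLevel.O1

    # AutoCastGuard-style RAII: the previous level is restored on exit.
    assert tracer._amp_level == core.AmpLevel.O0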