未验证 提交 59e425cd 编写于 作者: L Leo Chen 提交者: GitHub

[Amp] refine code of amp level (#36362)

* refine amp level

* fix typo

* update tracer._amp_level
上级 e051bba0
...@@ -24,6 +24,17 @@ namespace imperative { ...@@ -24,6 +24,17 @@ namespace imperative {
class VarBase; class VarBase;
AutoCastGuard::AutoCastGuard(std::shared_ptr<Tracer> tracer, AmpLevel level)
: tracer_(tracer) {
pre_amp_level_ = tracer_->GetAmpLevel();
if (pre_amp_level_ != level) {
tracer_->SetAmpLevel(level);
}
}
AutoCastGuard::~AutoCastGuard() { tracer_->SetAmpLevel(pre_amp_level_); }
AmpOperators::AmpOperators() AmpOperators::AmpOperators()
: allow_ops_(new std::unordered_set<std::string>()), : allow_ops_(new std::unordered_set<std::string>()),
block_ops_(new std::unordered_set<std::string>()), block_ops_(new std::unordered_set<std::string>()),
...@@ -117,7 +128,7 @@ static inline std::shared_ptr<imperative::VarBase> CastToType( ...@@ -117,7 +128,7 @@ static inline std::shared_ptr<imperative::VarBase> CastToType(
imperative::NameVarBaseMap outs = {{"Out", {out}}}; imperative::NameVarBaseMap outs = {{"Out", {out}}};
{ {
AutoCastGuard guard(tracer, 0); AutoCastGuard guard(tracer, AmpLevel::O0);
tracer->TraceOp("cast", ins, outs, std::move(attrs)); tracer->TraceOp("cast", ins, outs, std::move(attrs));
} }
......
...@@ -19,15 +19,22 @@ ...@@ -19,15 +19,22 @@
#include <tuple> #include <tuple>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
// Singleton implementation with C++ 11 // NOTE(zhiqiu): only O1 and O2 are valid now
enum class AmpLevel {
O0 = 0, // fp32
O1, // amp, mixed fp32-fp16
O2, // almost fp16
O3, // fp16
};
class Tracer; class Tracer;
// Singleton implementation with C++ 11
class AmpOperators { class AmpOperators {
public: public:
~AmpOperators(); ~AmpOperators();
...@@ -63,16 +70,9 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops); ...@@ -63,16 +70,9 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
// NOTE(zhiqiu): AutoCastGuard is used for RAII. // NOTE(zhiqiu): AutoCastGuard is used for RAII.
class AutoCastGuard { class AutoCastGuard {
public: public:
AutoCastGuard(std::shared_ptr<Tracer> tracer, int guard_level) AutoCastGuard(std::shared_ptr<Tracer> tracer, AmpLevel guard_level);
: tracer_(tracer) {
pre_amp_level_ = tracer_->AMPLevel();
if (pre_amp_level_ != guard_level) {
tracer_->SetAMPLevel(guard_level);
}
}
~AutoCastGuard() { tracer_->SetAMPLevel(pre_amp_level_); } ~AutoCastGuard();
// forbid copy and operator= // forbid copy and operator=
AutoCastGuard(const AutoCastGuard& guard) = delete; AutoCastGuard(const AutoCastGuard& guard) = delete;
...@@ -80,7 +80,7 @@ class AutoCastGuard { ...@@ -80,7 +80,7 @@ class AutoCastGuard {
private: private:
std::shared_ptr<Tracer> tracer_; std::shared_ptr<Tracer> tracer_;
int pre_amp_level_; AmpLevel pre_amp_level_;
}; };
NameVarBaseMap AutoCastInputs(const std::string& op_type, NameVarBaseMap AutoCastInputs(const std::string& op_type,
......
...@@ -176,10 +176,10 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -176,10 +176,10 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
: attr_checker->GetDefaultAttrMap(); : attr_checker->GetDefaultAttrMap();
NameVarBaseMap new_ins = ins; NameVarBaseMap new_ins = ins;
if (amp_level_ == 1) { if (amp_level_ == AmpLevel::O1) {
VLOG(5) << "Auto mixed precision run operator: " << type; VLOG(5) << "Auto mixed precision run operator: " << type;
new_ins = AutoCastInputs(type, ins); new_ins = AutoCastInputs(type, ins);
} else if (amp_level_ == 2) { } else if (amp_level_ == AmpLevel::O2) {
VLOG(5) << "Pure fp16 run operator: " << type; VLOG(5) << "Pure fp16 run operator: " << type;
new_ins = CastPureFp16Inputs(type, ins); new_ins = CastPureFp16Inputs(type, ins);
} }
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <vector> #include <vector>
#include "ThreadPool.h" #include "ThreadPool.h"
#include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/basic_engine.h" #include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/jit/program_desc_tracer.h" #include "paddle/fluid/imperative/jit/program_desc_tracer.h"
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
...@@ -31,6 +32,8 @@ ...@@ -31,6 +32,8 @@
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
enum class AmpLevel;
using GarbageCollectorMap = using GarbageCollectorMap =
std::map<platform::Place, std::map<platform::Place,
std::unique_ptr<paddle::framework::GarbageCollector>>; std::unique_ptr<paddle::framework::GarbageCollector>>;
...@@ -105,9 +108,9 @@ class Tracer { ...@@ -105,9 +108,9 @@ class Tracer {
void SetHasGrad(bool has_grad) { has_grad_ = has_grad; } void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }
void SetAMPLevel(int level) { amp_level_ = level; } void SetAmpLevel(AmpLevel level) { amp_level_ = level; }
int AMPLevel() const { return amp_level_; } AmpLevel GetAmpLevel() const { return amp_level_; }
paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists( paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
const platform::Place& place); const platform::Place& place);
...@@ -120,7 +123,7 @@ class Tracer { ...@@ -120,7 +123,7 @@ class Tracer {
platform::Place expected_place_; platform::Place expected_place_;
GarbageCollectorMap gcs_; GarbageCollectorMap gcs_;
static thread_local bool has_grad_; static thread_local bool has_grad_;
int amp_level_{0}; AmpLevel amp_level_{AmpLevel::O0};
}; };
// To access static variable current_tracer // To access static variable current_tracer
......
...@@ -1940,6 +1940,13 @@ void BindImperative(py::module *m_ptr) { ...@@ -1940,6 +1940,13 @@ void BindImperative(py::module *m_ptr) {
&imperative::jit::ProgramDescTracer::CreateProgramDesc) &imperative::jit::ProgramDescTracer::CreateProgramDesc)
.def("reset", &imperative::jit::ProgramDescTracer::Reset); .def("reset", &imperative::jit::ProgramDescTracer::Reset);
py::enum_<paddle::imperative::AmpLevel>(m, "AmpLevel", py::arithmetic())
.value("O0", paddle::imperative::AmpLevel::O0)
.value("O1", paddle::imperative::AmpLevel::O1)
.value("O2", paddle::imperative::AmpLevel::O2)
.value("O3", paddle::imperative::AmpLevel::O3)
.export_values();
py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>( py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
m, "Tracer", R"DOC()DOC") m, "Tracer", R"DOC()DOC")
.def("__init__", .def("__init__",
...@@ -1947,8 +1954,8 @@ void BindImperative(py::module *m_ptr) { ...@@ -1947,8 +1954,8 @@ void BindImperative(py::module *m_ptr) {
.def_property("_enable_program_desc_tracing", .def_property("_enable_program_desc_tracing",
&imperative::Tracer::IsProgramDescTracingEnabled, &imperative::Tracer::IsProgramDescTracingEnabled,
&imperative::Tracer::SetEnableProgramDescTracing) &imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_amp_level", &imperative::Tracer::AMPLevel, .def_property("_amp_level", &imperative::Tracer::GetAmpLevel,
&imperative::Tracer::SetAMPLevel) &imperative::Tracer::SetAmpLevel)
.def_property("_has_grad", &imperative::Tracer::HasGrad, .def_property("_has_grad", &imperative::Tracer::HasGrad,
&imperative::Tracer::SetHasGrad) &imperative::Tracer::SetHasGrad)
.def_property( .def_property(
......
...@@ -198,7 +198,7 @@ class _HPRecomputeFunction(PyLayer): ...@@ -198,7 +198,7 @@ class _HPRecomputeFunction(PyLayer):
# TODO support AMP # TODO support AMP
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
if tracer._amp_level == 0: if tracer._amp_level == core.AmpLevel.O0:
ctx.is_fw_autocast = False ctx.is_fw_autocast = False
else: else:
ctx.is_fw_autocast = True ctx.is_fw_autocast = True
......
...@@ -98,7 +98,7 @@ class RecomputeFunction(PyLayer): ...@@ -98,7 +98,7 @@ class RecomputeFunction(PyLayer):
# TODO support AMP # TODO support AMP
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
if tracer._amp_level == 0: if tracer._amp_level == core.AmpLevel.O0:
ctx.is_fw_autocast = False ctx.is_fw_autocast = False
else: else:
ctx.is_fw_autocast = True ctx.is_fw_autocast = True
......
...@@ -24,6 +24,8 @@ import paddle ...@@ -24,6 +24,8 @@ import paddle
import operator import operator
import types import types
AMP_LEVEL = core.AmpLevel
__all__ = ['amp_guard', 'amp_decorate'] __all__ = ['amp_guard', 'amp_decorate']
# The set of ops that support fp16 calculation and are considered numerically- # The set of ops that support fp16 calculation and are considered numerically-
...@@ -108,7 +110,7 @@ def _in_amp_guard(): ...@@ -108,7 +110,7 @@ def _in_amp_guard():
""" """
tracer = _dygraph_tracer() tracer = _dygraph_tracer()
if tracer: if tracer:
if tracer._amp_level == 1: if tracer._amp_level == core.AmpLevel.O1:
return True return True
else: else:
return False return False
...@@ -251,11 +253,11 @@ def amp_guard(enable=True, ...@@ -251,11 +253,11 @@ def amp_guard(enable=True,
enable = False enable = False
if level == 'O1': if level == 'O1':
amp_level = 1 amp_level = AMP_LEVEL.O1
_white_list = WHITE_LIST _white_list = WHITE_LIST
_black_list = BLACK_LIST _black_list = BLACK_LIST
else: else:
amp_level = 2 amp_level = AMP_LEVEL.O2
_white_list = PURE_FP16_WHITE_LIST _white_list = PURE_FP16_WHITE_LIST
_black_list = PURE_FP16_BLACK_LIST _black_list = PURE_FP16_BLACK_LIST
...@@ -264,7 +266,7 @@ def amp_guard(enable=True, ...@@ -264,7 +266,7 @@ def amp_guard(enable=True,
custom_black_list, level) custom_black_list, level)
if not enable: if not enable:
amp_level = 0 amp_level = AMP_LEVEL.O0
if tracer: if tracer:
# enable auto_cast # enable auto_cast
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册