Unverified commit 59e425cd, authored by Leo Chen, committed by GitHub

[Amp] refine code of amp level (#36362)

* refine amp level

* fix typo

* update tracer._amp_level
Parent e051bba0
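The user-visible core of the change: tracer._amp_level used to be a bare int (0/1/2) and is now the bound core.AmpLevel enum, so checks move from magic numbers to named values. A minimal sketch of the new behavior, assuming a Paddle build that contains this commit:

    from paddle.fluid import core, framework

    tracer = framework._dygraph_tracer()
    # before this commit: tracer._amp_level == 0
    # after this commit: the property returns (and accepts) the enum
    in_pure_fp32 = tracer._amp_level == core.AmpLevel.O0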
@@ -24,6 +24,17 @@ namespace imperative {
 class VarBase;
 
+AutoCastGuard::AutoCastGuard(std::shared_ptr<Tracer> tracer, AmpLevel level)
+    : tracer_(tracer) {
+  pre_amp_level_ = tracer_->GetAmpLevel();
+
+  if (pre_amp_level_ != level) {
+    tracer_->SetAmpLevel(level);
+  }
+}
+
+AutoCastGuard::~AutoCastGuard() { tracer_->SetAmpLevel(pre_amp_level_); }
+
 AmpOperators::AmpOperators()
     : allow_ops_(new std::unordered_set<std::string>()),
       block_ops_(new std::unordered_set<std::string>()),
@@ -117,7 +128,7 @@ static inline std::shared_ptr<imperative::VarBase> CastToType(
   imperative::NameVarBaseMap outs = {{"Out", {out}}};
 
   {
-    AutoCastGuard guard(tracer, 0);
+    AutoCastGuard guard(tracer, AmpLevel::O0);
     tracer->TraceOp("cast", ins, outs, std::move(attrs));
   }
......
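The guard moved out of the header keeps its RAII contract: remember the tracer's level at construction, switch only if it differs, and restore unconditionally at destruction. For readers more at home in Python, an illustrative context-manager analogue (hypothetical helper, not a Paddle API):

    from contextlib import contextmanager

    @contextmanager
    def auto_cast_guard(tracer, level):
        # mirrors AutoCastGuard: save, conditionally switch, always restore
        pre_level = tracer._amp_level
        if pre_level != level:
            tracer._amp_level = level
        try:
            yield
        finally:
            tracer._amp_level = pre_level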
@@ -19,15 +19,22 @@
 #include <tuple>
 #include <unordered_set>
 
-#include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/imperative/type_defs.h"
 
 namespace paddle {
 namespace imperative {
 
-// Singleton implementation with C++ 11
+// NOTE(zhiqiu): only O1 and O2 are valid now
+enum class AmpLevel {
+  O0 = 0,  // fp32
+  O1,      // amp, mixed fp32-fp16
+  O2,      // almost fp16
+  O3,      // fp16
+};
+
+class Tracer;
+
+// Singleton implementation with C++ 11
 class AmpOperators {
  public:
   ~AmpOperators();
@@ -63,16 +70,9 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
 // NOTE(zhiqiu): AutoCastGuard is used for RAII.
 class AutoCastGuard {
  public:
-  AutoCastGuard(std::shared_ptr<Tracer> tracer, int guard_level)
-      : tracer_(tracer) {
-    pre_amp_level_ = tracer_->AMPLevel();
-
-    if (pre_amp_level_ != guard_level) {
-      tracer_->SetAMPLevel(guard_level);
-    }
-  }
+  AutoCastGuard(std::shared_ptr<Tracer> tracer, AmpLevel guard_level);
 
-  ~AutoCastGuard() { tracer_->SetAMPLevel(pre_amp_level_); }
+  ~AutoCastGuard();
 
   // forbid copy and operator=
   AutoCastGuard(const AutoCastGuard& guard) = delete;
@@ -80,7 +80,7 @@ class AutoCastGuard {
  private:
   std::shared_ptr<Tracer> tracer_;
-  int pre_amp_level_;
+  AmpLevel pre_amp_level_;
 };
 
 NameVarBaseMap AutoCastInputs(const std::string& op_type,
......
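The four levels follow the usual AMP vocabulary: O0 keeps everything in fp32, O1 mixes fp32/fp16 through the op white/black lists, O2 runs almost everything in fp16, and O3 is pure fp16; per the NOTE, only O1 and O2 are actually wired up today. A quick probe of the Python binding added later in this diff (assuming this build; the printed form is pybind11's):

    from paddle.fluid import core

    for level in (core.AmpLevel.O0, core.AmpLevel.O1,
                  core.AmpLevel.O2, core.AmpLevel.O3):
        print(level, int(level))  # e.g. "AmpLevel.O0 0" ... "AmpLevel.O3 3"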
@@ -176,10 +176,10 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
                           : attr_checker->GetDefaultAttrMap();
 
   NameVarBaseMap new_ins = ins;
-  if (amp_level_ == 1) {
+  if (amp_level_ == AmpLevel::O1) {
     VLOG(5) << "Auto mixed precision run operator: " << type;
     new_ins = AutoCastInputs(type, ins);
-  } else if (amp_level_ == 2) {
+  } else if (amp_level_ == AmpLevel::O2) {
     VLOG(5) << "Pure fp16 run operator: " << type;
     new_ins = CastPureFp16Inputs(type, ins);
   }
......
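This dispatch is where the level takes effect: under O1 TraceOp rewrites inputs through the white/black lists, under O2 it casts nearly all inputs to fp16. An observable-effect sketch, assuming a CUDA device (amp_guard disables itself on CPU) and that matmul is on the O1 white list in this version:

    import paddle
    from paddle.fluid.dygraph import amp_guard

    x = paddle.rand([4, 4])           # float32
    with amp_guard(enable=True, level='O1'):
        y = paddle.matmul(x, x)       # white-listed op: inputs auto-cast to fp16
    print(y.dtype)                    # expected: paddle.float16 on GPU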
@@ -23,6 +23,7 @@
 #include <vector>
 
 #include "ThreadPool.h"
 #include "paddle/fluid/framework/garbage_collector.h"
+#include "paddle/fluid/imperative/amp_auto_cast.h"
 #include "paddle/fluid/imperative/basic_engine.h"
 #include "paddle/fluid/imperative/jit/program_desc_tracer.h"
 #include "paddle/fluid/imperative/layer.h"
@@ -31,6 +32,8 @@
 namespace paddle {
 namespace imperative {
 
+enum class AmpLevel;
+
 using GarbageCollectorMap =
     std::map<platform::Place,
              std::unique_ptr<paddle::framework::GarbageCollector>>;
@@ -105,9 +108,9 @@ class Tracer {
   void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }
 
-  void SetAMPLevel(int level) { amp_level_ = level; }
+  void SetAmpLevel(AmpLevel level) { amp_level_ = level; }
 
-  int AMPLevel() const { return amp_level_; }
+  AmpLevel GetAmpLevel() const { return amp_level_; }
 
   paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
       const platform::Place& place);
@@ -120,7 +123,7 @@
   platform::Place expected_place_;
   GarbageCollectorMap gcs_;
   static thread_local bool has_grad_;
-  int amp_level_{0};
+  AmpLevel amp_level_{AmpLevel::O0};
 };
 
 // To access static variable current_tracer
......
@@ -1940,6 +1940,13 @@ void BindImperative(py::module *m_ptr) {
            &imperative::jit::ProgramDescTracer::CreateProgramDesc)
       .def("reset", &imperative::jit::ProgramDescTracer::Reset);
 
+  py::enum_<paddle::imperative::AmpLevel>(m, "AmpLevel", py::arithmetic())
+      .value("O0", paddle::imperative::AmpLevel::O0)
+      .value("O1", paddle::imperative::AmpLevel::O1)
+      .value("O2", paddle::imperative::AmpLevel::O2)
+      .value("O3", paddle::imperative::AmpLevel::O3)
+      .export_values();
+
   py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
       m, "Tracer", R"DOC()DOC")
       .def("__init__",
@@ -1947,8 +1954,8 @@ void BindImperative(py::module *m_ptr) {
       .def_property("_enable_program_desc_tracing",
                     &imperative::Tracer::IsProgramDescTracingEnabled,
                     &imperative::Tracer::SetEnableProgramDescTracing)
-      .def_property("_amp_level", &imperative::Tracer::AMPLevel,
-                    &imperative::Tracer::SetAMPLevel)
+      .def_property("_amp_level", &imperative::Tracer::GetAmpLevel,
+                    &imperative::Tracer::SetAmpLevel)
       .def_property("_has_grad", &imperative::Tracer::HasGrad,
                     &imperative::Tracer::SetHasGrad)
       .def_property(
......
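With the enum bound and _amp_level retargeted to the renamed getter/setter, the property round-trips the enum rather than an int; export_values() should additionally expose module-level aliases such as core.O1. A small sketch (assuming this build):

    from paddle.fluid import core, framework

    tracer = framework._dygraph_tracer()
    tracer._amp_level = core.AmpLevel.O2   # setter now takes AmpLevel, not int
    assert tracer._amp_level == core.AmpLevel.O2
    tracer._amp_level = core.AmpLevel.O0   # back to pure fp32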
@@ -198,7 +198,7 @@ class _HPRecomputeFunction(PyLayer):
         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        if tracer._amp_level == 0:
+        if tracer._amp_level == core.AmpLevel.O0:
             ctx.is_fw_autocast = False
         else:
             ctx.is_fw_autocast = True
......
@@ -98,7 +98,7 @@ class RecomputeFunction(PyLayer):
         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        if tracer._amp_level == 0:
+        if tracer._amp_level == core.AmpLevel.O0:
             ctx.is_fw_autocast = False
         else:
             ctx.is_fw_autocast = True
......
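Both recompute paths capture the forward autocast state the same way: AmpLevel.O0 now plays the role the bare 0 used to, meaning "no autocast was active in the forward pass". Conceptually (hypothetical helper; the real bookkeeping lives on the PyLayer ctx above):

    from paddle.fluid import core, framework

    def capture_fw_autocast_state():
        tracer = framework._dygraph_tracer()
        # True iff forward ran under some amp_guard level other than O0
        return tracer._amp_level != core.AmpLevel.O0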
@@ -24,6 +24,8 @@ import paddle
 import operator
 import types
 
+AMP_LEVEL = core.AmpLevel
+
 __all__ = ['amp_guard', 'amp_decorate']
 
 # The set of ops that support fp16 calculation and are considered numerically-
@@ -108,7 +110,7 @@ def _in_amp_guard():
     """
     tracer = _dygraph_tracer()
     if tracer:
-        if tracer._amp_level == 1:
+        if tracer._amp_level == core.AmpLevel.O1:
             return True
         else:
             return False
@@ -251,11 +253,11 @@ def amp_guard(enable=True,
         enable = False
 
     if level == 'O1':
-        amp_level = 1
+        amp_level = AMP_LEVEL.O1
         _white_list = WHITE_LIST
         _black_list = BLACK_LIST
     else:
-        amp_level = 2
+        amp_level = AMP_LEVEL.O2
         _white_list = PURE_FP16_WHITE_LIST
         _black_list = PURE_FP16_BLACK_LIST
@@ -264,7 +266,7 @@
                                              custom_black_list, level)
 
     if not enable:
-        amp_level = 0
+        amp_level = AMP_LEVEL.O0
 
     if tracer:
         # enable auto_cast
......
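Putting the amp_guard plumbing together: level='O1' and level='O2' select AMP_LEVEL.O1/O2 with the matching op lists, and enable=False forces AMP_LEVEL.O0 regardless of level. A usage sketch, again assuming a CUDA device:

    import paddle
    from paddle.fluid import core, framework
    from paddle.fluid.dygraph import amp_guard

    with amp_guard(enable=True, level='O1'):
        assert framework._dygraph_tracer()._amp_level == core.AmpLevel.O1
    with amp_guard(enable=True, level='O2'):
        assert framework._dygraph_tracer()._amp_level == core.AmpLevel.O2
    with amp_guard(enable=False):
        assert framework._dygraph_tracer()._amp_level == core.AmpLevel.O0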