Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1ee4fc32
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1ee4fc32
编写于
10月 26, 2021
作者:
L
Leo Chen
提交者:
GitHub
10月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Amp] refine code of amp level (#36362) (#36726)
* refine amp level * fix typo * update tracer._amp_level
上级
53480c9c
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
49 addition
and
26 deletion
+49
-26
paddle/fluid/imperative/amp_auto_cast.cc
paddle/fluid/imperative/amp_auto_cast.cc
+12
-1
paddle/fluid/imperative/amp_auto_cast.h
paddle/fluid/imperative/amp_auto_cast.h
+12
-12
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+2
-2
paddle/fluid/imperative/tracer.h
paddle/fluid/imperative/tracer.h
+6
-3
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+9
-2
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
.../paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
+1
-1
python/paddle/distributed/fleet/utils/recompute.py
python/paddle/distributed/fleet/utils/recompute.py
+1
-1
python/paddle/fluid/dygraph/amp/auto_cast.py
python/paddle/fluid/dygraph/amp/auto_cast.py
+6
-4
未找到文件。
paddle/fluid/imperative/amp_auto_cast.cc
浏览文件 @
1ee4fc32
...
...
@@ -24,6 +24,17 @@ namespace imperative {
class
VarBase
;
AutoCastGuard
::
AutoCastGuard
(
std
::
shared_ptr
<
Tracer
>
tracer
,
AmpLevel
level
)
:
tracer_
(
tracer
)
{
pre_amp_level_
=
tracer_
->
GetAmpLevel
();
if
(
pre_amp_level_
!=
level
)
{
tracer_
->
SetAmpLevel
(
level
);
}
}
AutoCastGuard
::~
AutoCastGuard
()
{
tracer_
->
SetAmpLevel
(
pre_amp_level_
);
}
AmpOperators
::
AmpOperators
()
:
allow_ops_
(
new
std
::
unordered_set
<
std
::
string
>
()),
block_ops_
(
new
std
::
unordered_set
<
std
::
string
>
()),
...
...
@@ -117,7 +128,7 @@ static inline std::shared_ptr<imperative::VarBase> CastToType(
imperative
::
NameVarBaseMap
outs
=
{{
"Out"
,
{
out
}}};
{
AutoCastGuard
guard
(
tracer
,
0
);
AutoCastGuard
guard
(
tracer
,
AmpLevel
::
O
0
);
tracer
->
TraceOp
(
"cast"
,
ins
,
outs
,
std
::
move
(
attrs
));
}
...
...
paddle/fluid/imperative/amp_auto_cast.h
浏览文件 @
1ee4fc32
...
...
@@ -19,15 +19,22 @@
#include <tuple>
#include <unordered_set>
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
namespace
paddle
{
namespace
imperative
{
// Singleton implementation with C++ 11
// NOTE(zhiqiu): only O1 and O2 are valid now
enum
class
AmpLevel
{
O0
=
0
,
// fp32
O1
,
// amp, mixed fp32-fp16
O2
,
// almost fp16
O3
,
// fp16
};
class
Tracer
;
// Singleton implementation with C++ 11
class
AmpOperators
{
public:
~
AmpOperators
();
...
...
@@ -63,16 +70,9 @@ std::ostream& operator<<(std::ostream& os, AmpOperators& ops);
// NOTE(zhiqiu): AutoCastGuard is used for RAII.
class
AutoCastGuard
{
public:
AutoCastGuard
(
std
::
shared_ptr
<
Tracer
>
tracer
,
int
guard_level
)
:
tracer_
(
tracer
)
{
pre_amp_level_
=
tracer_
->
AMPLevel
();
if
(
pre_amp_level_
!=
guard_level
)
{
tracer_
->
SetAMPLevel
(
guard_level
);
}
}
AutoCastGuard
(
std
::
shared_ptr
<
Tracer
>
tracer
,
AmpLevel
guard_level
);
~
AutoCastGuard
()
{
tracer_
->
SetAMPLevel
(
pre_amp_level_
);
}
~
AutoCastGuard
()
;
// forbid copy and operator=
AutoCastGuard
(
const
AutoCastGuard
&
guard
)
=
delete
;
...
...
@@ -80,7 +80,7 @@ class AutoCastGuard {
private:
std
::
shared_ptr
<
Tracer
>
tracer_
;
int
pre_amp_level_
;
AmpLevel
pre_amp_level_
;
};
NameVarBaseMap
AutoCastInputs
(
const
std
::
string
&
op_type
,
...
...
paddle/fluid/imperative/tracer.cc
浏览文件 @
1ee4fc32
...
...
@@ -176,10 +176,10 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
:
attr_checker
->
GetDefaultAttrMap
();
NameVarBaseMap
new_ins
=
ins
;
if
(
amp_level_
==
1
)
{
if
(
amp_level_
==
AmpLevel
::
O
1
)
{
VLOG
(
5
)
<<
"Auto mixed precision run operator: "
<<
type
;
new_ins
=
AutoCastInputs
(
type
,
ins
);
}
else
if
(
amp_level_
==
2
)
{
}
else
if
(
amp_level_
==
AmpLevel
::
O
2
)
{
VLOG
(
5
)
<<
"Pure fp16 run operator: "
<<
type
;
new_ins
=
CastPureFp16Inputs
(
type
,
ins
);
}
...
...
paddle/fluid/imperative/tracer.h
浏览文件 @
1ee4fc32
...
...
@@ -23,6 +23,7 @@
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/jit/program_desc_tracer.h"
#include "paddle/fluid/imperative/layer.h"
...
...
@@ -31,6 +32,8 @@
namespace
paddle
{
namespace
imperative
{
enum
class
AmpLevel
;
using
GarbageCollectorMap
=
std
::
map
<
platform
::
Place
,
std
::
unique_ptr
<
paddle
::
framework
::
GarbageCollector
>>
;
...
...
@@ -105,9 +108,9 @@ class Tracer {
void
SetHasGrad
(
bool
has_grad
)
{
has_grad_
=
has_grad
;
}
void
SetA
MPLevel
(
int
level
)
{
amp_level_
=
level
;
}
void
SetA
mpLevel
(
AmpLevel
level
)
{
amp_level_
=
level
;
}
int
AMP
Level
()
const
{
return
amp_level_
;
}
AmpLevel
GetAmp
Level
()
const
{
return
amp_level_
;
}
paddle
::
framework
::
GarbageCollector
*
MutableGarbageCollectorIfNotExists
(
const
platform
::
Place
&
place
);
...
...
@@ -120,7 +123,7 @@ class Tracer {
platform
::
Place
expected_place_
;
GarbageCollectorMap
gcs_
;
static
thread_local
bool
has_grad_
;
int
amp_level_
{
0
};
AmpLevel
amp_level_
{
AmpLevel
::
O
0
};
};
// To access static variable current_tracer
...
...
paddle/fluid/pybind/imperative.cc
浏览文件 @
1ee4fc32
...
...
@@ -1940,6 +1940,13 @@ void BindImperative(py::module *m_ptr) {
&
imperative
::
jit
::
ProgramDescTracer
::
CreateProgramDesc
)
.
def
(
"reset"
,
&
imperative
::
jit
::
ProgramDescTracer
::
Reset
);
py
::
enum_
<
paddle
::
imperative
::
AmpLevel
>
(
m
,
"AmpLevel"
,
py
::
arithmetic
())
.
value
(
"O0"
,
paddle
::
imperative
::
AmpLevel
::
O0
)
.
value
(
"O1"
,
paddle
::
imperative
::
AmpLevel
::
O1
)
.
value
(
"O2"
,
paddle
::
imperative
::
AmpLevel
::
O2
)
.
value
(
"O3"
,
paddle
::
imperative
::
AmpLevel
::
O3
)
.
export_values
();
py
::
class_
<
imperative
::
Tracer
,
std
::
shared_ptr
<
imperative
::
Tracer
>>
(
m
,
"Tracer"
,
R"DOC()DOC"
)
.
def
(
"__init__"
,
...
...
@@ -1947,8 +1954,8 @@ void BindImperative(py::module *m_ptr) {
.
def_property
(
"_enable_program_desc_tracing"
,
&
imperative
::
Tracer
::
IsProgramDescTracingEnabled
,
&
imperative
::
Tracer
::
SetEnableProgramDescTracing
)
.
def_property
(
"_amp_level"
,
&
imperative
::
Tracer
::
AMP
Level
,
&
imperative
::
Tracer
::
SetA
MP
Level
)
.
def_property
(
"_amp_level"
,
&
imperative
::
Tracer
::
GetAmp
Level
,
&
imperative
::
Tracer
::
SetA
mp
Level
)
.
def_property
(
"_has_grad"
,
&
imperative
::
Tracer
::
HasGrad
,
&
imperative
::
Tracer
::
SetHasGrad
)
.
def_property
(
...
...
python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
浏览文件 @
1ee4fc32
...
...
@@ -198,7 +198,7 @@ class _HPRecomputeFunction(PyLayer):
# TODO support AMP
tracer
=
framework
.
_dygraph_tracer
()
if
tracer
.
_amp_level
==
0
:
if
tracer
.
_amp_level
==
core
.
AmpLevel
.
O
0
:
ctx
.
is_fw_autocast
=
False
else
:
ctx
.
is_fw_autocast
=
True
...
...
python/paddle/distributed/fleet/utils/recompute.py
浏览文件 @
1ee4fc32
...
...
@@ -98,7 +98,7 @@ class RecomputeFunction(PyLayer):
# TODO support AMP
tracer
=
framework
.
_dygraph_tracer
()
if
tracer
.
_amp_level
==
0
:
if
tracer
.
_amp_level
==
core
.
AmpLevel
.
O
0
:
ctx
.
is_fw_autocast
=
False
else
:
ctx
.
is_fw_autocast
=
True
...
...
python/paddle/fluid/dygraph/amp/auto_cast.py
浏览文件 @
1ee4fc32
...
...
@@ -24,6 +24,8 @@ import paddle
import
operator
import
types
AMP_LEVEL
=
core
.
AmpLevel
__all__
=
[
'amp_guard'
,
'amp_decorate'
]
# The set of ops that support fp16 calculation and are considered numerically-
...
...
@@ -108,7 +110,7 @@ def _in_amp_guard():
"""
tracer
=
_dygraph_tracer
()
if
tracer
:
if
tracer
.
_amp_level
==
1
:
if
tracer
.
_amp_level
==
core
.
AmpLevel
.
O
1
:
return
True
else
:
return
False
...
...
@@ -251,11 +253,11 @@ def amp_guard(enable=True,
enable
=
False
if
level
==
'O1'
:
amp_level
=
1
amp_level
=
AMP_LEVEL
.
O
1
_white_list
=
WHITE_LIST
_black_list
=
BLACK_LIST
else
:
amp_level
=
2
amp_level
=
AMP_LEVEL
.
O
2
_white_list
=
PURE_FP16_WHITE_LIST
_black_list
=
PURE_FP16_BLACK_LIST
...
...
@@ -264,7 +266,7 @@ def amp_guard(enable=True,
custom_black_list
,
level
)
if
not
enable
:
amp_level
=
0
amp_level
=
AMP_LEVEL
.
O
0
if
tracer
:
# enable auto_cast
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录