Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6f3c9643
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6f3c9643
编写于
4月 14, 2023
作者:
J
JZ-LIANG
提交者:
GitHub
4月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Eb118 BF16 Adoption (#52827)
* pr1 * pr2 * pr3 * fixed unitest * adopt for scale
上级
8cbc75ca
变更
11
展开全部
显示空白变更内容
内联
并排
Showing
11 changed file
with
1878 addition
and
970 deletion
+1878
-970
python/paddle/distributed/auto_parallel/constants.py
python/paddle/distributed/auto_parallel/constants.py
+4
-2
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
...dle/distributed/auto_parallel/operators/dist_embedding.py
+7
-4
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...paddle/distributed/auto_parallel/operators/dist_matmul.py
+1049
-631
python/paddle/distributed/auto_parallel/parallelizer_v2.py
python/paddle/distributed/auto_parallel/parallelizer_v2.py
+14
-5
python/paddle/distributed/passes/auto_parallel_amp.py
python/paddle/distributed/passes/auto_parallel_amp.py
+347
-166
python/paddle/distributed/passes/auto_parallel_fp16.py
python/paddle/distributed/passes/auto_parallel_fp16.py
+252
-156
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py
...paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py
+142
-0
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
.../fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
+1
-1
python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py
...e/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py
+55
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
...ddle/fluid/tests/unittests/auto_parallel/test_strategy.py
+4
-5
未找到文件。
python/paddle/distributed/auto_parallel/constants.py
浏览文件 @
6f3c9643
...
@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
...
@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
#########################################
#########################################
AMP
=
"amp"
AMP
=
"amp"
set_field_default_config
(
AMP
,
"enable"
,
False
)
set_field_default_config
(
AMP
,
"enable"
,
False
)
set_field_default_config
(
AMP
,
"dtype"
,
"float16"
)
set_field_default_config
(
AMP
,
"level"
,
"o1"
)
set_field_default_config
(
AMP
,
"init_loss_scaling"
,
32768.0
)
set_field_default_config
(
AMP
,
"init_loss_scaling"
,
32768.0
)
set_field_default_config
(
AMP
,
"incr_every_n_steps"
,
1000
)
set_field_default_config
(
AMP
,
"incr_every_n_steps"
,
1000
)
set_field_default_config
(
AMP
,
"decr_every_n_nan_or_inf"
,
2
)
set_field_default_config
(
AMP
,
"decr_every_n_nan_or_inf"
,
2
)
...
@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
...
@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config
(
AMP
,
"custom_white_list"
,
[])
set_field_default_config
(
AMP
,
"custom_white_list"
,
[])
set_field_default_config
(
AMP
,
"custom_black_list"
,
[])
set_field_default_config
(
AMP
,
"custom_black_list"
,
[])
set_field_default_config
(
AMP
,
"custom_black_varnames"
,
[])
set_field_default_config
(
AMP
,
"custom_black_varnames"
,
[])
set_field_default_config
(
AMP
,
"use_
pure_fp16
"
,
False
)
set_field_default_config
(
AMP
,
"use_
fp16_guard
"
,
False
)
set_field_default_config
(
AMP
,
"use_
fp16_guard"
,
Tru
e
)
set_field_default_config
(
AMP
,
"use_
bf16_guard"
,
Fals
e
)
set_field_default_config
(
AMP
,
"use_optimizer_fp16"
,
False
)
set_field_default_config
(
AMP
,
"use_optimizer_fp16"
,
False
)
#########################################
#########################################
...
...
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
浏览文件 @
6f3c9643
...
@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype
(
check_variable_and_dtype
(
Out_var
,
Out_var
,
'tensor'
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'c_allreduce_sum'
,
'c_allreduce_sum'
,
)
)
...
@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype
(
check_variable_and_dtype
(
Out_grad
,
Out_grad
,
'tensor'
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
'_c_identity'
,
)
)
...
@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
},
},
)
)
check_variable_and_dtype
(
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
)
check_dtype
(
check_dtype
(
intermediate_var_0
.
dtype
,
intermediate_var_0
.
dtype
,
'dtype'
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
'linear'
,
)
)
...
...
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
6f3c9643
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/parallelizer_v2.py
浏览文件 @
6f3c9643
...
@@ -254,17 +254,26 @@ class Parallelizer:
...
@@ -254,17 +254,26 @@ class Parallelizer:
self
.
_dist_context
.
serial_feed_vars
[
"inputs"
]
self
.
_dist_context
.
serial_feed_vars
[
"inputs"
]
+
self
.
_dist_context
.
serial_feed_vars
[
"labels"
]
+
self
.
_dist_context
.
serial_feed_vars
[
"labels"
]
)
)
if
config
[
"use_pure_fp16"
]:
self
.
_logger
.
info
(
"Applying AMP-{}-{} ..."
.
format
(
config
[
"dtype"
],
config
[
'level'
]
),
)
if
config
[
'level'
]
==
"o1"
:
auto_parallel_amp_pass
=
new_pass
(
"auto_parallel_amp"
,
config
)
auto_parallel_amp_pass
.
apply
(
[
main_program
],
[
startup_program
],
self
.
_pass_context
)
loss
=
auto_parallel_amp_pass
.
get_loss
()
elif
config
[
'level'
]
in
[
'o2'
,
'o3'
]:
config
[
"base_opt"
]
=
optimizer
config
[
"base_opt"
]
=
optimizer
auto_parallel_fp16_pass
=
new_pass
(
"auto_parallel_fp16"
,
config
)
auto_parallel_fp16_pass
=
new_pass
(
"auto_parallel_fp16"
,
config
)
auto_parallel_fp16_pass
.
apply
(
auto_parallel_fp16_pass
.
apply
(
[
main_program
],
[
startup_program
],
self
.
_pass_context
[
main_program
],
[
startup_program
],
self
.
_pass_context
)
)
loss
=
auto_parallel_fp16_pass
.
get_loss
()
else
:
else
:
auto_parallel_amp_pass
=
new_pass
(
"auto_parallel_amp"
,
config
)
raise
ValueError
(
"AMP level should be one of o1, o2, o3"
)
auto_parallel_amp_pass
.
apply
(
[
main_program
],
[
startup_program
],
self
.
_pass_context
)
# apply recompute pass
# apply recompute pass
# recompute is then train-only optimization
# recompute is then train-only optimization
...
...
python/paddle/distributed/passes/auto_parallel_amp.py
浏览文件 @
6f3c9643
此差异已折叠。
点击以展开。
python/paddle/distributed/passes/auto_parallel_fp16.py
浏览文件 @
6f3c9643
...
@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import (
...
@@ -27,14 +27,13 @@ from paddle.distributed.auto_parallel.utils import (
from
paddle.distributed.auto_parallel.process_group
import
(
from
paddle.distributed.auto_parallel.process_group
import
(
get_world_process_group
,
get_world_process_group
,
)
)
from
paddle.fluid.contrib.mixed_precision.fp16_
util
s
import
(
from
paddle.fluid.contrib.mixed_precision.fp16_
list
s
import
(
AutoMixedPrecisionLists
,
AutoMixedPrecisionLists
,
)
)
from
paddle.fluid.contrib.mixed_precision.fp16_utils
import
(
from
paddle.fluid.contrib.mixed_precision.fp16_utils
import
(
_keep_layer_norm_scale_bias_to_fp32
,
_keep_layer_norm_scale_bias_to_fp32
,
_need_keep_fp32
,
_need_keep_fp32
,
_valid_types
,
_valid_types
,
_dtype_to_str
,
)
)
from
paddle.distributed.auto_parallel.dist_attribute
import
(
from
paddle.distributed.auto_parallel.dist_attribute
import
(
OperatorDistributedAttribute
,
OperatorDistributedAttribute
,
...
@@ -55,6 +54,23 @@ __amp_skip_ops__ = [
...
@@ -55,6 +54,23 @@ __amp_skip_ops__ = [
'while'
,
'while'
,
'cast'
,
'cast'
,
]
]
__target_dtype__
=
None
def
_dtype_to_str
(
dtype
):
"""
Convert specific variable type to its corresponding string.
Args:
dtype (VarType): Variable type.
"""
if
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
# TODO(Xreki): change the returned str to "bf16" for BF16 data type.
# Currently too many codes use "cast_fp16" as key.
return
'fp16'
elif
dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
return
'bf16'
else
:
return
'fp32'
def
set_op_dtype_to_fp16
(
op
):
def
set_op_dtype_to_fp16
(
op
):
...
@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op):
...
@@ -62,14 +78,20 @@ def set_op_dtype_to_fp16(op):
op
.
has_attr
(
'in_dtype'
)
op
.
has_attr
(
'in_dtype'
)
and
op
.
attr
(
'in_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
and
op
.
attr
(
'in_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
):
):
op
.
_set_attr
(
'in_dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
op
.
_set_attr
(
'in_dtype'
,
__target_dtype__
)
if
(
if
(
op
.
has_attr
(
'out_dtype'
)
op
.
has_attr
(
'out_dtype'
)
and
op
.
attr
(
'out_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
and
op
.
attr
(
'out_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
):
):
op
.
_set_attr
(
'out_dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
op
.
_set_attr
(
'out_dtype'
,
__target_dtype__
)
if
op
.
has_attr
(
'dtype'
)
and
op
.
attr
(
'dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
if
op
.
has_attr
(
'dtype'
)
and
op
.
attr
(
'dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
op
.
_set_attr
(
'dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
op
.
_set_attr
(
'dtype'
,
__target_dtype__
)
if
__target_dtype__
==
core
.
VarDesc
.
VarType
.
BF16
:
if
op
.
has_attr
(
'use_mkldnn'
):
op
.
_set_attr
(
'use_mkldnn'
,
True
)
if
op
.
has_attr
(
'mkldnn_data_type'
):
op
.
_set_attr
(
'mkldnn_data_type'
,
'bfloat16'
)
# adapot for backward op
# adapot for backward op
...
@@ -156,6 +178,7 @@ class FP16State(object):
...
@@ -156,6 +178,7 @@ class FP16State(object):
list
list
)
# {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
)
# {forward_op_id: [(output_name, input_name, out_dtype, in_dtype, slot_name), ]}
self
.
is_train
=
False
self
.
is_train
=
False
self
.
out_var_op_deps
=
{}
def
_is_fp16_op
(
self
,
op_id
):
def
_is_fp16_op
(
self
,
op_id
):
return
self
.
_op_fp16_dict
.
get
(
op_id
,
None
)
return
self
.
_op_fp16_dict
.
get
(
op_id
,
None
)
...
@@ -169,6 +192,14 @@ class FP16State(object):
...
@@ -169,6 +192,14 @@ class FP16State(object):
# assume all backward block are behind forward blocks
# assume all backward block are behind forward blocks
for
block
in
self
.
program
.
blocks
:
for
block
in
self
.
program
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
for
name
in
op
.
output_arg_names
:
if
name
not
in
self
.
out_var_op_deps
:
self
.
out_var_op_deps
[
name
]
=
[
op
.
desc
.
original_id
()]
else
:
self
.
out_var_op_deps
[
name
].
extend
(
[
op
.
desc
.
original_id
()]
)
self
.
_mark_op
(
op
)
self
.
_mark_op
(
op
)
# set forward tensor dtype
# set forward tensor dtype
...
@@ -192,6 +223,18 @@ class FP16State(object):
...
@@ -192,6 +223,18 @@ class FP16State(object):
if
op
.
type
==
"assign"
and
"array_"
in
op
.
input_arg_names
[
0
]:
if
op
.
type
==
"assign"
and
"array_"
in
op
.
input_arg_names
[
0
]:
self
.
_op_fp16_dict
[
op
.
desc
.
original_id
()]
=
False
self
.
_op_fp16_dict
[
op
.
desc
.
original_id
()]
=
False
return
return
# If assign op is inplace-operation, assign op exec mode should be same with the created op of output_var.
if
op
.
type
==
"assign"
:
out_name
=
op
.
output_arg_names
[
0
]
if
len
(
self
.
out_var_op_deps
[
out_name
])
>
1
:
if
not
self
.
_op_fp16_dict
[
self
.
out_var_op_deps
[
out_name
][
0
]
]:
self
.
_op_fp16_dict
[
op
.
desc
.
original_id
()]
=
False
else
:
self
.
_op_fp16_dict
[
op
.
desc
.
original_id
()]
=
True
return
if
_need_keep_fp32
(
if
_need_keep_fp32
(
op
,
self
.
amp_list
.
unsupported_list
,
self
.
use_fp16_guard
op
,
self
.
amp_list
.
unsupported_list
,
self
.
use_fp16_guard
):
):
...
@@ -228,7 +271,7 @@ class FP16State(object):
...
@@ -228,7 +271,7 @@ class FP16State(object):
return
return
if
var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
if
var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP16
)
var
.
desc
.
set_dtype
(
__target_dtype__
)
def
resolute_tensor_dtype
(
self
,
block
):
def
resolute_tensor_dtype
(
self
,
block
):
...
@@ -260,7 +303,7 @@ class FP16State(object):
...
@@ -260,7 +303,7 @@ class FP16State(object):
out_var
=
block
.
vars
.
get
(
out_var_name
)
out_var
=
block
.
vars
.
get
(
out_var_name
)
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
continue
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
out_var
.
dtype
==
__target_dtype__
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
elif
is_backward_op
(
op
):
elif
is_backward_op
(
op
):
if
self
.
_is_fp16_op
(
op
.
desc
.
original_id
())
==
True
:
if
self
.
_is_fp16_op
(
op
.
desc
.
original_id
())
==
True
:
...
@@ -276,7 +319,7 @@ class FP16State(object):
...
@@ -276,7 +319,7 @@ class FP16State(object):
out_var
=
block
.
vars
.
get
(
out_var_name
)
out_var
=
block
.
vars
.
get
(
out_var_name
)
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
continue
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
out_var
.
dtype
==
__target_dtype__
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
def
cast_block
(
self
,
block
):
def
cast_block
(
self
,
block
):
...
@@ -295,7 +338,7 @@ class FP16State(object):
...
@@ -295,7 +338,7 @@ class FP16State(object):
op
,
op
,
idx
,
idx
,
block
,
block
,
core
.
VarDesc
.
VarType
.
FP16
,
__target_dtype__
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
self
.
dist_context
,
self
.
dist_context
,
)
)
...
@@ -305,7 +348,7 @@ class FP16State(object):
...
@@ -305,7 +348,7 @@ class FP16State(object):
idx
,
idx
,
block
,
block
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
__target_dtype__
,
self
.
dist_context
,
self
.
dist_context
,
)
)
elif
is_backward_op
(
op
):
elif
is_backward_op
(
op
):
...
@@ -315,7 +358,7 @@ class FP16State(object):
...
@@ -315,7 +358,7 @@ class FP16State(object):
op
,
op
,
idx
,
idx
,
block
,
block
,
core
.
VarDesc
.
VarType
.
FP16
,
__target_dtype__
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
self
.
dist_context
,
self
.
dist_context
,
)
)
...
@@ -325,7 +368,7 @@ class FP16State(object):
...
@@ -325,7 +368,7 @@ class FP16State(object):
idx
,
idx
,
block
,
block
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
__target_dtype__
,
self
.
dist_context
,
self
.
dist_context
,
)
)
elif
op
.
type
==
"sum"
:
elif
op
.
type
==
"sum"
:
...
@@ -399,6 +442,9 @@ class FP16State(object):
...
@@ -399,6 +442,9 @@ class FP16State(object):
dist_context
,
cast_var
,
ref_mapping
,
ref_mesh
dist_context
,
cast_var
,
ref_mapping
,
ref_mesh
)
)
op_namescope
=
"/"
if
op
.
has_attr
(
'op_namescope'
):
op_namescope
=
op
.
attr
(
'op_namescope'
)
cast_op
=
block
.
_insert_op_without_sync
(
cast_op
=
block
.
_insert_op_without_sync
(
idx
,
idx
,
type
=
"cast"
,
type
=
"cast"
,
...
@@ -410,6 +456,9 @@ class FP16State(object):
...
@@ -410,6 +456,9 @@ class FP16State(object):
OP_ROLE_KEY
:
OpRole
.
Forward
,
OP_ROLE_KEY
:
OpRole
.
Forward
,
},
},
)
)
cast_op
.
_set_attr
(
'op_namescope'
,
op_namescope
)
# for recompute
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
naive_set_dist_op_attr_for_program_by_mesh_and_mapping
(
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
cast_op
,
ref_mesh
,
ref_mapping
,
dist_context
)
)
...
@@ -455,22 +504,36 @@ class FP16State(object):
...
@@ -455,22 +504,36 @@ class FP16State(object):
)
in
self
.
forward_input_cast_ops
[
forward_op_id
]:
)
in
self
.
forward_input_cast_ops
[
forward_op_id
]:
# rename input
# rename input
# some forward output is not need by backward computation, e.g. logit in softmax_with_cross_entropy
if
op
.
type
!=
"scale"
and
slot_name
in
op
.
input_names
:
assert
src_name
in
op
.
input
(
assert
src_name
in
op
.
input
(
slot_name
slot_name
),
"var: {} not in op's {}. {}"
.
format
(
src_name
,
slot_name
,
str
(
op
))
),
"var: {} not in op's {}. {}"
.
format
(
src_name
,
slot_name
,
str
(
op
)
)
src_var_dist_attr
=
grad_op_attr
.
get_input_dist_attr
(
src_name
)
src_var_dist_attr
=
grad_op_attr
.
get_input_dist_attr
(
src_name
)
assert
src_var_dist_attr
is
not
None
assert
src_var_dist_attr
is
not
None
op
.
_rename_input
(
src_name
,
cast_name
)
op
.
_rename_input
(
src_name
,
cast_name
)
grad_op_attr
.
set_input_dist_attr
(
cast_name
,
src_var_dist_attr
)
grad_op_attr
.
set_input_dist_attr
(
cast_name
,
src_var_dist_attr
)
# NOTE Special for scale op, scale op's grad op is scale,
# so slot name map rule could not apply to grad scale op
# cast_name: mean_0.tmp_0.cast_bf16, src_name: mean_0.tmp_0, dst_dtype: paddle.bfloat16, src_dtype: paddle.float32, slot_name: X.
if
op
.
type
==
"scale"
:
grad_slot_name
=
"X"
# create cast grad
# create cast grad
else
:
grad_slot_name
=
slot_name
+
"@GRAD"
grad_slot_name
=
slot_name
+
"@GRAD"
assert
grad_slot_name
in
op
.
output_names
if
grad_slot_name
in
op
.
output_names
:
# some forward input maybe stop_gradient=True, e.g. input_mask
if
len
(
op
.
output
(
grad_slot_name
))
==
0
:
if
len
(
op
.
output
(
grad_slot_name
))
==
0
:
var
=
block
.
var
(
src_name
)
assert
var
.
stop_gradient
is
True
continue
continue
assert
len
(
op
.
output
(
grad_slot_name
))
==
1
assert
(
len
(
op
.
output
(
grad_slot_name
))
==
1
),
"[{}], Current Op: {}"
.
format
(
grad_slot_name
,
str
(
op
))
grad_name
=
op
.
output
(
grad_slot_name
)[
0
]
grad_name
=
op
.
output
(
grad_slot_name
)[
0
]
grad
=
block
.
var
(
grad_name
)
grad
=
block
.
var
(
grad_name
)
grad_dist_attr
=
grad_op_attr
.
get_output_dist_attr
(
grad_name
)
grad_dist_attr
=
grad_op_attr
.
get_output_dist_attr
(
grad_name
)
...
@@ -492,7 +555,9 @@ class FP16State(object):
...
@@ -492,7 +555,9 @@ class FP16State(object):
cast_grad
,
grad_dist_attr
cast_grad
,
grad_dist_attr
)
)
op
.
_rename_output
(
grad_name
,
cast_grad
.
name
)
op
.
_rename_output
(
grad_name
,
cast_grad
.
name
)
grad_op_attr
.
set_output_dist_attr
(
cast_grad
.
name
,
grad_dist_attr
)
grad_op_attr
.
set_output_dist_attr
(
cast_grad
.
name
,
grad_dist_attr
)
# add cast
# add cast
cast_op
=
block
.
_insert_op_without_sync
(
cast_op
=
block
.
_insert_op_without_sync
(
...
@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
...
@@ -573,7 +638,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
def
_split_grads
(
params_grads
):
def
_split_grads
(
params_grads
):
grads
=
[
g
for
_
,
g
in
params_grads
]
grads
=
[
g
for
_
,
g
in
params_grads
]
fp32_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
]
fp32_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
]
fp16_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
]
fp16_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
__target_dtype__
]
assert
len
(
fp32_grads
)
+
len
(
fp16_grads
)
==
len
(
assert
len
(
fp32_grads
)
+
len
(
fp16_grads
)
==
len
(
grads
grads
),
"Data types of all grads must be either fp16 or fp32."
),
"Data types of all grads must be either fp16 or fp32."
...
@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
...
@@ -633,17 +698,15 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
# TODO to support CUDAPinned/NPU/XPU Places
# TODO to support CUDAPinned/NPU/XPU Places
if
direction
==
"D2H"
:
if
direction
==
"D2H"
:
dst_place_type
=
0
dst_place_type
=
0
elif
direction
==
"D2H"
:
dst_place_type
=
1
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"direction [{}] is not supported yet."
.
format
(
direction
)
f
"direction [
{
direction
}
] is not supported yet."
)
)
attrs
=
{
'dst_place_type'
:
dst_place_type
}
attrs
=
{
'dst_place_type'
:
dst_place_type
}
new_op
=
block
.
_insert_op_without_sync
(
new_op
=
block
.
_insert_op_without_sync
(
index
=
idx
,
index
=
idx
,
type
=
'memcpy'
,
type
=
'memcpy
_d2h
'
,
inputs
=
{
'X'
:
[
src_var
]},
inputs
=
{
'X'
:
[
src_var
]},
outputs
=
{
'Out'
:
[
output_var
]},
outputs
=
{
'Out'
:
[
output_var
]},
attrs
=
attrs
,
attrs
=
attrs
,
...
@@ -678,17 +741,17 @@ def cast_startup_program():
...
@@ -678,17 +741,17 @@ def cast_startup_program():
for
op
in
startup_program
.
global_block
().
ops
:
for
op
in
startup_program
.
global_block
().
ops
:
if
is_initialization_op
(
op
):
if
is_initialization_op
(
op
):
output_name
=
op
.
output_arg_names
[
0
]
output_name
=
op
.
output_arg_names
[
0
]
if
(
if
param_to_dtype
.
get
(
output_name
,
None
)
==
__target_dtype__
:
param_to_dtype
.
get
(
output_name
,
None
)
==
core
.
VarDesc
.
VarType
.
FP16
):
assert
op
.
has_attr
(
assert
op
.
has_attr
(
'dtype'
'dtype'
),
"initialization op is supported to has dtype attribute but got {}."
.
format
(
),
"initialization op is supported to has dtype attribute but got {}."
.
format
(
str
(
op
)
str
(
op
)
)
)
out_var
=
startup_program
.
global_block
().
var
(
output_name
)
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
out_var
.
desc
.
set_dtype
(
__target_dtype__
)
if
op
.
attr
(
'dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
if
op
.
attr
(
'dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
op
.
_set_attr
(
'dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
op
.
_set_attr
(
'dtype'
,
__target_dtype__
)
@
register_pass
(
"auto_parallel_fp16"
)
@
register_pass
(
"auto_parallel_fp16"
)
...
@@ -701,14 +764,44 @@ class FP16Pass(AMPPass):
...
@@ -701,14 +764,44 @@ class FP16Pass(AMPPass):
# in distributed scenario, all ranks should have the same modification.
# in distributed scenario, all ranks should have the same modification.
def
_apply_single_impl
(
self
,
main_program
,
startup_program
,
context
):
def
_apply_single_impl
(
self
,
main_program
,
startup_program
,
context
):
self
.
dist_context
=
self
.
get_attr
(
"dist_context"
)
self
.
dist_context
=
self
.
get_attr
(
"dist_context"
)
self
.
target_dtype
=
self
.
get_attr
(
"dtype"
)
params_grads
=
self
.
get_attr
(
"params_grads"
)
params_grads
=
self
.
get_attr
(
"params_grads"
)
self
.
use_optimizer_fp16
=
self
.
get_attr
(
"use_optimizer_fp16"
,
None
)
if
self
.
use_optimizer_fp16
is
None
:
self
.
use_optimizer_fp16
=
self
.
get_attr
(
"level"
,
None
)
==
"o3"
# swith enviroment for fp16 / bf16.
if
self
.
target_dtype
==
"float16"
:
__target_dtype
=
core
.
VarDesc
.
VarType
.
FP16
elif
self
.
target_dtype
==
"bfloat16"
:
__target_dtype
=
core
.
VarDesc
.
VarType
.
BF16
else
:
raise
NotImplementedError
(
"target dtype [{}] is for amp o2 not supported yet."
.
format
(
self
.
target_dtype
)
)
global
__target_dtype__
__target_dtype__
=
__target_dtype
amp_list
=
AutoMixedPrecisionLists
(
amp_list
=
AutoMixedPrecisionLists
(
set
(
self
.
get_attr
(
"custom_white_list"
)),
set
(
self
.
get_attr
(
"custom_white_list"
)),
set
(
self
.
get_attr
(
"custom_black_list"
)),
set
(
self
.
get_attr
(
"custom_black_list"
)),
None
,
dtype
=
self
.
target_dtype
,
)
)
amp_list
.
unsupported_list
-=
{
"conditional_block_grad"
,
"conditional_block"
,
"conditional_block_infer"
,
"select_input"
,
"while"
,
"while_grad"
,
"cast"
,
"tensor_array_to_tensor"
,
"lod_array_length"
,
"write_to_array"
,
}
# NOTE don't not change input data dtype, since it is controled by dataloader
# NOTE don't not change input data dtype, since it is controled by dataloader
# and which is out of control of FP16 Pass
# and which is out of control of FP16 Pass
input_data_var_names
=
[
var
.
name
for
var
in
self
.
get_attr
(
"input_data"
)]
input_data_var_names
=
[
var
.
name
for
var
in
self
.
get_attr
(
"input_data"
)]
...
@@ -726,6 +819,7 @@ class FP16Pass(AMPPass):
...
@@ -726,6 +819,7 @@ class FP16Pass(AMPPass):
cast_startup_program
()
cast_startup_program
()
if
is_train
:
if
is_train
:
if
self
.
target_dtype
==
"fp16"
:
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
with
paddle
.
static
.
program_guard
(
main_program
,
startup_program
):
# TODO (JZ-LIANG)support cast forward program only when inference
# TODO (JZ-LIANG)support cast forward program only when inference
self
.
_init_amp_var
()
self
.
_init_amp_var
()
...
@@ -801,10 +895,12 @@ class FP16Pass(AMPPass):
...
@@ -801,10 +895,12 @@ class FP16Pass(AMPPass):
# modify optimizer
# modify optimizer
base_opt
=
self
.
get_attr
(
"base_opt"
)
base_opt
=
self
.
get_attr
(
"base_opt"
)
base_opt
.
_multi_precision
=
True
base_opt
.
_multi_precision
=
True
if
self
.
get_attr
(
"use_optimizer_fp16"
)
:
if
self
.
use_optimizer_fp16
:
base_opt
.
_multi_precision
=
False
base_opt
.
_multi_precision
=
False
if
self
.
target_dtype
==
"fp16"
:
if
isinstance
(
if
isinstance
(
base_opt
,
(
paddle
.
fluid
.
optimizer
.
Adam
,
paddle
.
optimizer
.
AdamW
)
base_opt
,
(
paddle
.
fluid
.
optimizer
.
Adam
,
paddle
.
optimizer
.
AdamW
),
):
):
with
main_program
.
_optimized_guard
([]):
with
main_program
.
_optimized_guard
([]):
# found_inf = paddle.tensor.creation._memcpy(
# found_inf = paddle.tensor.creation._memcpy(
...
...
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
6f3c9643
...
@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
...
@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_random_ctrl MODULES test_random_ctrl ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_random_ctrl MODULES test_random_ctrl ENVS
${
dist_ENVS
}
)
set_tests_properties
(
test_random_ctrl PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
set_tests_properties
(
test_random_ctrl PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
TIMEOUT 50
)
TIMEOUT 50
)
py_test_modules
(
test_amp_o2_pass MODULES test_amp_o2_pass ENVS
${
dist_ENVS
}
)
set_tests_properties
(
test_amp_o2_pass PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
TIMEOUT 50
)
py_test_modules
(
test_iterable_dataset MODULES test_iterable_dataset ENVS
py_test_modules
(
test_iterable_dataset MODULES test_iterable_dataset ENVS
${
dist_ENVS
}
)
${
dist_ENVS
}
)
set_tests_properties
(
test_iterable_dataset
set_tests_properties
(
test_iterable_dataset
...
...
python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py
0 → 100644
浏览文件 @
6f3c9643
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
random
import
re
import
unittest
import
numpy
as
np
from
get_gpt_model
import
FakeDataset
,
generate_model
import
paddle
from
paddle.distributed.fleet
import
auto
from
paddle.fluid.framework
import
core
paddle
.
enable_static
()
def
get_cuda_version
():
result
=
os
.
popen
(
"nvcc --version"
).
read
()
regex
=
r
'release (\S+),'
match
=
re
.
search
(
regex
,
result
)
if
match
:
num
=
str
(
match
.
group
(
1
))
integer
,
decimal
=
num
.
split
(
'.'
)
return
int
(
integer
)
*
1000
+
int
(
float
(
decimal
)
*
10
)
else
:
return
-
1
def
apply_pass
(
use_amp
=
False
,
amp_dtype
=
"bfloat16"
):
strategy
=
auto
.
Strategy
()
strategy
.
auto_mode
=
"semi"
strategy
.
reinit
=
True
if
use_amp
:
amp
=
strategy
.
amp
amp
.
enable
=
True
amp
.
dtype
=
amp_dtype
amp
.
level
=
"o2"
amp
.
custom_black_list
=
[
'c_softmax_with_cross_entropy'
,
'elementwise_div'
,
'reduce_sum'
,
]
return
strategy
def
reset_prog
():
paddle
.
fluid
.
framework
.
switch_main_program
(
paddle
.
static
.
Program
())
paddle
.
fluid
.
framework
.
switch_startup_program
(
paddle
.
static
.
Program
())
class
TestShardingStage2WithNewEXE
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
2
self
.
batch_num
=
10
self
.
clip_norm
=
0.2
self
.
dataset
=
FakeDataset
(
self
.
batch_size
*
self
.
batch_num
)
def
init
(
self
,
engine
):
paddle
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
random
.
seed
(
2022
)
place
=
paddle
.
fluid
.
CUDAPlace
(
paddle
.
distributed
.
ParallelEnv
().
dev_id
)
engine
.
_executor
=
paddle
.
static
.
Executor
(
place
)
def
get_engine
(
self
,
use_amp
=
False
,
amp_dtype
=
"bfloat16"
):
reset_prog
()
strategy
=
apply_pass
(
use_amp
,
amp_dtype
)
# clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
clip
=
None
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
model
,
loss
=
generate_model
(
"mp"
)
engine
=
auto
.
Engine
(
model
,
loss
,
opt
,
strategy
=
strategy
)
self
.
init
(
engine
)
return
engine
def
check_bf16
(
self
,
program
):
num_bf16
=
0
num_fp16
=
0
num_fp32
=
0
for
p
in
program
.
all_parameters
():
if
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
num_fp32
+=
1
if
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
num_fp16
+=
1
if
p
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
num_bf16
+=
1
self
.
assertEqual
(
num_bf16
,
25
)
self
.
assertEqual
(
num_fp16
,
0
)
self
.
assertEqual
(
num_fp32
,
11
)
def
test_param_grad_fuse_overlap
(
self
):
# std
mp_engine
=
self
.
get_engine
(
use_amp
=
False
)
mp_history
=
mp_engine
.
fit
(
self
.
dataset
,
3
,
epochs
=
1
,
steps_per_epoch
=
self
.
batch_num
,
log_freq
=
1
,
batch_size
=
self
.
batch_size
,
)
loss0
=
mp_history
.
history
[
'loss'
][
0
]
# bf16
mp_bf16_engine
=
self
.
get_engine
(
use_amp
=
True
)
if
not
paddle
.
is_compiled_with_cuda
()
or
get_cuda_version
()
<
11000
:
return
mp_bf16_history
=
mp_bf16_engine
.
fit
(
self
.
dataset
,
3
,
epochs
=
1
,
steps_per_epoch
=
self
.
batch_num
,
log_freq
=
1
,
batch_size
=
self
.
batch_size
,
)
loss1
=
mp_bf16_history
.
history
[
'loss'
][
0
]
np
.
testing
.
assert_allclose
(
loss0
,
loss1
,
atol
=
1e-3
,
rtol
=
1e-2
)
self
.
check_bf16
(
mp_bf16_engine
.
main_program
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
浏览文件 @
6f3c9643
...
@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
...
@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
]
]
amp
.
init_loss_scaling
=
32768
amp
.
init_loss_scaling
=
32768
amp
.
use_fp16_guard
=
False
amp
.
use_fp16_guard
=
False
amp
.
use_pure_fp16
=
level
in
[
"o2"
,
"o3"
]
amp
.
level
=
level
amp
.
use_optimizer_fp16
=
level
==
"o3"
amp
.
use_optimizer_fp16
=
level
==
"o3"
print
(
"amp level: "
,
level
)
print
(
"amp level: "
,
level
)
return
strategy
return
strategy
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py
0 → 100644
浏览文件 @
6f3c9643
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
subprocess
import
sys
import
tempfile
import
unittest
class
TestAMPO2
(
unittest
.
TestCase
):
def
test_bf16
(
self
):
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
launch_model_path
=
os
.
path
.
join
(
file_dir
,
"amp_o2_pass.py"
)
if
os
.
environ
.
get
(
"WITH_COVERAGE"
,
"OFF"
)
==
"ON"
:
coverage_args
=
[
"-m"
,
"coverage"
,
"run"
,
"--branch"
,
"-p"
]
else
:
coverage_args
=
[]
tmp_dir
=
tempfile
.
TemporaryDirectory
()
cmd
=
(
[
sys
.
executable
,
"-u"
]
+
coverage_args
+
[
"-m"
,
"paddle.distributed.launch"
,
"--devices"
,
"0,1"
,
"--log_dir"
,
tmp_dir
.
name
,
launch_model_path
,
]
)
process
=
subprocess
.
Popen
(
cmd
)
process
.
wait
()
self
.
assertEqual
(
process
.
returncode
,
0
)
tmp_dir
.
cleanup
()
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
浏览文件 @
6f3c9643
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
# import yaml
# import yaml
import
unittest
import
unittest
from
paddle.distributed.fleet
import
auto
from
paddle.distributed.fleet
import
auto
class
TestStrategy
(
unittest
.
TestCase
):
class
TestStrategy
(
unittest
.
TestCase
):
def
test_default_config
(
self
):
def
test_default_config
(
self
):
strategy
=
auto
.
Strategy
()
strategy
=
auto
.
Strategy
()
...
@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
...
@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
amp
=
strategy
.
amp
amp
=
strategy
.
amp
self
.
assertEqual
(
amp
.
enable
,
False
)
self
.
assertEqual
(
amp
.
enable
,
False
)
self
.
assertAlmostEqual
(
amp
.
dtype
,
"float16"
)
self
.
assertAlmostEqual
(
amp
.
level
,
"o1"
)
self
.
assertAlmostEqual
(
amp
.
init_loss_scaling
,
32768.0
)
self
.
assertAlmostEqual
(
amp
.
init_loss_scaling
,
32768.0
)
self
.
assertEqual
(
amp
.
incr_every_n_steps
,
1000
)
self
.
assertEqual
(
amp
.
incr_every_n_steps
,
1000
)
self
.
assertEqual
(
amp
.
decr_every_n_nan_or_inf
,
2
)
self
.
assertEqual
(
amp
.
decr_every_n_nan_or_inf
,
2
)
...
@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
...
@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
self
.
assertEqual
(
amp
.
custom_black_list
,
[])
self
.
assertEqual
(
amp
.
custom_black_list
,
[])
self
.
assertEqual
(
amp
.
custom_white_list
,
[])
self
.
assertEqual
(
amp
.
custom_white_list
,
[])
self
.
assertEqual
(
amp
.
custom_black_varnames
,
[])
self
.
assertEqual
(
amp
.
custom_black_varnames
,
[])
self
.
assertEqual
(
amp
.
use_pure_fp16
,
False
)
self
.
assertEqual
(
amp
.
use_fp16_guard
,
False
)
self
.
assertEqual
(
amp
.
use_fp16_guard
,
True
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
False
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
False
)
sharding
=
strategy
.
sharding
sharding
=
strategy
.
sharding
...
@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
...
@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
amp
.
custom_white_list
=
[
"x"
]
amp
.
custom_white_list
=
[
"x"
]
amp
.
custom_black_list
=
[
"y"
]
amp
.
custom_black_list
=
[
"y"
]
amp
.
custom_black_varnames
=
[
"z"
]
amp
.
custom_black_varnames
=
[
"z"
]
amp
.
use_pure_fp16
=
True
amp
.
use_fp16_guard
=
False
amp
.
use_fp16_guard
=
False
amp
.
use_optimizer_fp16
=
True
amp
.
use_optimizer_fp16
=
True
self
.
assertEqual
(
amp
.
enable
,
True
)
self
.
assertEqual
(
amp
.
enable
,
True
)
...
@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
...
@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
self
.
assertEqual
(
amp
.
custom_white_list
,
[
"x"
])
self
.
assertEqual
(
amp
.
custom_white_list
,
[
"x"
])
self
.
assertEqual
(
amp
.
custom_black_list
,
[
"y"
])
self
.
assertEqual
(
amp
.
custom_black_list
,
[
"y"
])
self
.
assertEqual
(
amp
.
custom_black_varnames
,
[
"z"
])
self
.
assertEqual
(
amp
.
custom_black_varnames
,
[
"z"
])
self
.
assertEqual
(
amp
.
use_pure_fp16
,
True
)
self
.
assertEqual
(
amp
.
use_fp16_guard
,
False
)
self
.
assertEqual
(
amp
.
use_fp16_guard
,
False
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
True
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
True
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录