未验证提交 6f8ab1fa · 作者:Zhang Ting · 提交者:GitHub

[AMP] use promote dtype when amp_level=O2 (#51063)

上级 5cdd9f2c
......@@ -122,24 +122,16 @@ inline phi::DataType GetAmpDestDtype(
auto amp_setting_dtype =
egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype();
auto dst_type = amp_setting_dtype;
if (amp_level == paddle::imperative::AmpLevel::O1) {
if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps()
->count(op_name)) {
dst_type = amp_setting_dtype;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name)) {
dst_type = phi::DataType::FLOAT32;
} else {
dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
}
} else if (amp_level == paddle::imperative::AmpLevel::O2) {
if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name)) {
dst_type = phi::DataType::FLOAT32;
}
if (paddle::imperative::AmpOperators::Instance().GetMutableAllowOps()->count(
op_name)) {
dst_type = amp_setting_dtype;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name)) {
dst_type = phi::DataType::FLOAT32;
} else {
dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
}
if (dst_type == amp_setting_dtype &&
......
......@@ -17,8 +17,8 @@ from .auto_cast import decorate # noqa: F401
from .auto_cast import amp_guard # noqa: F401
from .auto_cast import amp_decorate # noqa: F401
from .auto_cast import low_precision_op_list # noqa: F401
from .auto_cast import WHITE_LIST # noqa: F401
from .auto_cast import BLACK_LIST # noqa: F401
from .auto_cast import FP16_WHITE_LIST # noqa: F401
from .auto_cast import FP16_BLACK_LIST # noqa: F401
from .auto_cast import PURE_FP16_WHITE_LIST # noqa: F401
from .auto_cast import PURE_FP16_BLACK_LIST # noqa: F401
......
......@@ -25,7 +25,7 @@ AMP_LEVEL = core.AmpLevel
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
WHITE_LIST = {
FP16_WHITE_LIST = {
'conv2d',
'matmul',
'matmul_v2',
......@@ -37,7 +37,7 @@ WHITE_LIST = {
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
BLACK_LIST = {
FP16_BLACK_LIST = {
'exp',
'square',
'log',
......@@ -73,7 +73,8 @@ AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
}
PURE_FP16_WHITE_LIST = set()
PURE_FP16_WHITE_LIST = copy.copy(FP16_WHITE_LIST)
PURE_FP16_BLACK_LIST = {
'lookup_table',
'lookup_table_v2',
......@@ -90,7 +91,7 @@ PURE_FP16_BLACK_LIST = {
BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
BF16_BLACK_LIST = set()
PURE_BF16_WHITE_LIST = set()
PURE_BF16_WHITE_LIST = copy.copy(BF16_WHITE_LIST)
PURE_BF16_BLACK_LIST = set()
_g_amp_state_ = None
......@@ -139,8 +140,8 @@ def _update_list(
"""
if dtype == 'float16':
if level == 'O1':
_white_list = copy.copy(WHITE_LIST)
_black_list = copy.copy(BLACK_LIST)
_white_list = copy.copy(FP16_WHITE_LIST)
_black_list = copy.copy(FP16_BLACK_LIST)
else:
_white_list = copy.copy(PURE_FP16_WHITE_LIST)
_black_list = copy.copy(PURE_FP16_BLACK_LIST)
......@@ -424,8 +425,8 @@ def amp_guard(
if level == 'O1':
amp_level = AMP_LEVEL.O1
if dtype == 'float16':
_white_list = WHITE_LIST
_black_list = BLACK_LIST
_white_list = FP16_WHITE_LIST
_black_list = FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
......@@ -441,8 +442,8 @@ def amp_guard(
elif level == 'O0':
amp_level = AMP_LEVEL.O0
if dtype == 'float16':
_white_list = WHITE_LIST
_black_list = BLACK_LIST
_white_list = FP16_WHITE_LIST
_black_list = FP16_BLACK_LIST
elif dtype == 'bfloat16':
_white_list = BF16_WHITE_LIST
_black_list = BF16_BLACK_LIST
......
......@@ -88,8 +88,8 @@ class TestAutoCast(unittest.TestCase):
def custom_op_list(self):
with fluid.dygraph.guard():
tracer = fluid.framework._dygraph_tracer()
base_white_list = paddle.amp.WHITE_LIST
base_black_list = paddle.amp.BLACK_LIST
base_white_list = paddle.amp.FP16_WHITE_LIST
base_black_list = paddle.amp.FP16_BLACK_LIST
with paddle.amp.amp_guard(
custom_white_list=["log"], custom_black_list=["conv2d"]
):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册