Unverified commit 6f8ab1fa authored by Zhang Ting, committed by GitHub

[AMP] use promote dtype when amp_level=O2 (#51063)

Parent 5cdd9f2c
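In short: before this commit, O2 ran every op in the AMP dtype unless the op was block-listed (in which case it fell back to float32); with this commit, O2 reuses the same allow/block/promote decision that O1 already used. A minimal Python sketch of the promote rule being referred to (illustrative only, not Paddle's internal API; promote_dtype is a hypothetical name):

import paddle

def promote_dtype(input_dtypes, amp_dtype=paddle.float16):
    # Promote semantics: if any input is float32, keep the op in float32;
    # otherwise run it in the low-precision AMP dtype.
    if any(d == paddle.float32 for d in input_dtypes):
        return paddle.float32
    return amp_dtype

print(promote_dtype([paddle.float16, paddle.float32]))  # paddle.float32
print(promote_dtype([paddle.float16, paddle.float16]))  # paddle.float16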
...
@@ -122,10 +122,9 @@ inline phi::DataType GetAmpDestDtype(
   auto amp_setting_dtype =
       egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype();
   auto dst_type = amp_setting_dtype;
-  if (amp_level == paddle::imperative::AmpLevel::O1) {
-    if (paddle::imperative::AmpOperators::Instance()
-            .GetMutableAllowOps()
-            ->count(op_name)) {
+  if (paddle::imperative::AmpOperators::Instance().GetMutableAllowOps()->count(
+          op_name)) {
     dst_type = amp_setting_dtype;
   } else if (paddle::imperative::AmpOperators::Instance()
                  .GetMutableBlockOps()
@@ -134,13 +133,6 @@ inline phi::DataType GetAmpDestDtype(
   } else {
     dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
   }
-  } else if (amp_level == paddle::imperative::AmpLevel::O2) {
-    if (paddle::imperative::AmpOperators::Instance()
-            .GetMutableBlockOps()
-            ->count(op_name)) {
-      dst_type = phi::DataType::FLOAT32;
-    }
-  }
   if (dst_type == amp_setting_dtype &&
       (paddle::imperative::AmpOperators::Instance()
...
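Read as rough Python pseudocode (an assumed simplification of the C++ above; it omits the trailing unsupported-op check), the decision flow now shared by O1 and O2 is:

def get_amp_dest_dtype(op_name, input_dtypes, amp_dtype,
                       allow_ops, block_ops):
    # Allow (white) list: always run in the AMP dtype.
    if op_name in allow_ops:
        return amp_dtype
    # Block (black) list: always fall back to float32.
    if op_name in block_ops:
        return 'float32'
    # Everything else: promote, i.e. float32 wins if any input is float32.
    if 'float32' in input_dtypes:
        return 'float32'
    return amp_dtype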
...
@@ -17,8 +17,8 @@ from .auto_cast import decorate  # noqa: F401
 from .auto_cast import amp_guard  # noqa: F401
 from .auto_cast import amp_decorate  # noqa: F401
 from .auto_cast import low_precision_op_list  # noqa: F401
-from .auto_cast import WHITE_LIST  # noqa: F401
-from .auto_cast import BLACK_LIST  # noqa: F401
+from .auto_cast import FP16_WHITE_LIST  # noqa: F401
+from .auto_cast import FP16_BLACK_LIST  # noqa: F401
 from .auto_cast import PURE_FP16_WHITE_LIST  # noqa: F401
 from .auto_cast import PURE_FP16_BLACK_LIST  # noqa: F401
...
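Following from the import list above, the renamed O1 lists remain importable from paddle.amp. A quick check, assuming a Paddle build that contains this patch (the membership values follow from the list contents in the next file's diff):

from paddle.amp import FP16_WHITE_LIST, FP16_BLACK_LIST

print('conv2d' in FP16_WHITE_LIST)  # True
print('exp' in FP16_BLACK_LIST)     # True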
...
@@ -25,7 +25,7 @@ AMP_LEVEL = core.AmpLevel
 # The set of ops that support fp16 calculation and are considered numerically-
 # safe and performance-critical. These ops are always converted to fp16.
-WHITE_LIST = {
+FP16_WHITE_LIST = {
     'conv2d',
     'matmul',
     'matmul_v2',
@@ -37,7 +37,7 @@ WHITE_LIST = {
 # The set of ops that support fp16 calculation and are considered numerically-
 # dangerous and whose effects may also be observed in downstream ops.
-BLACK_LIST = {
+FP16_BLACK_LIST = {
     'exp',
     'square',
     'log',
@@ -73,7 +73,8 @@ AMP_RELATED_FLAGS_SETTING = {
     'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
 }
-PURE_FP16_WHITE_LIST = set()
+PURE_FP16_WHITE_LIST = copy.copy(FP16_WHITE_LIST)
 PURE_FP16_BLACK_LIST = {
     'lookup_table',
     'lookup_table_v2',
@@ -90,7 +91,7 @@ PURE_FP16_BLACK_LIST = {
 BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
 BF16_BLACK_LIST = set()
-PURE_BF16_WHITE_LIST = set()
+PURE_BF16_WHITE_LIST = copy.copy(BF16_WHITE_LIST)
 PURE_BF16_BLACK_LIST = set()
 _g_amp_state_ = None
@@ -139,8 +140,8 @@ def _update_list(
     """
     if dtype == 'float16':
         if level == 'O1':
-            _white_list = copy.copy(WHITE_LIST)
-            _black_list = copy.copy(BLACK_LIST)
+            _white_list = copy.copy(FP16_WHITE_LIST)
+            _black_list = copy.copy(FP16_BLACK_LIST)
         else:
             _white_list = copy.copy(PURE_FP16_WHITE_LIST)
             _black_list = copy.copy(PURE_FP16_BLACK_LIST)
@@ -424,8 +425,8 @@ def amp_guard(
     if level == 'O1':
         amp_level = AMP_LEVEL.O1
         if dtype == 'float16':
-            _white_list = WHITE_LIST
-            _black_list = BLACK_LIST
+            _white_list = FP16_WHITE_LIST
+            _black_list = FP16_BLACK_LIST
         elif dtype == 'bfloat16':
             _white_list = BF16_WHITE_LIST
             _black_list = BF16_BLACK_LIST
@@ -441,8 +442,8 @@ def amp_guard(
     elif level == 'O0':
         amp_level = AMP_LEVEL.O0
         if dtype == 'float16':
-            _white_list = WHITE_LIST
-            _black_list = BLACK_LIST
+            _white_list = FP16_WHITE_LIST
+            _black_list = FP16_BLACK_LIST
         elif dtype == 'bfloat16':
             _white_list = BF16_WHITE_LIST
             _black_list = BF16_BLACK_LIST
...
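Because PURE_FP16_WHITE_LIST is now seeded from FP16_WHITE_LIST (and PURE_BF16_WHITE_LIST from BF16_WHITE_LIST), the O2 white lists no longer start empty. A sanity check, assuming a Paddle build that contains this patch:

import paddle

# copy.copy of a set yields an equal set, so at import time the O2 (pure
# fp16) white list matches the O1 fp16 white list.
assert paddle.amp.PURE_FP16_WHITE_LIST == paddle.amp.FP16_WHITE_LIST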
...
@@ -88,8 +88,8 @@ class TestAutoCast(unittest.TestCase):
     def custom_op_list(self):
         with fluid.dygraph.guard():
             tracer = fluid.framework._dygraph_tracer()
-            base_white_list = paddle.amp.WHITE_LIST
-            base_black_list = paddle.amp.BLACK_LIST
+            base_white_list = paddle.amp.FP16_WHITE_LIST
+            base_black_list = paddle.amp.FP16_BLACK_LIST
             with paddle.amp.amp_guard(
                 custom_white_list=["log"], custom_black_list=["conv2d"]
             ):
...