Unverified commit c2c3bd43, authored by niuliling123 and committed by GitHub

[AMP] support OD level for static (#53768)

Parent 52889e38
@@ -61,12 +61,11 @@ class OptimizerWithMixedPrecision:
    Args:
        optimizer (Optimizer): A common Optimizer object.
        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
        level(str): Auto mixed precision level. Accepted values are "O1", "O2" and "OD".
            At the O1 level, operators in the white list use float16/bfloat16 inputs for
            their calculations and operators in the black list use float32 inputs. At the
            O2 level, the model's parameters are cast to float16/bfloat16 by `decorate`,
            operators whose inputs are all float16/bfloat16 are converted to
            float16/bfloat16, and operators with any float32 input run in float32. At the
            OD level, only operators in the default white list compute in float16/bfloat16.
        dtype(str): Whether to use 'float16' or 'bfloat16'.
        init_loss_scaling (float): The initial loss scaling factor.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
@@ -123,6 +122,7 @@ class OptimizerWithMixedPrecision:
        self._learning_rate = optimizer._learning_rate
        self._learning_rate_map = optimizer._learning_rate_map
        self._use_pure_fp16 = level == "O2"
        self._amp_level = level
        self._use_fp16_guard = use_amp_guard
        self._to_fp16_var_names = None
        if self._use_dynamic_loss_scaling:
@@ -241,7 +241,7 @@ class OptimizerWithMixedPrecision:
            self._amp_lists,
            use_fp16_guard=False,
            dest_type=self._amp_vartype,
            level=self._amp_level,
            use_promote=self.use_promote,
        )
@@ -380,7 +380,7 @@ class OptimizerWithMixedPrecision:
            self._amp_lists,
            use_fp16_guard=False,
            dest_type=self._amp_vartype,
            level=self._amp_level,
            use_promote=self.use_promote,
        )
@@ -773,12 +773,11 @@ def decorate(
        amp_lists(CustomOpLists, optional): A CustomOpLists object. The default
            white_list and black_list will be used for AMP training when it is
            not set. Default is None.
        level(str, optional): Auto mixed precision level. Accepted values are "O1",
            "O2" and "OD". At the O1 level, operators in the white list use
            float16/bfloat16 inputs for their calculations and operators in the black
            list use float32 inputs. At the O2 level, the model's parameters are cast
            to float16/bfloat16 by `decorate`, operators whose inputs are all
            float16/bfloat16 are converted to float16/bfloat16, and operators with any
            float32 input run in float32. At the OD level, operators in the default
            white list compute in float16/bfloat16 and the others compute in float32.
            Default is O1.
        dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
        master_weight(bool, optional): For level='O2', whether to use multi-precision
            during weight updating. If master_weight is None, in O2 level optimizer
@@ -847,15 +846,22 @@ def decorate(
    """
    # check amp_level: O0-O2
    level = level.upper()
    if not (level in ['O0', 'OD', 'O1', 'O2']):
        raise ValueError("level should be O0, OD, O1 or O2.")

    amp_dtype = check_amp_dtype(dtype)
    if amp_lists is None or level == 'OD':
        amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

    if level == 'OD':
        if amp_lists is not None:
            warnings.warn(
                "If the Amp level is set to OD, the amp list will not be used."
            )

        amp_lists.white_list = {"conv2d", "matmul_v2"}
        amp_lists.black_list = amp_lists.all_list - amp_lists.white_list

    if use_dynamic_loss_scaling is None:
        use_dynamic_loss_scaling = dtype == "float16"
......
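For context (not part of the diff itself), here is a minimal sketch of how the new OD level can be requested through the static-graph `decorate` API described above; the toy network and optimizer choices are illustrative only:

    import paddle
    from paddle.static import amp

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[None, 1, 28, 28], dtype="float32")
        conv = paddle.static.nn.conv2d(x, num_filters=6, filter_size=3)
        loss = paddle.mean(conv)

        sgd = paddle.optimizer.SGD(learning_rate=0.001)
        # With level='OD', only the default white list (conv2d, matmul_v2)
        # is computed in float16; every other operator stays in float32.
        sgd = amp.decorate(
            sgd,
            init_loss_scaling=128.0,
            use_dynamic_loss_scaling=True,
            level='OD',
        )
        sgd.minimize(loss)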
@@ -99,7 +99,7 @@ def _get_sys_unsupported_list(dtype):
        device = 'XPU'
    else:
        device = 'GPU'
    all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
    # sys_unsupported_list will include the following ops.
    supported_fp16_list = {
@@ -114,13 +114,13 @@ def _get_sys_unsupported_list(dtype):
    }
    sys_unsupported_list -= supported_fp16_list
    return device, sys_unsupported_list, all_ops
def _get_unsupported_list(dtype):
    # The set of ops that don't support fp16 calculation
    _, _sys_unsupported_list, _sys_all_list = _get_sys_unsupported_list(dtype)
    return _sys_unsupported_list, _sys_all_list


# The three sets listed below are changed dynamically. They don't contain all
@@ -200,7 +200,9 @@ class AutoMixedPrecisionLists:
        self.white_list = copy.copy(_get_white_list(self.amp_dtype))
        self.black_list = copy.copy(_get_black_list())
        self.gray_list = copy.copy(gray_list)
        unsupported_list, sys_all_list = _get_unsupported_list(self.amp_dtype)
        self.unsupported_list = copy.copy(unsupported_list)
        self.all_list = copy.copy(sys_all_list)
        self.black_varnames = copy.copy(custom_black_varnames)
        self._update_list()
@@ -232,7 +234,9 @@ class AutoMixedPrecisionLists:
                    self.gray_list.remove(op_name)
                self.black_list.add(op_name)
                self.unsupported_list.add(op_name)

        device, sys_unsupported_list, _ = _get_sys_unsupported_list(
            self.amp_dtype
        )
        actual_unsupported_list = []
        for op_name in sys_unsupported_list:
            if op_name in self.white_list:
......
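The changes in this file expose `all_list` on `AutoMixedPrecisionLists` precisely so that the OD paths in `decorate` (above) and `cast_model_to_fp16` (below) can derive the black list by set difference: the white list is trimmed to conv2d and matmul_v2, and everything else becomes black-listed. A self-contained toy illustration, where the op names stand in for the real output of `core.op_supported_infos`:

    # Hypothetical operator universe standing in for amp_lists.all_list.
    all_ops = {"conv2d", "matmul_v2", "relu", "softmax", "elementwise_add", "reduce_mean"}

    white_list = {"conv2d", "matmul_v2"}  # default OD white list from this patch
    black_list = all_ops - white_list     # every remaining op keeps computing in float32

    assert "conv2d" not in black_list
    assert "relu" in black_list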
@@ -426,7 +426,7 @@ def set_var_dst_dtype(
def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
    keep_fp32_var_names = set()
    if level == "O1" or level == "OD":
        return keep_fp32_var_names
    all_parameters = []
    for block in program.blocks:
@@ -618,6 +618,14 @@ def cast_model_to_fp16(
    if level == 'O2':
        amp_lists.black_list = amp_lists.black_list - black_list

    if level == 'OD':
        if amp_lists is not None:
            dtype = get_low_precision_dtypestr(dest_type)
            amp_lists = AutoMixedPrecisionLists(dtype)

        amp_lists.white_list = {"conv2d", "matmul_v2"}
        amp_lists.black_list = amp_lists.all_list - amp_lists.white_list

    global_block = program.global_block()
    keep_fp32_ops = set()
    keep_fp16_ops = set()
......
@@ -14,9 +14,11 @@
import unittest

import numpy as np
from amp_base_models import AmpTestBase, build_conv_model

import paddle
from paddle.static import amp


class TestAutoCast(AmpTestBase):
@@ -37,6 +39,72 @@ class TestAutoCast(AmpTestBase):
        self.assertEqual(out3.dtype, paddle.float32)

class TestStaticDecorate(AmpTestBase):
    def check_results(
        self, use_amp, dtype, level, use_promote, expected_op_calls
    ):
        (
            main_program,
            startup_program,
            optimizer,
            feed_vars,
            fetch_vars,
        ) = build_conv_model(use_amp, dtype, level, use_promote)
        self.assertEqual(main_program.num_blocks, 1)

        optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
        optimizer = paddle.static.amp.decorate(
            optimizer,
            init_loss_scaling=128.0,
            use_dynamic_loss_scaling=True,
            level=level,
        )

        amp.debugging.collect_operator_stats(main_program)
        op_stats_list = amp.debugging._get_op_stats_list(main_program)

        self._check_op_calls(
            op_stats_list[0], expected_fp16_calls=expected_op_calls
        )

        place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)

        max_iters = 2
        x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
        losses_o1 = self.run_program(
            main_program,
            startup_program,
            optimizer,
            feed_vars,
            fetch_vars,
            place,
            exe,
            x_fp32,
            max_iters,
            level,
        )

    def test_static_amp_o1(self):
        paddle.enable_static()
        expected_fp16_calls = {
            "conv2d": 1,
            "elementwise_add": 0,
            "relu": 0,
            "matmul_v2": 1,
            "softmax": 0,
            "reduce_mean": 0,
            "adamw": 0,
        }
        self.check_results(
            True,
            'float16',
            'OD',
            use_promote=True,
            expected_op_calls=expected_fp16_calls,
        )
        paddle.disable_static()

class TestGradScaler(AmpTestBase):
    def test_amp_grad_scaler(self):
        model = paddle.nn.Conv2D(3, 2, 3)
......
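A short sketch, mirroring the new test above, of inspecting the per-operator statistics by hand for a program that has already been built and decorated at level='OD'; `main_program` here is assumed to come from a helper such as `build_conv_model`:

    from paddle.static import amp

    # Collect per-op dtype statistics for the decorated program, then fetch the
    # raw table that the unit test asserts on via _check_op_calls.
    amp.debugging.collect_operator_stats(main_program)
    op_stats = amp.debugging._get_op_stats_list(main_program)[0]

    # Under OD only conv2d and matmul_v2 are expected to report float16 calls;
    # the remaining operators should stay in float32.
    print(op_stats["conv2d"], op_stats["matmul_v2"])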