Unverified commit c2c3bd43 authored by niuliling123 and committed by GitHub

[AMP] support OD level for static (#53768)

Parent 52889e38
......@@ -61,12 +61,11 @@ class OptimizerWithMixedPrecision:
Args:
optimizer (Optimizer): A common Optimizer object.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
level(str): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represent mixed precision, the input data type
of each operator will be casted by white_list and black_list;
O2 represent Pure fp16 or bf16, all operators parameters and input
data will be casted to fp16 or bf16, except operators in black_list,
don't support fp16 or bf16 kernel and batch_norm.
level(str): Auto mixed precision level. Accepted values are "O1", "O2" and "OD".
At the O1 level, operators in the white list use float16/bfloat16 inputs for their
calculations, while operators in the black list use float32 inputs. At the O2 level,
the model's parameters are cast to float16/bfloat16 by `decorate`, operators whose
inputs are all float16/bfloat16 run in float16/bfloat16, and operators with any
float32 input run in float32. At the OD level, only operators in the default white
list compute in float16/bfloat16; all other operators compute in float32.
dtype(str): Whether to use 'float16' or 'bfloat16'.
init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
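The level semantics described above can be summarized with a small, purely illustrative helper. This is a sketch only: `choose_op_dtype` is not a Paddle API, and the real pass also applies gray-list promotion and kernel-support checks.

```python
# Illustrative sketch of the per-operator dtype rules described above.
# `choose_op_dtype` is hypothetical and not part of Paddle.
def choose_op_dtype(op_type, white_list, black_list, level, low_dtype="float16"):
    if level == "OD":
        # Only the default white list runs in low precision; everything else is float32.
        return low_dtype if op_type in {"conv2d", "matmul_v2"} else "float32"
    if level == "O1":
        if op_type in white_list:
            return low_dtype
        if op_type in black_list:
            return "float32"
        return "promoted"  # gray-list ops follow promotion rules based on input dtypes
    if level == "O2":
        # Parameters are pre-cast; an op stays in float32 only if it is black-listed
        # or receives a float32 input.
        return "float32" if op_type in black_list else low_dtype
    return "float32"  # O0: plain float32 training
```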
......@@ -123,6 +122,7 @@ class OptimizerWithMixedPrecision:
self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map
self._use_pure_fp16 = level == "O2"
self._amp_level = level
self._use_fp16_guard = use_amp_guard
self._to_fp16_var_names = None
if self._use_dynamic_loss_scaling:
......@@ -241,7 +241,7 @@ class OptimizerWithMixedPrecision:
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
level=self._amp_level,
use_promote=self.use_promote,
)
......@@ -380,7 +380,7 @@ class OptimizerWithMixedPrecision:
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
level=self._amp_level,
use_promote=self.use_promote,
)
......@@ -773,12 +773,11 @@ def decorate(
amp_lists(CustomOpLists, optional): An CustomOpLists object. The default
white_list and black_list will be used for AMP training when it is
not set. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represent mixed precision, the input data type of
each operator will be casted by white_list and black_list;
O2 represent pure FP16 / BF16 training, all operators parameters
and input data will be casted to FP16 / BF16, except operators in
black_list, don't support FP16 / BF16 kernel and batch_norm. Default is O1.
level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD".
At the O1 level, operators in the white list use float16/bfloat16 inputs for their
calculations, while operators in the black list use float32 inputs. At the O2 level,
the model's parameters are cast to float16/bfloat16 by `decorate`, operators whose
inputs are all float16/bfloat16 run in float16/bfloat16, and operators with any
float32 input run in float32. At the OD level, only operators in the default white
list compute in float16/bfloat16, and the others compute in float32. Default is O1.
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optinal): For level='O2', whether to use multi-precision
during weight updating. If master_weight is None, in O2 level optimizer
......@@ -847,15 +846,22 @@ def decorate(
"""
# check amp_level: O0, OD, O1, O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
if not (level in ['O0', 'OD', 'O1', 'O2']):
raise ValueError("level should be O0, OD, O1 or O2.")
amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
if amp_lists is None or level == 'OD':
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
if level == 'OD':
if amp_lists is not None:
warnings.warn(
"If the Amp level is set to OD, the amp list will not be used."
)
amp_lists.white_list = {"conv2d", "matmul_v2"}
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
if use_dynamic_loss_scaling is None:
use_dynamic_loss_scaling = dtype == "float16"
......
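For reference, a minimal static-graph usage sketch of the new level. The toy network and hyper-parameters below are illustrative and not part of this commit; the `decorate` call mirrors the test added further down.

```python
import paddle
from paddle.static import amp

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name="x", shape=[None, 1, 6, 6], dtype="float32")
    conv = paddle.static.nn.conv2d(input=x, num_filters=2, filter_size=3)
    loss = paddle.mean(conv)

    optimizer = paddle.optimizer.SGD(learning_rate=0.001)
    # OD: only default white-list ops (conv2d, matmul_v2) run in float16;
    # parameters and all other ops stay in float32.
    optimizer = amp.decorate(
        optimizer,
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True,
        level="OD",
    )
    optimizer.minimize(loss)
```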
......@@ -99,7 +99,7 @@ def _get_sys_unsupported_list(dtype):
device = 'XPU'
else:
device = 'GPU'
_, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
# sys_unsupported_list will include the following ops.
supported_fp16_list = {
......@@ -114,13 +114,13 @@ def _get_sys_unsupported_list(dtype):
}
sys_unsupported_list -= supported_fp16_list
return device, sys_unsupported_list
return device, sys_unsupported_list, all_ops
def _get_unsupported_list(dtype):
# The set of ops that don't support fp16 calculation
_, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
return _sys_unsupported_list
_, _sys_unsupported_list, _sys_all_list = _get_sys_unsupported_list(dtype)
return _sys_unsupported_list, _sys_all_list
# The three sets listed below are changed dynamically. They don't contain all
......@@ -200,7 +200,9 @@ class AutoMixedPrecisionLists:
self.white_list = copy.copy(_get_white_list(self.amp_dtype))
self.black_list = copy.copy(_get_black_list())
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
unsupported_list, sys_all_list = _get_unsupported_list(self.amp_dtype)
self.unsupported_list = copy.copy(unsupported_list)
self.all_list = copy.copy(sys_all_list)
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()
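The new `all_list` attribute is what makes the OD black list computable. A small sketch of the resulting set arithmetic follows (import path and a CUDA build of Paddle are assumed); it mirrors the logic in `decorate` and `cast_model_to_fp16`.

```python
# Sketch of how the OD-level lists are derived from the new `all_list` attribute.
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists

amp_lists = AutoMixedPrecisionLists(dtype="float16")
od_white_list = {"conv2d", "matmul_v2"}
# Every other registered op is forced to compute in float32 under OD.
od_black_list = amp_lists.all_list - od_white_list
assert od_white_list.isdisjoint(od_black_list)
```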
......@@ -232,7 +234,9 @@ class AutoMixedPrecisionLists:
self.gray_list.remove(op_name)
self.black_list.add(op_name)
self.unsupported_list.add(op_name)
device, sys_unsupported_list = _get_sys_unsupported_list(self.amp_dtype)
device, sys_unsupported_list, _ = _get_sys_unsupported_list(
self.amp_dtype
)
actual_unsupported_list = []
for op_name in sys_unsupported_list:
if op_name in self.white_list:
......
......@@ -426,7 +426,7 @@ def set_var_dst_dtype(
def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
keep_fp32_var_names = set()
if level == "O1":
if level == "O1" or level == "OD":
return keep_fp32_var_names
all_parameters = []
for block in program.blocks:
......@@ -618,6 +618,14 @@ def cast_model_to_fp16(
if level == 'O2':
amp_lists.black_list = amp_lists.black_list - black_list
if level == 'OD':
if amp_lists is not None:
dtype = get_low_precision_dtypestr(dest_type)
amp_lists = AutoMixedPrecisionLists(dtype)
amp_lists.white_list = {"conv2d", "matmul_v2"}
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
global_block = program.global_block()
keep_fp32_ops = set()
keep_fp16_ops = set()
......
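To check what the OD cast actually did to a program, the debugging helpers used by the new test below can be applied to the decorated program. This is a sketch; `main_program` is assumed to be the program built and decorated as in the earlier usage sketch.

```python
# Sketch: count per-op float16/float32 calls after the OD-level pass, using
# the same helpers as the new test below.
from paddle.static import amp

amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
# Under OD, only conv2d and matmul_v2 are expected to report float16 calls.
print(op_stats_list[0])
```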
......@@ -14,9 +14,11 @@
import unittest
from amp_base_models import AmpTestBase
import numpy as np
from amp_base_models import AmpTestBase, build_conv_model
import paddle
from paddle.static import amp
class TestAutoCast(AmpTestBase):
......@@ -37,6 +39,72 @@ class TestAutoCast(AmpTestBase):
self.assertEqual(out3.dtype, paddle.float32)
class TestStaticDecorate(AmpTestBase):
def check_results(
self, use_amp, dtype, level, use_promote, expected_op_calls
):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_conv_model(use_amp, dtype, level, use_promote)
self.assertEqual(main_program.num_blocks, 1)
optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
level=level,
)
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
self._check_op_calls(
op_stats_list[0], expected_fp16_calls=expected_op_calls
)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
max_iters = 2
x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
losses_o1 = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_fp32,
max_iters,
level,
)
def test_static_amp_o1(self):
paddle.enable_static()
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 0,
"relu": 0,
"matmul_v2": 1,
"softmax": 0,
"reduce_mean": 0,
"adamw": 0,
}
self.check_results(
True,
'float16',
'OD',
use_promote=True,
expected_op_calls=expected_fp16_calls,
)
paddle.disable_static()
class TestGradScaler(AmpTestBase):
def test_amp_grad_scaler(self):
model = paddle.nn.Conv2D(3, 2, 3)
......