Unverified commit 680460fd, authored by Yiqun Liu, committed by GitHub


[AMP] Allow to enable multi_precision through paddle.static.amp.decorate and add documents for some apis. (#53012)

* Add document for some apis. test=docs_preview

* Allow to set master_weight in paddle.static.amp.decorate.

* Polish codes and add unittest.

* Refine docs.

* Remove the repetitive function.
Parent 7a9754a7
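A minimal usage sketch (not part of the commit) of how the new master_weight argument introduced by this change could be used; the toy network and hyper-parameters below are placeholders:

import paddle

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data(name='x', shape=[None, 10], dtype='float32')
    linear = paddle.nn.Linear(in_features=10, out_features=2)
    loss = paddle.mean(linear(x))

    optimizer = paddle.optimizer.AdamW(learning_rate=0.001)
    # master_weight=None keeps the O2 default (multi-precision enabled);
    # master_weight=False would disable master weights explicitly.
    optimizer = paddle.static.amp.decorate(
        optimizer, level="O2", dtype="float16", master_weight=True
    )
    optimizer.minimize(loss)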
@@ -13,8 +13,14 @@
# limitations under the License.
import copy
import logging
import paddle
from paddle.fluid.log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
class OperatorStatsUnit:
@@ -76,7 +82,7 @@ def _get_var_dtype_from_block(block, op, arg_name, is_input):
var = block._var_recursive(var_name)
return var.dtype
except:
-print(
+_logger.warning(
"Operator < {} > gets {} < {} : {} > error!".format(
op.type, "input" if is_input else "output", arg_name, var_name
)
@@ -99,7 +105,7 @@ def _extract_compute_dtype(op, block):
if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype
):
-print(
+_logger.warning(
"Operator < {} > has different input data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names
)
@@ -125,7 +131,7 @@ def _extract_compute_dtype(op, block):
if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype
):
-print(
+_logger.warning(
"Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names
)
@@ -145,6 +151,15 @@ def _merge_op_stats(op_stats_list):
def _get_op_stats_list(program):
def _is_special_ops_with_input_x(op_type):
# Operators that have an input X and whose inputs may have different dtypes.
special_op_list = ['cast', 'batch_norm', 'instance_norm', 'layer_norm']
if op_type in special_op_list:
return True
if op_type.replace("_grad", "") in special_op_list:
return True
return False
op_stats_list = []
for block in program.blocks:
block_op_stats_dict = {}
@@ -161,13 +176,7 @@ def _get_op_stats_list(program):
'create_double_buffer_reader',
]:
compute_dtype = None
-elif op.type in [
-'cast',
-'layer_norm',
-'layer_norm_grad',
-'batch_norm',
-'batch_norm_grad',
-]:
+elif _is_special_ops_with_input_x(op.type):
# Do not check the input / output dtype difference for these operators.
compute_dtype = _get_var_dtype_from_block(block, op, 'X', True)
elif "Param" in op.input_names:
@@ -183,6 +192,78 @@
def collect_operator_stats(program=None, print_subblocks=False):
"""
Collect the number of operators for different data types through parsing
the program. The statistical data are categorized according to four data
types, namely float32, float16, bfloat16 and others.
Args:
program(Program, optional): The program to parse. Default None, and the default main_program will be parsed.
print_subblocks(bool, optional): Whether to print the operator stats for each subblock. Default False.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
class SimpleConvNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = paddle.nn.Linear(in_features=26, out_features=10)
def forward(self, x):
out = self.conv(x)
out = paddle.nn.functional.relu(out)
out = self.linear(out)
out = paddle.nn.functional.softmax(out)
return out
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = paddle.optimizer.AdamW()
optimizer = paddle.static.amp.decorate(optimizer)
optimizer.minimize(loss)
paddle.static.amp.debugging.collect_operator_stats(main_program)
# <------------------------------------------------ op list of all blocks ------------------------------------------------->
# <------------------------------------------------------- op list -------------------------------------------------------->
# <--------------- Op Name ---------------- | -- FP16 Calls --- | -- BF16 Calls --- | --- FP32 Calls--- | -- Other Calls -->
# adamw | 0 | 0 | 4 | 0
# cast | 5 | 0 | 6 | 0
# check_finite_and_unscale | 0 | 0 | 1 | 0
# conv2d | 1 | 0 | 0 | 0
# conv2d_grad | 1 | 0 | 0 | 0
# elementwise_add | 2 | 0 | 0 | 0
# elementwise_add_grad | 2 | 0 | 0 | 0
# elementwise_mul | 0 | 0 | 1 | 0
# elementwise_mul_grad | 0 | 0 | 1 | 0
# fill_constant | 0 | 0 | 1 | 0
# matmul_v2 | 1 | 0 | 0 | 0
# matmul_v2_grad | 1 | 0 | 0 | 0
# memcpy | 0 | 0 | 0 | 1
# reduce_mean | 0 | 0 | 1 | 0
# reduce_mean_grad | 0 | 0 | 1 | 0
# relu | 1 | 0 | 0 | 0
# relu_grad | 1 | 0 | 0 | 0
# reshape2 | 0 | 0 | 1 | 0
# reshape2_grad | 0 | 0 | 1 | 0
# softmax | 0 | 0 | 1 | 0
# softmax_grad | 0 | 0 | 1 | 0
# update_loss_scaling | 0 | 0 | 1 | 0
# <----------------------------------------------------- op count: 22 ----------------------------------------------------->
"""
def _convert_to_list(op_stats_unit_dict):
for key, value in op_stats_unit_dict.items():
op_stats_unit_dict[key] = value.convert_to_list()
......
@@ -34,6 +34,21 @@ from .fp16_utils import (
from .function_overload import FunctionType, overload
def _set_multi_precision(optimizer, multi_precision):
if not isinstance(
optimizer,
(paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer),
):
raise RuntimeError(
"Current AMP training level is O2, optimizer is expected to be paddle.optimizer.Optimizer or paddle.fluid.optimizer.Optimizer, but receive {}.".format(
type(optimizer)
)
)
if multi_precision and hasattr(optimizer, "_multi_precision"):
optimizer._multi_precision = multi_precision
class OptimizerWithMixedPrecision:
"""
Optimizer with mixed-precision (MP) training. This is a wrapper of a common
@@ -767,22 +782,96 @@ def decorate(
amp_lists=None,
level='O1',
dtype='float16',
master_weight=None,
init_loss_scaling=2**15,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
-use_dynamic_loss_scaling=True,
+use_dynamic_loss_scaling=None,
use_amp_guard=False,
use_promote=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
-"""
-amp_dtype = check_amp_dtype(dtype)
-if amp_lists is None:
-amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
Args:
optimizer(Optimizer): A common Optimizer.
amp_lists(CustomOpLists, optional): A CustomOpLists object. The default
white_list and black_list will be used for AMP training when it is
not set. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represents mixed precision, where the input data type of
each operator will be casted according to the white_list and black_list;
O2 represents pure FP16 / BF16 training, where all operator parameters
and input data will be casted to FP16 / BF16, except for operators in the
black_list, operators without FP16 / BF16 kernels, and batch_norm. Default is O1.
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optional): For level='O2', whether to use multi-precision
during weight updating. If master_weight is None, the optimizer will use
multi-precision in O2 level. Default is None.
init_loss_scaling(float, optional): The initial loss scaling factor.
Default is 32768.
incr_every_n_steps(int, optional): Increases loss scaling every n
consecutive steps with finite gradients. Default is 1000.
decr_every_n_nan_or_inf(int, optional): Decreases loss scaling every n
accumulated steps with nan or inf gradients. Default is 2.
incr_ratio(float, optional): The multiplier to use when increasing the
loss scaling. Default is 2.
decr_ratio(float, optional): The less-than-one-multiplier to use when
decreasing the loss scaling. Default is 0.8.
use_dynamic_loss_scaling(bool, optional): Whether to use dynamic loss
scaling. Default is None, which means True for float16 and False
for bfloat16.
Returns:
An optimizer acting like a normal one but with mixed-precision training enabled.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
class SimpleConvNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.conv = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = paddle.nn.Linear(in_features=26, out_features=10)
def forward(self, x):
out = self.conv(x)
out = paddle.nn.functional.relu(out)
out = self.linear(out)
out = paddle.nn.functional.softmax(out)
return out
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = paddle.optimizer.AdamW()
optimizer = paddle.static.amp.decorate(optimizer, level="O2", dtype="float16")
optimizer.minimize(loss)
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_program)
# Call `amp_init` after FP32 parameters initialization, such as `exe.run(startup_program)`,
# to convert FP32 parameters to low precision FP16 / BF16.
optimizer.amp_init(place, scope=paddle.static.global_scope())
"""
# check amp_level: O0-O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
@@ -790,6 +879,18 @@ def decorate(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
if use_dynamic_loss_scaling is None:
use_dynamic_loss_scaling = dtype == "float16"
if optimizer is not None:
# support master_weight
multi_precision = not (master_weight is False)
_set_multi_precision(optimizer, multi_precision)
mp_optimizer = OptimizerWithMixedPrecision(
optimizer,
amp_lists,
......
@@ -42,7 +42,6 @@ def _build_optimizer(
beta2=0.836,
epsilon=1e-4,
weight_decay=0.01,
-multi_precision=True,
)
if use_amp:
optimizer = paddle.static.amp.decorate(
......
@@ -221,11 +221,29 @@ class TestModelCastBF16(unittest.TestCase):
class TestProgramBF16(AmpTestBase):
def _check_optimizer(self, program, expected_num_mp):
optimizers = []
for block in program.blocks:
for op in block.ops:
if "Param" in op.input_names and "Grad" in op.input_names:
optimizers.append(op)
actual_num_mp = 0
for op in optimizers:
if op.has_attr("multi_precision") and op.attr("multi_precision"):
actual_num_mp += 1
self.assertEqual(
actual_num_mp,
expected_num_mp,
f"The number of optimizers with multi_precison = True is expected to be {expected_num_mp}, but recieved {actual_num_mp}.",
)
def test_amp_bf16_o1(self):
main_program, startup_program = build_embedding_model(
True, "bfloat16", "O1"
)
self.assertEqual(main_program.num_blocks, 1)
self._check_optimizer(main_program, 0)
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
@@ -255,6 +273,11 @@ class TestProgramBF16(AmpTestBase):
"squared_l2_norm": 2,
"adamw": 2,
}
self._check_optimizer(
main_program,
expected_bf16_calls["matmul_v2"]
+ expected_bf16_calls["elementwise_add"],
)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)
......
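For reference, a short sketch (assuming a main_program built with an AMP-decorated optimizer, as in the docstring examples above) of how the effect of master_weight can be verified, mirroring the new _check_optimizer helper added in the test:

def count_multi_precision_ops(program):
    # Count optimizer ops (ops taking both Param and Grad inputs)
    # whose multi_precision attribute is set to True.
    count = 0
    for block in program.blocks:
        for op in block.ops:
            if "Param" in op.input_names and "Grad" in op.input_names:
                if op.has_attr("multi_precision") and op.attr("multi_precision"):
                    count += 1
    return count

# With level="O2" and master_weight left as None (or set to True), every
# adamw op in the program is expected to report multi_precision=True;
# with master_weight=False the expected count is 0.
# print(count_multi_precision_ops(main_program))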