Unverified commit dfcba7f4, authored by Yiqun Liu, committed by GitHub

[AMP] Unify the static amp codes of fp16 and bf16. (#52694)

* Unify the static amp codes of fp16 and bf16.

* Polish apis and add unittest.

* Add operator stats collecting tools for program.

* Add the check of number of bloat16 operators in unittest.

* Add warning for operator not supported for amp.

* Add testing of BF16 O1 and O2.
Parent f8d09011
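A minimal end-to-end sketch of the unified static AMP entry points this commit adds (`paddle.static.amp.amp_decorate` and `paddle.static.amp.debugging.collect_operator_stats`). The network, optimizer settings, and dtype below are illustrative only and assume a build where bfloat16 kernels are available:

import paddle
from paddle import nn

paddle.enable_static()
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
    with paddle.static.program_guard(main_program, startup_program):
        x = paddle.static.data(name='x', shape=[16, 16], dtype='float32')
        model = nn.Linear(16, 10)
        loss = paddle.mean(model(x))
        optimizer = paddle.optimizer.AdamW(
            learning_rate=0.01, multi_precision=True
        )
        # Unified API: pick the AMP level ("O1"/"O2") and dtype ("float16"/"bfloat16").
        optimizer = paddle.static.amp.amp_decorate(
            optimizer, level="O1", dtype="bfloat16"
        )
        optimizer.minimize(loss)

# New operator-statistics tool for static programs: prints per-operator
# fp16 / bf16 / fp32 / other call counts.
paddle.static.amp.debugging.collect_operator_stats(main_program)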
...@@ -271,7 +271,7 @@ def _print_operator_stats(op_count_dict): ...@@ -271,7 +271,7 @@ def _print_operator_stats(op_count_dict):
"{:-^17}>".format(" Other Calls "), "{:-^17}>".format(" Other Calls "),
) )
if op_count_dict is not None and isinstance(op_count_dict, dict): if op_count_dict is not None and isinstance(op_count_dict, dict):
for op_type in op_count_dict: for op_type in sorted(op_count_dict):
# fp16, bf16, fp32, other # fp16, bf16, fp32, other
value = op_count_dict[op_type] value = op_count_dict[op_type]
if isinstance(value, list): if isinstance(value, list):
......
...@@ -189,15 +189,20 @@ class PartialProgramLayer: ...@@ -189,15 +189,20 @@ class PartialProgramLayer:
self._infer_info = ProgramInfo() self._infer_info = ProgramInfo()
self._forward_end_index_map = {} self._forward_end_index_map = {}
custom_white_list, custom_black_list = None, None amp_dtype, custom_white_list, custom_black_list = None, None, None
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
if tracer: if tracer:
custom_white_list, custom_black_list = tracer._get_amp_op_list() custom_white_list, custom_black_list = tracer._get_amp_op_list()
# For AMP training amp_dtype = tracer._amp_dtype
self._amp_list = paddle.static.amp.fp16_lists.AutoMixedPrecisionLists( if amp_dtype is not None and amp_dtype in ['float16', 'bfloat16']:
custom_white_list=custom_white_list, # For AMP training
custom_black_list=custom_black_list, self._amp_list = (
) paddle.static.amp.fp16_lists.AutoMixedPrecisionLists(
custom_white_list=custom_white_list,
custom_black_list=custom_black_list,
dtype=amp_dtype,
)
)
# program_id -> list(scope) # program_id -> list(scope)
self._scope_cache = {} self._scope_cache = {}
......
...@@ -13,9 +13,10 @@ ...@@ -13,9 +13,10 @@
# limitations under the License. # limitations under the License.
from . import decorator from . import decorator
from .decorator import decorate from .decorator import decorate, amp_decorate
from . import fp16_lists from . import fp16_lists
from .fp16_lists import CustomOpLists, AutoMixedPrecisionLists from .fp16_lists import CustomOpLists, AutoMixedPrecisionLists
from . import fp16_utils from . import fp16_utils
from .fp16_utils import fp16_guard, cast_model_to_fp16, cast_parameters_to_fp16 from .fp16_utils import fp16_guard, cast_model_to_fp16, cast_parameters_to_fp16
from . import bf16 from . import bf16
from . import debugging
...@@ -138,13 +138,10 @@ def update_loss_scaling( ...@@ -138,13 +138,10 @@ def update_loss_scaling(
['float16', 'float32', 'float64', 'uint16'], ['float16', 'float32', 'float64', 'uint16'],
'update_loss_scaling', 'update_loss_scaling',
) )
if ( if e.dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
e.dtype == core.VarDesc.VarType.FP16
or e.dtype == core.VarDesc.VarType.BF16
):
assert ( assert (
prev_loss_scaling.dtype == core.VarDesc.VarType.FP32 prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16." ), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16 or bfloat16."
else: else:
assert ( assert (
prev_loss_scaling.dtype == e.dtype prev_loss_scaling.dtype == e.dtype
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
class OperatorStatsUnit:
def __init__(self):
self.op_type = None
self.fp32_calls = 0
self.fp16_calls = 0
self.bf16_calls = 0
self.other_calls = 0
def update(self, dtype):
if dtype is None:
self.other_calls = self.other_calls + 1
else:
if dtype == paddle.fluid.core.VarDesc.VarType.FP32:
self.fp32_calls = self.fp32_calls + 1
elif dtype == paddle.fluid.core.VarDesc.VarType.FP16:
self.fp16_calls = self.fp16_calls + 1
elif dtype == paddle.fluid.core.VarDesc.VarType.BF16:
self.bf16_calls = self.bf16_calls + 1
else:
self.other_calls = self.other_calls + 1
def addto(self, another):
self.fp32_calls += another.fp32_calls
self.fp16_calls += another.fp16_calls
self.bf16_calls += another.bf16_calls
self.other_calls += another.other_calls
def convert_to_list(self):
return [
self.fp16_calls,
self.bf16_calls,
self.fp32_calls,
self.other_calls,
]
def _is_floating_point(dtype):
if dtype in [
paddle.fluid.core.VarDesc.VarType.FP64,
paddle.fluid.core.VarDesc.VarType.FP32,
paddle.fluid.core.VarDesc.VarType.FP16,
paddle.fluid.core.VarDesc.VarType.BF16,
]:
return True
else:
return False
def _get_var_dtype_from_block(block, op, arg_name, is_input):
var_names = op.input(arg_name) if is_input else op.output(arg_name)
assert isinstance(var_names, list)
if len(var_names) == 0:
return None
var_name = var_names[0]
try:
var = block._var_recursive(var_name)
return var.dtype
except:
print(
"Operator < {} > gets {} < {} : {} > error!".format(
op.type, "input" if is_input else "output", arg_name, var_name
)
)
return None
def _extract_compute_dtype(op, block):
var_name = None
compute_dtype = None
for in_name in op.input_names:
var_dtype = _get_var_dtype_from_block(block, op, in_name, True)
if var_dtype is None:
continue
if compute_dtype is None:
compute_dtype = var_dtype
else:
if compute_dtype != var_dtype:
if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype
):
print(
"Operator < {} > has different input data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names
)
)
elif _is_floating_point(var_dtype):
# When there are multiple inputs, such as embedding
# (ids is integer, w is floating-point), the kernel
# dtype is normally decided by the floating-point input.
compute_dtype = var_dtype
for out_name in op.output_names:
var_dtype = _get_var_dtype_from_block(block, op, out_name, False)
if var_dtype is None:
continue
if compute_dtype is None:
# Kernel dtype is mostly decided by the input's dtype.
# When the operator has no input, it may have an attr
# such as dtype that specifies the output's dtype.
compute_dtype = var_dtype
else:
if compute_dtype != var_dtype:
if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype
):
print(
"Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names
)
)
return compute_dtype
def _merge_op_stats(op_stats_list):
merged_op_stats_dict = {}
for each_op_stats_dict in op_stats_list:
for op_type, unit in each_op_stats_dict.items():
if merged_op_stats_dict.get(op_type, None) is None:
merged_op_stats_dict[op_type] = copy.copy(unit)
else:
merged_op_stats_dict[op_type].addto(unit)
return merged_op_stats_dict
def _get_op_stats_list(program):
op_stats_list = []
for block in program.blocks:
block_op_stats_dict = {}
for op in block.ops:
if block_op_stats_dict.get(op.type, None) is None:
unit = OperatorStatsUnit()
block_op_stats_dict[op.type] = unit
else:
unit = block_op_stats_dict[op.type]
if op.type in [
'create_py_reader',
'read',
'create_double_buffer_reader',
]:
compute_dtype = None
elif op.type in [
'cast',
'layer_norm',
'layer_norm_grad',
'batch_norm',
'batch_norm_grad',
]:
# Do not check the input/output dtype difference for these operators.
compute_dtype = _get_var_dtype_from_block(block, op, 'X', True)
elif "Param" in op.input_names:
# Specify compute_dtype for optimizers.
compute_dtype = _get_var_dtype_from_block(
block, op, 'Param', True
)
else:
compute_dtype = _extract_compute_dtype(op, block)
unit.update(dtype=compute_dtype)
op_stats_list.append(block_op_stats_dict)
return op_stats_list
def collect_operator_stats(program=None, print_subblocks=False):
def _convert_to_list(op_stats_unit_dict):
for key, value in op_stats_unit_dict.items():
op_stats_unit_dict[key] = value.convert_to_list()
return op_stats_unit_dict
if program is None:
program = paddle.static.default_main_program()
op_stats_list = _get_op_stats_list(program)
merged_op_stats = _merge_op_stats(op_stats_list)
if print_subblocks and len(op_stats_list) > 1:
for i in range(len(op_stats_list)):
print("<{:-^120}>".format(" op list of block " + str(i) + " "))
paddle.amp.debugging._print_operator_stats(
_convert_to_list(op_stats_list[i])
)
print("<{:-^120}>".format(" op list of all blocks "))
paddle.amp.debugging._print_operator_stats(
_convert_to_list(merged_op_stats)
)
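Besides the printed summary, the per-block statistics can be inspected programmatically, as the unit tests added in this commit do; a small sketch assuming `main_program` was built with bf16 AMP:

from paddle.static import amp

op_stats_list = amp.debugging._get_op_stats_list(main_program)
# Each list entry maps op_type -> OperatorStatsUnit with fp16/bf16/fp32/other call counts.
matmul_stats = op_stats_list[0].get("matmul_v2")
if matmul_stats is not None:
    print("matmul_v2 bf16 calls:", matmul_stats.bf16_calls)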
...@@ -25,7 +25,7 @@ from paddle.fluid import ( ...@@ -25,7 +25,7 @@ from paddle.fluid import (
) )
from .amp_nn import check_finite_and_unscale, update_loss_scaling from .amp_nn import check_finite_and_unscale, update_loss_scaling
from .fp16_lists import AutoMixedPrecisionLists from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
from .fp16_utils import ( from .fp16_utils import (
cast_model_to_fp16, cast_model_to_fp16,
cast_parameters_to_fp16, cast_parameters_to_fp16,
...@@ -45,7 +45,14 @@ class OptimizerWithMixedPrecision: ...@@ -45,7 +45,14 @@ class OptimizerWithMixedPrecision:
Args: Args:
optimizer (Optimizer): A common Optimizer object. optimizer (Optimizer): A common Optimizer object.
amp_lists (CustomOpLists): An CustomOpLists object. amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
level(str): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represents mixed precision, where the input data type
of each operator is cast according to the white_list and black_list;
O2 represents pure fp16/bf16, where all operators' parameters and input
data are cast to fp16 or bf16, except for operators in the black_list,
operators without a fp16/bf16 kernel, and batch_norm.
dtype(str): The low precision data type, either 'float16' or 'bfloat16'.
init_loss_scaling (float): The initial loss scaling factor. init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling. use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
incr_every_n_steps(int): Increases loss scaling every n consecutive incr_every_n_steps(int): Increases loss scaling every n consecutive
...@@ -57,24 +64,23 @@ class OptimizerWithMixedPrecision: ...@@ -57,24 +64,23 @@ class OptimizerWithMixedPrecision:
scaling. scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling. the loss scaling.
use_pure_fp16(bool): Whether to use the pure fp16 training. Default False. use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
Default None, which means that its value is equal to `use_pure_fp16`. Default None, which means that its value is equal to `use_pure_fp16`.
""" """
def __init__( def __init__(
self, self,
optimizer, optimizer,
amp_lists, amp_lists,
level,
dtype,
init_loss_scaling, init_loss_scaling,
use_dynamic_loss_scaling, use_dynamic_loss_scaling,
incr_every_n_steps, incr_every_n_steps,
decr_every_n_nan_or_inf, decr_every_n_nan_or_inf,
incr_ratio, incr_ratio,
decr_ratio, decr_ratio,
use_pure_fp16, use_amp_guard=None,
use_fp16_guard,
): ):
self._optimizer = optimizer self._optimizer = optimizer
self._amp_lists = amp_lists self._amp_lists = amp_lists
...@@ -86,10 +92,21 @@ class OptimizerWithMixedPrecision: ...@@ -86,10 +92,21 @@ class OptimizerWithMixedPrecision:
self._loss_scaling = None self._loss_scaling = None
self._init_loss_scaling = init_loss_scaling self._init_loss_scaling = init_loss_scaling
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
if dtype == "bfloat16":
if use_dynamic_loss_scaling:
self._use_dynamic_loss_scaling = False
self._init_loss_scaling = 1.0
warnings.warn(
"Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
)
self._amp_vartype = core.VarDesc.VarType.BF16
else:
self._amp_vartype = core.VarDesc.VarType.FP16
self._learning_rate = optimizer._learning_rate self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map self._learning_rate_map = optimizer._learning_rate_map
self._use_pure_fp16 = use_pure_fp16 self._use_pure_fp16 = level == "O2"
self._use_fp16_guard = use_fp16_guard self._use_fp16_guard = use_amp_guard
self._to_fp16_var_names = None self._to_fp16_var_names = None
if self._use_dynamic_loss_scaling: if self._use_dynamic_loss_scaling:
self._incr_every_n_steps = incr_every_n_steps self._incr_every_n_steps = incr_every_n_steps
...@@ -209,10 +226,15 @@ class OptimizerWithMixedPrecision: ...@@ -209,10 +226,15 @@ class OptimizerWithMixedPrecision:
if self._use_pure_fp16: if self._use_pure_fp16:
self._to_fp16_var_names = cast_model_to_fp16( self._to_fp16_var_names = cast_model_to_fp16(
self._train_program, self._amp_lists, self._use_fp16_guard self._train_program,
self._amp_lists,
self._use_fp16_guard,
self._amp_vartype,
) )
else: else:
rewrite_program(self._train_program, self._amp_lists) rewrite_program(
self._train_program, self._amp_lists, self._amp_vartype
)
if loss.dtype != core.VarDesc.VarType.FP32: if loss.dtype != core.VarDesc.VarType.FP32:
loss = loss.astype('float32') loss = loss.astype('float32')
...@@ -258,7 +280,7 @@ class OptimizerWithMixedPrecision: ...@@ -258,7 +280,7 @@ class OptimizerWithMixedPrecision:
outputs={'Out': [name]}, outputs={'Out': [name]},
attrs={ attrs={
'in_dtype': core.VarDesc.VarType.FP32, 'in_dtype': core.VarDesc.VarType.FP32,
'out_dtype': core.VarDesc.VarType.FP16, 'out_dtype': self._amp_vartype,
}, },
) )
self._to_fp16_var_names = None self._to_fp16_var_names = None
...@@ -326,15 +348,24 @@ class OptimizerWithMixedPrecision: ...@@ -326,15 +348,24 @@ class OptimizerWithMixedPrecision:
), "Please call the minimize method first." ), "Please call the minimize method first."
if self._use_pure_fp16: if self._use_pure_fp16:
cast_parameters_to_fp16( cast_parameters_to_fp16(
place, self._train_program, scope, self._to_fp16_var_names place,
self._train_program,
scope,
self._to_fp16_var_names,
self._amp_vartype,
) )
if test_program is not None: if test_program is not None:
if self._use_pure_fp16: if self._use_pure_fp16:
cast_model_to_fp16( cast_model_to_fp16(
test_program, self._amp_lists, self._use_fp16_guard test_program,
self._amp_lists,
self._use_fp16_guard,
self._amp_vartype,
) )
elif use_fp16_test: elif use_fp16_test:
rewrite_program(test_program, self._amp_lists) rewrite_program(
test_program, self._amp_lists, self._amp_vartype
)
def apply_gradients(self, params_grads): def apply_gradients(self, params_grads):
""" """
...@@ -368,7 +399,10 @@ class OptimizerWithMixedPrecision: ...@@ -368,7 +399,10 @@ class OptimizerWithMixedPrecision:
return optimize_ops return optimize_ops
found_inf = self._check_finite_and_unscale(params_grads) found_inf = self._check_finite_and_unscale(params_grads)
if self._use_dynamic_loss_scaling: if (
self._use_dynamic_loss_scaling
and self._amp_vartype == core.VarDesc.VarType.FP16
):
self._add_dynamic_loss_scaling(params_grads, found_inf) self._add_dynamic_loss_scaling(params_grads, found_inf)
# Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow # Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
...@@ -395,10 +429,10 @@ class OptimizerWithMixedPrecision: ...@@ -395,10 +429,10 @@ class OptimizerWithMixedPrecision:
def _split_grads(self, params_grads): def _split_grads(self, params_grads):
grads = [g for _, g in params_grads] grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32] fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16] fp16_grads = [g for g in grads if g.dtype == self._amp_vartype]
assert len(fp32_grads) + len(fp16_grads) == len( assert len(fp32_grads) + len(fp16_grads) == len(
grads grads
), "Data types of all grads must be either fp16 or fp32." ), "Data types of all grads must be either fp16/bf16 or fp32."
return grads, fp32_grads, fp16_grads return grads, fp32_grads, fp16_grads
def _check_finite_and_unscale(self, params_grads): def _check_finite_and_unscale(self, params_grads):
...@@ -587,6 +621,7 @@ def decorate( ...@@ -587,6 +621,7 @@ def decorate(
use_dynamic_loss_scaling=True, use_dynamic_loss_scaling=True,
use_pure_fp16=False, use_pure_fp16=False,
use_fp16_guard=None, use_fp16_guard=None,
use_bf16=False,
): ):
""" """
Decorate the given optimizer to adapt to the mixed-precision training. Decorate the given optimizer to adapt to the mixed-precision training.
...@@ -608,6 +643,7 @@ def decorate( ...@@ -608,6 +643,7 @@ def decorate(
use_pure_fp16(bool): Whether to use the pure fp16 training. Default False. use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program. use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
Default None, which means that its value equals to `use_pure_fp16`. Default None, which means that its value equals to `use_pure_fp16`.
use_bf16(bool): Whether to enable bfloat16 training. Default False.
Returns: Returns:
An optimizer acting like a normal one but with mixed-precision training An optimizer acting like a normal one but with mixed-precision training
...@@ -678,23 +714,70 @@ def decorate( ...@@ -678,23 +714,70 @@ def decorate(
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0: if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code() run_example_code()
""" """
amp_dtype = "bfloat16" if use_bf16 else "float16"
if amp_lists is None: if amp_lists is None:
amp_lists = AutoMixedPrecisionLists() amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
if use_fp16_guard is None: if use_fp16_guard is None:
use_fp16_guard = use_pure_fp16 use_fp16_guard = use_pure_fp16
amp_level = "O2" if use_pure_fp16 else "O1"
mp_optimizer = OptimizerWithMixedPrecision( mp_optimizer = OptimizerWithMixedPrecision(
optimizer, optimizer,
amp_lists, amp_lists,
init_loss_scaling, level=amp_level,
use_dynamic_loss_scaling, dtype=amp_dtype,
incr_every_n_steps, init_loss_scaling=init_loss_scaling,
decr_every_n_nan_or_inf, use_dynamic_loss_scaling=use_dynamic_loss_scaling,
incr_ratio, incr_every_n_steps=incr_every_n_steps,
decr_ratio, decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
use_pure_fp16, incr_ratio=incr_ratio,
use_fp16_guard, decr_ratio=decr_ratio,
use_amp_guard=use_fp16_guard,
)
return mp_optimizer
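The legacy `decorate` API keeps its signature and now builds the same `OptimizerWithMixedPrecision`, with the new `use_bf16` flag selecting bfloat16; a hedged usage sketch (argument values are illustrative):

mp_optimizer = paddle.static.amp.decorate(
    optimizer,
    init_loss_scaling=128.0,
    use_dynamic_loss_scaling=True,
    use_pure_fp16=False,
    use_bf16=True,  # new flag introduced by this commit
)
mp_optimizer.minimize(loss)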
def amp_decorate(
optimizer,
amp_lists=None,
level='O1',
dtype='float16',
init_loss_scaling=2**15,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
use_dynamic_loss_scaling=True,
use_amp_guard=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
"""
amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
# check amp_level: O0-O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
mp_optimizer = OptimizerWithMixedPrecision(
optimizer,
amp_lists,
level=level,
dtype=amp_dtype,
init_loss_scaling=init_loss_scaling,
use_dynamic_loss_scaling=use_dynamic_loss_scaling,
incr_every_n_steps=incr_every_n_steps,
decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
use_amp_guard=use_amp_guard,
) )
return mp_optimizer return mp_optimizer
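For O2 (pure fp16/bf16) training, the decorated optimizer still needs its parameters cast after the startup program has run, which is what `amp_init` does; a sketch following the new tests, assuming `optimizer` is the object returned by `amp_decorate(..., level="O2")` and `minimize` has already been called:

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_program)
# Cast the fp32 parameters of the trained program to the chosen low-precision dtype.
optimizer.amp_init(place)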
...@@ -13,11 +13,17 @@ ...@@ -13,11 +13,17 @@
# limitations under the License. # limitations under the License.
import copy import copy
import logging
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
# lookup_table fp16 is slower than fp32, though fp16 is supported. # lookup_table fp16 is slower than fp32, though fp16 is supported.
_extra_unsupported_fp16_list = { _extra_unsupported_list = {
'lookup_table', 'lookup_table',
'lookup_table_v2', 'lookup_table_v2',
'scatter', 'scatter',
...@@ -25,17 +31,95 @@ _extra_unsupported_fp16_list = { ...@@ -25,17 +31,95 @@ _extra_unsupported_fp16_list = {
} }
def check_amp_dtype(dtype):
"""
Check amp_dtype: float16 or bfloat16
"""
if isinstance(dtype, str):
dtype = dtype.lower()
if dtype not in ['float16', 'bfloat16']:
raise ValueError(
"If enable AMP, dtype should be 'float16' or 'bfloat16'."
)
return dtype
def get_low_precision_vartype(dtype):
if isinstance(dtype, core.VarDesc.VarType):
return dtype
elif isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "float16":
var_type = core.VarDesc.VarType.FP16
elif dtype == "bfloat16":
var_type = core.VarDesc.VarType.BF16
else:
raise ValueError(
"If enable AMP, dtype should be 'float16' or 'bfloat16'."
)
return var_type
else:
raise TypeError(
"The type of dtype is expected to be string or core.VarDesc.VarType, but recieved {}.".format(
type(dtype)
)
)
def get_low_precision_dtypestr(dtype):
if isinstance(dtype, str):
return check_amp_dtype(dtype)
elif isinstance(dtype, core.VarDesc.VarType):
if dtype == core.VarDesc.VarType.FP16:
return "float16"
elif dtype == core.VarDesc.VarType.BF16:
return "bfloat16"
else:
raise ValueError(
"If enable AMP, dtype should be core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16."
)
else:
raise TypeError(
"The type of dtype is expected to be string or core.VarDesc.VarType, but recieved {}.".format(
type(dtype)
)
)
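A quick illustration of the dtype helpers above, mirroring the assertions in the new unit test later in this commit:

from paddle.fluid import core
from paddle.static.amp import fp16_lists

assert fp16_lists.check_amp_dtype("Float16") == "float16"
assert fp16_lists.get_low_precision_vartype("bfloat16") == core.VarDesc.VarType.BF16
assert fp16_lists.get_low_precision_dtypestr(core.VarDesc.VarType.FP16) == "float16"
# Unsupported dtypes raise ValueError (strings) or TypeError (other types).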
def _get_sys_unsupported_list(dtype):
var_type = get_low_precision_vartype(dtype)
# The set of ops that do not support fp16/bf16 calculation on the current device.
device = None
if core.is_compiled_with_xpu():
device = 'XPU'
elif core.is_compiled_with_custom_device('npu'):
device = 'NPU'
else:
device = 'GPU'
_, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
return device, sys_unsupported_list
def _get_unsupported_list(dtype):
# The set of ops that do not support fp16/bf16 calculation.
_, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
unsupported_list = _extra_unsupported_list | _sys_unsupported_list
return unsupported_list
class AutoMixedPrecisionLists: class AutoMixedPrecisionLists:
""" """
AutoMixedPrecisionLists is a class for black/white list. It can update AutoMixedPrecisionLists is a class for black/white list. It can update
pre-defined black list and white list according to users' custom black pre-defined black list and white list according to users' custom black
white lists. The lists are used for an algorithm which determines op's white lists. The lists are used for an algorithm which determines op's
execution mode (fp32 or fp16). execution mode (fp32, fp16 or bf16).
Args: Args:
custom_white_list (set): Users' custom white list. custom_white_list (set): Users' custom white list.
custom_black_list (set): Users' custom black list. custom_black_list (set): Users' custom black list.
custom_black_varnames (set): Users' custom black varibles' names. custom_black_varnames (set): Users' custom black varibles' names.
dtype (str): the low precision dtype, which can be set to 'float16' or 'bfloat16'.
""" """
def __init__( def __init__(
...@@ -43,13 +127,15 @@ class AutoMixedPrecisionLists: ...@@ -43,13 +127,15 @@ class AutoMixedPrecisionLists:
custom_white_list=None, custom_white_list=None,
custom_black_list=None, custom_black_list=None,
custom_black_varnames=None, custom_black_varnames=None,
dtype="float16",
): ):
self.amp_dtype = check_amp_dtype(dtype)
self._custom_white_list = custom_white_list self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list self._custom_black_list = custom_black_list
self.white_list = copy.copy(white_list) self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list) self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list) self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(unsupported_fp16_list) self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
self.black_varnames = copy.copy(custom_black_varnames) self.black_varnames = copy.copy(custom_black_varnames)
self._update_list() self._update_list()
...@@ -61,7 +147,7 @@ class AutoMixedPrecisionLists: ...@@ -61,7 +147,7 @@ class AutoMixedPrecisionLists:
for op_name in self._custom_white_list: for op_name in self._custom_white_list:
if op_name in self._custom_black_list: if op_name in self._custom_black_list:
raise ValueError( raise ValueError(
"Custom white list overlap " "custom black list" f"The given custom_white_list overlaps custom_black_list with < {op_name} >!"
) )
if self._custom_white_list: if self._custom_white_list:
for op_name in self._custom_white_list: for op_name in self._custom_white_list:
...@@ -70,7 +156,7 @@ class AutoMixedPrecisionLists: ...@@ -70,7 +156,7 @@ class AutoMixedPrecisionLists:
elif op_name in self.gray_list: elif op_name in self.gray_list:
self.gray_list.remove(op_name) self.gray_list.remove(op_name)
self.white_list.add(op_name) self.white_list.add(op_name)
if op_name in _extra_unsupported_fp16_list: if op_name in _extra_unsupported_list:
self.unsupported_list.remove(op_name) self.unsupported_list.remove(op_name)
if self._custom_black_list: if self._custom_black_list:
for op_name in self._custom_black_list: for op_name in self._custom_black_list:
...@@ -80,6 +166,15 @@ class AutoMixedPrecisionLists: ...@@ -80,6 +166,15 @@ class AutoMixedPrecisionLists:
self.gray_list.remove(op_name) self.gray_list.remove(op_name)
self.black_list.add(op_name) self.black_list.add(op_name)
self.unsupported_list.add(op_name) self.unsupported_list.add(op_name)
device, sys_unsupported_list = _get_sys_unsupported_list(self.amp_dtype)
actual_unsupported_list = []
for op_name in sys_unsupported_list:
if op_name in self.white_list:
actual_unsupported_list.append(op_name)
if len(actual_unsupported_list) > 0:
_logger.warning(
f"On current {device}, {self.amp_dtype} is not supported for operators < {actual_unsupported_list} > in white_list!"
)
# The three sets listed below are changed dynamiclly. They don't contain all # The three sets listed below are changed dynamiclly. They don't contain all
...@@ -175,24 +270,4 @@ gray_list = { ...@@ -175,24 +270,4 @@ gray_list = {
'fused_multi_transformer', 'fused_multi_transformer',
} }
# The set of ops that don't support fp16 calculation
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_sys_unsupported_fp16_list = []
if core.is_compiled_with_xpu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'XPU', core.VarDesc.VarType.FP16
)
elif core.is_compiled_with_custom_device('npu'):
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'NPU', core.VarDesc.VarType.FP16
)
else:
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'GPU', core.VarDesc.VarType.FP16
)
unsupported_fp16_list = (
_extra_unsupported_fp16_list | _sys_unsupported_fp16_list
)
CustomOpLists = AutoMixedPrecisionLists CustomOpLists = AutoMixedPrecisionLists
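With the new `dtype` argument, user lists can now be constructed for either low-precision type; a sketch mirroring the new test helpers (the op names are illustrative):

amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
    custom_white_list=["elementwise_add"],
    custom_black_list=["reduce_mean"],
    dtype="bfloat16",
)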
...@@ -17,11 +17,12 @@ import logging ...@@ -17,11 +17,12 @@ import logging
import numpy as np import numpy as np
import paddle
from paddle.fluid import core, framework, global_scope from paddle.fluid import core, framework, global_scope
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from .fp16_lists import AutoMixedPrecisionLists from .fp16_lists import AutoMixedPrecisionLists, get_low_precision_dtypestr
_logger = get_logger( _logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s' __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
...@@ -72,7 +73,9 @@ def _dtype_to_str(dtype): ...@@ -72,7 +73,9 @@ def _dtype_to_str(dtype):
Args: Args:
dtype (VarType): Variable type. dtype (VarType): Variable type.
""" """
if dtype == core.VarDesc.VarType.FP16: if dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
# TODO(Xreki): change the returned str to "bf16" for BF16 data type.
# Currently too much code uses "cast_fp16" as the key.
return 'fp16' return 'fp16'
else: else:
return 'fp32' return 'fp32'
...@@ -220,10 +223,10 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -220,10 +223,10 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
else: else:
if op.has_attr('in_dtype'): if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype) op._set_attr('in_dtype', dest_dtype)
if ( if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype in [
src_dtype == core.VarDesc.VarType.FP32 core.VarDesc.VarType.FP16,
and dest_dtype == core.VarDesc.VarType.FP16 core.VarDesc.VarType.BF16,
): ]:
for out_name in op.output_names: for out_name in op.output_names:
if _keep_fp32_output(op, out_name): if _keep_fp32_output(op, out_name):
continue continue
...@@ -232,9 +235,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): ...@@ -232,9 +235,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
if out_var.type not in _valid_types: if out_var.type not in _valid_types:
continue continue
if out_var.dtype == core.VarDesc.VarType.FP32: if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16) out_var.desc.set_dtype(dest_dtype)
if op.has_attr('out_dtype'): if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16) op._set_attr('out_dtype', dest_dtype)
return num_cast_ops return num_cast_ops
...@@ -417,7 +420,12 @@ def fp16_guard(): ...@@ -417,7 +420,12 @@ def fp16_guard():
yield yield
def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): def cast_model_to_fp16(
program,
amp_lists=None,
use_fp16_guard=True,
dest_type=core.VarDesc.VarType.FP16,
):
""" """
Traverse all ops in the whole model and set their inputs and outputs Traverse all ops in the whole model and set their inputs and outputs
to the fp16 data type. This function will do some special process for to the fp16 data type. This function will do some special process for
...@@ -428,10 +436,12 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -428,10 +436,12 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object. amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
use_fp16_guard(bool): Determine whether to use `fp16_guard` when use_fp16_guard(bool): Determine whether to use `fp16_guard` when
constructing the program. Default True. constructing the program. Default True.
dest_type(core.VarDesc.VarType): the target cast type, such as core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16.
""" """
if amp_lists is None: if amp_lists is None:
amp_lists = AutoMixedPrecisionLists() dtype = get_low_precision_dtypestr(dest_type)
amp_lists = AutoMixedPrecisionLists(dtype)
amp_lists.unsupported_list -= { amp_lists.unsupported_list -= {
"conditional_block_grad", "conditional_block_grad",
"conditional_block", "conditional_block",
...@@ -487,7 +497,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -487,7 +497,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
continue continue
if in_var.dtype == core.VarDesc.VarType.FP32: if in_var.dtype == core.VarDesc.VarType.FP32:
in_var.desc.set_dtype(core.VarDesc.VarType.FP16) in_var.desc.set_dtype(dest_type)
to_fp16_var_names.add(in_var_name) to_fp16_var_names.add(in_var_name)
_logger.debug( _logger.debug(
...@@ -524,28 +534,19 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -524,28 +534,19 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
continue continue
if out_var.dtype == core.VarDesc.VarType.FP32: if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16) out_var.desc.set_dtype(dest_type)
_logger.debug( _logger.debug(
"-- op type: {}, out var name: {}, out var dtype: {} --".format( "-- op type: {}, out var name: {}, out var dtype: {} --".format(
op.type, out_var_name, out_var.dtype op.type, out_var_name, out_var.dtype
) )
) )
if ( for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
op.has_attr('in_dtype') if (
and op.attr('in_dtype') == core.VarDesc.VarType.FP32 op.has_attr(attr_name)
): and op.attr(attr_name) == core.VarDesc.VarType.FP32
op._set_attr('in_dtype', core.VarDesc.VarType.FP16) ):
if ( op._set_attr(attr_name, dest_type)
op.has_attr('out_dtype')
and op.attr('out_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
if (
op.has_attr('dtype')
and op.attr('dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('dtype', core.VarDesc.VarType.FP16)
# process ops in keep_fp32_ops # process ops in keep_fp32_ops
op_var_rename_map = [ op_var_rename_map = [
...@@ -562,7 +563,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -562,7 +563,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
block, block,
op, op,
idx, idx,
core.VarDesc.VarType.FP16, dest_type,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
) )
num_cast_ops += pre_cast_num num_cast_ops += pre_cast_num
...@@ -570,7 +571,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -570,7 +571,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
out_var = block.vars.get(out_var_name) out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types: if out_var is None or out_var.type not in _valid_types:
continue continue
if out_var.dtype == core.VarDesc.VarType.FP16: if out_var.dtype == dest_type:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32) out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
post_ops = find_true_post_op(ops, op, out_var_name) post_ops = find_true_post_op(ops, op, out_var_name)
for post_op in post_ops: for post_op in post_ops:
...@@ -581,7 +582,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -581,7 +582,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
op, op,
idx + pre_cast_num + 1, idx + pre_cast_num + 1,
core.VarDesc.VarType.FP32, core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, dest_type,
out_var_name, out_var_name,
op_var_rename_map, op_var_rename_map,
) )
...@@ -592,7 +593,22 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True): ...@@ -592,7 +593,22 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
return to_fp16_var_names return to_fp16_var_names
def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): def _convert_float_to_bfloat16(place, fp32_array):
paddle.disable_static()
framework._set_expected_place(place)
fp32_tensor = paddle.to_tensor(fp32_array)
bf16_array = paddle.cast(fp32_tensor, paddle.bfloat16).numpy()
paddle.enable_static()
return bf16_array
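The new tests in this commit approximate the same float32-to-bfloat16 conversion in NumPy by keeping only the upper 16 bits of each float32 value (truncation); a minimal sketch of that style of helper, useful for preparing uint16 feed data:

import struct

import numpy as np

def float32_to_bf16_bits(value):
    # bfloat16 reuses float32's sign and exponent bits; dropping the low 16
    # mantissa bits yields its raw uint16 bit pattern (truncation, no rounding).
    return np.uint16(struct.unpack('<I', struct.pack('<f', float(value)))[0] >> 16)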
def cast_parameters_to_fp16(
place,
program,
scope=None,
to_fp16_var_names=None,
dest_type=core.VarDesc.VarType.FP16,
):
""" """
Traverse all parameters in the whole model and set them to the FP16 data type. Traverse all parameters in the whole model and set them to the FP16 data type.
Whereas, this function will keep parameters of batchnorms in FP32. Whereas, this function will keep parameters of batchnorms in FP32.
...@@ -604,6 +620,7 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): ...@@ -604,6 +620,7 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names` to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
will be set to FP16. Usually, it is the returned will be set to FP16. Usually, it is the returned
value of `cast_model_to_fp16` API. value of `cast_model_to_fp16` API.
dest_type(core.VarDesc.VarType): the target cast type, such as core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16.
""" """
all_parameters = [] all_parameters = []
for block in program.blocks: for block in program.blocks:
...@@ -613,13 +630,20 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None): ...@@ -613,13 +630,20 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
var_scope = scope if scope else global_scope() var_scope = scope if scope else global_scope()
for param in all_parameters: for param in all_parameters:
if param.name in fp16_var_names: if param.name in fp16_var_names:
_logger.debug(f"---- cast {param.name} to fp16 dtype ----") _logger.debug(f"---- cast {param.name} to fp16/bf16 dtype ----")
param_t = var_scope.find_var(param.name).get_tensor() if var_scope.find_var(param.name):
data = np.array(param_t) param_t = var_scope.find_var(param.name).get_tensor()
param_t.set(np.float16(data), place) data = np.array(param_t)
if dest_type == core.VarDesc.VarType.BF16:
bf16_data = _convert_float_to_bfloat16(place, data)
param_t.set(bf16_data, place)
else:
param_t.set(np.float16(data), place)
else:
_logger.warning(f"Cannot find {param.name}")
def rewrite_program(main_prog, amp_lists): def rewrite_program(main_prog, amp_lists, dest_type=core.VarDesc.VarType.FP16):
""" """
Traverse all ops in current block and insert cast op according to Traverse all ops in current block and insert cast op according to
which set current op belongs to. which set current op belongs to.
...@@ -638,6 +662,7 @@ def rewrite_program(main_prog, amp_lists): ...@@ -638,6 +662,7 @@ def rewrite_program(main_prog, amp_lists):
Args: Args:
main_prog (Program): The main program for training. main_prog (Program): The main program for training.
dest_type(core.VarDesc.VarType): the target cast type, such as core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16.
""" """
block = main_prog.global_block() block = main_prog.global_block()
block._sync_with_cpp() block._sync_with_cpp()
...@@ -708,19 +733,11 @@ def rewrite_program(main_prog, amp_lists): ...@@ -708,19 +733,11 @@ def rewrite_program(main_prog, amp_lists):
num_cast_ops = 0 num_cast_ops = 0
if op in black_op_set: if op in black_op_set:
num_cast_ops = _insert_cast_op( num_cast_ops = _insert_cast_op(
block, block, op, idx, dest_type, core.VarDesc.VarType.FP32
op,
idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
) )
elif op in white_op_set: elif op in white_op_set:
num_cast_ops = _insert_cast_op( num_cast_ops = _insert_cast_op(
block, block, op, idx, core.VarDesc.VarType.FP32, dest_type
op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
) )
else: else:
pass pass
......
...@@ -200,6 +200,7 @@ def create_parameter( ...@@ -200,6 +200,7 @@ def create_parameter(
[ [
'bool', 'bool',
'float16', 'float16',
'uint16',
'float32', 'float32',
'float64', 'float64',
'int8', 'int8',
......
...@@ -45,3 +45,7 @@ endfunction() ...@@ -45,3 +45,7 @@ endfunction()
foreach(TEST_OP ${TEST_OPS}) foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach() endforeach()
if(APPLE)
set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 300)
endif()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import nn
_fixed_add_param = np.random.random(size=[16, 16]).astype("float32")
def _build_optimizer(
use_amp, amp_dtype="float16", amp_level="O1", use_grad_clip=False
):
if use_grad_clip:
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
else:
grad_clip = None
optimizer = paddle.optimizer.AdamW(
learning_rate=0.01,
grad_clip=grad_clip,
beta1=0.78,
beta2=0.836,
epsilon=1e-4,
weight_decay=0.01,
multi_precision=True,
)
if use_amp:
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
custom_white_list=["elementwise_add"],
custom_black_list=["reduce_mean"],
dtype=amp_dtype,
)
optimizer = paddle.static.amp.amp_decorate(
optimizer, amp_lists=amp_lists, level=amp_level, dtype=amp_dtype
)
return optimizer
class SimpleAddNet(nn.Layer):
def __init__(self, dtype):
super().__init__()
global _fixed_add_param
self.weight = paddle.create_parameter(
name="add_w",
shape=[16, 16],
dtype=dtype,
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Assign(_fixed_add_param)
),
)
def forward(self, x):
return x + self.weight
def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
x_dtype = "float32"
if use_amp and amp_level == "O2":
if amp_dtype == "bfloat16":
x_dtype = "uint16"
elif amp_dtype == "float16":
x_dtype = "float16"
model = SimpleAddNet(x_dtype)
x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype)
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
optimizer.minimize(loss)
feed_vars = [x]
fetch_vars = [loss]
return main_program, startup_program, optimizer, feed_vars, fetch_vars
class SimpleConvNet(nn.Layer):
def __init__(self):
super().__init__()
self.conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = nn.Linear(in_features=6, out_features=10)
def forward(self, x):
out = self.conv(x)
out = nn.functional.relu(out)
out = self.linear(out)
out = nn.functional.softmax(out)
return out
def build_conv_model(use_amp, amp_dtype="float16", amp_level="O1"):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
optimizer.minimize(loss)
return main_program, startup_program
class SimpleEmbeddingNet(nn.Layer):
def __init__(self):
super().__init__()
self.vocab_size = 128
self.hidden_size = 16
self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
self.linear = nn.Linear(in_features=16, out_features=10)
def forward(self, x):
out = self.embedding(x)
scale = paddle.full(shape=[1], fill_value=2, dtype="int64")
out = paddle.multiply(out, scale.astype("float32"))
out = self.linear(out)
out = nn.functional.dropout(out, p=0.2)
return out
def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleEmbeddingNet()
x = paddle.static.data(name='x', shape=[None, 32], dtype='int64')
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level, True)
optimizer.minimize(loss)
return main_program, startup_program
class SimpleWhileNet(nn.Layer):
def __init__(self):
super().__init__()
self.linear = paddle.nn.Linear(16, 10)
def forward(self, x):
def cond(i, loop_len, x, result):
return i < loop_len
def body(i, loop_len, x, result):
result = self.linear(x)
paddle.increment(i)
return [i, loop_len, x, result]
i = paddle.zeros(shape=[1], dtype='int64')
loop_len = paddle.ones(shape=[1], dtype='int64')
result = paddle.zeros(
shape=x.shape[:-1] + self.linear.weight.shape[-1:], dtype="float32"
)
result.stop_gradient = False
_, _, _, results = paddle.static.nn.while_loop(
cond, body, [i, loop_len, x, result]
)
return results
def build_while_model():
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleWhileNet()
x = paddle.static.data(name='x', shape=[32, 16], dtype='float32')
out = model(x)
loss = paddle.mean(out)
return main_program, startup_program
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
import unittest import unittest
from paddle.fluid import core
from paddle.static.amp import fp16_lists
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists
...@@ -39,6 +41,48 @@ class TestAMPList(unittest.TestCase): ...@@ -39,6 +41,48 @@ class TestAMPList(unittest.TestCase):
for op in default_black_list: for op in default_black_list:
self.assertTrue(op in amp_list.black_list) self.assertTrue(op in amp_list.black_list)
def test_apis(self):
def _run_check_dtype():
fp16_lists.check_amp_dtype(dtype="int64")
self.assertRaises(ValueError, _run_check_dtype)
for vartype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
self.assertEqual(
fp16_lists.get_low_precision_vartype(vartype), vartype
)
self.assertEqual(
fp16_lists.get_low_precision_vartype("float16"),
core.VarDesc.VarType.FP16,
)
self.assertEqual(
fp16_lists.get_low_precision_vartype("bfloat16"),
core.VarDesc.VarType.BF16,
)
def _run_get_vartype():
fp16_lists.get_low_precision_vartype(dtype="int64")
self.assertRaises(ValueError, _run_get_vartype)
for dtype in ["float16", "bfloat16"]:
self.assertEqual(
fp16_lists.get_low_precision_dtypestr(dtype), dtype
)
self.assertEqual(
fp16_lists.get_low_precision_dtypestr(core.VarDesc.VarType.FP16),
"float16",
)
self.assertEqual(
fp16_lists.get_low_precision_dtypestr(core.VarDesc.VarType.BF16),
"bfloat16",
)
def _run_get_dtypestr():
fp16_lists.get_low_precision_dtypestr(dtype="int64")
self.assertRaises(ValueError, _run_get_dtypestr)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -14,10 +14,12 @@ ...@@ -14,10 +14,12 @@
import unittest import unittest
from amp_base_models import build_while_model
import paddle import paddle
class TestAMPList(unittest.TestCase): class TestOpStatsEager(unittest.TestCase):
def _check_result(self, dtype): def _check_result(self, dtype):
# Returned the dict. # Returned the dict.
op_list = paddle.fluid.core.get_low_precision_op_list() op_list = paddle.fluid.core.get_low_precision_op_list()
...@@ -65,5 +67,17 @@ class TestAMPList(unittest.TestCase): ...@@ -65,5 +67,17 @@ class TestAMPList(unittest.TestCase):
self._check_result(dtype=out.dtype) self._check_result(dtype=out.dtype)
class TestOpStatsStatic(unittest.TestCase):
def test_while_op(self):
paddle.enable_static()
main_program, startup_program = build_while_model()
self.assertEqual(main_program.num_blocks, 2)
paddle.static.amp.debugging.collect_operator_stats(
program=main_program, print_subblocks=True
)
paddle.disable_static()
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -17,6 +17,7 @@ import struct ...@@ -17,6 +17,7 @@ import struct
import unittest import unittest
import numpy as np import numpy as np
from amp_base_models import build_add_model, build_embedding_model
import paddle import paddle
from paddle import fluid from paddle import fluid
...@@ -26,6 +27,21 @@ from paddle.static import amp ...@@ -26,6 +27,21 @@ from paddle.static import amp
paddle.enable_static() paddle.enable_static()
def copy_bits_from_float_to_uint16(f):
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
def convert_float_to_uint16(in_list):
if in_list.dtype == np.float32:
new_output = []
for x in np.nditer(in_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, in_list.shape).view(np.uint16)
return new_output
else:
return in_list
def convert_uint16_to_float(in_list): def convert_uint16_to_float(in_list):
if in_list.dtype == np.uint16: if in_list.dtype == np.uint16:
in_list = np.asarray(in_list) in_list = np.asarray(in_list)
...@@ -204,5 +220,123 @@ class TestModelCastBF16(unittest.TestCase): ...@@ -204,5 +220,123 @@ class TestModelCastBF16(unittest.TestCase):
) )
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA and not support the bfloat16",
)
class TestProgramBF16(unittest.TestCase):
def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls):
for op_type, value in expected_bf16_calls.items():
self.assertEqual(
op_stats_dict[op_type].bf16_calls,
value,
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
)
def test_amp_bf16_o1(self):
main_program, startup_program = build_embedding_model(
True, "bfloat16", "O1"
)
self.assertEqual(main_program.num_blocks, 1)
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 0,
"adamw": 0,
}
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
def test_amp_bf16_o2(self):
main_program, startup_program = build_embedding_model(
True, "bfloat16", "O2"
)
self.assertEqual(main_program.num_blocks, 1)
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 2,
"adamw": 2,
}
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA and not support the bfloat16",
)
class TestStaticBF16(unittest.TestCase):
def _generate_feed_x(self):
x = np.random.random(size=[16, 16]).astype("float32")
x_bf16 = convert_float_to_uint16(x)
x_fp32 = convert_uint16_to_float(x_bf16)
return x_fp32, x_bf16
def test_compare_o1_o2(self):
def _run_o1(exe, x_np, max_iters):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", "O1")
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 O1] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
def _run_o2(exe, x_np, max_iters):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", "O2")
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
optimizer.amp_init(place)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 O2] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
max_iters = 2
x_fp32, x_bf16 = self._generate_feed_x()
losses_o1 = _run_o1(exe, x_fp32, max_iters)
losses_o2 = _run_o2(exe, x_bf16, max_iters)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -22,7 +22,3 @@ py_test_modules( ...@@ -22,7 +22,3 @@ py_test_modules(
set_tests_properties(test_image_classification_fp16 PROPERTIES TIMEOUT 120) set_tests_properties(test_image_classification_fp16 PROPERTIES TIMEOUT 120)
set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120) set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120)
set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120) set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120)
if(APPLE)
set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 300)
endif()