Unverified commit dfcba7f4 authored by Yiqun Liu, committed by GitHub

[AMP] Unify the static amp codes of fp16 and bf16. (#52694)

* Unify the static amp codes of fp16 and bf16.

* Polish apis and add unittest.

* Add operator stats collecting tools for program.

* Add a check on the number of bfloat16 operators in the unittest.

* Add a warning for operators not supported by amp.

* Add testing of BF16 O1 and O2.
Parent f8d09011
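For orientation, here is a minimal usage sketch of the unified static AMP entry points touched by this commit (amp_decorate and the new operator-stats tool); the model and loss construction are assumed and are not part of the diff:

import paddle

paddle.enable_static()
# ... build a static Program that computes `loss` ...
optimizer = paddle.optimizer.AdamW(learning_rate=0.01, multi_precision=True)
# Unified decorator: level is "O1" or "O2", dtype is "float16" or "bfloat16".
optimizer = paddle.static.amp.amp_decorate(optimizer, level="O1", dtype="bfloat16")
# optimizer.minimize(loss)
# Count fp16/bf16/fp32/other operator calls of the resulting program.
paddle.static.amp.debugging.collect_operator_stats(paddle.static.default_main_program())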
......@@ -271,7 +271,7 @@ def _print_operator_stats(op_count_dict):
"{:-^17}>".format(" Other Calls "),
)
if op_count_dict is not None and isinstance(op_count_dict, dict):
for op_type in op_count_dict:
for op_type in sorted(op_count_dict):
# fp16, bf16, fp32, other
value = op_count_dict[op_type]
if isinstance(value, list):
......
......@@ -189,15 +189,20 @@ class PartialProgramLayer:
self._infer_info = ProgramInfo()
self._forward_end_index_map = {}
custom_white_list, custom_black_list = None, None
amp_dtype, custom_white_list, custom_black_list = None, None, None
tracer = framework._dygraph_tracer()
if tracer:
custom_white_list, custom_black_list = tracer._get_amp_op_list()
# For AMP training
self._amp_list = paddle.static.amp.fp16_lists.AutoMixedPrecisionLists(
custom_white_list=custom_white_list,
custom_black_list=custom_black_list,
)
amp_dtype = tracer._amp_dtype
if amp_dtype is not None and amp_dtype in ['float16', 'bfloat16']:
# For AMP training
self._amp_list = (
paddle.static.amp.fp16_lists.AutoMixedPrecisionLists(
custom_white_list=custom_white_list,
custom_black_list=custom_black_list,
dtype=amp_dtype,
)
)
# program_id -> list(scope)
self._scope_cache = {}
......
......@@ -13,9 +13,10 @@
# limitations under the License.
from . import decorator
from .decorator import decorate
from .decorator import decorate, amp_decorate
from . import fp16_lists
from .fp16_lists import CustomOpLists, AutoMixedPrecisionLists
from . import fp16_utils
from .fp16_utils import fp16_guard, cast_model_to_fp16, cast_parameters_to_fp16
from . import bf16
from . import debugging
......@@ -138,13 +138,10 @@ def update_loss_scaling(
['float16', 'float32', 'float64', 'uint16'],
'update_loss_scaling',
)
if (
e.dtype == core.VarDesc.VarType.FP16
or e.dtype == core.VarDesc.VarType.BF16
):
if e.dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
assert (
prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16 or bfloat16."
else:
assert (
prev_loss_scaling.dtype == e.dtype
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import paddle
class OperatorStatsUnit:
def __init__(self):
self.op_type = None
self.fp32_calls = 0
self.fp16_calls = 0
self.bf16_calls = 0
self.other_calls = 0
def update(self, dtype):
if dtype is None:
self.other_calls = self.other_calls + 1
else:
if dtype == paddle.fluid.core.VarDesc.VarType.FP32:
self.fp32_calls = self.fp32_calls + 1
elif dtype == paddle.fluid.core.VarDesc.VarType.FP16:
self.fp16_calls = self.fp16_calls + 1
elif dtype == paddle.fluid.core.VarDesc.VarType.BF16:
self.bf16_calls = self.bf16_calls + 1
else:
self.other_calls = self.other_calls + 1
def addto(self, another):
self.fp32_calls += another.fp32_calls
self.fp16_calls += another.fp16_calls
self.bf16_calls += another.bf16_calls
self.other_calls += another.other_calls
def convert_to_list(self):
return [
self.fp16_calls,
self.bf16_calls,
self.fp32_calls,
self.other_calls,
]
def _is_floating_point(dtype):
if dtype in [
paddle.fluid.core.VarDesc.VarType.FP64,
paddle.fluid.core.VarDesc.VarType.FP32,
paddle.fluid.core.VarDesc.VarType.FP16,
paddle.fluid.core.VarDesc.VarType.BF16,
]:
return True
else:
return False
def _get_var_dtype_from_block(block, op, arg_name, is_input):
var_names = op.input(arg_name) if is_input else op.output(arg_name)
assert isinstance(var_names, list)
if len(var_names) == 0:
return None
var_name = var_names[0]
try:
var = block._var_recursive(var_name)
return var.dtype
except:
print(
"Operator < {} > gets {} < {} : {} > error!".format(
op.type, "input" if is_input else "output", arg_name, var_name
)
)
return None
def _extract_compute_dtype(op, block):
var_name = None
compute_dtype = None
for in_name in op.input_names:
var_dtype = _get_var_dtype_from_block(block, op, in_name, True)
if var_dtype is None:
continue
if compute_dtype is None:
compute_dtype = var_dtype
else:
if compute_dtype != var_dtype:
if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype
):
print(
"Operator < {} > has different input data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names
)
)
elif _is_floating_point(var_dtype):
# When there are multiple inputs, such as embedding
# (ids is integer, w is floating-point), the kernel
# dtype is normally decided by the input of floating-point.
compute_dtype = var_dtype
for out_name in op.output_names:
var_dtype = _get_var_dtype_from_block(block, op, out_name, False)
if var_dtype is None:
continue
if compute_dtype is None:
# Kernel dtype is mostly decided by the input's dtype.
# When the operator has no input, it might have an attr
# such as dtype to specify the output's dtype.
compute_dtype = var_dtype
else:
if compute_dtype != var_dtype:
if _is_floating_point(compute_dtype) and _is_floating_point(
var_dtype
):
print(
"Operator < {} > has different input / output data types, input_names = {}, output_names = {}.".format(
op.type, op.input_names, op.output_names
)
)
return compute_dtype
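# Example of the heuristic above (editor's note): for lookup_table_v2 the
# `Ids` input is an integer tensor while `W` may be float16/bfloat16, so the
# floating-point input decides the compute dtype reported for the op.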
def _merge_op_stats(op_stats_list):
merged_op_stats_dict = {}
for each_op_stats_dict in op_stats_list:
for op_type, unit in each_op_stats_dict.items():
if merged_op_stats_dict.get(op_type, None) is None:
merged_op_stats_dict[op_type] = copy.copy(unit)
else:
merged_op_stats_dict[op_type].addto(unit)
return merged_op_stats_dict
def _get_op_stats_list(program):
op_stats_list = []
for block in program.blocks:
block_op_stats_dict = {}
for op in block.ops:
if block_op_stats_dict.get(op.type, None) is None:
unit = OperatorStatsUnit()
block_op_stats_dict[op.type] = unit
else:
unit = block_op_stats_dict[op.type]
if op.type in [
'create_py_reader',
'read',
'create_double_buffer_reader',
]:
compute_dtype = None
elif op.type in [
'cast',
'layer_norm',
'layer_norm_grad',
'batch_norm',
'batch_norm_grad',
]:
# Do not check the input/output dtype difference for these operators.
compute_dtype = _get_var_dtype_from_block(block, op, 'X', True)
elif "Param" in op.input_names:
# Specify compute_dtype for optimizers.
compute_dtype = _get_var_dtype_from_block(
block, op, 'Param', True
)
else:
compute_dtype = _extract_compute_dtype(op, block)
unit.update(dtype=compute_dtype)
op_stats_list.append(block_op_stats_dict)
return op_stats_list
def collect_operator_stats(program=None, print_subblocks=False):
def _convert_to_list(op_stats_unit_dict):
for key, value in op_stats_unit_dict.items():
op_stats_unit_dict[key] = value.convert_to_list()
return op_stats_unit_dict
if program is None:
program = paddle.static.default_main_program()
op_stats_list = _get_op_stats_list(program)
merged_op_stats = _merge_op_stats(op_stats_list)
if print_subblocks and len(op_stats_list) > 1:
for i in range(len(op_stats_list)):
print("<{:-^120}>".format(" op list of block " + str(i) + " "))
paddle.amp.debugging._print_operator_stats(
_convert_to_list(op_stats_list[i])
)
print("<{:-^120}>".format(" op list of all blocks "))
paddle.amp.debugging._print_operator_stats(
_convert_to_list(merged_op_stats)
)
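# Usage sketch (editor's illustration; assumes a static Program has already
# been built, as in the unittests below):
#
#     paddle.static.amp.debugging.collect_operator_stats(
#         program=main_program, print_subblocks=True
#     )
#
# Each printed row reports [fp16_calls, bf16_calls, fp32_calls, other_calls]
# for one operator type, matching OperatorStatsUnit.convert_to_list().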
......@@ -25,7 +25,7 @@ from paddle.fluid import (
)
from .amp_nn import check_finite_and_unscale, update_loss_scaling
from .fp16_lists import AutoMixedPrecisionLists
from .fp16_lists import AutoMixedPrecisionLists, check_amp_dtype
from .fp16_utils import (
cast_model_to_fp16,
cast_parameters_to_fp16,
......@@ -45,7 +45,14 @@ class OptimizerWithMixedPrecision:
Args:
optimizer (Optimizer): A common Optimizer object.
amp_lists (CustomOpLists): An CustomOpLists object.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
level(str): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represents mixed precision, where the input data type
of each operator is casted according to the white_list and black_list;
O2 represents pure fp16 or bf16 training, where all parameters and input
data are casted to fp16 or bf16, except for operators in the black_list,
operators without an fp16/bf16 kernel, and batch_norm.
dtype(str): The low precision data type, which can be 'float16' or 'bfloat16'.
init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
incr_every_n_steps(int): Increases loss scaling every n consecutive
......@@ -57,24 +64,23 @@ class OptimizerWithMixedPrecision:
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling.
use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
use_amp_guard(bool): Whether to use `fp16_guard` when constructing the program.
Default None, which means that its value is equal to `use_pure_fp16`.
"""
def __init__(
self,
optimizer,
amp_lists,
level,
dtype,
init_loss_scaling,
use_dynamic_loss_scaling,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
use_pure_fp16,
use_fp16_guard,
use_amp_guard=None,
):
self._optimizer = optimizer
self._amp_lists = amp_lists
......@@ -86,10 +92,21 @@ class OptimizerWithMixedPrecision:
self._loss_scaling = None
self._init_loss_scaling = init_loss_scaling
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
if dtype == "bfloat16":
if use_dynamic_loss_scaling:
self._use_dynamic_loss_scaling = False
self._init_loss_scaling = 1.0
warnings.warn(
"Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
)
self._amp_vartype = core.VarDesc.VarType.BF16
else:
self._amp_vartype = core.VarDesc.VarType.FP16
self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map
self._use_pure_fp16 = use_pure_fp16
self._use_fp16_guard = use_fp16_guard
self._use_pure_fp16 = level == "O2"
self._use_fp16_guard = use_amp_guard
self._to_fp16_var_names = None
if self._use_dynamic_loss_scaling:
self._incr_every_n_steps = incr_every_n_steps
......@@ -209,10 +226,15 @@ class OptimizerWithMixedPrecision:
if self._use_pure_fp16:
self._to_fp16_var_names = cast_model_to_fp16(
self._train_program, self._amp_lists, self._use_fp16_guard
self._train_program,
self._amp_lists,
self._use_fp16_guard,
self._amp_vartype,
)
else:
rewrite_program(self._train_program, self._amp_lists)
rewrite_program(
self._train_program, self._amp_lists, self._amp_vartype
)
if loss.dtype != core.VarDesc.VarType.FP32:
loss = loss.astype('float32')
......@@ -258,7 +280,7 @@ class OptimizerWithMixedPrecision:
outputs={'Out': [name]},
attrs={
'in_dtype': core.VarDesc.VarType.FP32,
'out_dtype': core.VarDesc.VarType.FP16,
'out_dtype': self._amp_vartype,
},
)
self._to_fp16_var_names = None
......@@ -326,15 +348,24 @@ class OptimizerWithMixedPrecision:
), "Please call the minimize method first."
if self._use_pure_fp16:
cast_parameters_to_fp16(
place, self._train_program, scope, self._to_fp16_var_names
place,
self._train_program,
scope,
self._to_fp16_var_names,
self._amp_vartype,
)
if test_program is not None:
if self._use_pure_fp16:
cast_model_to_fp16(
test_program, self._amp_lists, self._use_fp16_guard
test_program,
self._amp_lists,
self._use_fp16_guard,
self._amp_vartype,
)
elif use_fp16_test:
rewrite_program(test_program, self._amp_lists)
rewrite_program(
test_program, self._amp_lists, self._amp_vartype
)
def apply_gradients(self, params_grads):
"""
......@@ -368,7 +399,10 @@ class OptimizerWithMixedPrecision:
return optimize_ops
found_inf = self._check_finite_and_unscale(params_grads)
if self._use_dynamic_loss_scaling:
if (
self._use_dynamic_loss_scaling
and self._amp_vartype == core.VarDesc.VarType.FP16
):
self._add_dynamic_loss_scaling(params_grads, found_inf)
# Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
......@@ -395,10 +429,10 @@ class OptimizerWithMixedPrecision:
def _split_grads(self, params_grads):
grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
fp16_grads = [g for g in grads if g.dtype == self._amp_vartype]
assert len(fp32_grads) + len(fp16_grads) == len(
grads
), "Data types of all grads must be either fp16 or fp32."
), "Data types of all grads must be either fp16/bf16 or fp32."
return grads, fp32_grads, fp16_grads
def _check_finite_and_unscale(self, params_grads):
......@@ -587,6 +621,7 @@ def decorate(
use_dynamic_loss_scaling=True,
use_pure_fp16=False,
use_fp16_guard=None,
use_bf16=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
......@@ -608,6 +643,7 @@ def decorate(
use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
Default None, which means that its value equals to `use_pure_fp16`.
use_bf16(bool): Whether to enable bfloat16 training. Default False.
Returns:
An optimizer acting like a normal one but with mixed-precision training
......@@ -678,23 +714,70 @@ def decorate(
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""
amp_dtype = "bfloat16" if use_bf16 else "float16"
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists()
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
if use_fp16_guard is None:
use_fp16_guard = use_pure_fp16
amp_level = "O2" if use_pure_fp16 else "O1"
mp_optimizer = OptimizerWithMixedPrecision(
optimizer,
amp_lists,
init_loss_scaling,
use_dynamic_loss_scaling,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
use_pure_fp16,
use_fp16_guard,
level=amp_level,
dtype=amp_dtype,
init_loss_scaling=init_loss_scaling,
use_dynamic_loss_scaling=use_dynamic_loss_scaling,
incr_every_n_steps=incr_every_n_steps,
decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
use_amp_guard=use_fp16_guard,
)
return mp_optimizer
def amp_decorate(
optimizer,
amp_lists=None,
level='O1',
dtype='float16',
init_loss_scaling=2**15,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
use_dynamic_loss_scaling=True,
use_amp_guard=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
"""
amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)
# check amp_level: O0-O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
mp_optimizer = OptimizerWithMixedPrecision(
optimizer,
amp_lists,
level=level,
dtype=amp_dtype,
init_loss_scaling=init_loss_scaling,
use_dynamic_loss_scaling=use_dynamic_loss_scaling,
incr_every_n_steps=incr_every_n_steps,
decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
use_amp_guard=use_amp_guard,
)
return mp_optimizer
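# Usage sketch for amp_decorate (editor's illustration; the optimizer and loss
# are assumed to be built under a static program guard, as in the unittests):
#
#     optimizer = paddle.optimizer.AdamW(learning_rate=0.01, multi_precision=True)
#     optimizer = amp_decorate(optimizer, level="O2", dtype="bfloat16")
#     optimizer.minimize(loss)
#     # For O2, cast the fp32 parameters after running the startup program:
#     optimizer.amp_init(place)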
......@@ -13,11 +13,17 @@
# limitations under the License.
import copy
import logging
from paddle.fluid import core
from paddle.fluid.log_helper import get_logger
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_extra_unsupported_fp16_list = {
_extra_unsupported_list = {
'lookup_table',
'lookup_table_v2',
'scatter',
......@@ -25,17 +31,95 @@ _extra_unsupported_fp16_list = {
}
def check_amp_dtype(dtype):
"""
Check amp_dtype: float16 or bfloat16
"""
if isinstance(dtype, str):
dtype = dtype.lower()
if dtype not in ['float16', 'bfloat16']:
raise ValueError(
"If enable AMP, dtype should be 'float16' or 'bfloat16'."
)
return dtype
def get_low_precision_vartype(dtype):
if isinstance(dtype, core.VarDesc.VarType):
return dtype
elif isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "float16":
var_type = core.VarDesc.VarType.FP16
elif dtype == "bfloat16":
var_type = core.VarDesc.VarType.BF16
else:
raise ValueError(
"If enable AMP, dtype should be 'float16' or 'bfloat16'."
)
return var_type
else:
raise TypeError(
"The type of dtype is expected to be string or core.VarDesc.VarType, but recieved {}.".format(
type(dtype)
)
)
def get_low_precision_dtypestr(dtype):
if isinstance(dtype, str):
return check_amp_dtype(dtype)
elif isinstance(dtype, core.VarDesc.VarType):
if dtype == core.VarDesc.VarType.FP16:
return "float16"
elif dtype == core.VarDesc.VarType.BF16:
return "bfloat16"
else:
raise ValueError(
"If enable AMP, dtype should be core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16."
)
else:
raise TypeError(
"The type of dtype is expected to be string or core.VarDesc.VarType, but recieved {}.".format(
type(dtype)
)
)
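# Examples (editor's note):
#     check_amp_dtype("FLOAT16") -> "float16"
#     get_low_precision_vartype("bfloat16") -> core.VarDesc.VarType.BF16
#     get_low_precision_dtypestr(core.VarDesc.VarType.FP16) -> "float16"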
def _get_sys_unsupported_list(dtype):
var_type = get_low_precision_vartype(dtype)
# The set of ops that don't support fp16/bf16 calculation
device = None
if core.is_compiled_with_xpu():
device = 'XPU'
elif core.is_compiled_with_custom_device('npu'):
device = 'NPU'
else:
device = 'GPU'
_, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
return device, sys_unsupported_list
def _get_unsupported_list(dtype):
# The set of ops that don't support fp16/bf16 calculation
_, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
unsupported_list = _extra_unsupported_list | _sys_unsupported_list
return unsupported_list
class AutoMixedPrecisionLists:
"""
AutoMixedPrecisionLists is a class for black/white lists. It can update
the pre-defined black and white lists according to users' custom black and
white lists. The lists are used for an algorithm which determines op's
execution mode (fp32 or fp16).
execution mode (fp32, fp16 or bf16).
Args:
custom_white_list (set): Users' custom white list.
custom_black_list (set): Users' custom black list.
custom_black_varnames (set): Users' custom black variables' names.
dtype (str): the low precision dtype, which can be set to 'float16' or 'bfloat16'.
"""
def __init__(
......@@ -43,13 +127,15 @@ class AutoMixedPrecisionLists:
custom_white_list=None,
custom_black_list=None,
custom_black_varnames=None,
dtype="float16",
):
self.amp_dtype = check_amp_dtype(dtype)
self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list
self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(unsupported_fp16_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()
......@@ -61,7 +147,7 @@ class AutoMixedPrecisionLists:
for op_name in self._custom_white_list:
if op_name in self._custom_black_list:
raise ValueError(
"Custom white list overlap " "custom black list"
f"The given custom_white_list overlaps custom_black_list with < {op_name} >!"
)
if self._custom_white_list:
for op_name in self._custom_white_list:
......@@ -70,7 +156,7 @@ class AutoMixedPrecisionLists:
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.white_list.add(op_name)
if op_name in _extra_unsupported_fp16_list:
if op_name in _extra_unsupported_list:
self.unsupported_list.remove(op_name)
if self._custom_black_list:
for op_name in self._custom_black_list:
......@@ -80,6 +166,15 @@ class AutoMixedPrecisionLists:
self.gray_list.remove(op_name)
self.black_list.add(op_name)
self.unsupported_list.add(op_name)
device, sys_unsupported_list = _get_sys_unsupported_list(self.amp_dtype)
actual_unsupported_list = []
for op_name in sys_unsupported_list:
if op_name in self.white_list:
actual_unsupported_list.append(op_name)
if len(actual_unsupported_list) > 0:
_logger.warning(
f"On current {device}, {self.amp_dtype} is not supported for operators < {actual_unsupported_list} > in white_list!"
)
# The three sets listed below are changed dynamically. They don't contain all
......@@ -175,24 +270,4 @@ gray_list = {
'fused_multi_transformer',
}
# The set of ops that don't support fp16 calculation
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_sys_unsupported_fp16_list = []
if core.is_compiled_with_xpu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'XPU', core.VarDesc.VarType.FP16
)
elif core.is_compiled_with_custom_device('npu'):
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'NPU', core.VarDesc.VarType.FP16
)
else:
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'GPU', core.VarDesc.VarType.FP16
)
unsupported_fp16_list = (
_extra_unsupported_fp16_list | _sys_unsupported_fp16_list
)
CustomOpLists = AutoMixedPrecisionLists
......@@ -17,11 +17,12 @@ import logging
import numpy as np
import paddle
from paddle.fluid import core, framework, global_scope
from paddle.fluid.log_helper import get_logger
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from .fp16_lists import AutoMixedPrecisionLists
from .fp16_lists import AutoMixedPrecisionLists, get_low_precision_dtypestr
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
......@@ -72,7 +73,9 @@ def _dtype_to_str(dtype):
Args:
dtype (VarType): Variable type.
"""
if dtype == core.VarDesc.VarType.FP16:
if dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
# TODO(Xreki): change the returned str to "bf16" for BF16 data type.
# Currently too many codes use "cast_fp16" as key.
return 'fp16'
else:
return 'fp32'
......@@ -220,10 +223,10 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if (
src_dtype == core.VarDesc.VarType.FP32
and dest_dtype == core.VarDesc.VarType.FP16
):
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype in [
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.BF16,
]:
for out_name in op.output_names:
if _keep_fp32_output(op, out_name):
continue
......@@ -232,9 +235,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
out_var.desc.set_dtype(dest_dtype)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
op._set_attr('out_dtype', dest_dtype)
return num_cast_ops
......@@ -417,7 +420,12 @@ def fp16_guard():
yield
def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
def cast_model_to_fp16(
program,
amp_lists=None,
use_fp16_guard=True,
dest_type=core.VarDesc.VarType.FP16,
):
"""
Traverse all ops in the whole model and set their inputs and outputs
to the fp16 data type. This function will do some special process for
......@@ -428,10 +436,12 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
use_fp16_guard(bool): Determine whether to use `fp16_guard` when
constructing the program. Default True.
dest_type(core.VarDesc.VarType): The cast type, such as core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16.
"""
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists()
dtype = get_low_precision_dtypestr(dest_type)
amp_lists = AutoMixedPrecisionLists(dtype)
amp_lists.unsupported_list -= {
"conditional_block_grad",
"conditional_block",
......@@ -487,7 +497,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
continue
if in_var.dtype == core.VarDesc.VarType.FP32:
in_var.desc.set_dtype(core.VarDesc.VarType.FP16)
in_var.desc.set_dtype(dest_type)
to_fp16_var_names.add(in_var_name)
_logger.debug(
......@@ -524,28 +534,19 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
out_var.desc.set_dtype(dest_type)
_logger.debug(
"-- op type: {}, out var name: {}, out var dtype: {} --".format(
op.type, out_var_name, out_var.dtype
)
)
if (
op.has_attr('in_dtype')
and op.attr('in_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
if (
op.has_attr('out_dtype')
and op.attr('out_dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
if (
op.has_attr('dtype')
and op.attr('dtype') == core.VarDesc.VarType.FP32
):
op._set_attr('dtype', core.VarDesc.VarType.FP16)
for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
if (
op.has_attr(attr_name)
and op.attr(attr_name) == core.VarDesc.VarType.FP32
):
op._set_attr(attr_name, dest_type)
# process ops in keep_fp32_ops
op_var_rename_map = [
......@@ -562,7 +563,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
block,
op,
idx,
core.VarDesc.VarType.FP16,
dest_type,
core.VarDesc.VarType.FP32,
)
num_cast_ops += pre_cast_num
......@@ -570,7 +571,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
if out_var.dtype == dest_type:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
post_ops = find_true_post_op(ops, op, out_var_name)
for post_op in post_ops:
......@@ -581,7 +582,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
op,
idx + pre_cast_num + 1,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
dest_type,
out_var_name,
op_var_rename_map,
)
......@@ -592,7 +593,22 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
return to_fp16_var_names
def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
def _convert_float_to_bfloat16(place, fp32_array):
paddle.disable_static()
framework._set_expected_place(place)
fp32_tensor = paddle.to_tensor(fp32_array)
bf16_array = paddle.cast(fp32_tensor, paddle.bfloat16).numpy()
paddle.enable_static()
return bf16_array
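# Note (editor's): bfloat16 keeps the upper 16 bits of the fp32 representation,
# and the bfloat16 tensor is exposed to numpy as uint16 bit patterns, e.g.
# np.float32(1.0) (bits 0x3F800000) is stored as 0x3F80 == 16256.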
def cast_parameters_to_fp16(
place,
program,
scope=None,
to_fp16_var_names=None,
dest_type=core.VarDesc.VarType.FP16,
):
"""
Traverse all parameters in the whole model and set them to the FP16 data type.
However, this function will keep the parameters of batch norms in FP32.
......@@ -604,6 +620,7 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
will be set to FP16. Usually, it is the returned
value of `cast_model_to_fp16` API.
dest_type(core.VarDesc.VarType): The cast type, such as core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16.
"""
all_parameters = []
for block in program.blocks:
......@@ -613,13 +630,20 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
var_scope = scope if scope else global_scope()
for param in all_parameters:
if param.name in fp16_var_names:
_logger.debug(f"---- cast {param.name} to fp16 dtype ----")
param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t)
param_t.set(np.float16(data), place)
_logger.debug(f"---- cast {param.name} to fp16/bf16 dtype ----")
if var_scope.find_var(param.name):
param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t)
if dest_type == core.VarDesc.VarType.BF16:
bf16_data = _convert_float_to_bfloat16(place, data)
param_t.set(bf16_data, place)
else:
param_t.set(np.float16(data), place)
else:
_logger.warning(f"Cannot find {param.name}")
def rewrite_program(main_prog, amp_lists):
def rewrite_program(main_prog, amp_lists, dest_type=core.VarDesc.VarType.FP16):
"""
Traverse all ops in the current block and insert cast ops according to
which set the current op belongs to.
......@@ -638,6 +662,7 @@ def rewrite_program(main_prog, amp_lists):
Args:
main_prog (Program): The main program for training.
dest_type(core.VarDesc.VarType): The cast type, such as core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16.
"""
block = main_prog.global_block()
block._sync_with_cpp()
......@@ -708,19 +733,11 @@ def rewrite_program(main_prog, amp_lists):
num_cast_ops = 0
if op in black_op_set:
num_cast_ops = _insert_cast_op(
block,
op,
idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32,
block, op, idx, dest_type, core.VarDesc.VarType.FP32
)
elif op in white_op_set:
num_cast_ops = _insert_cast_op(
block,
op,
idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16,
block, op, idx, core.VarDesc.VarType.FP32, dest_type
)
else:
pass
......
......@@ -200,6 +200,7 @@ def create_parameter(
[
'bool',
'float16',
'uint16',
'float32',
'float64',
'int8',
......
......@@ -45,3 +45,7 @@ endfunction()
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach()
if(APPLE)
set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 300)
endif()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import nn
_fixed_add_param = np.random.random(size=[16, 16]).astype("float32")
def _build_optimizer(
use_amp, amp_dtype="float16", amp_level="O1", use_grad_clip=False
):
if use_grad_clip:
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
else:
grad_clip = None
optimizer = paddle.optimizer.AdamW(
learning_rate=0.01,
grad_clip=grad_clip,
beta1=0.78,
beta2=0.836,
epsilon=1e-4,
weight_decay=0.01,
multi_precision=True,
)
if use_amp:
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
custom_white_list=["elementwise_add"],
custom_black_list=["reduce_mean"],
dtype=amp_dtype,
)
optimizer = paddle.static.amp.amp_decorate(
optimizer, amp_lists=amp_lists, level=amp_level, dtype=amp_dtype
)
return optimizer
class SimpleAddNet(nn.Layer):
def __init__(self, dtype):
super().__init__()
global _fixed_add_param
self.weight = paddle.create_parameter(
name="add_w",
shape=[16, 16],
dtype=dtype,
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Assign(_fixed_add_param)
),
)
def forward(self, x):
return x + self.weight
def build_add_model(use_amp, amp_dtype="float16", amp_level="O1"):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
x_dtype = "float32"
if use_amp and amp_level == "O2":
if amp_dtype == "bfloat16":
x_dtype = "uint16"
elif amp_dtype == "float16":
x_dtype = "float16"
model = SimpleAddNet(x_dtype)
x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype)
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
optimizer.minimize(loss)
feed_vars = [x]
fetch_vars = [loss]
return main_program, startup_program, optimizer, feed_vars, fetch_vars
class SimpleConvNet(nn.Layer):
def __init__(self):
super().__init__()
self.conv = nn.Conv2D(in_channels=1, out_channels=6, kernel_size=3)
self.linear = nn.Linear(in_features=6, out_features=10)
def forward(self, x):
out = self.conv(x)
out = nn.functional.relu(out)
out = self.linear(out)
out = nn.functional.softmax(out)
return out
def build_conv_model(use_amp, amp_dtype="float16", amp_level="O1"):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleConvNet()
x = paddle.static.data(
name='input', shape=[None, 1, 28, 28], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level)
optimizer.minimize(loss)
return main_program, startup_program
class SimpleEmbeddingNet(nn.Layer):
def __init__(self):
super().__init__()
self.vocab_size = 128
self.hidden_size = 16
self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
self.linear = nn.Linear(in_features=16, out_features=10)
def forward(self, x):
out = self.embedding(x)
scale = paddle.full(shape=[1], fill_value=2, dtype="int64")
out = paddle.multiply(out, scale.astype("float32"))
out = self.linear(out)
out = nn.functional.dropout(out, p=0.2)
return out
def build_embedding_model(use_amp, amp_dtype="float16", amp_level="O1"):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleEmbeddingNet()
x = paddle.static.data(name='x', shape=[None, 32], dtype='int64')
out = model(x)
loss = paddle.mean(out)
optimizer = _build_optimizer(use_amp, amp_dtype, amp_level, True)
optimizer.minimize(loss)
return main_program, startup_program
class SimpleWhileNet(nn.Layer):
def __init__(self):
super().__init__()
self.linear = paddle.nn.Linear(16, 10)
def forward(self, x):
def cond(i, loop_len, x, result):
return i < loop_len
def body(i, loop_len, x, result):
result = self.linear(x)
paddle.increment(i)
return [i, loop_len, x, result]
i = paddle.zeros(shape=[1], dtype='int64')
loop_len = paddle.ones(shape=[1], dtype='int64')
result = paddle.zeros(
shape=x.shape[:-1] + self.linear.weight.shape[-1:], dtype="float32"
)
result.stop_gradient = False
_, _, _, results = paddle.static.nn.while_loop(
cond, body, [i, loop_len, x, result]
)
return results
def build_while_model():
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleWhileNet()
x = paddle.static.data(name='x', shape=[32, 16], dtype='float32')
out = model(x)
loss = paddle.mean(out)
return main_program, startup_program
......@@ -14,6 +14,8 @@
import unittest
from paddle.fluid import core
from paddle.static.amp import fp16_lists
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists
......@@ -39,6 +41,48 @@ class TestAMPList(unittest.TestCase):
for op in default_black_list:
self.assertTrue(op in amp_list.black_list)
def test_apis(self):
def _run_check_dtype():
fp16_lists.check_amp_dtype(dtype="int64")
self.assertRaises(ValueError, _run_check_dtype)
for vartype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
self.assertEqual(
fp16_lists.get_low_precision_vartype(vartype), vartype
)
self.assertEqual(
fp16_lists.get_low_precision_vartype("float16"),
core.VarDesc.VarType.FP16,
)
self.assertEqual(
fp16_lists.get_low_precision_vartype("bfloat16"),
core.VarDesc.VarType.BF16,
)
def _run_get_vartype():
fp16_lists.get_low_precision_vartype(dtype="int64")
self.assertRaises(ValueError, _run_get_vartype)
for dtype in ["float16", "bfloat16"]:
self.assertEqual(
fp16_lists.get_low_precision_dtypestr(dtype), dtype
)
self.assertEqual(
fp16_lists.get_low_precision_dtypestr(core.VarDesc.VarType.FP16),
"float16",
)
self.assertEqual(
fp16_lists.get_low_precision_dtypestr(core.VarDesc.VarType.BF16),
"bfloat16",
)
def _run_get_dtypestr():
fp16_lists.get_low_precision_dtypestr(dtype="int64")
self.assertRaises(ValueError, _run_get_dtypestr)
if __name__ == "__main__":
unittest.main()
......@@ -14,10 +14,12 @@
import unittest
from amp_base_models import build_while_model
import paddle
class TestAMPList(unittest.TestCase):
class TestOpStatsEager(unittest.TestCase):
def _check_result(self, dtype):
# get_low_precision_op_list() returns a dict.
op_list = paddle.fluid.core.get_low_precision_op_list()
......@@ -65,5 +67,17 @@ class TestAMPList(unittest.TestCase):
self._check_result(dtype=out.dtype)
class TestOpStatsStatic(unittest.TestCase):
def test_while_op(self):
paddle.enable_static()
main_program, startup_program = build_while_model()
self.assertEqual(main_program.num_blocks, 2)
paddle.static.amp.debugging.collect_operator_stats(
program=main_program, print_subblocks=True
)
paddle.disable_static()
if __name__ == "__main__":
unittest.main()
......@@ -17,6 +17,7 @@ import struct
import unittest
import numpy as np
from amp_base_models import build_add_model, build_embedding_model
import paddle
from paddle import fluid
......@@ -26,6 +27,21 @@ from paddle.static import amp
paddle.enable_static()
def copy_bits_from_float_to_uint16(f):
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
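# For example, copy_bits_from_float_to_uint16(1.0) packs the fp32 bit pattern
# 0x3F800000 and keeps its upper 16 bits, giving 0x3F80 == 16256, which is the
# bfloat16 encoding of 1.0 (truncation rather than round-to-nearest).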
def convert_float_to_uint16(in_list):
if in_list.dtype == np.float32:
new_output = []
for x in np.nditer(in_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, in_list.shape).view(np.uint16)
return new_output
else:
return in_list
def convert_uint16_to_float(in_list):
if in_list.dtype == np.uint16:
in_list = np.asarray(in_list)
......@@ -204,5 +220,123 @@ class TestModelCastBF16(unittest.TestCase):
)
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA and not support the bfloat16",
)
class TestProgramBF16(unittest.TestCase):
def _check_bf16_calls(self, op_stats_dict, expected_bf16_calls):
for op_type, value in expected_bf16_calls.items():
self.assertEqual(
op_stats_dict[op_type].bf16_calls,
value,
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
)
def test_amp_bf16_o1(self):
main_program, startup_program = build_embedding_model(
True, "bfloat16", "O1"
)
self.assertEqual(main_program.num_blocks, 1)
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 0,
"adamw": 0,
}
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
def test_amp_bf16_o2(self):
main_program, startup_program = build_embedding_model(
True, "bfloat16", "O2"
)
self.assertEqual(main_program.num_blocks, 1)
amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 2,
"adamw": 2,
}
self._check_bf16_calls(op_stats_list[0], expected_bf16_calls)
@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA and not support the bfloat16",
)
class TestStaticBF16(unittest.TestCase):
def _generate_feed_x(self):
x = np.random.random(size=[16, 16]).astype("float32")
x_bf16 = convert_float_to_uint16(x)
x_fp32 = convert_uint16_to_float(x_bf16)
return x_fp32, x_bf16
def test_compare_o1_o2(self):
def _run_o1(exe, x_np, max_iters):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", "O1")
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 O1] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
def _run_o2(exe, x_np, max_iters):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_add_model(True, "bfloat16", "O2")
losses = []
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
optimizer.amp_init(place)
for iter_id in range(max_iters):
results = exe.run(
program=main_program,
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 O2] iter={iter_id}, loss={results[0]}")
losses.append(results[0])
return losses
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
max_iters = 2
x_fp32, x_bf16 = self._generate_feed_x()
losses_o1 = _run_o1(exe, x_fp32, max_iters)
losses_o2 = _run_o2(exe, x_bf16, max_iters)
if __name__ == '__main__':
unittest.main()
......@@ -22,7 +22,3 @@ py_test_modules(
set_tests_properties(test_image_classification_fp16 PROPERTIES TIMEOUT 120)
set_tests_properties(test_weight_decay_extend PROPERTIES TIMEOUT 120)
set_tests_properties(test_multi_precision_fp16_train PROPERTIES TIMEOUT 120)
if(APPLE)
set_tests_properties(test_model_cast_to_bf16 PROPERTIES TIMEOUT 300)
endif()