Unverified commit 6959eae5, authored by Yiqun Liu, committed by GitHub

Unify the static amp codes of fp16 and bf16. Reimplement #52694 in release/2.4. (#52697)

Parent commit: d1e8b1e2
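The change below adds a bf16 path to the existing static-graph AMP decorator and introduces a unified `amp_decorate` entry point. As a rough, illustrative sketch (assuming the contrib decorator is exposed as `paddle.static.amp.decorate`, as in the docstring examples further down), fp16 and bf16 training would be selected like this:

    import paddle

    paddle.enable_static()
    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)

    # fp16 path (unchanged behavior): dynamic loss scaling stays enabled.
    mp_optimizer = paddle.static.amp.decorate(
        optimizer, init_loss_scaling=2**15, use_dynamic_loss_scaling=True)

    # bf16 path added by this commit: the decorator disables dynamic loss
    # scaling and resets init_loss_scaling to 1.0 automatically.
    bf16_optimizer = paddle.static.amp.decorate(
        optimizer, use_dynamic_loss_scaling=True, use_bf16=True)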
......@@ -15,7 +15,7 @@
from __future__ import print_function
from . import decorator
from .decorator import *
from .decorator import decorate, amp_decorate
from . import fp16_lists
from .fp16_lists import *
from . import fp16_utils
......
......@@ -27,8 +27,8 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
$$Out = X / scale$$
If any tensor in X contains Inf or Nan, the Out will generate an indicator.
FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of
Out should not be used, and its data may not be deterministic.
Otherwise, FoundInfinite will be 0 (False).
Args:
......@@ -38,75 +38,98 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
"""
check_type(x, 'x', (tuple, list), 'check_finite_and_unscale')
for e in x:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'check_finite_and_unscale')
check_variable_and_dtype(
e,
"x",
['float16', 'float32', 'float64', 'uint16'],
'check_finite_and_unscale',
)
helper = LayerHelper("check_finite_and_unscale", **locals())
found_inf = helper.create_variable_for_type_inference(dtype='bool')
inputs = {'X': x, 'Scale': scale}
if core.is_compiled_with_npu():
check_variable_and_dtype(float_status, "float_status",
['float16', 'float32'],
'check_finite_and_unscale')
check_variable_and_dtype(
float_status,
"float_status",
['float16', 'float32'],
'check_finite_and_unscale',
)
inputs['FloatStatus'] = float_status
outputs = {'Out': x, 'FoundInfinite': found_inf}
helper.append_op(type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs)
helper.append_op(
type='check_finite_and_unscale', inputs=inputs, outputs=outputs
)
return x, found_inf
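The docstring above defines Out = X / scale, with FoundInfinite flagging any Inf/NaN input. A minimal NumPy sketch of those semantics (for reading convenience only, not the op implementation):

    import numpy as np

    def check_finite_and_unscale_ref(xs, scale):
        # xs: list of gradient arrays; scale: the current loss scaling value.
        found_inf = any(not np.all(np.isfinite(x)) for x in xs)
        if not found_inf:
            xs = [x / scale for x in xs]  # unscale only when all values are finite
        return xs, found_inf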
def update_loss_scaling(x,
found_inf,
prev_loss_scaling,
num_good_steps,
num_bad_steps,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
stop_update=False,
name=None):
def update_loss_scaling(
x,
found_inf,
prev_loss_scaling,
num_good_steps,
num_bad_steps,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
stop_update=False,
name=None,
):
"""
Update loss scaling according to overall gradients. If all gradients are
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
Otherwise, loss scaling will decrease by decr_ratio after
decr_every_n_nan_or_inf steps and each step some gradients are infinite.
Args:
x(list|tuple): The input tensors of update_loss_scaling operator.
found_inf (Variable): A boolean variable indicates whether
there is any infinite gradient.
prev_loss_scaling (Variable): Previous loss scaling.
num_good_steps (Variable): A variable accumulates good steps in which
all gradients are finite.
num_bad_steps (Variable): A variable accumulates bad steps in which
some gradients are infinite.
incr_every_n_steps (int): A variable represents increasing loss
scaling every n consecutive steps with
finite gradients.
decr_every_n_nan_or_inf (int): A variable represents decreasing
loss scaling every n accumulated
steps with nan or inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
loss scaling.
"""
check_variable_and_dtype(prev_loss_scaling, "prev_loss_scaling",
['float32', 'float64'], "update_loss_scaling")
check_variable_and_dtype(
prev_loss_scaling,
"prev_loss_scaling",
['float32', 'float64'],
"update_loss_scaling",
)
check_type(x, 'x', (tuple, list), 'update_loss_scaling')
for e in x:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'update_loss_scaling')
if e.dtype == core.VarDesc.VarType.FP16:
assert prev_loss_scaling.dtype == core.VarDesc.VarType.FP32, \
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
check_variable_and_dtype(
e,
"x",
['float16', 'float32', 'float64', 'uint16'],
'update_loss_scaling',
)
if (
e.dtype == core.VarDesc.VarType.FP16
or e.dtype == core.VarDesc.VarType.BF16
):
assert (
prev_loss_scaling.dtype == core.VarDesc.VarType.FP32
), "The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else:
assert prev_loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
assert (
prev_loss_scaling.dtype == e.dtype
), "The dtype of prev_loss_scaling should be equal to the dtype of x."
helper = LayerHelper("update_loss_scaling", **locals())
......@@ -115,14 +138,14 @@ def update_loss_scaling(x,
'FoundInfinite': found_inf,
'PrevLossScaling': prev_loss_scaling,
'InGoodSteps': num_good_steps,
'InBadSteps': num_bad_steps
'InBadSteps': num_bad_steps,
}
outputs = {
'Out': x,
'LossScaling': prev_loss_scaling,
'OutGoodSteps': num_good_steps,
'OutBadSteps': num_bad_steps
'OutBadSteps': num_bad_steps,
}
attrs = {
......@@ -137,9 +160,8 @@ def update_loss_scaling(x,
else:
attrs['stop_update'] = stop_update
helper.append_op(type='update_loss_scaling',
inputs=inputs,
outputs=outputs,
attrs=attrs)
helper.append_op(
type='update_loss_scaling', inputs=inputs, outputs=outputs, attrs=attrs
)
return x
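The dynamic loss-scaling policy described in the docstring can be summarized by the following reference sketch. It is an approximation for readers, not the kernel; in particular, clamping the decreased scaling at 1.0 is an assumption about the kernel's lower bound.

    def update_loss_scaling_ref(found_inf, scaling, good, bad,
                                incr_every_n_steps, decr_every_n_nan_or_inf,
                                incr_ratio, decr_ratio):
        # Returns the updated (scaling, good_steps, bad_steps).
        if found_inf:
            good, bad = 0, bad + 1
            if bad == decr_every_n_nan_or_inf:
                scaling, bad = max(scaling * decr_ratio, 1.0), 0  # assumed lower bound
        else:
            bad, good = 0, good + 1
            if good == incr_every_n_steps:
                scaling, good = scaling * incr_ratio, 0
        return scaling, good, bad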
......@@ -36,11 +36,11 @@ __all__ = ["decorate"]
class OptimizerWithMixedPrecision(object):
"""
Optimizer with mixed-precision (MP) training. This is a wrapper of a common
optimizer, plus the support of mixed-precision pre-training. The object
of this class almost has the same behavior as the common optimizer, with the
methods `minimize()`, `backward()`, `apply_gradients()` implemented.
Additionally, it enables the MP training automatically, i.e., the creation
and maintenance of master parameters, scaling of loss, etc.
Args:
......@@ -48,14 +48,14 @@ class OptimizerWithMixedPrecision(object):
amp_lists (CustomOpLists): A CustomOpLists object.
init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
incr_every_n_steps(int): Increases loss scaling every n consecutive
steps with finite gradients.
decr_every_n_nan_or_inf(int): Decreases loss scaling every n
accumulated steps with nan or
inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling.
use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
use_fp16_guard(bool): Whether to use `fp16_guard` when constructing the program.
......@@ -63,10 +63,20 @@ class OptimizerWithMixedPrecision(object):
"""
def __init__(self, optimizer, amp_lists, init_loss_scaling,
use_dynamic_loss_scaling, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio, use_pure_fp16,
use_fp16_guard):
def __init__(
self,
optimizer,
amp_lists,
init_loss_scaling,
use_dynamic_loss_scaling,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
use_pure_fp16,
use_fp16_guard,
use_bf16=False,
):
self._optimizer = optimizer
self._amp_lists = amp_lists
self._param_grads = None
......@@ -77,11 +87,23 @@ class OptimizerWithMixedPrecision(object):
self._loss_scaling = None
self._init_loss_scaling = init_loss_scaling
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
if use_bf16:
if use_dynamic_loss_scaling:
self._use_dynamic_loss_scaling = False
self._init_loss_scaling = 1.0
warnings.warn(
"Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
)
self._amp_dtype = core.VarDesc.VarType.BF16
else:
self._amp_dtype = core.VarDesc.VarType.FP16
self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map
self._use_pure_fp16 = use_pure_fp16
self._use_fp16_guard = use_fp16_guard
self._to_fp16_var_names = None
self._use_bf16 = use_bf16
if self._use_dynamic_loss_scaling:
self._incr_every_n_steps = incr_every_n_steps
self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
......@@ -97,9 +119,10 @@ class OptimizerWithMixedPrecision(object):
self._is_distributed = flag
def get_loss_scaling(self):
"""Return the real-time loss scaling factor.
"""
assert self._loss_scaling is not None, 'Please call minimize() before calling get_loss_scaling().'
"""Return the real-time loss scaling factor."""
assert (
self._loss_scaling is not None
), 'Please call minimize() before calling get_loss_scaling().'
return self._loss_scaling
def get_scaled_loss(self):
......@@ -117,7 +140,8 @@ class OptimizerWithMixedPrecision(object):
shape=[1],
value=self._init_loss_scaling,
dtype='float32',
persistable=True)
persistable=True,
)
if self._use_dynamic_loss_scaling:
self._num_good_steps = layers.create_global_var(
......@@ -125,37 +149,43 @@ class OptimizerWithMixedPrecision(object):
shape=[1],
value=0,
dtype='int32',
persistable=True)
persistable=True,
)
self._num_bad_steps = layers.create_global_var(
name=unique_name.generate("num_bad_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
persistable=True,
)
# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
if isinstance(self._optimizer._learning_rate, float):
self._optimizer._learning_rate_map[default_main_program()] = \
layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._optimizer._learning_rate),
dtype='float32',
persistable=True)
def backward(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
self._optimizer._learning_rate_map[
default_main_program()
] = layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._optimizer._learning_rate),
dtype='float32',
persistable=True,
)
def backward(
self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None,
):
"""
Backward propagation or auto differentiation for gradients' computation.
Args:
loss (Variable): The loss Variable to minimize.
startup_program (Program|None): The startup Program for initializing
parameters in `parameter_list`.
parameter_list (list|None): A list of Variables to update.
no_grad_set (set|None): A set of Variables should be ignored.
......@@ -163,7 +193,7 @@ class OptimizerWithMixedPrecision(object):
backward operator for one parameter.
Returns:
A list of (param, grad), which is a tuple of a parameter and its
gradient respectively, and the scaled loss.
"""
train_program = loss.block.program
......@@ -171,9 +201,9 @@ class OptimizerWithMixedPrecision(object):
# NOTE(zhiqiu): _float_status is only used for NPU.
if core.is_compiled_with_npu():
float_status = paddle.static.data(name="float_status",
shape=[8],
dtype='float32')
float_status = paddle.static.data(
name="float_status", shape=[8], dtype='float32'
)
self._train_program.global_block().append_op(
type="alloc_float_status",
outputs={"FloatStatus": float_status},
......@@ -192,9 +222,15 @@ class OptimizerWithMixedPrecision(object):
if self._use_pure_fp16:
self._to_fp16_var_names = cast_model_to_fp16(
self._train_program, self._amp_lists, self._use_fp16_guard)
self._train_program,
self._amp_lists,
self._use_fp16_guard,
self._amp_dtype,
)
else:
rewrite_program(self._train_program, self._amp_lists)
rewrite_program(
self._train_program, self._amp_lists, self._amp_dtype
)
if loss.dtype != core.VarDesc.VarType.FP32:
loss = loss.astype('float32')
......@@ -205,10 +241,13 @@ class OptimizerWithMixedPrecision(object):
else:
self._scaled_loss = loss
params_grads = self._optimizer.backward(self._scaled_loss,
startup_program,
parameter_list, no_grad_set,
callbacks)
params_grads = self._optimizer.backward(
self._scaled_loss,
startup_program,
parameter_list,
no_grad_set,
callbacks,
)
if self._supports_check_nan_inf():
self._add_cast_ops_to_startup_program(startup_program)
return params_grads
......@@ -216,8 +255,11 @@ class OptimizerWithMixedPrecision(object):
def _add_cast_ops_to_startup_program(self, startup_program):
names = list(self._to_fp16_var_names) if self._to_fp16_var_names else []
names.sort()
startup_program = default_startup_program(
) if startup_program is None else startup_program
startup_program = (
default_startup_program()
if startup_program is None
else startup_program
)
block = startup_program.global_block()
param_names = [p.name for p in block.all_parameters()]
for name in names:
......@@ -225,28 +267,28 @@ class OptimizerWithMixedPrecision(object):
continue
tmp = block.create_var(dtype=core.VarDesc.VarType.FP32)
block.append_op(type='assign',
inputs={'X': [name]},
outputs={'Out': [tmp]})
block.append_op(type='cast',
inputs={'X': [tmp]},
outputs={'Out': [name]},
attrs={
'in_dtype': core.VarDesc.VarType.FP32,
'out_dtype': core.VarDesc.VarType.FP16,
})
block.append_op(
type='assign', inputs={'X': [name]}, outputs={'Out': [tmp]}
)
block.append_op(
type='cast',
inputs={'X': [tmp]},
outputs={'Out': [name]},
attrs={
'in_dtype': core.VarDesc.VarType.FP32,
'out_dtype': self._amp_dtype,
},
)
self._to_fp16_var_names = None
def amp_init(self,
place,
scope=None,
test_program=None,
use_fp16_test=False):
def amp_init(
self, place, scope=None, test_program=None, use_fp16_test=False
):
"""
Init the amp training, such as cast fp32 parameters to fp16 type.
Args:
place(CUDAPlace): place is used to initialize
fp16 parameters with fp32 values.
scope(Scope): The scope is used to find fp32 parameters.
test_program(Program): The program is used for testing.
......@@ -273,7 +315,7 @@ class OptimizerWithMixedPrecision(object):
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can avoid the poor accuracy
# or the slow convergence in a way.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_black_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
......@@ -293,30 +335,40 @@ class OptimizerWithMixedPrecision(object):
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""
assert self._train_program is not None, \
"Please call the minimize method first."
assert (
self._train_program is not None
), "Please call the minimize method first."
if self._use_pure_fp16:
cast_parameters_to_fp16(place, self._train_program, scope,
self._to_fp16_var_names)
cast_parameters_to_fp16(
place,
self._train_program,
scope,
self._to_fp16_var_names,
self._amp_dtype,
)
if test_program is not None:
if self._use_pure_fp16:
cast_model_to_fp16(test_program, self._amp_lists,
self._use_fp16_guard)
cast_model_to_fp16(
test_program,
self._amp_lists,
self._use_fp16_guard,
self._amp_dtype,
)
elif use_fp16_test:
rewrite_program(test_program, self._amp_lists)
rewrite_program(test_program, self._amp_lists, self._amp_dtype)
def apply_gradients(self, params_grads):
"""
Check scaled gradients to determine whether to update loss scaling and update
parameters by their scaled gradients.
Args:
params_grads (list): A list of params and scaled grads.
Returns:
A list of optimize operators.
"""
......@@ -327,7 +379,10 @@ class OptimizerWithMixedPrecision(object):
# When not using dynamic loss scaling and the init loss scaling value is equal to 1.0,
# the model can be optimized.
if not self._use_dynamic_loss_scaling and self._init_loss_scaling == 1.0:
if (
not self._use_dynamic_loss_scaling
and self._init_loss_scaling == 1.0
):
return self._optimizer.apply_gradients(params_grads)
if self._supports_check_nan_inf():
......@@ -338,7 +393,10 @@ class OptimizerWithMixedPrecision(object):
return optimize_ops
found_inf = self._check_finite_and_unscale(params_grads)
if self._use_dynamic_loss_scaling:
if (
self._use_dynamic_loss_scaling
and self._amp_dtype == core.VarDesc.VarType.FP16
):
self._add_dynamic_loss_scaling(params_grads, found_inf)
# Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
......@@ -346,13 +404,16 @@ class OptimizerWithMixedPrecision(object):
real_optimizer = self._optimizer
while hasattr(real_optimizer, "inner_opt"):
real_optimizer = real_optimizer.inner_opt
if isinstance(real_optimizer,
(paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW)):
if isinstance(
real_optimizer,
(paddle.fluid.optimizer.Adam, paddle.optimizer.AdamW),
):
# NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
# copy it in advance to avoid multiple time copies.
with self._train_program._optimized_guard([]):
found_inf = paddle.tensor.creation._memcpy(
found_inf, paddle.CPUPlace())
found_inf, paddle.CPUPlace()
)
real_optimizer._set_auxiliary_var('found_inf', found_inf)
elif hasattr(real_optimizer, "_set_auxiliary_var"):
real_optimizer._set_auxiliary_var('found_inf', found_inf)
......@@ -362,9 +423,10 @@ class OptimizerWithMixedPrecision(object):
def _split_grads(self, params_grads):
grads = [g for _, g in params_grads]
fp32_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP32]
fp16_grads = [g for g in grads if g.dtype == core.VarDesc.VarType.FP16]
assert len(fp32_grads) + len(fp16_grads) == len(grads), \
"Data types of all grads must be either fp16 or fp32."
fp16_grads = [g for g in grads if g.dtype == self._amp_dtype]
assert len(fp32_grads) + len(fp16_grads) == len(
grads
), "Data types of all grads must be either fp16/bf16 or fp32."
return grads, fp32_grads, fp16_grads
def _check_finite_and_unscale(self, params_grads):
......@@ -380,7 +442,8 @@ class OptimizerWithMixedPrecision(object):
grads,
self._loss_scaling,
name="find_infinite_scale",
float_status=self._float_status)
float_status=self._float_status,
)
found_infs.append(found_inf)
else:
for p, g in params_grads:
......@@ -391,7 +454,8 @@ class OptimizerWithMixedPrecision(object):
],
self._loss_scaling,
name="find_infinite_scale",
float_status=self._float_status)
float_status=self._float_status,
)
found_infs.append(found_inf)
elif self._use_pure_fp16:
if fp32_grads:
......@@ -400,7 +464,8 @@ class OptimizerWithMixedPrecision(object):
fp32_grads,
self._loss_scaling,
name="find_infinite_scale_fp32",
float_status=self._float_status)
float_status=self._float_status,
)
found_infs.append(fp32_found_inf)
if fp16_grads:
with self._train_program._optimized_guard(fp16_grads):
......@@ -408,7 +473,8 @@ class OptimizerWithMixedPrecision(object):
fp16_grads,
self._loss_scaling,
name="find_infinite_scale_fp16",
float_status=self._float_status)
float_status=self._float_status,
)
found_infs.append(fp16_found_inf)
else:
with self._train_program._optimized_guard(grads):
......@@ -416,7 +482,8 @@ class OptimizerWithMixedPrecision(object):
grads,
self._loss_scaling,
name="find_infinite_scale",
float_status=self._float_status)
float_status=self._float_status,
)
if self._is_distributed or self._use_pure_fp16:
with self._train_program._optimized_guard([]):
......@@ -439,7 +506,8 @@ class OptimizerWithMixedPrecision(object):
self._incr_ratio,
self._decr_ratio,
stop_update=self._optimizer._get_stop_update_var(),
name="update_loss_scaling")
name="update_loss_scaling",
)
return
grads, fp32_grads, fp16_grads = self._split_grads(params_grads)
......@@ -447,42 +515,48 @@ class OptimizerWithMixedPrecision(object):
stop_update = False
with self._train_program._optimized_guard([]):
if fp32_grads:
update_loss_scaling(fp32_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=stop_update,
name="update_loss_scaling_fp32")
update_loss_scaling(
fp32_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=stop_update,
name="update_loss_scaling_fp32",
)
stop_update = True
if fp16_grads:
update_loss_scaling(fp16_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=stop_update,
name="update_loss_scaling_fp16")
update_loss_scaling(
fp16_grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
stop_update=stop_update,
name="update_loss_scaling_fp16",
)
else:
with self._train_program._optimized_guard([]):
update_loss_scaling(grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
name="update_loss_scaling")
update_loss_scaling(
grads,
found_inf,
self._loss_scaling,
self._num_good_steps,
self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf,
self._incr_ratio,
self._decr_ratio,
name="update_loss_scaling",
)
def apply_optimize(self, loss, startup_program, params_grads):
program = loss.block.program
......@@ -490,11 +564,9 @@ class OptimizerWithMixedPrecision(object):
optimize_ops = self.apply_gradients(params_grads)
return optimize_ops
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
def minimize(
self, loss, startup_program=None, parameter_list=None, no_grad_set=None
):
"""
Perform optimization by minimizing the given loss.
......@@ -511,48 +583,55 @@ class OptimizerWithMixedPrecision(object):
"""
opt_dict = self._optimizer.__class__.__dict__
if 'minimize' in opt_dict and isinstance(opt_dict['minimize'],
types.FunctionType):
if 'minimize' in opt_dict and isinstance(
opt_dict['minimize'], types.FunctionType
):
warnings.warn(
"The decorated optimizer has its own `minimize` method, but it will not be executed."
)
scaled_params_grads = self.backward(loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
scaled_params_grads = self.backward(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set,
)
optimize_ops = self.apply_optimize(loss, startup_program,
scaled_params_grads)
optimize_ops = self.apply_optimize(
loss, startup_program, scaled_params_grads
)
return optimize_ops, scaled_params_grads
def decorate(optimizer,
amp_lists=None,
init_loss_scaling=2**15,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
use_dynamic_loss_scaling=True,
use_pure_fp16=False,
use_fp16_guard=None):
"""
def decorate(
optimizer,
amp_lists=None,
init_loss_scaling=2**15,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
use_dynamic_loss_scaling=True,
use_pure_fp16=False,
use_fp16_guard=None,
use_bf16=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
Args:
optimizer(Optimizer): A common Optimizer.
amp_lists (CustomOpLists): A CustomOpLists object.
init_loss_scaling(float): The initial loss scaling factor.
incr_every_n_steps(int): Increases loss scaling every n consecutive
steps with finite gradients.
decr_every_n_nan_or_inf(int): Decreases loss scaling every n
accumulated steps with nan or
inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling.
use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
use_pure_fp16(bool): Whether to use the pure fp16 training. Default False.
......@@ -560,11 +639,11 @@ def decorate(optimizer,
Default None, which means that its value equals to `use_pure_fp16`.
Returns:
An optimizer acting like a normal one but with mixed-precision training
enabled.
Examples 1:
.. code-block:: python
# black&white list based strategy example
import paddle
......@@ -604,7 +683,7 @@ def decorate(optimizer,
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can avoid the poor accuracy
# or the slow convergence in a way.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_black_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
......@@ -624,19 +703,82 @@ def decorate(optimizer,
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""
dtype = "bfloat16" if use_bf16 else "float16"
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists()
amp_lists = AutoMixedPrecisionLists(dtype=dtype)
if use_fp16_guard is None:
use_fp16_guard = use_pure_fp16
mp_optimizer = OptimizerWithMixedPrecision(
optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling,
incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio,
use_pure_fp16, use_fp16_guard)
optimizer,
amp_lists,
init_loss_scaling,
use_dynamic_loss_scaling,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
use_pure_fp16,
use_fp16_guard,
use_bf16,
)
return mp_optimizer
def amp_decorate(
optimizer,
amp_lists=None,
level='O1',
dtype='float16',
init_loss_scaling=2**15,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
use_dynamic_loss_scaling=True,
use_amp_guard=False,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
"""
# check amp_dtype: float16 or bfloat16
dtype = dtype.lower()
if not (dtype in ['float16', 'bfloat16']):
raise ValueError(
"If enable AMP, dtype should be 'float16' or 'bfloat16'."
)
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists(dtype=dtype)
# check amp_level: O0-O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
use_pure_fp16 = level == "O2"
use_fp16_guard = use_amp_guard
use_bf16 = dtype == "bfloat16"
mp_optimizer = OptimizerWithMixedPrecision(
optimizer,
amp_lists,
init_loss_scaling,
use_dynamic_loss_scaling,
incr_every_n_steps,
decr_every_n_nan_or_inf,
incr_ratio,
decr_ratio,
use_pure_fp16,
use_fp16_guard,
use_bf16,
)
return mp_optimizer
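A hedged end-to-end sketch of how the new amp_decorate entry point could be used in a static-graph program (the import path follows the __init__ change at the top of this diff; a CUDA device with bf16 support is assumed):

    import paddle
    from paddle.fluid.contrib.mixed_precision import amp_decorate

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[None, 32], dtype='float32')
        hidden = paddle.static.nn.fc(x=x, size=64)
        loss = paddle.mean(hidden)
        optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
        # O2 + bfloat16: pure bf16 training; dynamic loss scaling is turned off
        # internally and init_loss_scaling is forced to 1.0.
        mp_optimizer = amp_decorate(optimizer, level='O2', dtype='bfloat16')
        mp_optimizer.minimize(loss)

    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)
    # Cast the initialized fp32 parameters to bf16, mirroring the fp16 flow.
    mp_optimizer.amp_init(place, scope=paddle.static.global_scope())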
......@@ -13,16 +13,47 @@
# limitations under the License.
import copy
from ... import core
__all__ = ["CustomOpLists", "AutoMixedPrecisionLists"]
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_extra_unsupported_fp16_list = {
'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
_extra_unsupported_list = {
'lookup_table',
'lookup_table_v2',
'scatter',
'scatter_grad',
}
def _get_unsupported_list(dtype):
if dtype == "float16":
amp_dtype = core.VarDesc.VarType.FP16
elif dtype == "bfloat16":
amp_dtype = core.VarDesc.VarType.BF16
else:
raise ValueError(
"If enable AMP, dtype should be 'float16' or 'bfloat16'."
)
# The set of ops that don't support fp16 calculation
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_sys_unsupported_list = []
# _sys_unsupported_bf16_list = []
if core.is_compiled_with_xpu():
_, _, _sys_unsupported_list = core.op_supported_infos('XPU', amp_dtype)
elif core.is_compiled_with_npu():
_, _, _sys_unsupported_list = core.op_supported_infos('NPU', amp_dtype)
elif core.is_compiled_with_mlu():
_, _, _sys_unsupported_list = core.op_supported_infos('MLU', amp_dtype)
else:
_, _, _sys_unsupported_list = core.op_supported_infos('GPU', amp_dtype)
unsupported_list = _extra_unsupported_list | _sys_unsupported_list
return unsupported_list
class AutoMixedPrecisionLists(object):
"""
AutoMixedPrecisionLists is a class for black/white list. It can update
......@@ -36,16 +67,20 @@ class AutoMixedPrecisionLists(object):
custom_black_varnames (set): Users' custom black variables' names.
"""
def __init__(self,
custom_white_list=None,
custom_black_list=None,
custom_black_varnames=None):
def __init__(
self,
custom_white_list=None,
custom_black_list=None,
custom_black_varnames=None,
dtype="float16",
):
self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list
self.amp_dtype = dtype
self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(unsupported_fp16_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()
......@@ -56,8 +91,9 @@ class AutoMixedPrecisionLists(object):
if self._custom_white_list and self._custom_black_list:
for op_name in self._custom_white_list:
if op_name in self._custom_black_list:
raise ValueError("Custom white list overlap "
"custom black list")
raise ValueError(
"Custom white list overlap " "custom black list"
)
if self._custom_white_list:
for op_name in self._custom_white_list:
if op_name in self.black_list:
......@@ -65,7 +101,7 @@ class AutoMixedPrecisionLists(object):
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.white_list.add(op_name)
if op_name in _extra_unsupported_fp16_list:
if op_name in _extra_unsupported_list:
self.unsupported_list.remove(op_name)
if self._custom_black_list:
for op_name in self._custom_black_list:
......@@ -170,22 +206,4 @@ gray_list = {
'fused_multi_transformer',
}
# The set of ops that don't support fp16 calculation
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_sys_unsupported_fp16_list = []
if core.is_compiled_with_xpu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'XPU', core.VarDesc.VarType.FP16)
elif core.is_compiled_with_npu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'NPU', core.VarDesc.VarType.FP16)
elif core.is_compiled_with_mlu():
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'MLU', core.VarDesc.VarType.FP16)
else:
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'GPU', core.VarDesc.VarType.FP16)
unsupported_fp16_list = _extra_unsupported_fp16_list | _sys_unsupported_fp16_list
CustomOpLists = AutoMixedPrecisionLists
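With the new dtype argument, one list class now serves both precisions, and the unsupported-op set is queried per dtype through core.op_supported_infos instead of the old module-level unsupported_fp16_list. An illustrative sketch:

    from paddle.fluid.contrib.mixed_precision.fp16_lists import (
        AutoMixedPrecisionLists,
    )

    # Default fp16 lists, with one extra op forced to stay in fp32.
    fp16_lists = AutoMixedPrecisionLists(custom_black_list={'softmax'})
    # bf16 lists: same black/white/gray sets, different unsupported set.
    bf16_lists = AutoMixedPrecisionLists(dtype="bfloat16")
    print(len(fp16_lists.unsupported_list), len(bf16_lists.unsupported_list))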
......@@ -14,6 +14,7 @@
from __future__ import print_function
import paddle
from ... import core
from ... import framework
from ... import layers
......@@ -27,13 +28,14 @@ import numpy as np
__all__ = ["fp16_guard", "cast_model_to_fp16", "cast_parameters_to_fp16"]
_logger = get_logger(__name__,
logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
_valid_types = [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]
_fp16_guard_pattern = "__use_fp16__"
......@@ -75,7 +77,9 @@ def _dtype_to_str(dtype):
Args:
dtype (VarType): Variable type.
"""
if dtype == core.VarDesc.VarType.FP16:
if dtype in [core.VarDesc.VarType.FP16, core.VarDesc.VarType.BF16]:
# TODO(Xreki): change the returned str to "bf16" for BF16 data type.
# Currently too many codes use "cast_fp16" as key.
return 'fp16'
else:
return 'fp32'
......@@ -108,7 +112,12 @@ def _keep_fp32_input(op, in_name):
return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return in_name in {
'LnScale', 'LnBias', 'Ln2Scale', 'Ln2Bias', "Ln1Scale", "Ln1Bias"
'LnScale',
'LnBias',
'Ln2Scale',
'Ln2Bias',
"Ln1Scale",
"Ln1Bias",
}
if op_type == 'fused_multi_transformer':
return in_name in {'LnScale', 'LnBias', 'FFNLnScale', 'FFNLnBias'}
......@@ -125,8 +134,12 @@ def _keep_fp32_output(op, out_name):
return out_name not in {'Y', 'ConvX', 'ConvZ'}
if op_type in ['fused_attention', 'fused_feedforward']:
return out_name in {
'LnMean', 'LnVariance', 'Ln2Mean', 'Ln2Variance', 'Ln1Mean',
'Ln1Variance'
'LnMean',
'LnVariance',
'Ln2Mean',
'Ln2Variance',
'Ln1Mean',
'Ln1Variance',
}
return False
......@@ -149,7 +162,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
op, in_name):
op, in_name
):
continue
for in_var_name in op.input(in_name):
in_var = block._find_var_recursive(in_var_name)
......@@ -165,11 +179,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
# set cast_op device to `all`, can reduce send cast_var.
# TODO: need remove this after we unified the dynamic
# and static pipeline interface.
if src_dtype == core.VarDesc.VarType.FP32 and in_var.stop_gradient:
if (
src_dtype == core.VarDesc.VarType.FP32
and in_var.stop_gradient
):
prev_op = None
if in_var.op is op:
prev_op = find_true_prev_op(block.ops, op,
in_var_name)
prev_op = find_true_prev_op(
block.ops, op, in_var_name
)
elif in_var.op is not None:
prev_op = in_var.op
......@@ -177,33 +195,40 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
if prev_op is not None:
prev_op_device = prev_op.attr('op_device')
if prev_op_device is not None and 'all' in prev_op_device:
if (
prev_op_device is not None
and 'all' in prev_op_device
):
op_device = prev_op_device
out_var = block.create_var(
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=in_var.stop_gradient)
block._insert_op_without_sync(idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype":
out_var.dtype,
"op_device": op_device,
"op_role":
op.attr("op_role"),
})
stop_gradient=in_var.stop_gradient,
)
block._insert_op_without_sync(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype,
"op_device": op_device,
"op_role": op.attr("op_role"),
},
)
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype in [
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.BF16,
]:
for out_name in op.output_names:
if _keep_fp32_output(op, out_name):
continue
......@@ -212,41 +237,48 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
if out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
out_var.desc.set_dtype(dest_dtype)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
op._set_attr('out_dtype', dest_dtype)
return num_cast_ops
def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
op_var_rename_map):
def _insert_cast_post_op(
block, op, idx, src_dtype, dest_dtype, target_name, op_var_rename_map
):
num_cast_ops = 0
target_var = block.var(target_name)
if target_var.type not in _valid_types or target_var.dtype == dest_dtype:
return num_cast_ops
assert target_var.dtype == src_dtype, \
"The real dtype({}) is not equal to the src dtype({})".format(
_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype))
assert (
target_var.dtype == src_dtype
), "The real dtype({}) is not equal to the src dtype({})".format(
_dtype_to_str(target_var.dtype), _dtype_to_str(src_dtype)
)
cast_name = target_var.name + '.cast_' + _dtype_to_str(dest_dtype)
cast_var = block.vars.get(cast_name)
if cast_var is None or cast_var.dtype != dest_dtype:
cast_var = block.create_var(name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=target_var.stop_gradient)
block._insert_op(idx,
type="cast",
inputs={"X": target_var},
outputs={"Out": cast_var},
attrs={
"in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype,
"op_device": op.attr("op_device"),
"op_role": op.attr("op_role"),
})
cast_var = block.create_var(
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=target_var.stop_gradient,
)
block._insert_op(
idx,
type="cast",
inputs={"X": target_var},
outputs={"Out": cast_var},
attrs={
"in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype,
"op_device": op.attr("op_device"),
"op_role": op.attr("op_role"),
},
)
num_cast_ops += 1
op_var_rename_map[block.idx][target_var.name] = cast_var.name
......@@ -272,8 +304,10 @@ def find_true_prev_op(ops, cur_op, var_name):
prev_op.append(op)
if prev_op:
if not len(prev_op) == 1:
raise ValueError("There must be only one previous op "
"that outputs {0} variable".format(var_name))
raise ValueError(
"There must be only one previous op "
"that outputs {0} variable".format(var_name)
)
else:
return prev_op[0]
return None
......@@ -315,8 +349,7 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False):
def find_op_index(block_desc, cur_op_desc):
"""
"""
""" """
for idx in range(block_desc.op_size()):
if cur_op_desc == block_desc.op(idx):
return idx
......@@ -350,8 +383,9 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
return True
if use_fp16_guard:
if op.has_attr("op_namescope") and \
(_fp16_guard_pattern in op.attr("op_namescope")):
if op.has_attr("op_namescope") and (
_fp16_guard_pattern in op.attr("op_namescope")
):
# op in fp16 guard
return False
else:
......@@ -388,7 +422,12 @@ def fp16_guard():
yield
def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
def cast_model_to_fp16(
program,
amp_lists=None,
use_fp16_guard=True,
dest_type=core.VarDesc.VarType.FP16,
):
"""
Traverse all ops in the whole model and set their inputs and outputs
to the fp16 data type. This function will do some special process for
......@@ -399,6 +438,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
use_fp16_guard(bool): Determine whether to use `fp16_guard` when
constructing the program. Default True.
dest_type(core.VarDesc.VarType): The cast type, such as core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16.
"""
if amp_lists is None:
......@@ -421,7 +461,8 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
for in_name in op.input_names:
# for ipu, all inputs must be converted to fp16
if not core.is_compiled_with_ipu() and _keep_fp32_input(
op, in_name):
op, in_name
):
continue
for in_var_name in op.input(in_name):
in_var = None
......@@ -429,29 +470,36 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
in_var = block.var(in_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block --".
format(e))
"-- {}, try to get it in the global block --".format(
e
)
)
in_var = global_block.var(in_var_name)
if in_var is not None:
_logger.debug(
"-- var {} is got in the global block --".
format(in_var_name))
"-- var {} is got in the global block --".format(
in_var_name
)
)
if in_var is None or in_var.type not in _valid_types:
continue
if in_var.dtype == core.VarDesc.VarType.FP32:
in_var.desc.set_dtype(core.VarDesc.VarType.FP16)
in_var.desc.set_dtype(dest_type)
to_fp16_var_names.add(in_var_name)
_logger.debug(
"-- op type: {}, in var name: {}, in var dtype: {} --".
format(op.type, in_var_name, in_var.dtype))
"-- op type: {}, in var name: {}, in var dtype: {} --".format(
op.type, in_var_name, in_var.dtype
)
)
for out_name in op.output_names:
# for ipu, all outputs must be converted to fp16
if not core.is_compiled_with_ipu() and _keep_fp32_output(
op, out_name):
op, out_name
):
continue
for out_var_name in op.output(out_name):
out_var = None
......@@ -459,32 +507,35 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
out_var = block.var(out_var_name)
except ValueError as e:
_logger.debug(
"-- {}, try to get it in the global block --".
format(e))
"-- {}, try to get it in the global block --".format(
e
)
)
out_var = global_block.var(out_var_name)
if out_var is not None:
_logger.debug(
"-- var {} is got in the global block --".
format(out_var_name))
"-- var {} is got in the global block --".format(
out_var_name
)
)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
out_var.desc.set_dtype(dest_type)
_logger.debug(
"-- op type: {}, out var name: {}, out var dtype: {} --"
.format(op.type, out_var_name, out_var.dtype))
if op.has_attr('in_dtype') and op.attr(
'in_dtype') == core.VarDesc.VarType.FP32:
op._set_attr('in_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('out_dtype') and op.attr(
'out_dtype') == core.VarDesc.VarType.FP32:
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
if op.has_attr('dtype') and op.attr(
'dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
"-- op type: {}, out var name: {}, out var dtype: {} --".format(
op.type, out_var_name, out_var.dtype
)
)
for attr_name in ['in_dtype', 'out_dtype', 'dtype']:
if (
op.has_attr(attr_name)
and op.attr(attr_name) == core.VarDesc.VarType.FP32
):
op._set_attr(attr_name, dest_type)
# process ops in keep_fp32_ops
op_var_rename_map = [
......@@ -497,25 +548,29 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
op = ops[idx]
num_cast_ops = 0
if op in keep_fp32_ops:
pre_cast_num = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32)
pre_cast_num = _insert_cast_op(
block, op, idx, dest_type, core.VarDesc.VarType.FP32
)
num_cast_ops += pre_cast_num
for out_var_name in op.output_arg_names:
out_var = block.vars.get(out_var_name)
if out_var is None or out_var.type not in _valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
if out_var.dtype == dest_type:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
post_ops = find_true_post_op(ops, op, out_var_name)
for post_op in post_ops:
if post_op in keep_fp32_ops:
continue
post_cast_num = _insert_cast_post_op(
block, op, idx + pre_cast_num + 1,
block,
op,
idx + pre_cast_num + 1,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16, out_var_name,
op_var_rename_map)
dest_type,
out_var_name,
op_var_rename_map,
)
num_cast_ops += post_cast_num
idx += num_cast_ops + 1
......@@ -523,7 +578,22 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
return to_fp16_var_names
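Since cast_model_to_fp16 and cast_parameters_to_fp16 now take a dest_type, callers can target bf16 with the same pair of utilities. A hedged sketch of that flow, mirroring what OptimizerWithMixedPrecision.amp_init does for pure fp16 (a CUDA device with bf16 support is assumed):

    import paddle
    from paddle.fluid import core
    from paddle.fluid.contrib.mixed_precision.fp16_utils import (
        cast_model_to_fp16,
        cast_parameters_to_fp16,
    )

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
        out = paddle.static.nn.fc(x=x, size=8)

    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)

    # Rewrite the program for bf16, then cast the already-initialized fp32
    # parameters to the same dtype.
    to_bf16_vars = cast_model_to_fp16(
        main_prog, use_fp16_guard=False, dest_type=core.VarDesc.VarType.BF16)
    cast_parameters_to_fp16(
        place, main_prog, to_fp16_var_names=to_bf16_vars,
        dest_type=core.VarDesc.VarType.BF16)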
def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
def _convert_float_to_bfloat16(place, fp32_array):
paddle.disable_static()
framework._set_expected_place(place)
fp32_tensor = paddle.to_tensor(fp32_array)
bf16_array = paddle.cast(fp32_tensor, paddle.bfloat16).numpy()
paddle.enable_static()
return bf16_array
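The helper above casts through the dynamic graph and returns a NumPy array holding bfloat16 bit patterns (NumPy has no native bfloat16, so the dtype is typically uint16). For intuition, a pure-NumPy approximation keeps the high 16 bits of each float32 word; note this truncates instead of rounding the way a real cast usually does:

    import numpy as np

    def float32_to_bfloat16_bits(fp32_array):
        # bfloat16 shares float32's sign and exponent and keeps only the top
        # 7 mantissa bits, so dropping the low 16 bits of the float32
        # representation yields the bf16 bit pattern (by truncation).
        bits = np.ascontiguousarray(fp32_array, dtype=np.float32).view(np.uint32)
        return (bits >> 16).astype(np.uint16)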
def cast_parameters_to_fp16(
place,
program,
scope=None,
to_fp16_var_names=None,
dest_type=core.VarDesc.VarType.FP16,
):
"""
Traverse all parameters in the whole model and set them to the FP16 data type.
Whereas, this function will keep parameters of batchnorms in FP32.
......@@ -535,6 +605,7 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
will be set to FP16. Usually, it is the returned
value of `cast_model_to_fp16` API.
dest_type(core.VarDesc.VarType): The cast type, such as core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16.
"""
all_parameters = []
for block in program.blocks:
......@@ -544,13 +615,22 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
var_scope = scope if scope else global_scope()
for param in all_parameters:
if param.name in fp16_var_names:
_logger.debug("---- cast {} to fp16 dtype ----".format(param.name))
param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t)
param_t.set(np.float16(data), place)
_logger.debug(
"---- cast {} to fp16/bf16 dtype ----".format(param.name)
)
if var_scope.find_var(param.name):
param_t = var_scope.find_var(param.name).get_tensor()
data = np.array(param_t)
if dest_type == core.VarDesc.VarType.BF16:
bf16_data = _convert_float_to_bfloat16(place, data)
param_t.set(bf16_data, place)
else:
param_t.set(np.float16(data), place)
else:
_logger.warning(f"Cannot find {param.name}")
def rewrite_program(main_prog, amp_lists):
def rewrite_program(main_prog, amp_lists, dest_type=core.VarDesc.VarType.FP16):
"""
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
......@@ -569,6 +649,7 @@ def rewrite_program(main_prog, amp_lists):
Args:
main_prog (Program): The main program for training.
dest_type(core.VarDesc.VarType): The cast type, such as core.VarDesc.VarType.FP16 or core.VarDesc.VarType.BF16.
"""
block = main_prog.global_block()
block._sync_with_cpp()
......@@ -585,7 +666,8 @@ def rewrite_program(main_prog, amp_lists):
continue
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists):
op, amp_lists
):
black_op_set.add(op)
continue
......@@ -611,11 +693,15 @@ def rewrite_program(main_prog, amp_lists):
else:
prev_op = in_var.op
# if it's one of inputs
if prev_op in black_op_set or \
prev_op.type in amp_lists.black_list:
if (
prev_op in black_op_set
or prev_op.type in amp_lists.black_list
):
is_black_op = True
elif prev_op in white_op_set or \
prev_op.type in amp_lists.white_list:
elif (
prev_op in white_op_set
or prev_op.type in amp_lists.white_list
):
is_white_op = True
if is_black_op:
black_op_set.add(op)
......@@ -633,13 +719,13 @@ def rewrite_program(main_prog, amp_lists):
op = ops[idx]
num_cast_ops = 0
if op in black_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32)
num_cast_ops = _insert_cast_op(
block, op, idx, dest_type, core.VarDesc.VarType.FP32
)
elif op in white_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16)
num_cast_ops = _insert_cast_op(
block, op, idx, core.VarDesc.VarType.FP32, dest_type
)
else:
pass
......@@ -670,13 +756,16 @@ def update_role_var_grad(main_prog, params_grads):
if role & int(BACKWARD) and op.has_attr('op_role_var'):
op._remove_attr("op_role_var")
else:
raise ValueError("The cast op {0} must be in BACKWARD role "
"and have op_role_var attr.".format(op))
raise ValueError(
"The cast op {0} must be in BACKWARD role "
"and have op_role_var attr.".format(op)
)
fp16_grad_name = op.input(op.input_names[0])[0]
op_for_fp16_grad = find_true_prev_op(block.ops, op, fp16_grad_name)
op_role_var_attr_name = \
op_role_var_attr_name = (
core.op_proto_and_checker_maker.kOpRoleVarAttrName()
)
attr_val = [p.name, fp16_grad_name]
if op_for_fp16_grad.has_attr(op_role_var_attr_name):
attr_val.extend(op_for_fp16_grad.attr(op_role_var_attr_name))
......@@ -690,18 +779,22 @@ def update_role_var_grad(main_prog, params_grads):
continue
post_ops = find_true_post_op(block.ops, op, g.name)
if post_ops:
raise ValueError("The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0]))
raise ValueError(
"The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0])
)
# add new op in the python and cpp at the same time
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op = framework.Operator(block=block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None)
new_op = framework.Operator(
block=block,
desc=new_op_desc,
type=None,
inputs=None,
outputs=None,
attrs=None,
)
block.ops.append(new_op)
op_idx = find_op_index(block.desc, op.desc)
if op_idx == -1:
......
......@@ -18,19 +18,32 @@ import six
import paddle
from paddle.fluid import framework, backward, core, program_guard
from paddle.fluid.executor import _is_enable_standalone_executor, _is_dy2st_enable_standalone_executor
from paddle.fluid.executor import (
_is_enable_standalone_executor,
_is_dy2st_enable_standalone_executor,
)
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
RETURN_NO_VALUE_MAGIC_NUM,
)
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.framework import _apply_pass
from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program, cast_model_to_fp16
from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid.contrib.mixed_precision.decorator import (
AutoMixedPrecisionLists,
)
from paddle.fluid.contrib.mixed_precision.fp16_utils import (
rewrite_program,
cast_model_to_fp16,
)
from paddle.fluid.dygraph.amp.auto_cast import (
_in_amp_guard,
_in_pure_fp16_guard,
)
import paddle.compat as cpt
from paddle import _C_ops, _legacy_C_ops
......@@ -64,7 +77,8 @@ class NestSequence(object):
var_ids = []
for idx, var in enumerate(self.__input_list):
if isinstance(
var, (framework.Variable, core.VarBase, core.eager.Tensor)):
var, (framework.Variable, core.VarBase, core.eager.Tensor)
):
var_ids.append(idx)
return var_ids
......@@ -77,15 +91,17 @@ class NestSequence(object):
warning_types = set()
for var in self.__input_list:
if not isinstance(
var,
(framework.Variable, core.VarBase, core.eager.Tensor)):
var, (framework.Variable, core.VarBase, core.eager.Tensor)
):
warning_types.add(type(var))
if warning_types:
logging_utils.warn(
"Output of traced function contains non-tensor type values: {}. "
"Currently, We don't support to update them while training and will return "
"what we first saw. Please try to return them as tensor.".
format(list(warning_types)))
"what we first saw. Please try to return them as tensor.".format(
list(warning_types)
)
)
@property
def var_ids(self):
......@@ -139,12 +155,9 @@ class PartialProgramLayer:
Layer: A Layer object that run all ops internally in static mode.
"""
def __init__(self,
main_program,
inputs,
outputs,
parameters=None,
**kwargs):
def __init__(
self, main_program, inputs, outputs, parameters=None, **kwargs
):
super(PartialProgramLayer, self).__init__()
self._inputs = NestSequence(inputs)
self._outputs = NestSequence(outputs, need_check=True)
......@@ -160,14 +173,18 @@ class PartialProgramLayer:
# Set default mode to train
self.training = True
custom_white_list, custom_black_list = None, None
amp_dtype, custom_white_list, custom_black_list = None, None, None
tracer = framework._dygraph_tracer()
if tracer:
custom_white_list, custom_black_list = tracer._get_amp_op_list()
# For AMP training
self._amp_list = AutoMixedPrecisionLists(
custom_white_list=custom_white_list,
custom_black_list=custom_black_list)
amp_dtype = tracer._amp_dtype
if amp_dtype is not None and amp_dtype in ['float16', 'bfloat16']:
# For AMP training
self._amp_list = AutoMixedPrecisionLists(
custom_white_list=custom_white_list,
custom_black_list=custom_black_list,
dtype=amp_dtype,
)
# program_id -> list(scope)
self._scope_cache = {}
......@@ -203,7 +220,8 @@ class PartialProgramLayer:
return self._origin_main_program.clone(for_test=is_infer_mode)
else:
train_program = self._append_backward_desc(
self._origin_main_program)
self._origin_main_program
)
# Note: Only set grad type once after initializing train program. So we put it here.
self._set_grad_type(self._params, train_program)
return train_program
......@@ -223,16 +241,18 @@ class PartialProgramLayer:
@switch_to_static_graph
def _create_pure_fp16_program(self, is_infer_mode=False):
pure_fp16_program = self._origin_main_program.clone(
for_test=is_infer_mode)
for_test=is_infer_mode
)
with program_guard(pure_fp16_program):
cast_model_to_fp16(pure_fp16_program,
self._amp_list,
use_fp16_guard=False)
cast_model_to_fp16(
pure_fp16_program, self._amp_list, use_fp16_guard=False
)
if is_infer_mode:
return pure_fp16_program
else:
train_pure_fp16_program = self._append_backward_desc(
pure_fp16_program)
pure_fp16_program
)
self._set_grad_type(self._params, train_pure_fp16_program)
return train_pure_fp16_program
......@@ -240,23 +260,27 @@ class PartialProgramLayer:
def _create_forward_backward_train_program(self):
whole_program = self._create_program()
forward_end_op_index = self._infer_program.desc.block(0).op_size()
return self._get_forward_backward_program_form(whole_program,
forward_end_op_index)
return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
@switch_to_static_graph
def _create_forward_backward_train_amp_program(self):
whole_program = self._create_amp_program()
forward_end_op_index = self._infer_amp_program.desc.block(0).op_size()
return self._get_forward_backward_program_form(whole_program,
forward_end_op_index)
return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
@switch_to_static_graph
def _create_forward_backward_train_pure_fp16_program(self):
whole_program = self._create_pure_fp16_program()
forward_end_op_index = self._infer_pure_fp16_program.desc.block(
0).op_size()
return self._get_forward_backward_program_form(whole_program,
forward_end_op_index)
0
).op_size()
return self._get_forward_backward_program_form(
whole_program, forward_end_op_index
)
@LazyInitialized
def _train_program(self):
......@@ -352,8 +376,9 @@ class PartialProgramLayer:
@LazyInitialized
def _train_program_id(self):
program_id = _hash_with_id(self._train_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)
core._set_cached_executor_build_strategy(
program_id, self._build_strategy
)
return program_id
@LazyInitialized
......@@ -363,8 +388,9 @@ class PartialProgramLayer:
@LazyInitialized
def _train_amp_program_id(self):
program_id = _hash_with_id(self._train_amp_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)
core._set_cached_executor_build_strategy(
program_id, self._build_strategy
)
return program_id
@LazyInitialized
......@@ -374,8 +400,9 @@ class PartialProgramLayer:
@LazyInitialized
def _train_pure_fp16_program_id(self):
program_id = _hash_with_id(self._train_pure_fp16_program, self)
core._set_cached_executor_build_strategy(program_id,
self._build_strategy)
core._set_cached_executor_build_strategy(
program_id, self._build_strategy
)
return program_id
@LazyInitialized
......@@ -411,8 +438,9 @@ class PartialProgramLayer:
return main_program
def prepare_gradient_aggregation(self, start_idx, main_program,
target_program):
def prepare_gradient_aggregation(
self, start_idx, main_program, target_program
):
"""
Why do we need to add a gradient aggregation operation?
In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
......@@ -420,7 +448,7 @@ class PartialProgramLayer:
x = 2 * in # <---- x is a non-leaf node in program.
y = x + 3
return x, y
loss = forward(in)[0].sum()
        loss.backward()  # <----- x@grad will be overwritten by the elementwise_add_grad op
"""
......@@ -430,8 +458,8 @@ class PartialProgramLayer:
            Return True if there exists an op whose inputs include var.
"""
if not isinstance(var, framework.Variable) or var.type not in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS,
]:
return False
if var.dtype not in [paddle.float32, paddle.float64]:
......@@ -448,20 +476,28 @@ class PartialProgramLayer:
new_grad_name = var.name + suffix + "@GRAD"
finded_ops = list(
filter(
lambda x: x[0] >= start_idx and any([
out_arg == var_grad_name
for out_arg in x[1].output_arg_names
]), enumerate(target_program.block(0).ops)))
lambda x: x[0] >= start_idx
and any(
[
out_arg == var_grad_name
for out_arg in x[1].output_arg_names
]
),
enumerate(target_program.block(0).ops),
)
)
            # len(finded_ops) may equal zero when stop_gradient takes effect.
            # len(finded_ops) may be > 1, because there may be a fill_constant op.
if len(finded_ops) == 0:
return None
# step1: create a new var named var.name@GRAD
target_program.block(0).create_var(name=new_grad_name,
type=var.type,
dtype=var.dtype,
shape=var.shape)
target_program.block(0).create_var(
name=new_grad_name,
type=var.type,
dtype=var.dtype,
shape=var.shape,
)
# step2: rename the var.name@GRAD to var.name@GRAD@dy2static
for idx, op in finded_ops:
op._rename_input(var_grad_name, new_grad_name)
......@@ -472,11 +508,13 @@ class PartialProgramLayer:
finded_ops[-1][0] + 1,
type='sum',
inputs={'X': [var_grad_name, new_grad_name]},
outputs={"Out": var_grad_name})
outputs={"Out": var_grad_name},
)
return None
to_processed_vars = list(
filter(_need_aggregation, self._outputs.tolist()))
filter(_need_aggregation, self._outputs.tolist())
)
for _var in to_processed_vars:
_insert_aggregation_ops_for_var(target_program, _var)
......@@ -492,8 +530,9 @@ class PartialProgramLayer:
if targets and self._params:
backward.gradients(targets=targets, inputs=[])
start_idx = len(
main_program.block(0).ops) + 2 * len(self._outputs.tolist())
start_idx = len(main_program.block(0).ops) + 2 * len(
self._outputs.tolist()
)
self.prepare_gradient_aggregation(start_idx, main_program, program)
......@@ -512,7 +551,10 @@ class PartialProgramLayer:
found_param = False
for block in program.blocks:
for op in block.ops:
if param.name in op.input_arg_names or param.name in op.output_arg_names:
if (
param.name in op.input_arg_names
or param.name in op.output_arg_names
):
required_params.append(param)
found_param = True
break
......@@ -529,15 +571,21 @@ class PartialProgramLayer:
var_desc = block.vars[name].desc
var_base = None
if not framework._in_eager_mode_:
var_base = core.VarBase(var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(), False)
var_base = core.VarBase(
var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
else:
var_base = core.eager.Tensor(var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(), False)
var_base = core.eager.Tensor(
var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
double_grads.append(var_base)
return self._valid_vars(double_grads)
......@@ -557,36 +605,62 @@ class PartialProgramLayer:
attrs = [
'global_block',
self.program.desc.block(0), 'start_op_index', 0, 'end_op_index',
self._get_end_op_index(), 'is_test', not self.training,
'program_id', self.program_id
self.program.desc.block(0),
'start_op_index',
0,
'end_op_index',
self._get_end_op_index(),
'is_test',
not self.training,
'program_id',
self.program_id,
]
if self._cuda_graph_capture_mode:
attrs.extend(
('cuda_graph_capture_mode', self._cuda_graph_capture_mode,
'cuda_graph_pool_id', self._cuda_graph_pool_id))
use_interpretorcore = _is_enable_standalone_executor(
) and _is_dy2st_enable_standalone_executor()
(
'cuda_graph_capture_mode',
self._cuda_graph_capture_mode,
'cuda_graph_pool_id',
self._cuda_graph_pool_id,
)
)
use_interpretorcore = (
_is_enable_standalone_executor()
and _is_dy2st_enable_standalone_executor()
)
attrs.extend(('use_interpretorcore', use_interpretorcore))
if use_interpretorcore:
attrs.extend(
('forward_global_block', self.forward_program.desc.block(0),
'backward_global_block', self.backward_program.desc.block(0)))
(
'forward_global_block',
self.forward_program.desc.block(0),
'backward_global_block',
self.backward_program.desc.block(0),
)
)
_legacy_C_ops.run_program(
self._valid_vars(in_vars), self._valid_vars(self._params),
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars),
self._create_scope_vec(program_id=self.program_id,
use_scope_cache=True),
self._double_grads, self._cuda_graph_vec, *attrs)
self._create_scope_vec(
program_id=self.program_id, use_scope_cache=True
),
self._double_grads,
self._cuda_graph_vec,
*attrs
)
else:
_legacy_C_ops.run_program(self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars),
self._create_scope_vec(),
self._double_grads, self._cuda_graph_vec,
*attrs)
_legacy_C_ops.run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
self._valid_vars(out_vars),
self._create_scope_vec(),
self._double_grads,
self._cuda_graph_vec,
*attrs
)
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
......@@ -594,9 +668,11 @@ class PartialProgramLayer:
if _in_pure_fp16_guard():
for i, var in enumerate(in_vars):
name = var.name
if (self.program.global_block().has_var(name)
and self.program.global_block().var(name).dtype
== paddle.float16):
if (
self.program.global_block().has_var(name)
and self.program.global_block().var(name).dtype
== paddle.float16
):
in_vars[i] = var.astype('float16')
in_vars[i].name = name
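
Under `_in_pure_fp16_guard`, inputs whose program-side variables are float16 are cast on the fly, and the original variable name is restored afterwards, presumably because `astype` returns a new tensor with an auto-generated name while the program matches feed variables by name. A small dygraph sketch of that name bookkeeping (assumes a working paddle install; the values are arbitrary):

import numpy as np
import paddle

x = paddle.to_tensor(np.ones([2, 2], dtype="float32"))
x.name = "input_0"
casted = x.astype("float16")  # astype returns a new tensor with an auto-generated name
casted.name = x.name          # restore the original name so it still matches the program variable
assert casted.name == "input_0"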
......@@ -627,25 +703,32 @@ class PartialProgramLayer:
return self._infer_program
@switch_to_static_graph
def _get_forward_backward_program_form(self, whole_program,
forward_end_op_index):
def _get_forward_backward_program_form(
self, whole_program, forward_end_op_index
):
forward_builded_program = add_build_strategy_for(
whole_program, 0, forward_end_op_index, self._build_strategy)
whole_program, 0, forward_end_op_index, self._build_strategy
)
backward_start_op_index = forward_end_op_index + 2 * len(
self._outputs.var_ids)
self._outputs.var_ids
)
backward_end_op_index = whole_program.desc.block(0).op_size()
backward_builded_program = add_build_strategy_for(
whole_program, backward_start_op_index, backward_end_op_index,
self._build_strategy)
self._apply_inplace_pass(forward_builded_program,
backward_builded_program)
whole_program,
backward_start_op_index,
backward_end_op_index,
self._build_strategy,
)
self._apply_inplace_pass(
forward_builded_program, backward_builded_program
)
return [forward_builded_program, backward_builded_program]
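
The only non-obvious part of the split above is the index arithmetic: the forward slice ends at the inference program's op count, and the backward slice is assumed to start `2 * len(outputs)` ops later, presumably to skip the per-output ops inserted when the backward pass is appended. A tiny sketch with made-up sizes:

# Made-up sizes, illustrating only the slicing arithmetic used above.
forward_end_op_index = 120          # op_size() of the inference program
num_outputs = 3                     # len(self._outputs.var_ids)
whole_op_size = 260                 # op_size() of the whole forward + backward program

backward_start_op_index = forward_end_op_index + 2 * num_outputs   # 126
backward_end_op_index = whole_op_size                              # 260

forward_range = (0, forward_end_op_index)
backward_range = (backward_start_op_index, backward_end_op_index)
assert forward_range == (0, 120) and backward_range == (126, 260)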
def _apply_inplace_pass(self, forward_program, backward_program):
attr_types = {
"use_cuda": "bool",
"mem_opt_skip_vars": "list[str]",
"for_partial_block": "bool"
"for_partial_block": "bool",
}
empty_startup_program = paddle.static.Program()
use_cuda = True if core.is_compiled_with_cuda() else False
......@@ -667,22 +750,33 @@ class PartialProgramLayer:
forward_mem_opt_skip_vars.append(var.desc.name())
backward_mem_opt_skip_vars.append(var.desc.name())
for var_name in core.parse_safe_eager_deletion_skip_vars(
backward_program.desc):
backward_program.desc
):
forward_mem_opt_skip_vars.append(var_name)
attrs = {
"use_cuda": use_cuda,
"mem_opt_skip_vars": forward_mem_opt_skip_vars,
"for_partial_block": True
"for_partial_block": True,
}
_apply_pass(forward_program, empty_startup_program,
"buffer_shared_inplace_pass", attrs, attr_types)
_apply_pass(
forward_program,
empty_startup_program,
"buffer_shared_inplace_pass",
attrs,
attr_types,
)
attrs = {
"use_cuda": use_cuda,
"mem_opt_skip_vars": backward_mem_opt_skip_vars,
"for_partial_block": True
"for_partial_block": True,
}
_apply_pass(backward_program, empty_startup_program,
"buffer_shared_inplace_pass", attrs, attr_types)
_apply_pass(
backward_program,
empty_startup_program,
"buffer_shared_inplace_pass",
attrs,
attr_types,
)
def _prepare(self, inputs):
"""
......@@ -698,23 +792,28 @@ class PartialProgramLayer:
if isinstance(value, np.ndarray):
var = None
if not framework._in_eager_mode_:
var = core.VarBase(value=value,
name=self._inputs[i].desc.name(),
persistable=False,
place=expected_place,
zero_copy=True)
var = core.VarBase(
value=value,
name=self._inputs[i].desc.name(),
persistable=False,
place=expected_place,
zero_copy=True,
)
else:
var = core.eager.Tensor(value=value,
name=self._inputs[i].desc.name(),
persistable=False,
place=expected_place,
zero_copy=True)
var = core.eager.Tensor(
value=value,
name=self._inputs[i].desc.name(),
persistable=False,
place=expected_place,
zero_copy=True,
)
elif isinstance(value, (core.VarBase, core.eager.Tensor)):
                # NOTE(Aurelius84): If var is on CPUPlace, it would be transferred to
                # CUDAPlace multiple times when used as the input of multiple ops, so we
                # move it in advance to avoid this problem.
if value.stop_gradient and not value.place._equals(
expected_place):
expected_place
):
var = value._copy_to(expected_place, False)
var.stop_gradient = True
else:
......@@ -737,12 +836,21 @@ class PartialProgramLayer:
return out_varbase_map[var_desc.name()]
if not framework._in_eager_mode_:
var_base = core.VarBase(var_desc.dtype(), var_desc.shape(),
var_desc.name(), var_desc.type(), False)
var_base = core.VarBase(
var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
else:
var_base = core.eager.Tensor(var_desc.dtype(), var_desc.shape(),
var_desc.name(), var_desc.type(),
False)
var_base = core.eager.Tensor(
var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
var_base.stop_gradient = var.stop_gradient
out_varbase_map[var_desc.name()] = var_base
return var_base
......@@ -755,20 +863,30 @@ class PartialProgramLayer:
def _create_scope_vec(self, program_id=None, use_scope_cache=False):
# Hold forward variables
tmp_scope_vec = None
inner_scope = self._get_scope(program_id=program_id,
use_scope_cache=use_scope_cache)
inner_scope = self._get_scope(
program_id=program_id, use_scope_cache=use_scope_cache
)
if not framework._in_eager_mode_:
tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
"program_out_scope",
core.VarDesc.VarType.STEP_SCOPES, True)
tmp_scope_vec = core.VarBase(
core.VarDesc.VarType.FP32,
[],
"program_out_scope",
core.VarDesc.VarType.STEP_SCOPES,
True,
)
tmp_scope_vec.value().set_scope(inner_scope)
else:
tmp_scope_vec = [inner_scope]
return tmp_scope_vec
def _create_cuda_graph_vec(self):
var = core.VarBase(core.VarDesc.VarType.FP32, [], "cuda_graph",
core.VarDesc.VarType.RAW, True)
var = core.VarBase(
core.VarDesc.VarType.FP32,
[],
"cuda_graph",
core.VarDesc.VarType.RAW,
True,
)
var.stop_gradient = True
return var
......@@ -791,8 +909,9 @@ class PartialProgramLayer:
return main_program.clone(for_test=True)
def _is_no_value(self, var):
if isinstance(var,
(core.VarBase, core.eager.Tensor)) and var.shape == [1]:
if isinstance(var, (core.VarBase, core.eager.Tensor)) and var.shape == [
1
]:
            # NOTE: .numpy() will insert a MemcpySync operation, which hurts performance.
if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
return True
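
`_is_no_value` treats a shape-[1] tensor holding `RETURN_NO_VALUE_MAGIC_NUM` as a placeholder for "no return value", and `_remove_no_value` below filters such placeholders out of the restored outputs. A plain-Python sketch of the same filtering (the sentinel value here is illustrative, not the constant Paddle uses):

RETURN_NO_VALUE_MAGIC_NUM = 1.77e8   # illustrative stand-in for the real constant

def is_no_value(v):
    # Mirrors _is_no_value: a shape-[1] holder carrying the magic number.
    return isinstance(v, list) and len(v) == 1 and v[0] == RETURN_NO_VALUE_MAGIC_NUM

out_vars = ([3.0], [RETURN_NO_VALUE_MAGIC_NUM], [7.0])
res = tuple(v for v in out_vars if not is_no_value(v))
has_removed = len(out_vars) > len(res)   # only drop placeholders if something was actually removed
assert res == ([3.0], [7.0]) and has_removed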
......@@ -808,13 +927,14 @@ class PartialProgramLayer:
return out_vars
elif isinstance(out_vars, (tuple, list)):
if isinstance(out_vars, tuple):
res = tuple(var for var in out_vars
if not self._is_no_value(var))
res = tuple(
var for var in out_vars if not self._is_no_value(var)
)
else:
# isinstance(out_vars, list)
res = [var for var in out_vars if not self._is_no_value(var)]
has_removed = (len(out_vars) > len(res))
has_removed = len(out_vars) > len(res)
            # len(out_vars) > len(res) means some vars have been removed. This guards
            # against out_vars being empty or holding just one element at the beginning.
if len(res) == 0 and has_removed:
......@@ -835,7 +955,8 @@ class PartialProgramLayer:
for param in params:
grad_name = param.name + core.grad_var_suffix()
grad_var = train_program.desc.block(0).find_var(
cpt.to_bytes(grad_name))
cpt.to_bytes(grad_name)
)
            # NOTE: failing to find the var desc may be fine, e.g. in batch_norm.
if grad_var is None:
continue
......@@ -864,15 +985,18 @@ class PartialProgramLayer:
if not isinstance(self._params, (list, tuple)):
raise TypeError(
"Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
% type(self._params))
% type(self._params)
)
param_and_buffer_names_set = set()
for i, var in enumerate(self._params):
            # self._params contains parameters and buffers with persistable=True.
if not isinstance(var, (core.VarBase, core.eager.Tensor)):
raise TypeError(
'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'
.format(i, type(var)))
'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'.format(
i, type(var)
)
)
param_and_buffer_names_set.add(var.name)
for block in main_program.blocks:
......@@ -886,7 +1010,8 @@ class PartialProgramLayer:
"\n\tRevise suggestion: "
"\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
"\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
% name)
% name
)
def _valid_vars(self, vars):
"""
......@@ -903,13 +1028,23 @@ def _create_fake_var():
"""
if not framework._in_eager_mode_:
return [
core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
core.VarDesc.VarType.RAW, False)
core.VarBase(
core.VarDesc.VarType.FP32,
[],
"Fake_var",
core.VarDesc.VarType.RAW,
False,
)
]
else:
return [
core.eager.Tensor(core.VarDesc.VarType.FP32, [], "Fake_var",
core.VarDesc.VarType.RAW, False)
core.eager.Tensor(
core.VarDesc.VarType.FP32,
[],
"Fake_var",
core.VarDesc.VarType.RAW,
False,
)
]
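
`_create_fake_var` builds a single RAW placeholder named "Fake_var"; the surrounding `_valid_vars` helper (truncated above) presumably falls back to it so the run_program op never receives an empty variable list. A minimal sketch of that fallback pattern with hypothetical helper names:

def create_fake_var():
    # Stand-in for the RAW "Fake_var" placeholder built above.
    return ["Fake_var"]

def valid_vars(vars):
    # Hypothetical helper mirroring the assumed fallback: never hand an
    # empty variable list to the run_program op.
    return vars if vars else create_fake_var()

assert valid_vars([]) == ["Fake_var"]
assert valid_vars(["w@GRAD"]) == ["w@GRAD"]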
......@@ -918,23 +1053,27 @@ def partial_program_from(concrete_program):
if inputs and isinstance(inputs[0], layers.Layer):
inputs = inputs[1:]
return PartialProgramLayer(concrete_program.main_program, inputs,
concrete_program.outputs,
concrete_program.parameters,
**concrete_program.kwargs)
return PartialProgramLayer(
concrete_program.main_program,
inputs,
concrete_program.outputs,
concrete_program.parameters,
**concrete_program.kwargs
)
@switch_to_static_graph
def add_build_strategy_for(program,
start_op_index,
end_op_index,
build_strategy=None):
if (start_op_index < end_op_index):
def add_build_strategy_for(
program, start_op_index, end_op_index, build_strategy=None
):
if start_op_index < end_op_index:
compiled_program = paddle.static.CompiledProgram(
core.Graph(program.desc, start_op_index, end_op_index),
build_strategy=build_strategy)
compiled_program._compile(core.Scope(),
framework._current_expected_place())
build_strategy=build_strategy,
)
compiled_program._compile(
core.Scope(), framework._current_expected_place()
)
ir_graph = framework.IrGraph(compiled_program._graph)
builded_program = ir_graph.to_program()
if hasattr(compiled_program._program, 'lr_sheduler'):
......
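
`add_build_strategy_for` compiles the op range `[start_op_index, end_op_index)` of a program with the given build strategy and converts the optimized graph back into a `Program`, which is how `_get_forward_backward_program_form` obtains its forward and backward halves. A hedged usage sketch, assuming a toy static program and illustrative split indices (the real calls are left commented because the helper lives in this module):

import paddle

paddle.enable_static()

main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name="x", shape=[None, 4], dtype="float32")
    y = paddle.nn.functional.relu(x * 2.0)

total_ops = main.desc.block(0).op_size()
strategy = paddle.static.BuildStrategy()
strategy.enable_inplace = True

# Hypothetical split point; in PartialProgramLayer the real indices come from
# the inference program's op count plus the per-output offset shown earlier.
split = total_ops // 2
# forward_part = add_build_strategy_for(main, 0, split, strategy)
# backward_part = add_build_strategy_for(main, split, total_ops, strategy)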