Commit 30e178fa authored by Jie Fang, committed by Yibing Liu

init auto loss scaling (#17194)

* init auto loss scaling

test=develop

* change API.spec

* change ifelse to switch and use reduce_sum to optimize checking isfinite

test=develop

* Remove redundant code

test=develop
Parent 4a1b7fec
...@@ -421,7 +421,7 @@ paddle.fluid.contrib.HDFSClient.upload (ArgSpec(args=['self', 'hdfs_path', 'loca
paddle.fluid.contrib.multi_download (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'trainer_id', 'trainers', 'multi_processes'], varargs=None, keywords=None, defaults=(5,)), ('document', '100927be598ed8f9eaa1f3ef1b23568a'))
paddle.fluid.contrib.multi_upload (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'multi_processes', 'overwrite', 'sync'], varargs=None, keywords=None, defaults=(5, False, True)), ('document', '183f34c83d30dbe16e09e8716c41958a'))
paddle.fluid.contrib.extend_with_decoupled_weight_decay (ArgSpec(args=['base_optimizer'], varargs=None, keywords=None, defaults=None), ('document', 'a1095dfd4ec725747f662d69cd7659d4'))
paddle.fluid.contrib.mixed_precision.decorate (ArgSpec(args=['optimizer', 'init_loss_scaling', 'incr_every_n_steps', 'decr_every_n_nan_or_inf', 'incr_ratio', 'decr_ratio', 'use_dynamic_loss_scaling'], varargs=None, keywords=None, defaults=(1.0, 1000, 2, 2.0, 0.8, False)), ('document', 'bdb8f9dbb0d94b3957272c53eeee9818'))
paddle.fluid.transpiler.DistributeTranspiler.__init__ (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program (ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None), ('document', '292ab72977afbe58e6a3bde175452680'))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs (ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None), ('document', '78f4949aedf317666a89ca74b3748ba8'))
......
...@@ -18,6 +18,7 @@ from ... import layers
from ... import unique_name
from . import fp16_utils
from .fp16_utils import create_master_params_grads, master_param_to_train_param
from .fp16_utils import update_loss_scaling
__all__ = ["decorate"]
...@@ -35,15 +36,51 @@ class OptimizerWithMixedPrecison(object):
optimizer (Optimizer): A common Optimizer object.
init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
incr_every_n_steps(int): Increases loss scaling every n consecutive
steps with finite gradients.
decr_every_n_nan_or_inf(int): Decreases loss scaling every n
accumulated steps with nan or
inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling.
""" """
def __init__(self, optimizer, init_loss_scaling, use_dynamic_loss_scaling): def __init__(self, optimizer, init_loss_scaling, use_dynamic_loss_scaling,
incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio,
decr_ratio):
self._optimizer = optimizer
self._param_grads = None
self._train_program = default_main_program()
self._startup_prog = default_startup_program()
self._loss_scaling = layers.create_global_var(
name=unique_name.generate("loss_scaling"),
shape=[1],
value=init_loss_scaling,
dtype='float32',
persistable=True)
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
if self._use_dynamic_loss_scaling:
self._incr_every_n_steps = layers.fill_constant(
shape=[1], dtype='int32', value=incr_every_n_steps)
self._decr_every_n_nan_or_inf = layers.fill_constant(
shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
self._incr_ratio = incr_ratio
self._decr_ratio = decr_ratio
self._num_good_steps = layers.create_global_var(
name=unique_name.generate("num_good_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
self._num_bad_steps = layers.create_global_var(
name=unique_name.generate("num_bad_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
...@@ -104,9 +141,33 @@ class OptimizerWithMixedPrecison(object):
Returns:
A list of optimize operators.
"""
if self._use_dynamic_loss_scaling:
grads = [layers.reduce_sum(g) for [_, g] in master_params_grads]
all_grads = layers.concat(grads)
all_grads_sum = layers.reduce_sum(all_grads)
is_overall_finite = layers.isfinite(all_grads_sum)
update_loss_scaling(is_overall_finite, self._loss_scaling,
self._num_good_steps, self._num_bad_steps,
self._incr_every_n_steps,
self._decr_every_n_nan_or_inf, self._incr_ratio,
self._decr_ratio)
# apply_gradient append all ops in global block, thus we shouldn't
# apply gradient in the switch branch.
with layers.Switch() as switch:
with switch.case(is_overall_finite):
pass
with switch.default():
for _, g in master_params_grads:
layers.assign(layers.zeros_like(g), g)
optimize_ops = self._optimizer.apply_gradients(master_params_grads)
master_param_to_train_param(master_params_grads, self._param_grads,
self._train_program)
return optimize_ops
def minimize(self, loss):
...@@ -126,13 +187,28 @@ class OptimizerWithMixedPrecison(object):
return scaled_loss, optimize_ops, master_params_grads
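For intuition, the overflow check added in apply_gradients above first reduces each gradient to a scalar, sums those scalars, and applies isfinite to a single value instead of checking every element (the reduce_sum optimization mentioned in the commit message). A minimal eager-style sketch of the same idea, using NumPy as a stand-in for the fluid ops and a placeholder gradient list:

import numpy as np

def grads_are_finite(grads):
    # Reduce each gradient to one scalar, then sum the scalars; any nan or inf
    # in the inputs makes the final value non-finite.
    total = np.sum([np.sum(g) for g in grads])
    return bool(np.isfinite(total))

# When the check fails, the patch zeroes every gradient inside a Switch branch,
# so non-finite values never reach the parameter update.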
def decorate(optimizer,
init_loss_scaling=1.0,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
use_dynamic_loss_scaling=False):
""" """
Decorate the given optimizer to adapt to the mixed-precision training. Decorate the given optimizer to adapt to the mixed-precision training.
Args: Args:
optimizer(Optimizer): A common Optimizer. optimizer(Optimizer): A common Optimizer.
init_loss_scaling(float): The initial loss scaling factor. init_loss_scaling(float): The initial loss scaling factor.
incr_every_n_steps(int): Increases loss scaling every n consecutive
steps with finite gradients.
decr_every_n_nan_or_inf(int): Decreases loss scaling every n
accumulated steps with nan or
inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling.
use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
Returns:
...@@ -151,7 +227,8 @@ def decorate(optimizer, init_loss_scaling=1.0, use_dynamic_loss_scaling=False):
scaled_loss, _, _ = mp_optimizer.minimize(loss)
"""
mp_optimizer = OptimizerWithMixedPrecison(
optimizer, init_loss_scaling, use_dynamic_loss_scaling,
incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio)
return mp_optimizer
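For reference, a minimal usage sketch of the extended interface with dynamic loss scaling enabled; the Adam optimizer and init_loss_scaling=8.0 follow the test change further below, avg_cost stands for a user-defined loss, and the remaining keywords simply restate the new defaults:

import paddle.fluid as fluid

optimizer = fluid.optimizer.Adam(learning_rate=0.001)
mp_optimizer = fluid.contrib.mixed_precision.decorate(
    optimizer=optimizer,
    init_loss_scaling=8.0,
    incr_every_n_steps=1000,
    decr_every_n_nan_or_inf=2,
    incr_ratio=2.0,
    decr_ratio=0.8,
    use_dynamic_loss_scaling=True)
# minimize returns the scaled loss, the optimize ops, and the fp32 master
# parameter/gradient pairs.
scaled_loss, optimize_ops, master_params_grads = mp_optimizer.minimize(avg_cost)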
...@@ -91,15 +91,11 @@ def create_master_params_grads(params_grads, main_prog, startup_prog,
append_cast_op(startup_p, startup_master_param, startup_prog)
# cast fp16 gradients to fp32 before apply gradients
if g.name.find("batch_norm") > -1:
scaled_g = g / loss_scaling
master_params_grads.append([p, scaled_g])
continue
master_grad = layers.cast(x=g, dtype="float32")
master_grad = master_grad / loss_scaling
master_params_grads.append([master_param, master_grad])
return master_params_grads
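The unconditional division above undoes the scaling that was applied to the loss before backpropagation; because loss_scaling is now a graph Variable rather than a Python float, the old "if loss_scaling > 1" shortcut can no longer be evaluated at graph-construction time. A small numerical illustration of why the scaling matters in fp16 (values chosen only for illustration):

import numpy as np

loss_scaling = 1024.0
true_grad = 1e-8                             # below the smallest fp16 subnormal
lost = np.float16(true_grad)                 # rounds to 0.0, the update vanishes
kept = np.float16(true_grad * loss_scaling)  # ~1.02e-05, representable in fp16
recovered = np.float32(kept) / loss_scaling  # ~1e-08 again, recovered in fp32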
...@@ -123,3 +119,77 @@ def master_param_to_train_param(master_params_grads, params_grads, main_prog):
with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
# fp32 -> fp16
append_cast_op(m_p_g[0], train_p, main_prog)
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
num_bad_steps, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
"""
Update loss scaling according to the overall gradients. If all gradients are
finite for incr_every_n_steps consecutive steps, the loss scaling is increased
by incr_ratio. Otherwise, the loss scaling is decreased by decr_ratio after
accumulating decr_every_n_nan_or_inf steps in which some gradients are nan or inf.
Args:
is_overall_finite (Variable): A boolean variable indicating whether
all gradients are finite.
prev_loss_scaling (Variable): Previous loss scaling.
num_good_steps (Variable): A variable that accumulates good steps in
which all gradients are finite.
num_bad_steps (Variable): A variable that accumulates bad steps in
which some gradients are infinite.
incr_every_n_steps (Variable): The number of consecutive steps with
finite gradients required before the loss
scaling is increased.
decr_every_n_nan_or_inf (Variable): The number of accumulated steps with
nan or inf gradients required before
the loss scaling is decreased.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
loss scaling.
"""
zero_steps = layers.fill_constant(shape=[1], dtype='int32', value=0)
with layers.Switch() as switch:
with switch.case(is_overall_finite):
should_incr_loss_scaling = layers.less_than(incr_every_n_steps,
num_good_steps + 1)
with layers.Switch() as switch1:
with switch1.case(should_incr_loss_scaling):
new_loss_scaling = prev_loss_scaling * incr_ratio
loss_scaling_is_finite = layers.isfinite(new_loss_scaling)
with layers.Switch() as switch2:
with switch2.case(loss_scaling_is_finite):
layers.assign(new_loss_scaling, prev_loss_scaling)
with switch2.default():
pass
layers.assign(zero_steps, num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch1.default():
layers.increment(num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch.default():
should_decr_loss_scaling = layers.less_than(decr_every_n_nan_or_inf,
num_bad_steps + 1)
with layers.Switch() as switch3:
with switch3.case(should_decr_loss_scaling):
new_loss_scaling = prev_loss_scaling * decr_ratio
static_loss_scaling = \
layers.fill_constant(shape=[1],
dtype='float32',
value=1.0)
less_than_one = layers.less_than(new_loss_scaling,
static_loss_scaling)
with layers.Switch() as switch4:
with switch4.case(less_than_one):
layers.assign(static_loss_scaling,
prev_loss_scaling)
with switch4.default():
layers.assign(new_loss_scaling, prev_loss_scaling)
layers.assign(zero_steps, num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch3.default():
layers.assign(zero_steps, num_good_steps)
layers.increment(num_bad_steps)
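To make the nested Switch blocks above easier to follow, here is an equivalent eager-style sketch of the update rule with plain Python scalars; it is an illustration, not code from the patch:

import math

def update_loss_scaling_eager(is_overall_finite, state, incr_every_n_steps,
                              decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
    # state is a dict holding 'loss_scaling', 'num_good_steps' and
    # 'num_bad_steps' as plain numbers.
    if is_overall_finite:
        if state['num_good_steps'] >= incr_every_n_steps:
            new_scaling = state['loss_scaling'] * incr_ratio
            if math.isfinite(new_scaling):   # skip the increase if it overflows
                state['loss_scaling'] = new_scaling
            state['num_good_steps'] = 0
            state['num_bad_steps'] = 0
        else:
            state['num_good_steps'] += 1
            state['num_bad_steps'] = 0
    else:
        if state['num_bad_steps'] >= decr_every_n_nan_or_inf:
            # clamp so that the scaling never drops below 1.0
            state['loss_scaling'] = max(1.0, state['loss_scaling'] * decr_ratio)
            state['num_good_steps'] = 0
            state['num_bad_steps'] = 0
        else:
            state['num_good_steps'] = 0
            state['num_bad_steps'] += 1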
...@@ -135,7 +135,9 @@ def train(net_type, use_cuda, save_dirname, is_local):
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
mp_optimizer = fluid.contrib.mixed_precision.decorate(
optimizer=optimizer,
init_loss_scaling=8.0,
use_dynamic_loss_scaling=True)
scaled_loss, _, _ = mp_optimizer.minimize(avg_cost)
......