提交 f35c8ce6 编写于 作者: M mapingshuo

add custom amp

上级 6c5c547e
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import decorator
from .decorator import *
from .fp16_lists import AutoMixedPrecisionLists
__all__ = decorator.__all__
__all__ += fp16_lists.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import default_main_program
from paddle.fluid import default_startup_program
from paddle.fluid import layers
from paddle.fluid import unique_name
from paddle.fluid import framework
from . import fp16_utils
from .fp16_utils import update_loss_scaling, rewrite_program
from .fp16_utils import update_role_var_grad
from .fp16_lists import AutoMixedPrecisionLists
__all__ = ["decorate"]
class OptimizerWithMixedPrecision(object):
"""
Optimizer with mixed-precision (MP) training. This is a wrapper of a common
optimizer, plus the support of mixed-precision pre-training. The object
of this class almost has the same behavior as the common optimizer, with the
methods `minimize()`, `backward()`, `apply_gradients()` implemented.
Additionally, it enables the MP training automatically, i.e, the creation
and maintenance of master parameters, scaling of loss, etc.
Args:
optimizer (Optimizer): A common Optimizer object.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
incr_every_n_steps(int): Increases loss scaling every n consecutive
steps with finite gradients.
decr_every_n_nan_or_inf(int): Decreases loss scaling every n
accumulated steps with nan or
inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling.
"""
def __init__(self, optimizer, amp_lists, init_loss_scaling,
use_dynamic_loss_scaling, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
self._optimizer = optimizer
self._amp_lists = amp_lists
self._param_grads = None
self._train_program = default_main_program()
self._startup_prog = default_startup_program()
self._scaled_loss = None
self._loss_scaling = layers.create_global_var(
name=unique_name.generate("loss_scaling"),
shape=[1],
value=init_loss_scaling,
dtype='float32',
persistable=True)
self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
if self._use_dynamic_loss_scaling:
self._incr_every_n_steps = layers.fill_constant(
shape=[1], dtype='int32', value=incr_every_n_steps)
self._decr_every_n_nan_or_inf = layers.fill_constant(
shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
self._incr_ratio = incr_ratio
self._decr_ratio = decr_ratio
self._num_good_steps = layers.create_global_var(
name=unique_name.generate("num_good_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
self._num_bad_steps = layers.create_global_var(
name=unique_name.generate("num_bad_steps"),
shape=[1],
value=0,
dtype='int32',
persistable=True)
# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
if isinstance(optimizer._learning_rate, float):
optimizer._learning_rate_map[default_main_program()] = \
layers.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(optimizer._learning_rate),
dtype='float32',
persistable=True)
def get_loss_scaling(self):
"""Return the real-time loss scaling factor.
"""
return self._loss_scaling
def get_scaled_loss(self):
"""Return the scaled loss.
It's useful when you feed customed loss into executor.
"""
return self._scaled_loss
def backward(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
"""
Backward propagation or auto differentiation for gradients' computation.
Args:
loss (Variable): The loss Variable to minimize.
startup_program (Program|None): The startup Program for initializing
parameters in `parameter_list`.
parameter_list (list|None): A list of Variables to update.
no_grad_set (set|None): A set of Variables should be ignored.
callbacks (list|None): A list of callable objects to run when appending
backward operator for one parameter.
Returns:
A list of (param, grad), which is a tuple of a parameter and its
gradient respectively, and the scaled loss.
"""
rewrite_program(self._train_program, self._amp_lists)
with framework.name_scope('mixed_precision'):
self._scaled_loss = loss * self._loss_scaling
self._scaled_loss = loss * self._loss_scaling
self._params_grads = self._optimizer.backward(
self._scaled_loss, startup_program, parameter_list, no_grad_set,
callbacks)
# Change the op_role_var attr for some ops, so that gradients
# transferred across GPUs can be FP16.
update_role_var_grad(self._train_program, self._params_grads)
scaled_params_grads = []
with framework.name_scope('mixed_precision'):
for p, g in self._params_grads:
with self._train_program._optimized_guard([p, g]):
scaled_g = g / self._loss_scaling
scaled_params_grads.append([p, scaled_g])
return scaled_params_grads
def apply_gradients(self, scaled_params_grads):
"""
Check scaled gradients to determine whether to update loss scaling and update
parameters by their scaled gradients,
Args:
scaled_params_grads (list): A list of params and scaled grads.
Returns:
A list of optimize operators.
"""
if self._use_dynamic_loss_scaling and len(scaled_params_grads) > 0:
with framework.name_scope('mixed_precision'):
with self._train_program._optimized_guard(scaled_params_grads[
0]):
grads = [
layers.reduce_sum(g) for [_, g] in scaled_params_grads
]
all_grads_sum = layers.sums(grads)
is_overall_finite = layers.isfinite(all_grads_sum)
update_loss_scaling(
is_overall_finite, self._loss_scaling,
self._num_good_steps, self._num_bad_steps,
self._incr_every_n_steps, self._decr_every_n_nan_or_inf,
self._incr_ratio, self._decr_ratio)
# apply_gradient append all ops in global block, thus we shouldn't
# apply gradient in the switch branch.
with layers.Switch() as switch:
with switch.case(is_overall_finite):
pass
with switch.default():
for _, g in scaled_params_grads:
layers.assign(layers.zeros_like(g), g)
optimize_ops = self._optimizer.apply_gradients(scaled_params_grads)
return optimize_ops
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
"""
Perform optimization by minimizing the given loss.
Args:
loss (Variable): The loss Variable.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
parameter_list (list): list of Variables to update.
no_grad_set (set|None): set of Variables should be ignored.
Returns:
The scaled loss by scaling factor, the list of optimize ops, and a
list of scaled parameters and gradients.
"""
scaled_params_grads = self.backward(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
optimize_ops = self.apply_gradients(scaled_params_grads)
return optimize_ops, scaled_params_grads
def decorate(optimizer,
amp_lists=None,
init_loss_scaling=2**15,
incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2,
incr_ratio=2.0,
decr_ratio=0.8,
use_dynamic_loss_scaling=True):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
Args:
optimizer(Optimizer): A common Optimizer.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
init_loss_scaling(float): The initial loss scaling factor.
incr_every_n_steps(int): Increases loss scaling every n consecutive
steps with finite gradients.
decr_every_n_nan_or_inf(int): Decreases loss scaling every n
accumulated steps with nan or
inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
the loss scaling.
use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
Returns:
An optimizer acting like a normal one but with mixed-precision training
enabled.
Examples:
.. code-block:: python
loss = network()
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
mp_optimizer = fluid.contrib.mixed_precision.decorate(
optimizer=optimizer, init_loss_scaling=8.0)
ops, param_grads = mp_optimizer.minimize(loss)
scaled_loss = mp_optimizer.get_scaled_loss()
"""
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists()
mp_optimizer = OptimizerWithMixedPrecision(
optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling,
incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio)
return mp_optimizer
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
__all__ = ["AutoMixedPrecisionLists"]
class AutoMixedPrecisionLists(object):
"""
AutoMixedPrecisionLists is a class for black/white list. It can update
pre-defined black list and white list according to users' custom black
white lists. The lists are used for an algorithm which determines op's
execution mode (fp32 or fp16).
Args:
custom_white_list (set): Users' custom white list.
custom_black_list (set): Users' custom black list.
"""
def __init__(self,
custom_white_list=None,
custom_black_list=None,
custom_black_varnames=None):
self._custom_white_list = custom_white_list
self._custom_black_list = custom_black_list
self.white_list = copy.copy(white_list)
self.black_list = copy.copy(black_list)
self.gray_list = copy.copy(gray_list)
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()
def _update_list(self):
"""
Update black and white list according to users' custom list.
"""
if self._custom_white_list and self._custom_black_list:
for op_name in self._custom_white_list:
if op_name in self._custom_black_list:
raise ValueError("Custom white list overlap "
"custom black list")
if self._custom_white_list:
for op_name in self._custom_white_list:
if op_name in self.black_list:
self.black_list.remove(op_name)
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.white_list.add(op_name)
if self._custom_black_list:
for op_name in self._custom_black_list:
if op_name in self.white_list:
self.white_list.remove(op_name)
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.black_list.add(op_name)
# The three sets listed below are changed dynamiclly. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
white_list = {
'conv2d',
'matmul',
'mul',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
black_list = {
'exp',
'square',
'log',
'mean',
'sum',
'cos_sim',
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'cross_entropy',
'cross_entropy2',
}
# This set contains two types of ops. All ops supported fp16 calculation. One
# of two types is considered numerically-safe, but may be made unsafe by an
# upstream blacklist op. Another type do not have numerically-significant
# effects, like stack, flatten2.
gray_list = {
'elementwise_add',
'elementwise_sub',
'elementwise_mul',
'elementwise_div',
'elementwise_max',
'elementwise_min',
'elementwise_pow',
'elementwise_mod',
'elementwise_floordiv',
'batch_norm',
'tanh',
'sigmoid',
'lookup_table',
'top_k',
'pool2d',
'pool3d',
'dropout',
'relu',
'relu6',
'leaky_relu',
'soft_relu',
'flatten2',
'stack',
'unstack',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'slice',
'rank',
'scale',
'transpose2',
'reshape2',
'gather',
'fill_constant',
'get_tensor_from_selected_rows',
'sign',
'cast',
}
'''
# The set of ops that don't support fp16 calculation
unsupported_fp16_list = {
# from python/paddle/fluid/layers/io.py
'send',
'send_barrier',
'recv',
'fetch_barrier',
'create_py_reader',
'create_double_buffer_reader',
'read',
'load',
# from python/paddle/fluid/control_flow.py
'increment',
'less_than',
'less_equal',
'greater_than',
'greater_equal',
'equal',
'not_equal',
'read_from_array',
'shrink_rnn_memory',
'lod_array_length',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
'print',
'conditional_block',
'while',
'ifelse',
'is_empty',
'lstm',
'cudnn_lstm',
'lstmp',
'gru',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
'bpr_loss',
'chunk_eval',
'sequence_conv',
'sequence_softmax',
# Depthwise conv2d isn't fast and safe currently.
# ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h#L79
'depthwise_conv2d',
# Tensor Core kernels are not available for 3D convolutions currently.
'conv3d',
'sequence_pool',
'sequence_concat',
'sequence_slice',
'data_norm',
'layer_norm',
'group_norm',
'spectral_norm',
'depthwise_conv2d_transpose',
'sequence_expand',
'conv_transposed2d',
'conv_transposed3d',
'sequence_expand_as',
'sequence_pad',
'sequence_unpad',
'sequence_erase',
'beam_search',
'beam_search_decode',
'lstm_unit',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'reduce_prod',
'reduce_all',
'reduce_any',
'split',
'edit_distance',
'ctc_align',
'warpctc',
'sequence_reshape',
'nce',
'hierarchical_sigmoid',
'im2sequence',
'row_conv',
'multiplex',
'sample_logits',
'one_hot',
'smooth_l1_loss',
'squeeze2',
'unsqueeze2',
'lod_reset',
'lrn',
'pad',
'pad_constant_like',
'label_smooth',
'scatter',
'sequence_scatter',
'random_crop',
'mean_iou',
'selu',
'crop',
'affine_grid',
'rank_loss',
'margin_rank_loss',
'pad2d',
'elu',
'pow',
'stanh',
'hard_sigmoid',
'swish',
'prelu',
'brelu',
'sequence_enumerate',
'sequence_mask',
'expand',
'sampling_id',
'maxout',
'space_to_depth',
'sequence_reverse',
'similarity_focus',
'hash',
'grid_sampler',
'log_loss',
'teacher_student_sigmoid_loss',
'add_position_encoding',
'bilinear_tensor_product',
'shuffle_channel',
'temporal_shift',
'psroi_pool',
'huber_loss',
'kldiv_loss',
'tree_conv',
'pixel_shuffle',
'fsp',
'cvm',
'affine_channel',
'roi_pool',
'roi_align',
'anchor_generator',
'generate_proposals',
'generate_proposal_labels',
'generate_mask_labels',
}
'''
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid import core
from paddle.fluid import layers
def _rename_arg(op, old_name, new_name):
"""
If an op has old_name input and output, rename these input
args new_name.
Args:
op (Operator): Current operator.
old_name (str): The old name of input args.
new_name (str): The new name of input args.
"""
op_desc = op.desc
if isinstance(op_desc, tuple):
op_desc = op_desc[0]
op_desc._rename_input(old_name, new_name)
op_desc._rename_output(old_name, new_name)
def _dtype_to_str(dtype):
"""
Convert specific variable type to its corresponding string.
Args:
dtype (VarType): Variable type.
"""
if dtype == core.VarDesc.VarType.FP16:
return 'fp16'
else:
return 'fp32'
def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
"""
Insert cast op and rename args of input and output.
Args:
block (Program): The block in which the operator is.
op (Operator): The operator to insert cast op.
idx (int): The index of current operator.
src_dtype (VarType): The input variable dtype of cast op.
dest_dtype (VarType): The output variable dtype of cast op.
Returns:
num_cast_op (int): The number of cast ops that have been inserted.
"""
num_cast_ops = 0
valid_types = [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
]
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
if in_name != 'X':
continue
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
if in_var.type not in valid_types:
continue
if in_var.dtype == src_dtype:
cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
out_var = block.vars.get(cast_name)
if out_var is None or out_var.dtype != dest_dtype:
out_var = block.create_var(
name=cast_name,
dtype=dest_dtype,
persistable=False,
stop_gradient=False)
block._insert_op(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype
})
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32:
for out_name in op.output_names:
if op.type == 'batch_norm' and out_name != 'Y':
continue
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
if out_var.type not in valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP32:
out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
return num_cast_ops
def find_true_prev_op(ops, cur_op, var_name):
"""
Find the true prev op that outputs var_name variable.
Args:
ops (list): A list of ops.
cur_op (Operator): Current operator which has var_name variable.
var_name (string): Variable name.
"""
prev_op = []
for op in ops:
if op == cur_op:
break
for out_name in op.output_names:
for out_var_name in op.output(out_name):
if out_var_name == var_name:
prev_op.append(op)
if prev_op:
if not len(prev_op) == 1:
raise ValueError("There must be only one previous op "
"that outputs {0} variable".format(var_name))
else:
return prev_op[0]
return None
def find_true_post_op(ops, cur_op, var_name):
"""
if there are post ops, return them, if there is no post op,
return None instead.
Args:
ops (list): A list of ops.
cur_op (Operator): Current operator which has var_name variable.
var_name (string): Variable name.
"""
post_op = []
for idx, op in enumerate(ops):
if op == cur_op:
break
for i in range(idx + 1, len(ops)):
op = ops[i]
for in_name in op.input_names:
for in_var_name in op.input(in_name):
if in_var_name == var_name:
post_op.append(op)
if post_op != []:
return post_op
return None
def find_op_index(block_desc, cur_op_desc):
"""
"""
for idx in range(block_desc.op_size()):
if cur_op_desc == block_desc.op(idx):
return idx
return -1
def _is_in_black_varnames(op, amp_lists):
for in_name in op.input_arg_names:
if in_name in amp_lists.black_varnames:
return True
for out_name in op.output_arg_names:
if out_name in amp_lists.black_varnames:
return True
return False
def rewrite_program(main_prog, amp_lists):
"""
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
1. When an op belongs to the black list, add it to black set
2. When an op belongs to the white list, add it to white set
3. When an op belongs to the gray list. If one
of its inputs is the output of black set op or black list op,
add it to black set. If all of its previous ops are not black
op and one of its inputs is the output of white set op or
white list op, add it to white set.
4. When an op isn't in the lists, add it to black op set.
5. Add necessary cast ops to make sure that black set op will be
computed in fp32 mode, while white set op will be computed in
fp16 mode.
Args:
main_prog (Program): The main program for training.
"""
block = main_prog.global_block()
ops = block.ops
white_op_set = set()
black_op_set = set()
for op in ops:
if amp_lists.black_varnames is not None and _is_in_black_varnames(
op, amp_lists):
black_op_set.add(op)
continue
if op.type in amp_lists.black_list:
black_op_set.add(op)
elif op.type in amp_lists.white_list:
white_op_set.add(op)
elif op.type in amp_lists.gray_list:
is_black_op = False
is_white_op = False
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
# this in_var isn't the output of other op
if in_var.op is None:
continue
elif in_var.op is op:
prev_op = find_true_prev_op(ops, op, in_var_name)
if prev_op is None:
continue
else:
prev_op = in_var.op
# if it's one of inputs
if prev_op in black_op_set or \
prev_op.type in amp_lists.black_list:
is_black_op = True
elif prev_op in white_op_set or \
prev_op.type in amp_lists.white_list:
is_white_op = True
if is_black_op:
black_op_set.add(op)
elif is_white_op:
white_op_set.add(op)
else:
pass
else:
# For numerical safe, we apply fp32 computation on ops that
# are not determined which list they should stay.
black_op_set.add(op)
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op in black_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32)
elif op in white_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16)
else:
pass
idx += num_cast_ops + 1
def update_role_var_grad(main_prog, params_grads):
"""
Update op_role_var attr for some ops to make sure the gradients
transferred across GPUs is FP16.
1. Check whether the op that outputs gradient is cast or not.
2. If op is cast and gradient is FP32, remove the op_role_var
and find the prev op which outputs FP16 gradient
3. Update the op_role_var of the prev op.
Args:
main_prog (Program): The main program for training.
params_grads (list): A list of params and grads.
"""
block = main_prog.global_block()
BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward
OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
for p, g in params_grads:
op = g.op
if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
role = op.attr('op_role')
if role & int(BACKWARD) and op.has_attr('op_role_var'):
op.desc.remove_attr("op_role_var")
else:
raise ValueError("The cast op {0} must be in BACKWARD role "
"and have op_role_var attr.".format(op))
fp16_grad_name = op.input(op.input_names[0])[0]
op_for_fp16_grad = find_true_prev_op(block.ops, op, fp16_grad_name)
op_role_var_attr_name = \
core.op_proto_and_checker_maker.kOpRoleVarAttrName()
attr_val = [p.name, fp16_grad_name]
if op_for_fp16_grad.has_attr(op_role_var_attr_name):
attr_val.extend(op_for_fp16_grad.attr(op_role_var_attr_name))
op_for_fp16_grad._set_attr(op_role_var_attr_name, attr_val)
# Maximize the all_reduce overlap, and perform the cast
# operation after gradients transfer.
op._set_attr('op_role', OPTIMIZE)
# optimize op should stay behind forward and backward ops
if op == block.ops[-1]:
continue
post_ops = find_true_post_op(block.ops, op, g.name)
if post_ops is not None:
raise ValueError("The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"is used by {1}".format(op, post_ops[0]))
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op.desc)
op_idx = find_op_index(block.desc, op.desc)
if op_idx == -1:
raise ValueError("The op {0} is not in program".format(op))
block.desc._remove_op(op_idx, op_idx + 1)
block._sync_with_cpp()
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
num_bad_steps, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
"""
Update loss scaling according to overall gradients. If all gradients is
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
Otherwise, loss scaling will decrease by decr_ratio after
decr_every_n_nan_or_inf steps and each step some gradients are infinite.
Args:
is_overall_finite (Variable): A boolean variable indicates whether
all gradients are finite.
prev_loss_scaling (Variable): Previous loss scaling.
num_good_steps (Variable): A variable accumulates good steps in which
all gradients are finite.
num_bad_steps (Variable): A variable accumulates bad steps in which
some gradients are infinite.
incr_every_n_steps (Variable): A variable represents increasing loss
scaling every n consecutive steps with
finite gradients.
decr_every_n_nan_or_inf (Variable): A variable represents decreasing
loss scaling every n accumulated
steps with nan or inf gradients.
incr_ratio(float): The multiplier to use when increasing the loss
scaling.
decr_ratio(float): The less-than-one-multiplier to use when decreasing
loss scaling.
"""
zero_steps = layers.fill_constant(shape=[1], dtype='int32', value=0)
with layers.Switch() as switch:
with switch.case(is_overall_finite):
should_incr_loss_scaling = layers.less_than(incr_every_n_steps,
num_good_steps + 1)
with layers.Switch() as switch1:
with switch1.case(should_incr_loss_scaling):
new_loss_scaling = prev_loss_scaling * incr_ratio
loss_scaling_is_finite = layers.isfinite(new_loss_scaling)
with layers.Switch() as switch2:
with switch2.case(loss_scaling_is_finite):
layers.assign(new_loss_scaling, prev_loss_scaling)
with switch2.default():
pass
layers.assign(zero_steps, num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch1.default():
layers.increment(num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch.default():
should_decr_loss_scaling = layers.less_than(decr_every_n_nan_or_inf,
num_bad_steps + 1)
with layers.Switch() as switch3:
with switch3.case(should_decr_loss_scaling):
new_loss_scaling = prev_loss_scaling * decr_ratio
static_loss_scaling = \
layers.fill_constant(shape=[1],
dtype='float32',
value=1.0)
less_than_one = layers.less_than(new_loss_scaling,
static_loss_scaling)
with layers.Switch() as switch4:
with switch4.case(less_than_one):
layers.assign(static_loss_scaling,
prev_loss_scaling)
with switch4.default():
layers.assign(new_loss_scaling, prev_loss_scaling)
layers.assign(zero_steps, num_good_steps)
layers.assign(zero_steps, num_bad_steps)
with switch3.default():
layers.assign(zero_steps, num_good_steps)
layers.increment(num_bad_steps)
......@@ -16,7 +16,7 @@ from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper
from .common import is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
from .meta_optimizer_base import MetaOptimizerBase
from paddle.fluid import unique_name, core
from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecision
from zero.decorator import decorate as amp_decorate
import paddle.fluid as fluid
import math
......@@ -813,8 +813,7 @@ class ZeroOptimizer(MetaOptimizerBase):
optimizer._set_checkpoints(ckpts)
if self.user_defined_strategy.zero_configs["amp"]:
optimizer = fluid.contrib.mixed_precision.decorate(
optimizer, use_dynamic_loss_scaling=True)
optimizer = amp_decorate(optimizer, use_dynamic_loss_scaling=True)
self._nrings = self.user_defined_strategy.zero_configs["nrings"]
self._fuse_broadcast_MB_bytes = self.user_defined_strategy.zero_configs[
......@@ -1184,8 +1183,7 @@ class ZeroOptimizer(MetaOptimizerBase):
optimizer = self.inner_opt
if self.user_defined_strategy.zero_configs["amp"]:
optimizer = fluid.contrib.mixed_precision.decorate(
optimizer, use_dynamic_loss_scaling=True)
optimizer = amp_decorate(optimizer, use_dynamic_loss_scaling=True)
optimize_ops, params_grads = optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册