diff --git a/python/paddle/distributed/fleet/meta_optimizers/zero/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/zero/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6296bcac93015c5f6c55861575a45a3a33b3628
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_optimizers/zero/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+from . import decorator
+from . import fp16_lists
+from .decorator import *
+from .fp16_lists import AutoMixedPrecisionLists
+
+# Build a new list here: assigning decorator.__all__ and then using `+=`
+# would extend decorator.__all__ in place through the shared list object.
+__all__ = decorator.__all__ + fp16_lists.__all__
diff --git a/python/paddle/distributed/fleet/meta_optimizers/zero/decorator.py b/python/paddle/distributed/fleet/meta_optimizers/zero/decorator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3088db5a44d88e98f85013b40b7da56fdf82a109
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_optimizers/zero/decorator.py
@@ -0,0 +1,272 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.fluid import default_main_program
+from paddle.fluid import default_startup_program
+from paddle.fluid import layers
+from paddle.fluid import unique_name
+from paddle.fluid import framework
+from . import fp16_utils
+from .fp16_utils import update_loss_scaling, rewrite_program
+from .fp16_utils import update_role_var_grad
+from .fp16_lists import AutoMixedPrecisionLists
+
+__all__ = ["decorate"]
+
+
+class OptimizerWithMixedPrecision(object):
+    """
+    Optimizer with mixed-precision (MP) training. This is a wrapper of a common
+    optimizer, plus the support of mixed-precision pre-training. The object
+    of this class almost has the same behavior as the common optimizer, with the
+    methods `minimize()`, `backward()`, `apply_gradients()` implemented.
+    Additionally, it enables the MP training automatically, i.e, the creation
+    and maintenance of master parameters, scaling of loss, etc.
+
+    Args:
+        optimizer (Optimizer): A common Optimizer object.
+        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
+        init_loss_scaling (float): The initial loss scaling factor.
+        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
+        incr_every_n_steps(int): Increases loss scaling every n consecutive
+                                 steps with finite gradients.
+        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
+                                      accumulated steps with nan or
+                                      inf gradients.
+        incr_ratio(float): The multiplier to use when increasing the loss
+                           scaling.
+        decr_ratio(float): The less-than-one-multiplier to use when decreasing
+                           the loss scaling.
+
+    """
+
+    def __init__(self, optimizer, amp_lists, init_loss_scaling,
+                 use_dynamic_loss_scaling, incr_every_n_steps,
+                 decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
+        self._optimizer = optimizer
+        self._amp_lists = amp_lists
+        self._params_grads = None
+        self._train_program = default_main_program()
+        self._startup_prog = default_startup_program()
+        self._scaled_loss = None
+        self._loss_scaling = layers.create_global_var(
+            name=unique_name.generate("loss_scaling"),
+            shape=[1],
+            value=init_loss_scaling,
+            dtype='float32',
+            persistable=True)
+        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
+        if self._use_dynamic_loss_scaling:
+            self._incr_every_n_steps = layers.fill_constant(
+                shape=[1], dtype='int32', value=incr_every_n_steps)
+            self._decr_every_n_nan_or_inf = layers.fill_constant(
+                shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
+            self._incr_ratio = incr_ratio
+            self._decr_ratio = decr_ratio
+            self._num_good_steps = layers.create_global_var(
+                name=unique_name.generate("num_good_steps"),
+                shape=[1],
+                value=0,
+                dtype='int32',
+                persistable=True)
+            self._num_bad_steps = layers.create_global_var(
+                name=unique_name.generate("num_bad_steps"),
+                shape=[1],
+                value=0,
+                dtype='int32',
+                persistable=True)
+
+        # Ensure the data type of learning rate vars is float32 (same as the
+        # master parameter dtype)
+        if isinstance(optimizer._learning_rate, float):
+            optimizer._learning_rate_map[default_main_program()] = \
+                layers.create_global_var(
+                    name=unique_name.generate("learning_rate"),
+                    shape=[1],
+                    value=float(optimizer._learning_rate),
+                    dtype='float32',
+                    persistable=True)
+
+    def get_loss_scaling(self):
+        """Return the real-time loss scaling factor.
+        """
+        return self._loss_scaling
+
+    def get_scaled_loss(self):
+        """Return the scaled loss.
+        It's useful when you feed customized loss into executor.
+        """
+
+        return self._scaled_loss
+
+    def backward(self,
+                 loss,
+                 startup_program=None,
+                 parameter_list=None,
+                 no_grad_set=None,
+                 callbacks=None):
+        """
+        Backward propagation or auto differentiation for gradients' computation.
+
+        Args:
+            loss (Variable): The loss Variable to minimize.
+            startup_program (Program|None): The startup Program for initializing
+                                       parameters in `parameter_list`.
+            parameter_list (list|None): A list of Variables to update.
+            no_grad_set (set|None): A set of Variables should be ignored.
+            callbacks (list|None): A list of callable objects to run when appending
+                                   backward operator for one parameter.
+
+        Returns:
+            A list of [param, grad] pairs, where each gradient has already
+            been divided by the loss scaling factor.
+        """
+        rewrite_program(self._train_program, self._amp_lists)
+        with framework.name_scope('mixed_precision'):
+            # Scale the loss exactly once; a second multiply is a dead op.
+            self._scaled_loss = loss * self._loss_scaling
+        self._params_grads = self._optimizer.backward(
+            self._scaled_loss, startup_program, parameter_list, no_grad_set,
+            callbacks)
+        # Change the op_role_var attr for some ops, so that gradients
+        # transferred across GPUs can be FP16.
+        update_role_var_grad(self._train_program, self._params_grads)
+        scaled_params_grads = []
+        with framework.name_scope('mixed_precision'):
+            for p, g in self._params_grads:
+                with self._train_program._optimized_guard([p, g]):
+                    scaled_g = g / self._loss_scaling
+                    scaled_params_grads.append([p, scaled_g])
+        return scaled_params_grads
+
+    def apply_gradients(self, scaled_params_grads):
+        """
+        Check scaled gradients to determine whether to update loss scaling and update
+        parameters by their scaled gradients,
+
+        Args:
+            scaled_params_grads (list): A list of params and scaled grads.
+
+        Returns:
+            A list of optimize operators.
+        """
+
+        if self._use_dynamic_loss_scaling and len(scaled_params_grads) > 0:
+            with framework.name_scope('mixed_precision'):
+                with self._train_program._optimized_guard(scaled_params_grads[
+                        0]):
+                    grads = [
+                        layers.reduce_sum(g) for [_, g] in scaled_params_grads
+                    ]
+                    all_grads_sum = layers.sums(grads)
+                    is_overall_finite = layers.isfinite(all_grads_sum)
+
+                    update_loss_scaling(
+                        is_overall_finite, self._loss_scaling,
+                        self._num_good_steps, self._num_bad_steps,
+                        self._incr_every_n_steps, self._decr_every_n_nan_or_inf,
+                        self._incr_ratio, self._decr_ratio)
+
+                    # apply_gradient append all ops in global block, thus we shouldn't
+                    # apply gradient in the switch branch.
+                    with layers.Switch() as switch:
+                        with switch.case(is_overall_finite):
+                            pass
+                        with switch.default():
+                            for _, g in scaled_params_grads:
+                                layers.assign(layers.zeros_like(g), g)
+
+        optimize_ops = self._optimizer.apply_gradients(scaled_params_grads)
+
+        return optimize_ops
+
+    def minimize(self,
+                 loss,
+                 startup_program=None,
+                 parameter_list=None,
+                 no_grad_set=None):
+        """
+        Perform optimization by minimizing the given loss.
+
+        Args:
+            loss (Variable): The loss Variable.
+            startup_program (Program): startup_program for initializing parameters
+                in `parameter_list`.
+            parameter_list (list): list of Variables to update.
+            no_grad_set (set|None): set of Variables should be ignored.
+
+        Returns:
+            A tuple (optimize_ops, scaled_params_grads): the list of optimize
+            ops and the list of [param, grad] pairs returned by `backward()`.
+        """
+        scaled_params_grads = self.backward(
+            loss,
+            startup_program=startup_program,
+            parameter_list=parameter_list,
+            no_grad_set=no_grad_set)
+
+        optimize_ops = self.apply_gradients(scaled_params_grads)
+
+        return optimize_ops, scaled_params_grads
+
+
+def decorate(optimizer,
+             amp_lists=None,
+             init_loss_scaling=2**15,
+             incr_every_n_steps=1000,
+             decr_every_n_nan_or_inf=2,
+             incr_ratio=2.0,
+             decr_ratio=0.8,
+             use_dynamic_loss_scaling=True):
+    """
+    Decorate the given optimizer to adapt to the mixed-precision training.
+
+    Args:
+        optimizer(Optimizer): A common Optimizer.
+        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
+        init_loss_scaling(float): The initial loss scaling factor.
+        incr_every_n_steps(int): Increases loss scaling every n consecutive
+                                 steps with finite gradients.
+        decr_every_n_nan_or_inf(int): Decreases loss scaling every n
+                                      accumulated steps with nan or
+                                      inf gradients.
+        incr_ratio(float): The multiplier to use when increasing the loss
+                           scaling.
+        decr_ratio(float): The less-than-one-multiplier to use when decreasing
+                           the loss scaling.
+        use_dynamic_loss_scaling(bool): Whether to use dynamic loss scaling.
+
+    Returns:
+        An optimizer acting like a normal one but with mixed-precision training
+        enabled.
+
+    Examples:
+        .. code-block:: python
+
+            loss = network()
+            optimizer = fluid.optimizer.Adam(learning_rate=0.001)
+
+            mp_optimizer = fluid.contrib.mixed_precision.decorate(
+                      optimizer=optimizer, init_loss_scaling=8.0)
+
+            ops, param_grads = mp_optimizer.minimize(loss)
+            scaled_loss = mp_optimizer.get_scaled_loss()
+    """
+    if amp_lists is None:
+        amp_lists = AutoMixedPrecisionLists()
+    mp_optimizer = OptimizerWithMixedPrecision(
+        optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling,
+        incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio)
+
+    return mp_optimizer
diff --git a/python/paddle/distributed/fleet/meta_optimizers/zero/fp16_lists.py b/python/paddle/distributed/fleet/meta_optimizers/zero/fp16_lists.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f301b7148d005d4e3d5d272fd78f78af6dc1e6a
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_optimizers/zero/fp16_lists.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+__all__ = ["AutoMixedPrecisionLists"]
+
+
+class AutoMixedPrecisionLists(object):
+    """
+    AutoMixedPrecisionLists is a class for black/white list. It can update
+    pre-defined black list and white list according to users' custom black
+    white lists. The lists are used for an algorithm which determines op's
+    execution mode (fp32 or fp16).
+
+    Args:
+        custom_white_list (set): Users' custom white list.
+        custom_black_list (set): Users' custom black list.
+        custom_black_varnames (set): Users' custom black variable names.
+    """
+
+    def __init__(self,
+                 custom_white_list=None,
+                 custom_black_list=None,
+                 custom_black_varnames=None):
+        self._custom_white_list = custom_white_list
+        self._custom_black_list = custom_black_list
+        self.white_list = copy.copy(white_list)
+        self.black_list = copy.copy(black_list)
+        self.gray_list = copy.copy(gray_list)
+        self.black_varnames = copy.copy(custom_black_varnames)
+        self._update_list()
+
+    def _update_list(self):
+        """
+        Update black and white list according to users' custom list.
+        """
+        if self._custom_white_list and self._custom_black_list:
+            for op_name in self._custom_white_list:
+                if op_name in self._custom_black_list:
+                    raise ValueError("Custom white list overlap "
+                                     "custom black list")
+        if self._custom_white_list:
+            for op_name in self._custom_white_list:
+                if op_name in self.black_list:
+                    self.black_list.remove(op_name)
+                elif op_name in self.gray_list:
+                    self.gray_list.remove(op_name)
+                self.white_list.add(op_name)
+        if self._custom_black_list:
+            for op_name in self._custom_black_list:
+                if op_name in self.white_list:
+                    self.white_list.remove(op_name)
+                elif op_name in self.gray_list:
+                    self.gray_list.remove(op_name)
+                self.black_list.add(op_name)
+
+
+# The three sets listed below are changed dynamically. They don't contain all
+# paddle ops currently.
+
+# The set of ops that support fp16 calculation and are considered numerically-
+# safe and performance-critical. These ops are always converted to fp16.
+white_list = {
+    'conv2d',
+    'matmul',
+    'mul',
+}
+
+# The set of ops that support fp16 calculation and are considered numerically-
+# dangerous and whose effects may also be observed in downstream ops.
+black_list = {
+    'exp',
+    'square',
+    'log',
+    'mean',
+    'sum',
+    'cos_sim',
+    'softmax',
+    'softmax_with_cross_entropy',
+    'sigmoid_cross_entropy_with_logits',
+    'cross_entropy',
+    'cross_entropy2',
+}
+
+# This set contains two types of ops. All ops supported fp16 calculation.
+# One of the two types is considered numerically-safe, but may be made
+# unsafe by an upstream blacklist op. The other type does not have
+# numerically-significant effects, like stack, flatten2.
+gray_list = {
+    'elementwise_add',
+    'elementwise_sub',
+    'elementwise_mul',
+    'elementwise_div',
+    'elementwise_max',
+    'elementwise_min',
+    'elementwise_pow',
+    'elementwise_mod',
+    'elementwise_floordiv',
+    'batch_norm',
+    'tanh',
+    'sigmoid',
+    'lookup_table',
+    'top_k',
+    'pool2d',
+    'pool3d',
+    'dropout',
+    'relu',
+    'relu6',
+    'leaky_relu',
+    'soft_relu',
+    'flatten2',
+    'stack',
+    'unstack',
+    'uniform_random_batch_size_like',
+    'gaussian_random',
+    'gaussian_random_batch_size_like',
+    'slice',
+    'rank',
+    'scale',
+    'transpose2',
+    'reshape2',
+    'gather',
+    'fill_constant',
+    'get_tensor_from_selected_rows',
+    'sign',
+    'cast',
+}
+'''
+# The set of ops that don't support fp16 calculation
+unsupported_fp16_list = {
+    # from python/paddle/fluid/layers/io.py
+    'send',
+    'send_barrier',
+    'recv',
+    'fetch_barrier',
+    'create_py_reader',
+    'create_double_buffer_reader',
+    'read',
+    'load',
+
+    # from python/paddle/fluid/control_flow.py
+    'increment',
+    'less_than',
+    'less_equal',
+    'greater_than',
+    'greater_equal',
+    'equal',
+    'not_equal',
+    'read_from_array',
+    'shrink_rnn_memory',
+    'lod_array_length',
+    'logical_and',
+    'logical_or',
+    'logical_xor',
+    'logical_not',
+    'print',
+    'conditional_block',
+    'while',
+    'ifelse',
+    'is_empty',
+
+    'lstm',
+    'cudnn_lstm',
+    'lstmp',
+    'gru',
+    'gru_unit',
+    'linear_chain_crf',
+    'crf_decoding',
+    'bpr_loss',
+    'chunk_eval',
+    'sequence_conv',
+    'sequence_softmax',
+    # Depthwise conv2d isn't fast and safe currently.
+    # ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h#L79
+    'depthwise_conv2d',
+    # Tensor Core kernels are not available for 3D convolutions currently.
+    'conv3d',
+    'sequence_pool',
+    'sequence_concat',
+    'sequence_slice',
+    'data_norm',
+    'layer_norm',
+    'group_norm',
+    'spectral_norm',
+    'depthwise_conv2d_transpose',
+    'sequence_expand',
+    'conv_transposed2d',
+    'conv_transposed3d',
+    'sequence_expand_as',
+    'sequence_pad',
+    'sequence_unpad',
+    'sequence_erase',
+    'beam_search',
+    'beam_search_decode',
+    'lstm_unit',
+    'reduce_sum',
+    'reduce_mean',
+    'reduce_max',
+    'reduce_min',
+    'reduce_prod',
+    'reduce_all',
+    'reduce_any',
+    'split',
+    'edit_distance',
+    'ctc_align',
+    'warpctc',
+    'sequence_reshape',
+    'nce',
+    'hierarchical_sigmoid',
+    'im2sequence',
+    'row_conv',
+    'multiplex',
+    'sample_logits',
+    'one_hot',
+    'smooth_l1_loss',
+    'squeeze2',
+    'unsqueeze2',
+    'lod_reset',
+    'lrn',
+    'pad',
+    'pad_constant_like',
+    'label_smooth',
+    'scatter',
+    'sequence_scatter',
+    'random_crop',
+    'mean_iou',
+    'selu',
+    'crop',
+    'affine_grid',
+    'rank_loss',
+    'margin_rank_loss',
+    'pad2d',
+    'elu',
+    'pow',
+    'stanh',
+    'hard_sigmoid',
+    'swish',
+    'prelu',
+    'brelu',
+    'sequence_enumerate',
+    'sequence_mask',
+    'expand',
+    'sampling_id',
+    'maxout',
+    'space_to_depth',
+    'sequence_reverse',
+    'similarity_focus',
+    'hash',
+    'grid_sampler',
+    'log_loss',
+    'teacher_student_sigmoid_loss',
+    'add_position_encoding',
+    'bilinear_tensor_product',
+    'shuffle_channel',
+    'temporal_shift',
+    'psroi_pool',
+    'huber_loss',
+    'kldiv_loss',
+    'tree_conv',
+    'pixel_shuffle',
+    'fsp',
+    'cvm',
+
+    'affine_channel',
+    'roi_pool',
+    'roi_align',
+    'anchor_generator',
+    'generate_proposals',
+    'generate_proposal_labels',
+    'generate_mask_labels',
+
+}
+'''
diff --git a/python/paddle/distributed/fleet/meta_optimizers/zero/fp16_utils.py b/python/paddle/distributed/fleet/meta_optimizers/zero/fp16_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2026c7ea2195c305872d68e157be716a56a95e5a
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_optimizers/zero/fp16_utils.py
@@ -0,0 +1,404 @@
+# Copyright
+# (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+from paddle.fluid import core
+from paddle.fluid import layers
+
+
+def _rename_arg(op, old_name, new_name):
+    """
+    If an op has old_name input and output, rename these input
+    args new_name.
+
+    Args:
+        op (Operator): Current operator.
+        old_name (str): The old name of input args.
+        new_name (str): The new name of input args.
+    """
+    op_desc = op.desc
+    if isinstance(op_desc, tuple):
+        op_desc = op_desc[0]
+    op_desc._rename_input(old_name, new_name)
+    op_desc._rename_output(old_name, new_name)
+
+
+def _dtype_to_str(dtype):
+    """
+    Convert specific variable type to its corresponding string.
+
+    Args:
+        dtype (VarType): Variable type.
+    """
+    if dtype == core.VarDesc.VarType.FP16:
+        return 'fp16'
+    else:
+        return 'fp32'
+
+
+def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
+    """
+    Insert cast op and rename args of input and output.
+
+    Args:
+        block (Program): The block in which the operator is.
+        op (Operator): The operator to insert cast op.
+        idx (int): The index of current operator.
+        src_dtype (VarType): The input variable dtype of cast op.
+        dest_dtype (VarType): The output variable dtype of cast op.
+
+    Returns:
+        num_cast_op (int): The number of cast ops that have been inserted.
+    """
+    num_cast_ops = 0
+    valid_types = [
+        core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
+        core.VarDesc.VarType.LOD_TENSOR_ARRAY
+    ]
+
+    for in_name in op.input_names:
+        if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
+            if in_name != 'X':
+                continue
+        for in_var_name in op.input(in_name):
+            in_var = block.var(in_var_name)
+            if in_var.type not in valid_types:
+                continue
+            if in_var.dtype == src_dtype:
+                cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
+                out_var = block.vars.get(cast_name)
+                if out_var is None or out_var.dtype != dest_dtype:
+                    out_var = block.create_var(
+                        name=cast_name,
+                        dtype=dest_dtype,
+                        persistable=False,
+                        stop_gradient=False)
+
+                    block._insert_op(
+                        idx,
+                        type="cast",
+                        inputs={"X": in_var},
+                        outputs={"Out": out_var},
+                        attrs={
+                            "in_dtype": in_var.dtype,
+                            "out_dtype": out_var.dtype
+                        })
+                    num_cast_ops += 1
+                _rename_arg(op, in_var.name, out_var.name)
+            else:
+                if op.has_attr('in_dtype'):
+                    op._set_attr('in_dtype', dest_dtype)
+    if src_dtype == core.VarDesc.VarType.FP32:
+        for out_name in op.output_names:
+            if op.type == 'batch_norm' and out_name != 'Y':
+                continue
+            for out_var_name in op.output(out_name):
+                out_var = block.var(out_var_name)
+                if out_var.type not in valid_types:
+                    continue
+                if out_var.dtype == core.VarDesc.VarType.FP32:
+                    out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
+                    if op.has_attr('out_dtype'):
+                        op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
+    return num_cast_ops
+
+
+def find_true_prev_op(ops, cur_op, var_name):
+    """
+    Find the true prev op that outputs var_name variable.
+
+    Args:
+        ops (list): A list of ops.
+        cur_op (Operator): Current operator which has var_name variable.
+        var_name (string): Variable name.
+    """
+    prev_op = []
+    for op in ops:
+        if op == cur_op:
+            break
+        for out_name in op.output_names:
+            for out_var_name in op.output(out_name):
+                if out_var_name == var_name:
+                    prev_op.append(op)
+    if prev_op:
+        if not len(prev_op) == 1:
+            raise ValueError("There must be only one previous op "
+                             "that outputs {0} variable".format(var_name))
+        else:
+            return prev_op[0]
+    return None
+
+
+def find_true_post_op(ops, cur_op, var_name):
+    """
+    if there are post ops, return them, if there is no post op,
+    return None instead.
+    Args:
+        ops (list): A list of ops.
+        cur_op (Operator): Current operator which has var_name variable.
+        var_name (string): Variable name.
+    """
+    post_op = []
+    # Pre-bind idx so an empty `ops` list yields None instead of a NameError.
+    idx = -1
+    for idx, op in enumerate(ops):
+        if op == cur_op:
+            break
+
+    for i in range(idx + 1, len(ops)):
+        op = ops[i]
+        for in_name in op.input_names:
+            for in_var_name in op.input(in_name):
+                if in_var_name == var_name:
+                    post_op.append(op)
+    if post_op != []:
+        return post_op
+    return None
+
+
+def find_op_index(block_desc, cur_op_desc):
+    """
+    Find the position of an op desc inside a block desc.
+
+    Args:
+        block_desc: The block desc to search; it is queried through
+            `op_size()` and `op(idx)`.
+        cur_op_desc: The op desc to look for.
+
+    Returns:
+        The index of the first op desc that compares equal to cur_op_desc,
+        or -1 if there is none.
+    """
+    for idx in range(block_desc.op_size()):
+        if cur_op_desc == block_desc.op(idx):
+            return idx
+    return -1
+
+
+def _is_in_black_varnames(op, amp_lists):
+    """
+    Return True if any input or output argument name of op is contained in
+    amp_lists.black_varnames, otherwise False.
+    """
+    for in_name in op.input_arg_names:
+        if in_name in amp_lists.black_varnames:
+            return True
+
+    for out_name in op.output_arg_names:
+        if out_name in amp_lists.black_varnames:
+            return True
+
+    return False
+
+
+def rewrite_program(main_prog, amp_lists):
+    """
+    Traverse all ops in current block and insert cast op according to
+    which set current op belongs to.
+
+    1. When an op belongs to the black list, add it to black set
+    2. When an op belongs to the white list, add it to white set
+    3. When an op belongs to the gray list. If one
+       of its inputs is the output of black set op or black list op,
+       add it to black set. If all of its previous ops are not black
+       op and one of its inputs is the output of white set op or
+       white list op, add it to white set.
+    4. When an op isn't in the lists, add it to black op set.
+    5. Add necessary cast ops to make sure that black set op will be
+       computed in fp32 mode, while white set op will be computed in
+       fp16 mode.
+
+    Args:
+        main_prog (Program): The main program for training.
+        amp_lists (AutoMixedPrecisionLists): The black/white/gray op lists
+            and the optional black variable names used for the decision.
+    """
+    block = main_prog.global_block()
+    ops = block.ops
+    white_op_set = set()
+    black_op_set = set()
+    for op in ops:
+        if amp_lists.black_varnames is not None and _is_in_black_varnames(
+                op, amp_lists):
+            black_op_set.add(op)
+            continue
+
+        if op.type in amp_lists.black_list:
+            black_op_set.add(op)
+        elif op.type in amp_lists.white_list:
+            white_op_set.add(op)
+        elif op.type in amp_lists.gray_list:
+            is_black_op = False
+            is_white_op = False
+            for in_name in op.input_names:
+                # if this op has inputs
+                if in_name:
+                    for in_var_name in op.input(in_name):
+                        in_var = block.var(in_var_name)
+                        # this in_var isn't the output of other op
+                        if in_var.op is None:
+                            continue
+                        elif in_var.op is op:
+                            prev_op = find_true_prev_op(ops, op, in_var_name)
+                            if prev_op is None:
+                                continue
+                        else:
+                            prev_op = in_var.op
+                        # if it's one of inputs
+                        if prev_op in black_op_set or \
+                                prev_op.type in amp_lists.black_list:
+                            is_black_op = True
+                        elif prev_op in white_op_set or \
+                                prev_op.type in amp_lists.white_list:
+                            is_white_op = True
+            if is_black_op:
+                black_op_set.add(op)
+            elif is_white_op:
+                white_op_set.add(op)
+            else:
+                pass
+        else:
+            # For numerical safe, we apply fp32 computation on ops that
+            # are not determined which list they should stay.
+            black_op_set.add(op)
+
+    idx = 0
+    while idx < len(ops):
+        op = ops[idx]
+        num_cast_ops = 0
+        if op in black_op_set:
+            num_cast_ops = _insert_cast_op(block, op, idx,
+                                           core.VarDesc.VarType.FP16,
+                                           core.VarDesc.VarType.FP32)
+        elif op in white_op_set:
+            num_cast_ops = _insert_cast_op(block, op, idx,
+                                           core.VarDesc.VarType.FP32,
+                                           core.VarDesc.VarType.FP16)
+        else:
+            pass
+
+        idx += num_cast_ops + 1
+
+
+def update_role_var_grad(main_prog, params_grads):
+    """
+    Update op_role_var attr for some ops to make sure the gradients
+    transferred across GPUs is FP16.
+    1. Check whether the op that outputs gradient is cast or not.
+    2. If op is cast and gradient is FP32, remove the op_role_var
+       and find the prev op which outputs FP16 gradient
+    3. Update the op_role_var of the prev op.
+
+    Args:
+        main_prog (Program): The main program for training.
+        params_grads (list): A list of params and grads.
+
+    Raises:
+        ValueError: If a cast op is not a backward op with op_role_var, if
+            no op produces its FP16 input, if its output is consumed by a
+            later op, or if it cannot be found in the block.
+    """
+    block = main_prog.global_block()
+    BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward
+    OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
+    for p, g in params_grads:
+        op = g.op
+        if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
+            role = op.attr('op_role')
+            if role & int(BACKWARD) and op.has_attr('op_role_var'):
+                op.desc.remove_attr("op_role_var")
+            else:
+                raise ValueError("The cast op {0} must be in BACKWARD role "
+                                 "and have op_role_var attr.".format(op))
+
+            fp16_grad_name = op.input(op.input_names[0])[0]
+            op_for_fp16_grad = find_true_prev_op(block.ops, op, fp16_grad_name)
+            # find_true_prev_op returns None when nothing produces the var.
+            if op_for_fp16_grad is None:
+                raise ValueError("There is no previous op that outputs "
+                                 "{0} variable".format(fp16_grad_name))
+            op_role_var_attr_name = \
+                core.op_proto_and_checker_maker.kOpRoleVarAttrName()
+            attr_val = [p.name, fp16_grad_name]
+            if op_for_fp16_grad.has_attr(op_role_var_attr_name):
+                attr_val.extend(op_for_fp16_grad.attr(op_role_var_attr_name))
+            op_for_fp16_grad._set_attr(op_role_var_attr_name, attr_val)
+
+            # Maximize the all_reduce overlap, and perform the cast
+            # operation after gradients transfer.
+            op._set_attr('op_role', OPTIMIZE)
+            # optimize op should stay behind forward and backward ops
+            if op == block.ops[-1]:
+                continue
+            post_ops = find_true_post_op(block.ops, op, g.name)
+            if post_ops is not None:
+                raise ValueError("The cast op {0}'s output should not be "
+                                 "used by a non-optimize op, however, it "
+                                 "is used by {1}".format(op, post_ops[0]))
+            new_op_desc = block.desc.append_op()
+            new_op_desc.copy_from(op.desc)
+
+            op_idx = find_op_index(block.desc, op.desc)
+            if op_idx == -1:
+                raise ValueError("The op {0} is not in program".format(op))
+            block.desc._remove_op(op_idx, op_idx + 1)
+        block._sync_with_cpp()
+
+
+def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
+                        num_bad_steps, incr_every_n_steps,
+                        decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
+    """
+    Update loss scaling according to overall gradients. If all gradients is
+    finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
+    Otherwise, loss scaling will decrease by decr_ratio after
+    decr_every_n_nan_or_inf steps and each step some gradients are infinite.
+
+    Args:
+        is_overall_finite (Variable): A boolean variable indicates whether
+                                     all gradients are finite.
+        prev_loss_scaling (Variable): Previous loss scaling.
+        num_good_steps (Variable): A variable accumulates good steps in which
+                                   all gradients are finite.
+        num_bad_steps (Variable): A variable accumulates bad steps in which
+                                  some gradients are infinite.
+        incr_every_n_steps (Variable): A variable represents increasing loss
+                                       scaling every n consecutive steps with
+                                       finite gradients.
+        decr_every_n_nan_or_inf (Variable): A variable represents decreasing
+                                            loss scaling every n accumulated
+                                            steps with nan or inf gradients.
+        incr_ratio(float): The multiplier to use when increasing the loss
+                           scaling.
+        decr_ratio(float): The less-than-one-multiplier to use when decreasing
+                           loss scaling.
+    """
+    zero_steps = layers.fill_constant(shape=[1], dtype='int32', value=0)
+    with layers.Switch() as switch:
+        with switch.case(is_overall_finite):
+            should_incr_loss_scaling = layers.less_than(incr_every_n_steps,
+                                                        num_good_steps + 1)
+            with layers.Switch() as switch1:
+                with switch1.case(should_incr_loss_scaling):
+                    new_loss_scaling = prev_loss_scaling * incr_ratio
+                    loss_scaling_is_finite = layers.isfinite(new_loss_scaling)
+                    with layers.Switch() as switch2:
+                        with switch2.case(loss_scaling_is_finite):
+                            layers.assign(new_loss_scaling, prev_loss_scaling)
+                        with switch2.default():
+                            pass
+                    layers.assign(zero_steps, num_good_steps)
+                    layers.assign(zero_steps, num_bad_steps)
+
+                with switch1.default():
+                    layers.increment(num_good_steps)
+                    layers.assign(zero_steps, num_bad_steps)
+
+        with switch.default():
+            should_decr_loss_scaling = layers.less_than(decr_every_n_nan_or_inf,
+                                                        num_bad_steps + 1)
+            with layers.Switch() as switch3:
+                with switch3.case(should_decr_loss_scaling):
+                    new_loss_scaling = prev_loss_scaling * decr_ratio
+                    static_loss_scaling = \
+                        layers.fill_constant(shape=[1],
+                                             dtype='float32',
+                                             value=1.0)
+                    less_than_one = layers.less_than(new_loss_scaling,
+                                                     static_loss_scaling)
+                    with layers.Switch() as switch4:
+                        with switch4.case(less_than_one):
+                            layers.assign(static_loss_scaling,
+                                          prev_loss_scaling)
+                        with switch4.default():
+                            layers.assign(new_loss_scaling, prev_loss_scaling)
+                    layers.assign(zero_steps, num_good_steps)
+                    layers.assign(zero_steps, num_bad_steps)
+                with switch3.default():
+                    layers.assign(zero_steps, num_good_steps)
+                    layers.increment(num_bad_steps)
diff --git a/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py
index 11adf124a2c329965cba007fab834420639f2b60..5856dd070faef62234ca4096518afbfc604b76ed 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py
@@ -16,7 +16,7 @@ from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper
 from .common import is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
 from .meta_optimizer_base import MetaOptimizerBase
 from paddle.fluid import unique_name, core
-from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecision
+from .zero.decorator import decorate as amp_decorate
 import paddle.fluid as fluid
 
 import math
@@ -813,8 +813,7 @@ class ZeroOptimizer(MetaOptimizerBase):
             optimizer._set_checkpoints(ckpts)
 
         if self.user_defined_strategy.zero_configs["amp"]:
-            optimizer = fluid.contrib.mixed_precision.decorate(
-                optimizer, use_dynamic_loss_scaling=True)
+            optimizer = amp_decorate(optimizer, use_dynamic_loss_scaling=True)
 
         self._nrings = self.user_defined_strategy.zero_configs["nrings"]
         self._fuse_broadcast_MB_bytes = self.user_defined_strategy.zero_configs[
@@ -1184,8 +1183,7 @@ class ZeroOptimizer(MetaOptimizerBase):
 
         optimizer = self.inner_opt
         if self.user_defined_strategy.zero_configs["amp"]:
-            optimizer = fluid.contrib.mixed_precision.decorate(
-                optimizer, use_dynamic_loss_scaling=True)
+            optimizer = amp_decorate(optimizer, use_dynamic_loss_scaling=True)
 
         optimize_ops, params_grads = optimizer.minimize(
             loss, startup_program, parameter_list, no_grad_set)