提交 172c2fac 编写于 作者: J Jie Fang 提交者: Yibing Liu

init black/white lists (#17847)

test=develop
上级 e06c69c7
......@@ -18,7 +18,7 @@ from ... import layers
from ... import unique_name
from . import fp16_utils
from .fp16_utils import create_master_params_grads, master_param_to_train_param
from .fp16_utils import update_loss_scaling
from .fp16_utils import update_loss_scaling, rewrite_program
__all__ = ["decorate"]
......@@ -120,6 +120,7 @@ class OptimizerWithMixedPrecison(object):
A list of (param, grad), which is a tuple of a parameter and its
gradient respectively, and the scaled loss.
"""
rewrite_program(self._train_program)
scaled_loss = loss * self._loss_scaling
self._param_grads = self._optimizer.backward(
scaled_loss, startup_program, parameter_list, no_grad_set,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The three sets listed below are changed dynamiclly. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
white_list = {
'conv2d',
'matmul',
'mul',
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
black_list = {
'exp',
'square',
'log',
'mean',
'sum',
'cos_sim',
'softmax',
'softmax_with_cross_entropy',
'sigmoid_cross_entropy_with_logits',
'cross_entropy',
'cross_entropy2',
}
# This set contains two types of ops. All ops supported fp16 calculation. One
# of two types is considered numerically-safe, but may be made unsafe by an
# updtream blacklist op. Another type do not have numerically-significant
# effects, like stack, flatten2.
gray_list = {
'elementwise_add',
'elementwise_sub',
'elementwise_mul',
'elementwise_div',
'elementwise_max',
'elementwise_min',
'elementwise_pow',
'elementwise_mod',
'elementwise_floordiv',
'tanh',
'sigmoid',
'lookup_table',
'top_k',
'pool2d',
'pool3d',
'dropout',
'relu',
'relu6',
'leaky_relu',
'soft_relu',
'flatten2',
'stack',
'unstack',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
'slice',
'rank',
'scale',
'transpose2',
'reshape2',
'gather',
'fill_constant',
'get_tensor_from_selected_rows',
'sign',
'cast',
}
'''
# The set of ops that don't support fp16 calculation
unsupported_fp16_list = {
# from python/paddle/fluid/layers/io.py
'send',
'send_barrier',
'recv',
'fetch_barrier',
'create_recordio_file_reader',
'create_random_data_generator',
'create_py_reader',
'create_shuffle_reader',
'create_batch_reader',
'create_double_buffer_reader',
'create_multi_pass_reader',
'read',
'load',
# from python/paddle/fluid/control_flow.py
'increment',
'less_than',
'less_equal',
'greater_than',
'greater_equal',
'equal',
'not_equal',
'read_from_array',
'shrink_rnn_memory',
'lod_array_length',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
'print',
'conditional_block',
'while',
'ifelse',
'is_empty',
'lstm',
'cudnn_lstm',
'lstmp',
'gru',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
'bpr_loss',
'chunk_eval',
'sequence_conv',
'sequence_softmax',
# Depthwise conv2d isn't fast and safe currently.
# ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h#L79
'depthwise_conv2d',
# Tensor Core kernels are not available for 3D convolutions currently.
'conv3d',
'sequence_pool',
'sequence_concat',
'sequence_slice',
'data_norm',
'layer_norm',
'group_norm',
'spectral_norm',
'depthwise_conv2d_transpose',
'sequence_expand',
'conv_transposed2d',
'conv_transposed3d',
'sequence_expand_as',
'sequence_pad',
'sequence_unpad',
'sequence_erase',
'beam_search',
'beam_search_decode',
'lstm_unit',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'reduce_prod',
'reduce_all',
'reduce_any',
'split',
'edit_distance',
'ctc_align',
'warpctc',
'sequence_reshape',
'nce',
'hierarchical_sigmoid',
'im2sequence',
'row_conv',
'multiplex',
'sample_logits',
'one_hot',
'smooth_l1_loss',
'squeeze2',
'unsqueeze2',
'lod_reset',
'lrn',
'pad',
'pad_constant_like',
'label_smooth',
'scatter',
'sequence_scatter',
'random_crop',
'mean_iou',
'selu',
'crop',
'affine_grid',
'rank_loss',
'margin_rank_loss',
'pad2d',
'elu',
'pow',
'stanh',
'hard_sigmoid',
'swish',
'prelu',
'brelu',
'sequence_enumerate',
'sequence_mask',
'expand',
'sampling_id',
'maxout',
'space_to_depth',
'sequence_reverse',
'similarity_focus',
'hash',
'grid_sampler',
'log_loss',
'teacher_student_sigmoid_loss',
'add_position_encoding',
'bilinear_tensor_product',
'shuffle_channel',
'temporal_shift',
'psroi_pool',
'huber_loss',
'kldiv_loss',
'tree_conv',
'pixel_shuffle',
'fsp',
'cvm',
'affine_channel',
'roi_pool',
'roi_align',
'anchor_generator',
'generate_proposals',
'generate_proposal_labels',
'generate_mask_labels',
}
'''
......@@ -17,6 +17,7 @@ from __future__ import print_function
from ... import core
from ... import layers
from ... import framework
from .fp16_lists import black_list, white_list, gray_list
def append_cast_op(i, o, prog):
......@@ -121,6 +122,183 @@ def master_param_to_train_param(master_params_grads, params_grads, main_prog):
append_cast_op(m_p_g[0], train_p, main_prog)
def _rename_arg(op, old_name, new_name):
"""
If an op has old_name input and output, rename these input
args new_name.
Args:
op (Operator): Current operator.
old_name (str): The old name of input args.
new_name (str): The new name of input args.
"""
op_desc = op.desc
if isinstance(op_desc, tuple):
op_desc = op_desc[0]
op_desc._rename_input(old_name, new_name)
op_desc._rename_output(old_name, new_name)
def _dtype_to_str(dtype):
"""
Convert specific variable type to its corresponding string.
Args:
dtype (VarType): Variable type.
"""
if dtype == core.VarDesc.VarType.FP16:
return 'fp16'
else:
return 'fp32'
def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
"""
Insert cast op and rename args of input and output.
Args:
block (Program): The block in which the operator is.
op (Operator): The operator to insert cast op.
idx (int): The index of current operator.
src_dtype (VarType): The input variable dtype of cast op.
desr_dtype (VarType): The output variable dtype of cast op.
Returns:
num_cast_op (int): The number of cast ops that have been inserted.
"""
num_cast_ops = 0
valid_types = [
core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY
]
for in_name in op.input_names:
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
if in_var.type not in valid_types:
continue
if in_var.dtype == src_dtype:
out_var = block.create_var(
name=in_var.name + \
'.cast_' + _dtype_to_str(dest_dtype),
dtype=dest_dtype,
persistable=False,
stop_gradient=False)
block._insert_op(
idx,
type="cast",
inputs={"X": in_var},
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype
})
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
else:
if op.has_attr('in_dtype'):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names:
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
if out_var.type not in valid_types:
continue
if out_var.dtype == core.VarDesc.VarType.FP16:
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
if op.has_attr('out_dtype'):
op._set_attr('out_dtype', core.VarDesc.VarType.FP32)
return num_cast_ops
def find_true_prev_op(ops, var_name):
for op in ops:
for out_name in op.output_names:
for out_var_name in op.output(out_name):
if out_var_name == var_name:
return op
def rewrite_program(main_prog):
"""
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
1. When an op belongs to the black list, add it to black set
2. When an op belongs to the white list, add it to white set
3. When an op belongs to the gray list. If one
of its inputs is the output of black set op or black list op,
add it to black set. If all of its previous ops are not black
op and one of its inputs is the output of white set op or
white list op, add it to white set.
4. When an op isn't in the lists, add it to black op set.
5. Add necessary cast ops to make sure that black set op will be
computed in fp32 mode, while white set op will be computed in
fp16 mode.
Args:
main_prog (Program): The main program for training.
"""
block = main_prog.global_block()
ops = block.ops
white_op_set = set()
black_op_set = set()
for i in range(len(ops)):
op = ops[i]
if op.type in black_list:
black_op_set.add(op)
elif op.type in white_list:
white_op_set.add(op)
elif op.type in op.type in gray_list:
is_black_op = False
is_white_op = False
for in_name in op.input_names:
# if this op has inputs
if in_name:
for in_var_name in op.input(in_name):
in_var = block.var(in_var_name)
# this in_var isn't the output of other op
if in_var.op is None:
continue
if in_var.op is op:
prev_op = find_true_prev_op(ops, in_var_name)
else:
prev_op = in_var.op
# if it's one of inputs
if prev_op in black_op_set or \
prev_op.type in black_list:
is_black_op = True
if prev_op in white_op_set or \
prev_op.type in white_list:
is_white_op = True
if is_black_op:
black_op_set.add(op)
elif is_white_op:
white_op_set.add(op)
else:
pass
else:
# For numerical safe, we apply fp32 computation on ops that
# are not determined which list they should stay.
black_op_set.add(op)
idx = 0
while idx < len(ops):
op = ops[idx]
num_cast_ops = 0
if op in black_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP16,
core.VarDesc.VarType.FP32)
elif op in white_op_set:
num_cast_ops = _insert_cast_op(block, op, idx,
core.VarDesc.VarType.FP32,
core.VarDesc.VarType.FP16)
else:
pass
idx += num_cast_ops + 1
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
num_bad_steps, incr_every_n_steps,
decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册