Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a9da6ead
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a9da6ead
编写于
6月 25, 2019
作者:
Y
Yibing Liu
提交者:
GitHub
6月 25, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
init black/white lists (#17847) (#18309)
test=release/1.5
上级
0fad63a3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
414 addition
and
1 deletion
+414
-1
python/paddle/fluid/contrib/mixed_precision/decorator.py
python/paddle/fluid/contrib/mixed_precision/decorator.py
+2
-1
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
+234
-0
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+178
-0
未找到文件。
python/paddle/fluid/contrib/mixed_precision/decorator.py
浏览文件 @
a9da6ead
...
@@ -18,7 +18,7 @@ from ... import layers
...
@@ -18,7 +18,7 @@ from ... import layers
from
...
import
unique_name
from
...
import
unique_name
from
.
import
fp16_utils
from
.
import
fp16_utils
from
.fp16_utils
import
create_master_params_grads
,
master_param_to_train_param
from
.fp16_utils
import
create_master_params_grads
,
master_param_to_train_param
from
.fp16_utils
import
update_loss_scaling
from
.fp16_utils
import
update_loss_scaling
,
rewrite_program
__all__
=
[
"decorate"
]
__all__
=
[
"decorate"
]
...
@@ -120,6 +120,7 @@ class OptimizerWithMixedPrecison(object):
...
@@ -120,6 +120,7 @@ class OptimizerWithMixedPrecison(object):
A list of (param, grad), which is a tuple of a parameter and its
A list of (param, grad), which is a tuple of a parameter and its
gradient respectively, and the scaled loss.
gradient respectively, and the scaled loss.
"""
"""
rewrite_program
(
self
.
_train_program
)
scaled_loss
=
loss
*
self
.
_loss_scaling
scaled_loss
=
loss
*
self
.
_loss_scaling
self
.
_param_grads
=
self
.
_optimizer
.
backward
(
self
.
_param_grads
=
self
.
_optimizer
.
backward
(
scaled_loss
,
startup_program
,
parameter_list
,
no_grad_set
,
scaled_loss
,
startup_program
,
parameter_list
,
no_grad_set
,
...
...
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
0 → 100644
浏览文件 @
a9da6ead
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The three sets listed below are changed dynamiclly. They don't contain all
# paddle ops currently.
# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
white_list
=
{
'conv2d'
,
'matmul'
,
'mul'
,
}
# The set of ops that support fp16 calculation and are considered numerically-
# dangerous and whose effects may also be observed in downstream ops.
black_list
=
{
'exp'
,
'square'
,
'log'
,
'mean'
,
'sum'
,
'cos_sim'
,
'softmax'
,
'softmax_with_cross_entropy'
,
'sigmoid_cross_entropy_with_logits'
,
'cross_entropy'
,
'cross_entropy2'
,
}
# This set contains two types of ops. All ops supported fp16 calculation. One
# of two types is considered numerically-safe, but may be made unsafe by an
# updtream blacklist op. Another type do not have numerically-significant
# effects, like stack, flatten2.
gray_list
=
{
'elementwise_add'
,
'elementwise_sub'
,
'elementwise_mul'
,
'elementwise_div'
,
'elementwise_max'
,
'elementwise_min'
,
'elementwise_pow'
,
'elementwise_mod'
,
'elementwise_floordiv'
,
'tanh'
,
'sigmoid'
,
'lookup_table'
,
'top_k'
,
'pool2d'
,
'pool3d'
,
'dropout'
,
'relu'
,
'relu6'
,
'leaky_relu'
,
'soft_relu'
,
'flatten2'
,
'stack'
,
'unstack'
,
'uniform_random_batch_size_like'
,
'gaussian_random'
,
'gaussian_random_batch_size_like'
,
'slice'
,
'rank'
,
'scale'
,
'transpose2'
,
'reshape2'
,
'gather'
,
'fill_constant'
,
'get_tensor_from_selected_rows'
,
'sign'
,
'cast'
,
}
'''
# The set of ops that don't support fp16 calculation
unsupported_fp16_list = {
# from python/paddle/fluid/layers/io.py
'send',
'send_barrier',
'recv',
'fetch_barrier',
'create_recordio_file_reader',
'create_random_data_generator',
'create_py_reader',
'create_shuffle_reader',
'create_batch_reader',
'create_double_buffer_reader',
'create_multi_pass_reader',
'read',
'load',
# from python/paddle/fluid/control_flow.py
'increment',
'less_than',
'less_equal',
'greater_than',
'greater_equal',
'equal',
'not_equal',
'read_from_array',
'shrink_rnn_memory',
'lod_array_length',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
'print',
'conditional_block',
'while',
'ifelse',
'is_empty',
'lstm',
'cudnn_lstm',
'lstmp',
'gru',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
'bpr_loss',
'chunk_eval',
'sequence_conv',
'sequence_softmax',
# Depthwise conv2d isn't fast and safe currently.
# ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h#L79
'depthwise_conv2d',
# Tensor Core kernels are not available for 3D convolutions currently.
'conv3d',
'sequence_pool',
'sequence_concat',
'sequence_slice',
'data_norm',
'layer_norm',
'group_norm',
'spectral_norm',
'depthwise_conv2d_transpose',
'sequence_expand',
'conv_transposed2d',
'conv_transposed3d',
'sequence_expand_as',
'sequence_pad',
'sequence_unpad',
'sequence_erase',
'beam_search',
'beam_search_decode',
'lstm_unit',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'reduce_prod',
'reduce_all',
'reduce_any',
'split',
'edit_distance',
'ctc_align',
'warpctc',
'sequence_reshape',
'nce',
'hierarchical_sigmoid',
'im2sequence',
'row_conv',
'multiplex',
'sample_logits',
'one_hot',
'smooth_l1_loss',
'squeeze2',
'unsqueeze2',
'lod_reset',
'lrn',
'pad',
'pad_constant_like',
'label_smooth',
'scatter',
'sequence_scatter',
'random_crop',
'mean_iou',
'selu',
'crop',
'affine_grid',
'rank_loss',
'margin_rank_loss',
'pad2d',
'elu',
'pow',
'stanh',
'hard_sigmoid',
'swish',
'prelu',
'brelu',
'sequence_enumerate',
'sequence_mask',
'expand',
'sampling_id',
'maxout',
'space_to_depth',
'sequence_reverse',
'similarity_focus',
'hash',
'grid_sampler',
'log_loss',
'teacher_student_sigmoid_loss',
'add_position_encoding',
'bilinear_tensor_product',
'shuffle_channel',
'temporal_shift',
'psroi_pool',
'huber_loss',
'kldiv_loss',
'tree_conv',
'pixel_shuffle',
'fsp',
'cvm',
'affine_channel',
'roi_pool',
'roi_align',
'anchor_generator',
'generate_proposals',
'generate_proposal_labels',
'generate_mask_labels',
}
'''
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
浏览文件 @
a9da6ead
...
@@ -17,6 +17,7 @@ from __future__ import print_function
...
@@ -17,6 +17,7 @@ from __future__ import print_function
from
...
import
core
from
...
import
core
from
...
import
layers
from
...
import
layers
from
...
import
framework
from
...
import
framework
from
.fp16_lists
import
black_list
,
white_list
,
gray_list
def
append_cast_op
(
i
,
o
,
prog
):
def
append_cast_op
(
i
,
o
,
prog
):
...
@@ -121,6 +122,183 @@ def master_param_to_train_param(master_params_grads, params_grads, main_prog):
...
@@ -121,6 +122,183 @@ def master_param_to_train_param(master_params_grads, params_grads, main_prog):
append_cast_op
(
m_p_g
[
0
],
train_p
,
main_prog
)
append_cast_op
(
m_p_g
[
0
],
train_p
,
main_prog
)
def
_rename_arg
(
op
,
old_name
,
new_name
):
"""
If an op has old_name input and output, rename these input
args new_name.
Args:
op (Operator): Current operator.
old_name (str): The old name of input args.
new_name (str): The new name of input args.
"""
op_desc
=
op
.
desc
if
isinstance
(
op_desc
,
tuple
):
op_desc
=
op_desc
[
0
]
op_desc
.
_rename_input
(
old_name
,
new_name
)
op_desc
.
_rename_output
(
old_name
,
new_name
)
def
_dtype_to_str
(
dtype
):
"""
Convert specific variable type to its corresponding string.
Args:
dtype (VarType): Variable type.
"""
if
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
return
'fp16'
else
:
return
'fp32'
def
_insert_cast_op
(
block
,
op
,
idx
,
src_dtype
,
dest_dtype
):
"""
Insert cast op and rename args of input and output.
Args:
block (Program): The block in which the operator is.
op (Operator): The operator to insert cast op.
idx (int): The index of current operator.
src_dtype (VarType): The input variable dtype of cast op.
desr_dtype (VarType): The output variable dtype of cast op.
Returns:
num_cast_op (int): The number of cast ops that have been inserted.
"""
num_cast_ops
=
0
valid_types
=
[
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
]
for
in_name
in
op
.
input_names
:
for
in_var_name
in
op
.
input
(
in_name
):
in_var
=
block
.
var
(
in_var_name
)
if
in_var
.
type
not
in
valid_types
:
continue
if
in_var
.
dtype
==
src_dtype
:
out_var
=
block
.
create_var
(
name
=
in_var
.
name
+
\
'.cast_'
+
_dtype_to_str
(
dest_dtype
),
dtype
=
dest_dtype
,
persistable
=
False
,
stop_gradient
=
False
)
block
.
_insert_op
(
idx
,
type
=
"cast"
,
inputs
=
{
"X"
:
in_var
},
outputs
=
{
"Out"
:
out_var
},
attrs
=
{
"in_dtype"
:
in_var
.
dtype
,
"out_dtype"
:
out_var
.
dtype
})
num_cast_ops
+=
1
_rename_arg
(
op
,
in_var
.
name
,
out_var
.
name
)
else
:
if
op
.
has_attr
(
'in_dtype'
):
op
.
_set_attr
(
'in_dtype'
,
dest_dtype
)
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
for
out_name
in
op
.
output_names
:
for
out_var_name
in
op
.
output
(
out_name
):
out_var
=
block
.
var
(
out_var_name
)
if
out_var
.
type
not
in
valid_types
:
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
if
op
.
has_attr
(
'out_dtype'
):
op
.
_set_attr
(
'out_dtype'
,
core
.
VarDesc
.
VarType
.
FP32
)
return
num_cast_ops
def
find_true_prev_op
(
ops
,
var_name
):
for
op
in
ops
:
for
out_name
in
op
.
output_names
:
for
out_var_name
in
op
.
output
(
out_name
):
if
out_var_name
==
var_name
:
return
op
def
rewrite_program
(
main_prog
):
"""
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
1. When an op belongs to the black list, add it to black set
2. When an op belongs to the white list, add it to white set
3. When an op belongs to the gray list. If one
of its inputs is the output of black set op or black list op,
add it to black set. If all of its previous ops are not black
op and one of its inputs is the output of white set op or
white list op, add it to white set.
4. When an op isn't in the lists, add it to black op set.
5. Add necessary cast ops to make sure that black set op will be
computed in fp32 mode, while white set op will be computed in
fp16 mode.
Args:
main_prog (Program): The main program for training.
"""
block
=
main_prog
.
global_block
()
ops
=
block
.
ops
white_op_set
=
set
()
black_op_set
=
set
()
for
i
in
range
(
len
(
ops
)):
op
=
ops
[
i
]
if
op
.
type
in
black_list
:
black_op_set
.
add
(
op
)
elif
op
.
type
in
white_list
:
white_op_set
.
add
(
op
)
elif
op
.
type
in
op
.
type
in
gray_list
:
is_black_op
=
False
is_white_op
=
False
for
in_name
in
op
.
input_names
:
# if this op has inputs
if
in_name
:
for
in_var_name
in
op
.
input
(
in_name
):
in_var
=
block
.
var
(
in_var_name
)
# this in_var isn't the output of other op
if
in_var
.
op
is
None
:
continue
if
in_var
.
op
is
op
:
prev_op
=
find_true_prev_op
(
ops
,
in_var_name
)
else
:
prev_op
=
in_var
.
op
# if it's one of inputs
if
prev_op
in
black_op_set
or
\
prev_op
.
type
in
black_list
:
is_black_op
=
True
if
prev_op
in
white_op_set
or
\
prev_op
.
type
in
white_list
:
is_white_op
=
True
if
is_black_op
:
black_op_set
.
add
(
op
)
elif
is_white_op
:
white_op_set
.
add
(
op
)
else
:
pass
else
:
# For numerical safe, we apply fp32 computation on ops that
# are not determined which list they should stay.
black_op_set
.
add
(
op
)
idx
=
0
while
idx
<
len
(
ops
):
op
=
ops
[
idx
]
num_cast_ops
=
0
if
op
in
black_op_set
:
num_cast_ops
=
_insert_cast_op
(
block
,
op
,
idx
,
core
.
VarDesc
.
VarType
.
FP16
,
core
.
VarDesc
.
VarType
.
FP32
)
elif
op
in
white_op_set
:
num_cast_ops
=
_insert_cast_op
(
block
,
op
,
idx
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
)
else
:
pass
idx
+=
num_cast_ops
+
1
def
update_loss_scaling
(
is_overall_finite
,
prev_loss_scaling
,
num_good_steps
,
def
update_loss_scaling
(
is_overall_finite
,
prev_loss_scaling
,
num_good_steps
,
num_bad_steps
,
incr_every_n_steps
,
num_bad_steps
,
incr_every_n_steps
,
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
):
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录