BaiXuePrincess / Paddle (fork of PaddlePaddle / Paddle)

Commit f35c8ce6
Authored on Sep 23, 2020 by mapingshuo
Parent: 6c5c547e

add custom amp
Showing 5 changed files with 984 additions and 5 deletions (+984 −5)
python/paddle/distributed/fleet/meta_optimizers/zero/__init__.py      +21  −0
python/paddle/distributed/fleet/meta_optimizers/zero/decorator.py     +272 −0
python/paddle/distributed/fleet/meta_optimizers/zero/fp16_lists.py    +284 −0
python/paddle/distributed/fleet/meta_optimizers/zero/fp16_utils.py    +404 −0
python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py     +3   −5
python/paddle/distributed/fleet/meta_optimizers/zero/__init__.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

from . import decorator
from .decorator import *
from . import fp16_lists
from .fp16_lists import AutoMixedPrecisionLists

__all__ = decorator.__all__
__all__ += fp16_lists.__all__
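For orientation (not part of the diff itself): a minimal sketch of how the surface exported by this `__init__.py` is meant to be consumed. The import path simply mirrors the file locations added in this commit, and the custom white-list entry is an arbitrary example.

# Illustrative only; assumes the package layout added by this commit is importable.
from paddle.distributed.fleet.meta_optimizers.zero import (
    AutoMixedPrecisionLists, decorate)

# Build custom black/white lists (defined in fp16_lists.py below), then use
# decorate(...) to wrap any fluid Optimizer (see decorator.py below).
amp_lists = AutoMixedPrecisionLists(custom_white_list={'elementwise_add'})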
python/paddle/distributed/fleet/meta_optimizers/zero/decorator.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import default_main_program
from paddle.fluid import default_startup_program
from paddle.fluid import layers
from paddle.fluid import unique_name
from paddle.fluid import framework
from . import fp16_utils
from .fp16_utils import update_loss_scaling, rewrite_program
from .fp16_utils import update_role_var_grad
from .fp16_lists import AutoMixedPrecisionLists

__all__ = ["decorate"]
class OptimizerWithMixedPrecision(object):
    """
    Optimizer with mixed-precision (MP) training. This is a wrapper of a common
    optimizer, plus the support of mixed-precision training. An object of this
    class behaves almost the same as the common optimizer, with the methods
    `minimize()`, `backward()` and `apply_gradients()` implemented.
    Additionally, it enables MP training automatically, i.e., the creation and
    maintenance of master parameters, scaling of the loss, and so on.

    Args:
        optimizer (Optimizer): A common Optimizer object.
        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
        init_loss_scaling (float): The initial loss scaling factor.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
        incr_every_n_steps (int): Increase loss scaling every n consecutive
                                  steps with finite gradients.
        decr_every_n_nan_or_inf (int): Decrease loss scaling every n
                                       accumulated steps with nan or inf
                                       gradients.
        incr_ratio (float): The multiplier to use when increasing the loss
                            scaling.
        decr_ratio (float): The less-than-one multiplier to use when decreasing
                            the loss scaling.
    """
    def __init__(self, optimizer, amp_lists, init_loss_scaling,
                 use_dynamic_loss_scaling, incr_every_n_steps,
                 decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
        self._optimizer = optimizer
        self._amp_lists = amp_lists
        self._param_grads = None
        self._train_program = default_main_program()
        self._startup_prog = default_startup_program()
        self._scaled_loss = None
        self._loss_scaling = layers.create_global_var(
            name=unique_name.generate("loss_scaling"),
            shape=[1],
            value=init_loss_scaling,
            dtype='float32',
            persistable=True)
        self._use_dynamic_loss_scaling = use_dynamic_loss_scaling
        if self._use_dynamic_loss_scaling:
            self._incr_every_n_steps = layers.fill_constant(
                shape=[1], dtype='int32', value=incr_every_n_steps)
            self._decr_every_n_nan_or_inf = layers.fill_constant(
                shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)
            self._incr_ratio = incr_ratio
            self._decr_ratio = decr_ratio
            self._num_good_steps = layers.create_global_var(
                name=unique_name.generate("num_good_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True)
            self._num_bad_steps = layers.create_global_var(
                name=unique_name.generate("num_bad_steps"),
                shape=[1],
                value=0,
                dtype='int32',
                persistable=True)

        # Ensure the data type of learning rate vars is float32 (same as the
        # master parameter dtype)
        if isinstance(optimizer._learning_rate, float):
            optimizer._learning_rate_map[default_main_program()] = \
                layers.create_global_var(
                    name=unique_name.generate("learning_rate"),
                    shape=[1],
                    value=float(optimizer._learning_rate),
                    dtype='float32',
                    persistable=True)
    def get_loss_scaling(self):
        """Return the real-time loss scaling factor."""
        return self._loss_scaling

    def get_scaled_loss(self):
        """Return the scaled loss.

        It is useful when you need to feed a customized loss into the executor.
        """
        return self._scaled_loss
    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        """
        Backward propagation or auto differentiation for gradients' computation.

        Args:
            loss (Variable): The loss Variable to minimize.
            startup_program (Program|None): The startup Program for initializing
                                            parameters in `parameter_list`.
            parameter_list (list|None): A list of Variables to update.
            no_grad_set (set|None): A set of Variables that should be ignored.
            callbacks (list|None): A list of callable objects to run when
                                   appending the backward operator for one
                                   parameter.

        Returns:
            A list of (param, grad) pairs, where each grad is the unscaled
            gradient of the corresponding parameter, computed from the scaled
            loss.
        """
        rewrite_program(self._train_program, self._amp_lists)
        with framework.name_scope('mixed_precision'):
            self._scaled_loss = loss * self._loss_scaling
        self._params_grads = self._optimizer.backward(
            self._scaled_loss, startup_program, parameter_list, no_grad_set,
            callbacks)
        # Change the op_role_var attr for some ops, so that gradients
        # transferred across GPUs can be FP16.
        update_role_var_grad(self._train_program, self._params_grads)
        scaled_params_grads = []
        with framework.name_scope('mixed_precision'):
            for p, g in self._params_grads:
                with self._train_program._optimized_guard([p, g]):
                    scaled_g = g / self._loss_scaling
                    scaled_params_grads.append([p, scaled_g])

        return scaled_params_grads
    def apply_gradients(self, scaled_params_grads):
        """
        Check the scaled gradients to determine whether to update the loss
        scaling, and update the parameters by their scaled gradients.

        Args:
            scaled_params_grads (list): A list of params and scaled grads.

        Returns:
            A list of optimize operators.
        """
        if self._use_dynamic_loss_scaling and len(scaled_params_grads) > 0:
            with framework.name_scope('mixed_precision'):
                with self._train_program._optimized_guard(
                        scaled_params_grads[0]):
                    grads = [
                        layers.reduce_sum(g) for [_, g] in scaled_params_grads
                    ]
                    all_grads_sum = layers.sums(grads)
                    is_overall_finite = layers.isfinite(all_grads_sum)

                    update_loss_scaling(is_overall_finite, self._loss_scaling,
                                        self._num_good_steps,
                                        self._num_bad_steps,
                                        self._incr_every_n_steps,
                                        self._decr_every_n_nan_or_inf,
                                        self._incr_ratio, self._decr_ratio)

                    # apply_gradients appends all ops in the global block,
                    # thus we shouldn't apply the gradients inside the switch
                    # branch.
                    with layers.Switch() as switch:
                        with switch.case(is_overall_finite):
                            pass
                        with switch.default():
                            for _, g in scaled_params_grads:
                                layers.assign(layers.zeros_like(g), g)

        optimize_ops = self._optimizer.apply_gradients(scaled_params_grads)

        return optimize_ops
    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Perform optimization by minimizing the given loss.

        Args:
            loss (Variable): The loss Variable.
            startup_program (Program): startup_program for initializing
                                       parameters in `parameter_list`.
            parameter_list (list): A list of Variables to update.
            no_grad_set (set|None): A set of Variables that should be ignored.

        Returns:
            The list of optimize ops and the list of scaled parameters and
            gradients.
        """
        scaled_params_grads = self.backward(
            loss,
            startup_program=startup_program,
            parameter_list=parameter_list,
            no_grad_set=no_grad_set)

        optimize_ops = self.apply_gradients(scaled_params_grads)

        return optimize_ops, scaled_params_grads
def decorate(optimizer,
             amp_lists=None,
             init_loss_scaling=2**15,
             incr_every_n_steps=1000,
             decr_every_n_nan_or_inf=2,
             incr_ratio=2.0,
             decr_ratio=0.8,
             use_dynamic_loss_scaling=True):
    """
    Decorate the given optimizer to adapt it to mixed-precision training.

    Args:
        optimizer (Optimizer): A common Optimizer.
        amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
        init_loss_scaling (float): The initial loss scaling factor.
        incr_every_n_steps (int): Increase loss scaling every n consecutive
                                  steps with finite gradients.
        decr_every_n_nan_or_inf (int): Decrease loss scaling every n
                                       accumulated steps with nan or inf
                                       gradients.
        incr_ratio (float): The multiplier to use when increasing the loss
                            scaling.
        decr_ratio (float): The less-than-one multiplier to use when decreasing
                            the loss scaling.
        use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.

    Returns:
        An optimizer acting like a normal one but with mixed-precision training
        enabled.

    Examples:
        .. code-block:: python

            loss = network()
            optimizer = fluid.optimizer.Adam(learning_rate=0.001)

            mp_optimizer = fluid.contrib.mixed_precision.decorate(
                optimizer=optimizer, init_loss_scaling=8.0)

            ops, param_grads = mp_optimizer.minimize(loss)
            scaled_loss = mp_optimizer.get_scaled_loss()
    """
    if amp_lists is None:
        amp_lists = AutoMixedPrecisionLists()
    mp_optimizer = OptimizerWithMixedPrecision(
        optimizer, amp_lists, init_loss_scaling, use_dynamic_loss_scaling,
        incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio)

    return mp_optimizer
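Beyond the `minimize()` example in the docstring above, the backward/apply_gradients split is what the ZeRO optimizer in this commit relies on, since it needs to post-process gradients between the two phases. A hedged sketch, assuming a static-graph program has already been built; `network()` and the gradient-sharding step are placeholders, not part of this commit:

# Illustrative sketch; network() is a hypothetical model-building function.
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.zero.decorator import decorate

loss = network()
optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
mp_optimizer = decorate(optimizer, init_loss_scaling=2**15,
                        use_dynamic_loss_scaling=True)

# Phase 1: backward on the scaled loss; yields (param, unscaled grad) pairs.
scaled_params_grads = mp_optimizer.backward(loss)

# ... gradients could be sharded or re-routed here (the ZeRO use case) ...

# Phase 2: check finiteness, update the loss scaling, then apply the updates.
optimize_ops = mp_optimizer.apply_gradients(scaled_params_grads)
scaled_loss = mp_optimizer.get_scaled_loss()  # fetch/feed this if needed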
python/paddle/distributed/fleet/meta_optimizers/zero/fp16_lists.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy

__all__ = ["AutoMixedPrecisionLists"]
class AutoMixedPrecisionLists(object):
    """
    AutoMixedPrecisionLists is a class for the black/white lists. It can update
    the pre-defined black list and white list according to users' custom black
    and white lists. The lists are used by an algorithm that determines an op's
    execution mode (fp32 or fp16).

    Args:
        custom_white_list (set): Users' custom white list.
        custom_black_list (set): Users' custom black list.
        custom_black_varnames (set): Users' custom black variable names.
    """
    def __init__(self,
                 custom_white_list=None,
                 custom_black_list=None,
                 custom_black_varnames=None):
        self._custom_white_list = custom_white_list
        self._custom_black_list = custom_black_list
        self.white_list = copy.copy(white_list)
        self.black_list = copy.copy(black_list)
        self.gray_list = copy.copy(gray_list)
        self.black_varnames = copy.copy(custom_black_varnames)
        self._update_list()
    def _update_list(self):
        """
        Update the black and white lists according to users' custom lists.
        """
        if self._custom_white_list and self._custom_black_list:
            for op_name in self._custom_white_list:
                if op_name in self._custom_black_list:
                    raise ValueError("Custom white list overlaps "
                                     "custom black list")
        if self._custom_white_list:
            for op_name in self._custom_white_list:
                if op_name in self.black_list:
                    self.black_list.remove(op_name)
                elif op_name in self.gray_list:
                    self.gray_list.remove(op_name)
                self.white_list.add(op_name)
        if self._custom_black_list:
            for op_name in self._custom_black_list:
                if op_name in self.white_list:
                    self.white_list.remove(op_name)
                elif op_name in self.gray_list:
                    self.gray_list.remove(op_name)
                self.black_list.add(op_name)
# The three sets listed below are changed dynamically. They don't contain all
# paddle ops currently.

# The set of ops that support fp16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16.
white_list = {
    'conv2d',
    'matmul',
    'mul',
}
# The set of ops that support fp16 calculation but are considered numerically-
# dangerous, and whose effects may also be observed in downstream ops.
black_list = {
    'exp',
    'square',
    'log',
    'mean',
    'sum',
    'cos_sim',
    'softmax',
    'softmax_with_cross_entropy',
    'sigmoid_cross_entropy_with_logits',
    'cross_entropy',
    'cross_entropy2',
}
# This set contains two types of ops. All of them support fp16 calculation. One
# type is considered numerically-safe, but may be made unsafe by an upstream
# blacklist op. The other type does not have numerically-significant effects,
# like stack and flatten2.
gray_list = {
    'elementwise_add',
    'elementwise_sub',
    'elementwise_mul',
    'elementwise_div',
    'elementwise_max',
    'elementwise_min',
    'elementwise_pow',
    'elementwise_mod',
    'elementwise_floordiv',
    'batch_norm',
    'tanh',
    'sigmoid',
    'lookup_table',
    'top_k',
    'pool2d',
    'pool3d',
    'dropout',
    'relu',
    'relu6',
    'leaky_relu',
    'soft_relu',
    'flatten2',
    'stack',
    'unstack',
    'uniform_random_batch_size_like',
    'gaussian_random',
    'gaussian_random_batch_size_like',
    'slice',
    'rank',
    'scale',
    'transpose2',
    'reshape2',
    'gather',
    'fill_constant',
    'get_tensor_from_selected_rows',
    'sign',
    'cast',
}
'''
# The set of ops that don't support fp16 calculation
unsupported_fp16_list = {
# from python/paddle/fluid/layers/io.py
'send',
'send_barrier',
'recv',
'fetch_barrier',
'create_py_reader',
'create_double_buffer_reader',
'read',
'load',
# from python/paddle/fluid/control_flow.py
'increment',
'less_than',
'less_equal',
'greater_than',
'greater_equal',
'equal',
'not_equal',
'read_from_array',
'shrink_rnn_memory',
'lod_array_length',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
'print',
'conditional_block',
'while',
'ifelse',
'is_empty',
'lstm',
'cudnn_lstm',
'lstmp',
'gru',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
'bpr_loss',
'chunk_eval',
'sequence_conv',
'sequence_softmax',
# Depthwise conv2d isn't fast and safe currently.
# ref: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h#L79
'depthwise_conv2d',
# Tensor Core kernels are not available for 3D convolutions currently.
'conv3d',
'sequence_pool',
'sequence_concat',
'sequence_slice',
'data_norm',
'layer_norm',
'group_norm',
'spectral_norm',
'depthwise_conv2d_transpose',
'sequence_expand',
'conv_transposed2d',
'conv_transposed3d',
'sequence_expand_as',
'sequence_pad',
'sequence_unpad',
'sequence_erase',
'beam_search',
'beam_search_decode',
'lstm_unit',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'reduce_prod',
'reduce_all',
'reduce_any',
'split',
'edit_distance',
'ctc_align',
'warpctc',
'sequence_reshape',
'nce',
'hierarchical_sigmoid',
'im2sequence',
'row_conv',
'multiplex',
'sample_logits',
'one_hot',
'smooth_l1_loss',
'squeeze2',
'unsqueeze2',
'lod_reset',
'lrn',
'pad',
'pad_constant_like',
'label_smooth',
'scatter',
'sequence_scatter',
'random_crop',
'mean_iou',
'selu',
'crop',
'affine_grid',
'rank_loss',
'margin_rank_loss',
'pad2d',
'elu',
'pow',
'stanh',
'hard_sigmoid',
'swish',
'prelu',
'brelu',
'sequence_enumerate',
'sequence_mask',
'expand',
'sampling_id',
'maxout',
'space_to_depth',
'sequence_reverse',
'similarity_focus',
'hash',
'grid_sampler',
'log_loss',
'teacher_student_sigmoid_loss',
'add_position_encoding',
'bilinear_tensor_product',
'shuffle_channel',
'temporal_shift',
'psroi_pool',
'huber_loss',
'kldiv_loss',
'tree_conv',
'pixel_shuffle',
'fsp',
'cvm',
'affine_channel',
'roi_pool',
'roi_align',
'anchor_generator',
'generate_proposals',
'generate_proposal_labels',
'generate_mask_labels',
}
'''
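A short sketch of how the custom lists rearrange the defaults above; the op names are ordinary Paddle op types, and the moves follow the behaviour of `_update_list`:

# Illustrative only: custom entries move ops between the default sets.
from paddle.distributed.fleet.meta_optimizers.zero.fp16_lists import \
    AutoMixedPrecisionLists

amp_lists = AutoMixedPrecisionLists(
    custom_white_list={'pool2d'},         # promoted from the gray list to white
    custom_black_list={'conv2d'},         # demoted from the white list to black
    custom_black_varnames={'loss'})       # ops touching 'loss' are forced to fp32

assert 'pool2d' in amp_lists.white_list and 'pool2d' not in amp_lists.gray_list
assert 'conv2d' in amp_lists.black_list and 'conv2d' not in amp_lists.white_list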
python/paddle/distributed/fleet/meta_optimizers/zero/fp16_utils.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

from paddle.fluid import core
from paddle.fluid import layers
def _rename_arg(op, old_name, new_name):
    """
    If an op has old_name as an input or output arg, rename it to new_name.

    Args:
        op (Operator): Current operator.
        old_name (str): The old name of the input/output args.
        new_name (str): The new name of the input/output args.
    """
    op_desc = op.desc
    if isinstance(op_desc, tuple):
        op_desc = op_desc[0]
    op_desc._rename_input(old_name, new_name)
    op_desc._rename_output(old_name, new_name)
def _dtype_to_str(dtype):
    """
    Convert a specific variable type to its corresponding string.

    Args:
        dtype (VarType): Variable type.
    """
    if dtype == core.VarDesc.VarType.FP16:
        return 'fp16'
    else:
        return 'fp32'
def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
    """
    Insert cast op and rename args of input and output.

    Args:
        block (Program): The block in which the operator is.
        op (Operator): The operator to insert cast ops for.
        idx (int): The index of the current operator.
        src_dtype (VarType): The input variable dtype of cast op.
        dest_dtype (VarType): The output variable dtype of cast op.

    Returns:
        num_cast_ops (int): The number of cast ops that have been inserted.
    """
    num_cast_ops = 0
    valid_types = [
        core.VarDesc.VarType.LOD_TENSOR, core.VarDesc.VarType.SELECTED_ROWS,
        core.VarDesc.VarType.LOD_TENSOR_ARRAY
    ]

    for in_name in op.input_names:
        if src_dtype == core.VarDesc.VarType.FP32 and op.type == 'batch_norm':
            if in_name != 'X':
                continue
        for in_var_name in op.input(in_name):
            in_var = block.var(in_var_name)
            if in_var.type not in valid_types:
                continue
            if in_var.dtype == src_dtype:
                cast_name = in_var.name + '.cast_' + _dtype_to_str(dest_dtype)
                out_var = block.vars.get(cast_name)
                if out_var is None or out_var.dtype != dest_dtype:
                    out_var = block.create_var(
                        name=cast_name,
                        dtype=dest_dtype,
                        persistable=False,
                        stop_gradient=False)

                    block._insert_op(
                        idx,
                        type="cast",
                        inputs={"X": in_var},
                        outputs={"Out": out_var},
                        attrs={
                            "in_dtype": in_var.dtype,
                            "out_dtype": out_var.dtype
                        })
                    num_cast_ops += 1
                _rename_arg(op, in_var.name, out_var.name)
            else:
                if op.has_attr('in_dtype'):
                    op._set_attr('in_dtype', dest_dtype)
    if src_dtype == core.VarDesc.VarType.FP32:
        for out_name in op.output_names:
            if op.type == 'batch_norm' and out_name != 'Y':
                continue
            for out_var_name in op.output(out_name):
                out_var = block.var(out_var_name)
                if out_var.type not in valid_types:
                    continue
                if out_var.dtype == core.VarDesc.VarType.FP32:
                    out_var.desc.set_dtype(core.VarDesc.VarType.FP16)
                    if op.has_attr('out_dtype'):
                        op._set_attr('out_dtype', core.VarDesc.VarType.FP16)
    return num_cast_ops
def find_true_prev_op(ops, cur_op, var_name):
    """
    Find the true prev op that outputs the var_name variable.

    Args:
        ops (list): A list of ops.
        cur_op (Operator): Current operator which has the var_name variable.
        var_name (string): Variable name.
    """
    prev_op = []
    for op in ops:
        if op == cur_op:
            break
        for out_name in op.output_names:
            for out_var_name in op.output(out_name):
                if out_var_name == var_name:
                    prev_op.append(op)
    if prev_op:
        if not len(prev_op) == 1:
            raise ValueError("There must be only one previous op "
                             "that outputs {0} variable".format(var_name))
        else:
            return prev_op[0]
    return None
def find_true_post_op(ops, cur_op, var_name):
    """
    If there are post ops that consume the var_name variable, return them;
    otherwise return None.

    Args:
        ops (list): A list of ops.
        cur_op (Operator): Current operator which has the var_name variable.
        var_name (string): Variable name.
    """
    post_op = []
    for idx, op in enumerate(ops):
        if op == cur_op:
            break
    for i in range(idx + 1, len(ops)):
        op = ops[i]
        for in_name in op.input_names:
            for in_var_name in op.input(in_name):
                if in_var_name == var_name:
                    post_op.append(op)
    if post_op != []:
        return post_op
    return None
def find_op_index(block_desc, cur_op_desc):
    """
    Return the index of cur_op_desc in block_desc, or -1 if it is not found.
    """
    for idx in range(block_desc.op_size()):
        if cur_op_desc == block_desc.op(idx):
            return idx
    return -1
def _is_in_black_varnames(op, amp_lists):
    for in_name in op.input_arg_names:
        if in_name in amp_lists.black_varnames:
            return True

    for out_name in op.output_arg_names:
        if out_name in amp_lists.black_varnames:
            return True

    return False
def rewrite_program(main_prog, amp_lists):
    """
    Traverse all ops in the current block and insert cast ops according to
    which set each op belongs to:

    1. When an op belongs to the black list, add it to the black set.
    2. When an op belongs to the white list, add it to the white set.
    3. When an op belongs to the gray list: if one of its inputs is the output
       of a black-set op or a black-list op, add it to the black set; if none
       of its previous ops are black and one of its inputs is the output of a
       white-set op or a white-list op, add it to the white set.
    4. When an op isn't in any of the lists, add it to the black set.
    5. Add the necessary cast ops to make sure that black-set ops are computed
       in fp32 mode, while white-set ops are computed in fp16 mode.

    Args:
        main_prog (Program): The main program for training.
        amp_lists (AutoMixedPrecisionLists): The black/white/gray op lists.
    """
    block = main_prog.global_block()
    ops = block.ops
    white_op_set = set()
    black_op_set = set()
    for op in ops:
        if amp_lists.black_varnames is not None and _is_in_black_varnames(
                op, amp_lists):
            black_op_set.add(op)
            continue

        if op.type in amp_lists.black_list:
            black_op_set.add(op)
        elif op.type in amp_lists.white_list:
            white_op_set.add(op)
        elif op.type in amp_lists.gray_list:
            is_black_op = False
            is_white_op = False
            for in_name in op.input_names:
                # if this op has inputs
                if in_name:
                    for in_var_name in op.input(in_name):
                        in_var = block.var(in_var_name)
                        # this in_var isn't the output of other op
                        if in_var.op is None:
                            continue
                        elif in_var.op is op:
                            prev_op = find_true_prev_op(ops, op, in_var_name)
                            if prev_op is None:
                                continue
                        else:
                            prev_op = in_var.op
                        # if it's one of inputs
                        if prev_op in black_op_set or \
                                prev_op.type in amp_lists.black_list:
                            is_black_op = True
                        elif prev_op in white_op_set or \
                                prev_op.type in amp_lists.white_list:
                            is_white_op = True
            if is_black_op:
                black_op_set.add(op)
            elif is_white_op:
                white_op_set.add(op)
            else:
                pass
        else:
            # For numerical safety, we apply fp32 computation on ops for which
            # it is not determined which list they should stay in.
            black_op_set.add(op)

    idx = 0
    while idx < len(ops):
        op = ops[idx]
        num_cast_ops = 0
        if op in black_op_set:
            num_cast_ops = _insert_cast_op(block, op, idx,
                                           core.VarDesc.VarType.FP16,
                                           core.VarDesc.VarType.FP32)
        elif op in white_op_set:
            num_cast_ops = _insert_cast_op(block, op, idx,
                                           core.VarDesc.VarType.FP32,
                                           core.VarDesc.VarType.FP16)
        else:
            pass

        idx += num_cast_ops + 1
def update_role_var_grad(main_prog, params_grads):
    """
    Update the op_role_var attr for some ops to make sure the gradients
    transferred across GPUs are FP16:

    1. Check whether the op that outputs the gradient is a cast op.
    2. If the op is cast and the gradient is FP32, remove the op_role_var
       attr and find the prev op which outputs the FP16 gradient.
    3. Update the op_role_var of that prev op.

    Args:
        main_prog (Program): The main program for training.
        params_grads (list): A list of params and grads.
    """
    block = main_prog.global_block()
    BACKWARD = core.op_proto_and_checker_maker.OpRole.Backward
    OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
    for p, g in params_grads:
        op = g.op
        if g.dtype == core.VarDesc.VarType.FP32 and op.type == 'cast':
            role = op.attr('op_role')
            if role & int(BACKWARD) and op.has_attr('op_role_var'):
                op.desc.remove_attr("op_role_var")
            else:
                raise ValueError("The cast op {0} must be in BACKWARD role "
                                 "and have op_role_var attr.".format(op))

            fp16_grad_name = op.input(op.input_names[0])[0]
            op_for_fp16_grad = find_true_prev_op(block.ops, op, fp16_grad_name)
            op_role_var_attr_name = \
                core.op_proto_and_checker_maker.kOpRoleVarAttrName()
            attr_val = [p.name, fp16_grad_name]
            if op_for_fp16_grad.has_attr(op_role_var_attr_name):
                attr_val.extend(op_for_fp16_grad.attr(op_role_var_attr_name))
            op_for_fp16_grad._set_attr(op_role_var_attr_name, attr_val)

            # Maximize the all_reduce overlap, and perform the cast
            # operation after the gradients transfer.
            op._set_attr('op_role', OPTIMIZE)
            # optimize ops should stay behind forward and backward ops
            if op == block.ops[-1]:
                continue
            post_ops = find_true_post_op(block.ops, op, g.name)
            if post_ops is not None:
                raise ValueError("The cast op {0}'s output should not be "
                                 "used by a non-optimize op, however, it "
                                 "is used by {1}".format(op, post_ops[0]))
            new_op_desc = block.desc.append_op()
            new_op_desc.copy_from(op.desc)

            op_idx = find_op_index(block.desc, op.desc)
            if op_idx == -1:
                raise ValueError("The op {0} is not in program".format(op))
            block.desc._remove_op(op_idx, op_idx + 1)
    block._sync_with_cpp()
def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
                        num_bad_steps, incr_every_n_steps,
                        decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
    """
    Update the loss scaling according to the overall gradients. If all
    gradients are finite for incr_every_n_steps consecutive steps, the loss
    scaling will increase by incr_ratio. Otherwise, the loss scaling will
    decrease by decr_ratio after decr_every_n_nan_or_inf accumulated steps in
    which some gradients are nan or inf.

    Args:
        is_overall_finite (Variable): A boolean variable indicating whether
                                      all gradients are finite.
        prev_loss_scaling (Variable): Previous loss scaling.
        num_good_steps (Variable): A variable that accumulates good steps in
                                   which all gradients are finite.
        num_bad_steps (Variable): A variable that accumulates bad steps in
                                  which some gradients are infinite.
        incr_every_n_steps (Variable): A variable representing the number of
                                       consecutive steps with finite gradients
                                       after which loss scaling increases.
        decr_every_n_nan_or_inf (Variable): A variable representing the number
                                            of accumulated steps with nan or
                                            inf gradients after which loss
                                            scaling decreases.
        incr_ratio (float): The multiplier to use when increasing the loss
                            scaling.
        decr_ratio (float): The less-than-one multiplier to use when decreasing
                            the loss scaling.
    """
    zero_steps = layers.fill_constant(shape=[1], dtype='int32', value=0)
    with layers.Switch() as switch:
        with switch.case(is_overall_finite):
            should_incr_loss_scaling = layers.less_than(incr_every_n_steps,
                                                        num_good_steps + 1)
            with layers.Switch() as switch1:
                with switch1.case(should_incr_loss_scaling):
                    new_loss_scaling = prev_loss_scaling * incr_ratio
                    loss_scaling_is_finite = layers.isfinite(new_loss_scaling)
                    with layers.Switch() as switch2:
                        with switch2.case(loss_scaling_is_finite):
                            layers.assign(new_loss_scaling, prev_loss_scaling)
                        with switch2.default():
                            pass
                    layers.assign(zero_steps, num_good_steps)
                    layers.assign(zero_steps, num_bad_steps)
                with switch1.default():
                    layers.increment(num_good_steps)
                    layers.assign(zero_steps, num_bad_steps)
        with switch.default():
            should_decr_loss_scaling = layers.less_than(
                decr_every_n_nan_or_inf, num_bad_steps + 1)
            with layers.Switch() as switch3:
                with switch3.case(should_decr_loss_scaling):
                    new_loss_scaling = prev_loss_scaling * decr_ratio
                    static_loss_scaling = \
                        layers.fill_constant(shape=[1],
                                             dtype='float32',
                                             value=1.0)
                    less_than_one = layers.less_than(new_loss_scaling,
                                                     static_loss_scaling)
                    with layers.Switch() as switch4:
                        with switch4.case(less_than_one):
                            layers.assign(static_loss_scaling,
                                          prev_loss_scaling)
                        with switch4.default():
                            layers.assign(new_loss_scaling, prev_loss_scaling)
                    layers.assign(zero_steps, num_good_steps)
                    layers.assign(zero_steps, num_bad_steps)
                with switch3.default():
                    layers.assign(zero_steps, num_good_steps)
                    layers.increment(num_bad_steps)
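To make the black/white/gray decision in `rewrite_program` easier to follow, here is a standalone, pure-Python restatement of the same rule over a toy op sequence; the op and variable names are made up, and the real function walks `block.ops` and inserts `cast` ops as shown above rather than collecting sets of tuples.

# Toy restatement of the rewrite_program classification rule; not Paddle code.
white, black, gray = {'conv2d', 'mul'}, {'softmax'}, {'elementwise_add', 'relu'}

# (op_type, input names, output name) triples standing in for block.ops.
toy_ops = [
    ('conv2d', ['x'], 'c'),
    ('elementwise_add', ['c', 'b'], 'a'),   # gray op fed by a white op -> fp16
    ('softmax', ['a'], 's'),
    ('elementwise_add', ['s', 'b2'], 'o'),  # gray op fed by a black op -> fp32
]

producer, fp16_ops, fp32_ops = {}, set(), set()
for op_type, ins, out in toy_ops:
    key = (op_type, out)
    if op_type in black:
        fp32_ops.add(key)
    elif op_type in white:
        fp16_ops.add(key)
    elif op_type in gray:
        prev = [producer[i] for i in ins if i in producer]
        if any(p in fp32_ops or p[0] in black for p in prev):
            fp32_ops.add(key)        # poisoned by a black predecessor
        elif any(p in fp16_ops or p[0] in white for p in prev):
            fp16_ops.add(key)
    else:
        fp32_ops.add(key)            # unknown ops stay fp32 for safety
    producer[out] = key

print(sorted(fp16_ops))  # conv2d and the first elementwise_add run in fp16
print(sorted(fp32_ops))  # softmax and the second elementwise_add stay in fp32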
python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py
@@ -16,7 +16,7 @@ from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper
 from .common import is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
 from .meta_optimizer_base import MetaOptimizerBase
 from paddle.fluid import unique_name, core
-from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecision
+from zero.decorator import decorate as amp_decorate
 import paddle.fluid as fluid
 import math

@@ -813,8 +813,7 @@ class ZeroOptimizer(MetaOptimizerBase):
         optimizer._set_checkpoints(ckpts)
         if self.user_defined_strategy.zero_configs["amp"]:
-            optimizer = fluid.contrib.mixed_precision.decorate(
-                optimizer, use_dynamic_loss_scaling=True)
+            optimizer = amp_decorate(optimizer, use_dynamic_loss_scaling=True)
         self._nrings = self.user_defined_strategy.zero_configs["nrings"]
         self._fuse_broadcast_MB_bytes = self.user_defined_strategy.zero_configs[

@@ -1184,8 +1183,7 @@ class ZeroOptimizer(MetaOptimizerBase):
         optimizer = self.inner_opt
         if self.user_defined_strategy.zero_configs["amp"]:
-            optimizer = fluid.contrib.mixed_precision.decorate(
-                optimizer, use_dynamic_loss_scaling=True)
+            optimizer = amp_decorate(optimizer, use_dynamic_loss_scaling=True)
         optimize_ops, params_grads = optimizer.minimize(
             loss, startup_program, parameter_list, no_grad_set)
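The hunks above read `user_defined_strategy.zero_configs["amp"]` and `["nrings"]`, but the strategy plumbing itself is not part of this diff. The following wiring is therefore only a guess inferred from those reads; the `zero_configs` dict and its keys are hypothetical here, and only the reads shown in the diff are confirmed.

# Hypothetical wiring, inferred from the zero_configs reads in the diff above.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.zero_configs = {     # assumed dict exposed by the strategy object
    "amp": True,              # routes the inner optimizer through amp_decorate(...)
    "nrings": 1,              # number of communication rings used by ZeroOptimizer
}
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)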