Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
e3334f3e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e3334f3e
编写于
9月 23, 2020
作者:
M
mapingshuo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add zero
上级
43240a1b
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
1323 addition
and
22 deletion
+1323
-22
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+10
-0
paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
+4
-2
python/paddle/distributed/fleet/base/distributed_strategy.py
python/paddle/distributed/fleet/base/distributed_strategy.py
+33
-0
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+3
-0
python/paddle/distributed/fleet/meta_optimizers/__init__.py
python/paddle/distributed/fleet/meta_optimizers/__init__.py
+1
-0
python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py
...addle/distributed/fleet/meta_optimizers/zero_optimizer.py
+1245
-0
python/paddle/fluid/clip.py
python/paddle/fluid/clip.py
+1
-1
python/paddle/fluid/contrib/mixed_precision/decorator.py
python/paddle/fluid/contrib/mixed_precision/decorator.py
+18
-14
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+8
-5
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
e3334f3e
...
...
@@ -24,6 +24,14 @@ enum Mode {
message
RecomputeConfig
{
repeated
string
checkpoints
=
1
;
}
message
ZeROConfig
{
optional
bool
amp
=
1
[
default
=
true
];
optional
int32
nrings
=
2
[
default
=
3
];
optional
float
fuse_broadcast_MB_bytes
=
3
[
default
=
64.0
];
repeated
string
checkpoints
=
4
;
optional
bool
allreduce
=
5
[
default
=
false
];
}
message
AMPConfig
{
optional
float
init_loss_scaling
=
1
[
default
=
32768.0
];
optional
int32
incr_every_n_steps
=
2
[
default
=
1000
];
...
...
@@ -127,6 +135,7 @@ message DistributedStrategy {
optional
int32
conv_workspace_size_limit
=
22
[
default
=
4000
];
optional
bool
cudnn_batchnorm_spatial_persistent
=
23
[
default
=
true
];
optional
bool
adaptive_localsgd
=
24
[
default
=
false
];
optional
bool
zero
=
25
[
default
=
false
];
optional
RecomputeConfig
recompute_configs
=
101
;
optional
AMPConfig
amp_configs
=
102
;
...
...
@@ -138,6 +147,7 @@ message DistributedStrategy {
optional
LarsConfig
lars_configs
=
108
;
optional
LambConfig
lamb_configs
=
109
;
optional
AdaptiveLocalSGDConfig
adaptive_localsgd_configs
=
110
;
optional
ZeROConfig
zero_configs
=
111
;
optional
BuildStrategy
build_strategy
=
201
;
optional
ExecutionStrategy
execution_strategy
=
202
;
}
...
...
paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
浏览文件 @
e3334f3e
...
...
@@ -55,8 +55,10 @@ class CSyncCommStreamOp : public framework::OperatorBase {
class
CSyncCommStreamOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
{
AddInput
(
"X"
,
"(Tensor) Dependency of the variable need to sync"
);
AddOutput
(
"Out"
,
"(Tensor) Dependency of the variable need to sync"
);
AddInput
(
"X"
,
"(Tensor) Dependency of the variable need to sync"
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) Dependency of the variable need to sync"
)
.
AsDuplicable
();
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) ring id."
).
SetDefault
(
0
);
AddComment
(
R"DOC(
CSyncCommStream Operator
...
...
python/paddle/distributed/fleet/base/distributed_strategy.py
浏览文件 @
e3334f3e
...
...
@@ -611,6 +611,39 @@ class DistributedStrategy(object):
"checkpoint_configs"
)
assign_configs_value
(
self
.
strategy
.
recompute_configs
,
configs
)
@
property
def
zero
(
self
):
"""
Indicating whether we are using Zero Redundancy Optimizer for memory
optimization
Default value: False
Examples:
.. code-block:: python
import paddle.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.zero = True
"""
return
self
.
strategy
.
zero
@
zero
.
setter
def
zero
(
self
,
flag
):
if
isinstance
(
flag
,
bool
):
self
.
strategy
.
zero
=
flag
else
:
print
(
"WARNING: zero should have value of bool type"
)
@
property
def
zero_configs
(
self
):
"""
Set zero configurations.
"""
return
get_msg_dict
(
self
.
strategy
.
zero_configs
)
@
zero_configs
.
setter
def
zero_configs
(
self
,
configs
):
check_configs_key
(
self
.
strategy
.
zero_configs
,
configs
,
"zero_configs"
)
assign_configs_value
(
self
.
strategy
.
zero_configs
,
configs
)
@
property
def
pipeline
(
self
):
"""
...
...
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
e3334f3e
...
...
@@ -1086,6 +1086,9 @@ class Fleet(object):
context
[
"program_optimize_ops"
]
=
optimize_ops
context
[
"program_params_grads"
]
=
params_grads
if
self
.
user_defined_strategy
.
zero
:
graph_optimizer
=
None
if
graph_optimizer
:
optimize_ops
,
params_grads
=
graph_optimizer
.
minimize
(
loss
,
...
...
python/paddle/distributed/fleet/meta_optimizers/__init__.py
浏览文件 @
e3334f3e
...
...
@@ -23,3 +23,4 @@ from .lars_optimizer import LarsOptimizer
from
.parameter_server_graph_optimizer
import
ParameterServerGraphOptimizer
from
.dgc_optimizer
import
DGCOptimizer
from
.lamb_optimizer
import
LambOptimizer
from
.zero_optimizer
import
ZeroOptimizer
python/paddle/distributed/fleet/meta_optimizers/zero_optimizer.py
0 → 100644
浏览文件 @
e3334f3e
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.common
import
OpRole
,
OP_ROLE_KEY
,
OP_ROLE_VAR_KEY
,
CollectiveHelper
from
.common
import
is_update_op
,
is_loss_grad_op
,
is_backward_op
,
is_optimizer_op
from
.meta_optimizer_base
import
MetaOptimizerBase
from
paddle.fluid
import
unique_name
,
core
from
paddle.fluid.contrib.mixed_precision.decorator
import
OptimizerWithMixedPrecision
import
paddle.fluid
as
fluid
import
math
import
re
__all__
=
[
"ZeroOptimizer"
]
def
_pretty_op_desc_
(
op_desc
,
prefix
):
out_s
=
"%s
\t
name:[%s]
\n
%s
\t
inputs:[%s]
\n
%s
\t
outputs:[%s]"
%
\
(
prefix
+
"_op"
,
str
(
op_desc
.
type
()),
prefix
+
"_input"
,
" "
.
join
(
op_desc
.
input_arg_names
()),
prefix
+
"_output"
,
" "
.
join
(
op_desc
.
output_arg_names
()))
return
out_s
class
SubProgram
(
object
):
def
__init__
(
self
,
block
):
self
.
_block
=
block
self
.
_allreduce_vars
=
[]
# sub program start idx
self
.
_start_idx
=
-
1
# sub program end idx
self
.
_end_idx
=
-
1
# param name to broadcast name
self
.
_param2broadcast
=
{}
self
.
_broadcast_vars
=
[]
# cast op pairs, fp16 name (str) -> fp32 name (str)
self
.
_cast_ops
=
{}
# fill constant vars
self
.
_fill_constant_vars
=
[]
# parameter mems
self
.
_param_mem
=
0.0
class
ProgramDeps
(
object
):
def
__init__
(
self
,
block
,
start_vars
,
end_vars
):
self
.
_block
=
block
# vars where to start to build the deps
self
.
_start_vars
=
start_vars
# vars where to stop to build the deps
self
.
_end_vars
=
end_vars
# var name -> op idxs which depends on this var
self
.
_var_deps
=
{}
# sub block deps which is a subset of this topo
self
.
_sub_block_deps
=
{}
# var name -> op idxs which generate var
self
.
_var_to_generate_op
=
{}
self
.
_should_removed_var
=
set
()
self
.
_father_block_deps
=
None
self
.
_build_deps
()
def
get_sub_block_deps
(
self
,
idx
):
if
idx
in
self
.
_sub_block_deps
:
return
self
.
_sub_block_deps
[
idx
]
else
:
return
None
def
get_var_deps
(
self
,
var_name
):
if
var_name
in
self
.
_var_deps
:
return
self
.
_var_deps
[
var_name
]
else
:
return
None
def
_build_deps
(
self
,
):
for
var_name
in
self
.
_start_vars
:
self
.
_var_deps
[
var_name
]
=
[
-
1
]
self
.
_var_to_generate_op
[
var_name
]
=
[
-
1
]
for
idx
,
op
in
enumerate
(
self
.
_block
.
ops
):
if
op
.
type
in
[
"c_allreduce_sum"
,
"c_sync_comm_stream"
,
"c_calc_comm_stream"
]:
continue
input_vars
=
op
.
desc
.
input_arg_names
()
output_vars
=
op
.
desc
.
output_arg_names
()
deps_reduce
=
False
for
input_name
in
input_vars
:
if
input_name
in
self
.
_var_deps
:
deps_reduce
=
True
if
deps_reduce
:
for
input_name
in
input_vars
:
if
input_name
in
self
.
_var_deps
:
self
.
_var_deps
[
input_name
].
append
(
idx
)
for
output_name
in
output_vars
:
self
.
_var_deps
[
output_name
]
=
[]
if
output_name
not
in
self
.
_var_to_generate_op
:
self
.
_var_to_generate_op
[
output_name
]
=
[
idx
]
else
:
self
.
_var_to_generate_op
[
output_name
].
append
(
idx
)
if
op
.
type
==
"conditional_block"
:
# subblock
assert
(
op
.
desc
.
has_attr
(
"sub_block"
))
subblock_idx
=
op
.
desc
.
attr
(
"sub_block"
).
id
subblock_deps
=
ProgramDeps
(
self
.
_block
.
program
.
block
(
subblock_idx
),
op
.
desc
.
input_arg_names
(),
op
.
desc
.
output_arg_names
())
self
.
_sub_block_deps
[
subblock_idx
]
=
subblock_deps
subblock_deps
.
_father_block_deps
=
self
def
crop_input_var_from_op
(
self
,
op_idx
,
var_name
):
if
var_name
in
self
.
_var_deps
:
# update var -> dep_var_op
if
self
.
_var_deps
[
var_name
]
!=
[]:
assert
(
op_idx
in
self
.
_var_deps
[
var_name
])
self
.
_var_deps
[
var_name
].
remove
(
op_idx
)
# update _should_removed_var
if
var_name
in
self
.
_start_vars
:
self
.
_should_removed_var
.
discard
(
var_name
)
elif
self
.
_var_deps
[
var_name
]
==
[]:
# no more deps of this var
self
.
_should_removed_var
.
add
(
var_name
)
elif
self
.
_var_to_generate_op
[
var_name
][
-
1
]
>=
self
.
_var_deps
[
var_name
][
-
1
]:
# there are circle in the graph
self
.
_should_removed_var
.
add
(
var_name
)
else
:
# input_name should not be deleted
self
.
_should_removed_var
.
discard
(
var_name
)
def
crop_output_var_from_op
(
self
,
op_idx
,
var_name
):
if
var_name
in
self
.
_var_to_generate_op
:
assert
(
op_idx
in
self
.
_var_to_generate_op
[
var_name
])
self
.
_var_to_generate_op
[
var_name
].
remove
(
op_idx
)
if
self
.
_block
.
has_var
(
var_name
)
and
self
.
_var_to_generate_op
[
var_name
]
==
[]:
print
(
"main_block remove var {}"
.
format
(
var_name
))
self
.
_block
.
_remove_var
(
var_name
)
def
remove_op
(
self
,
op_idx
):
# update deps
op
=
self
.
_block
.
ops
[
op_idx
]
print
(
"main_block remove op {}"
.
format
(
op
.
type
))
for
input_name
in
op
.
desc
.
input_arg_names
():
self
.
crop_input_var_from_op
(
op_idx
,
input_name
)
for
output_name
in
op
.
desc
.
output_arg_names
():
self
.
crop_output_var_from_op
(
op_idx
,
output_name
)
self
.
_block
.
_remove_op
(
op_idx
)
def
should_remove_op
(
self
,
op_idx
):
op
=
self
.
_block
.
ops
[
op_idx
]
for
output_name
in
op
.
desc
.
output_arg_names
():
if
output_name
not
in
self
.
_should_removed_var
:
return
False
return
True
class
ZeroOptimizer
(
MetaOptimizerBase
):
def
__init__
(
self
,
optimizer
):
super
(
ZeroOptimizer
,
self
).
__init__
(
optimizer
)
self
.
inner_opt
=
optimizer
self
.
_main_program
=
None
self
.
_startup_program
=
None
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
# params and fp16 params is for broadcast
self
.
_params
=
set
([])
self
.
_fp16_params
=
set
([])
# fp16 to fp32
self
.
_fp16_to_params
=
{}
self
.
_broadcast_vars
=
set
([])
# _param(str) -> device_id(int)
self
.
_param2device
=
{}
# varname(str) -> param(Variable)
# reduced grads to param name
self
.
_reduced_grads_to_param
=
{}
# self._nrings(int) is for nccl communicate
self
.
_nrings
=
3
# self._sub_progs
self
.
_sub_progs
=
[]
self
.
_fuse_broadcast_MB_bytes
=
64
self
.
_dtype_to_size
=
{
core
.
VarDesc
.
VarType
.
FP16
:
2
,
core
.
VarDesc
.
VarType
.
FP32
:
4
,
core
.
VarDesc
.
VarType
.
FP64
:
8
,
core
.
VarDesc
.
VarType
.
INT16
:
2
,
core
.
VarDesc
.
VarType
.
INT32
:
4
,
core
.
VarDesc
.
VarType
.
INT64
:
8
,
core
.
VarDesc
.
VarType
.
BOOL
:
1
,
core
.
VarDesc
.
VarType
.
UINT8
:
1
,
}
def
_get_var_size
(
self
,
param
):
"""
input:
- param: var
return:
var size in Bytes
"""
assert
-
1
not
in
param
.
shape
return
reduce
(
lambda
x
,
y
:
x
*
y
,
param
.
shape
)
*
self
.
_dtype_to_size
[
param
.
dtype
]
/
1024.0
/
1024.0
def
_can_apply
(
self
):
return
self
.
user_defined_strategy
.
zero
def
_disable_strategy
(
self
,
dist_strategy
):
dist_strategy
.
zero
=
False
def
_is_fp16_cast_op
(
self
,
block
,
op
):
if
op
.
type
!=
"cast"
:
return
False
if
is_optimizer_op
(
op
):
return
False
assert
(
len
(
op
.
desc
.
input_arg_names
())
==
1
)
assert
(
len
(
op
.
desc
.
output_arg_names
())
==
1
)
input_name
,
output_name
=
op
.
desc
.
input_arg_names
()[
0
],
op
.
desc
.
output_arg_names
()[
0
]
if
input_name
not
in
self
.
_params
:
return
False
input_var
=
block
.
var
(
input_name
)
output_var
=
block
.
var
(
output_name
)
if
input_var
.
dtype
!=
core
.
VarDesc
.
VarType
.
FP32
or
\
output_var
.
dtype
!=
core
.
VarDesc
.
VarType
.
FP16
:
return
False
return
True
def
_split_params
(
self
,
params
):
param2device
=
{}
total_param_mem
=
0.0
param2mem
=
[]
for
param
in
params
:
mem
=
self
.
_get_var_size
(
param
)
total_param_mem
+=
mem
param2mem
.
append
((
param
.
name
,
mem
))
# print(param.name, mem)
# print("total_param_mem: ", total_param_mem)
device_num
=
self
.
role_maker
.
worker_num
()
# print("device_num: ", device_num)
device2params
=
{
x
:
[]
for
x
in
range
(
device_num
)}
device_idx
=
0
mem_accu
=
0.0
for
param_name
,
mem
in
param2mem
:
if
mem_accu
>
total_param_mem
*
1.0
*
(
device_idx
+
1
)
/
device_num
:
device_idx
+=
1
device2params
[
device_idx
].
append
(
param_name
)
param2device
[
param_name
]
=
device_idx
mem_accu
+=
mem
# for debug
print
(
device2params
)
return
param2device
def
_is_opti_var
(
self
,
var_name
):
if
var_name
in
self
.
_params
:
return
True
for
suffix
in
[
"_moment1_0"
,
"_moment2_0"
,
"_beta1_pow_acc_0"
,
"_beta2_pow_acc_0"
]:
base_name
=
re
.
sub
(
suffix
,
''
,
var_name
)
if
base_name
in
self
.
_params
:
return
True
return
False
def
_var_device_id
(
self
,
var_name
):
if
not
self
.
_is_opti_var
(
var_name
):
return
-
1
if
var_name
in
self
.
_param2device
:
return
self
.
_param2device
[
var_name
]
for
suffix
in
[
"_moment1_0"
,
"_moment2_0"
,
"_beta1_pow_acc_0"
,
"_beta2_pow_acc_0"
]:
base_name
=
re
.
sub
(
suffix
,
''
,
var_name
)
if
base_name
in
self
.
_param2device
:
return
self
.
_param2device
[
base_name
]
return
-
1
def
_insert_scale_loss_grad_ops
(
self
,
block
,
scale
=
1.0
):
'''
In order to keep the learning rate consistent in different numbers of
training workers, we scale the loss grad by the number of workers
'''
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
is_loss_grad_op
(
op
):
loss_grad_var
=
block
.
vars
[
op
.
output_arg_names
[
0
]]
block
.
_insert_op
(
idx
+
1
,
type
=
'scale'
,
inputs
=
{
'X'
:
loss_grad_var
},
outputs
=
{
'Out'
:
loss_grad_var
},
attrs
=
{
'scale'
:
scale
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
def
_split_program
(
self
,
block
):
for
op_idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
int
(
op
.
attr
(
'op_role'
))
!=
int
(
OpRole
.
Optimize
):
last_backward_op_idx
=
op_idx
+
1
break
sub_prog
=
SubProgram
(
block
)
sub_prog
.
_end_idx
=
last_backward_op_idx
for
op_idx
in
reversed
(
range
(
last_backward_op_idx
)):
op
=
block
.
ops
[
op_idx
]
assert
(
int
(
op
.
attr
(
'op_role'
))
!=
int
(
OpRole
.
Optimize
))
if
sub_prog
.
_param_mem
>=
self
.
_fuse_broadcast_MB_bytes
:
sub_prog
.
_start_idx
=
op_idx
+
1
self
.
_sub_progs
.
insert
(
0
,
sub_prog
)
sub_prog
=
SubProgram
(
block
)
sub_prog
.
_end_idx
=
op_idx
+
1
# find broadcast vars
for
input_name
in
op
.
desc
.
input_arg_names
():
if
input_name
not
in
self
.
_broadcast_vars
:
continue
root_device
=
self
.
_param2device
[
input_name
]
if
input_name
in
sub_prog
.
_param2broadcast
:
# skip broadcast because it reuse the old broadcast var
broadcast_name
=
sub_prog
.
_param2broadcast
[
input_name
]
if
input_name
!=
broadcast_name
:
op
.
_rename_input
(
input_name
,
broadcast_name
)
continue
if
root_device
==
self
.
role_maker
.
worker_index
():
broadcast_var_name
=
input_name
else
:
broadcast_var_name
=
unique_name
.
generate
(
input_name
+
"@BroadCast"
)
sub_prog
.
_fill_constant_vars
.
append
(
broadcast_var_name
)
sub_prog
.
_param2broadcast
[
input_name
]
=
broadcast_var_name
sub_prog
.
_broadcast_vars
.
append
(
(
broadcast_var_name
,
self
.
_param2device
[
input_name
]))
sub_prog
.
_param_mem
+=
self
.
_get_var_size
(
self
.
_main_program
.
global_block
().
var
(
input_name
))
# find reduce vars
if
is_backward_op
(
op
)
and
\
OP_ROLE_VAR_KEY
in
op
.
attr_names
:
op_role_var
=
op
.
all_attrs
()[
OP_ROLE_VAR_KEY
]
if
len
(
op_role_var
)
!=
0
:
assert
len
(
op_role_var
)
%
2
==
0
for
i
in
range
(
0
,
len
(
op_role_var
),
2
):
param
,
reduced_grad
=
op_role_var
[
i
],
op_role_var
[
i
+
1
]
sub_prog
.
_allreduce_vars
.
append
(
reduced_grad
)
assert
(
reduced_grad
not
in
self
.
_reduced_grads_to_param
)
self
.
_reduced_grads_to_param
[
reduced_grad
]
=
param
# find cast op
if
self
.
_is_fp16_cast_op
(
block
,
op
):
fp32_param
=
op
.
desc
.
input_arg_names
()[
0
]
fp16_param
=
op
.
desc
.
output_arg_names
()[
0
]
if
self
.
_param2device
[
fp32_param
]
==
self
.
role_maker
.
worker_index
():
sub_prog
.
_cast_ops
[
fp16_param
]
=
fp32_param
if
sub_prog
.
_param_mem
>
0
:
sub_prog
.
_start_idx
=
0
self
.
_sub_progs
.
insert
(
0
,
sub_prog
)
return
def
_is_gradient_clip_sum_op
(
self
,
op
):
return
op
.
type
==
"sum"
and
op
.
desc
.
has_attr
(
"op_namescope"
)
\
and
op
.
desc
.
attr
(
"op_namescope"
).
startswith
(
"/gradient_clip_@CLIP"
)
def
_is_amp_sum_op
(
self
,
op
):
return
op
.
type
==
"sum"
and
op
.
desc
.
has_attr
(
"op_namescope"
)
\
and
op
.
desc
.
attr
(
"op_namescope"
).
startswith
(
"/mixed_precision"
)
def
_is_amp_subblock
(
self
,
op
):
return
op
.
type
==
"conditional_block"
and
op
.
desc
.
has_attr
(
"op_namescope"
)
\
and
op
.
desc
.
attr
(
"op_namescope"
).
startswith
(
"/mixed_precision"
)
def
_prune_main_program
(
self
,
block
):
"""
calculate deps from allredce op to optimize op,
remove ops and vars not needed in this worker
"""
# build prog deps
reduced_grads
=
[]
var_to_reduce_var
=
{}
for
idx
,
op
in
enumerate
(
block
.
ops
):
input_names
=
op
.
desc
.
input_arg_names
()
output_names
=
op
.
desc
.
output_arg_names
()
if
op
.
type
==
"c_allreduce_sum"
:
assert
(
len
(
output_names
)
==
1
)
output_name
=
output_names
[
0
]
reduced_grads
.
append
(
output_name
)
var_to_reduce_var
[
output_name
]
=
output_name
else
:
non_persistable_input
=
[
x
for
x
in
input_names
if
not
block
.
var
(
x
).
persistable
]
if
len
(
non_persistable_input
)
==
1
and
len
(
output_names
)
==
1
and
non_persistable_input
[
0
]
in
var_to_reduce_var
:
var_to_reduce_var
[
output_names
[
0
]]
=
var_to_reduce_var
[
non_persistable_input
[
0
]]
params
=
[]
for
var_name
,
_
in
block
.
vars
.
items
():
if
self
.
_is_opti_var
(
var_name
)
and
\
self
.
_var_device_id
(
var_name
)
!=
self
.
role_maker
.
worker_index
():
params
.
append
(
var_name
)
program_deps
=
ProgramDeps
(
block
,
reduced_grads
,
params
)
# Init
for
var_name
in
program_deps
.
_end_vars
:
program_deps
.
_should_removed_var
.
add
(
var_name
)
# Prune
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
op
.
type
in
[
"c_allreduce_sum"
,
"c_sync_comm_stream"
,
"c_calc_comm_stream"
,
"c_gen_nccl_id"
,
"c_comm_init"
]:
pass
elif
self
.
_is_gradient_clip_sum_op
(
op
)
or
self
.
_is_amp_sum_op
(
op
):
reversed_input_vars
=
[]
for
input_name
in
op
.
desc
.
input_arg_names
():
assert
(
input_name
in
var_to_reduce_var
)
reduce_var
=
var_to_reduce_var
[
input_name
]
param_name
=
self
.
_reduced_grads_to_param
[
reduce_var
]
if
self
.
_param2device
[
param_name
]
!=
self
.
role_maker
.
worker_index
():
program_deps
.
crop_input_var_from_op
(
idx
,
input_name
)
else
:
reversed_input_vars
.
append
(
input_name
)
op
.
desc
.
set_input
(
"X"
,
reversed_input_vars
)
assert
(
len
(
op
.
desc
.
output_arg_names
())
==
1
)
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
block
.
_insert_op
(
idx
+
1
,
type
=
'c_sync_comm_stream'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
'ring_id'
:
0
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
block
.
_insert_op
(
idx
+
1
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
'ring_id'
:
0
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
block
.
_insert_op
(
idx
+
1
,
type
=
'c_sync_calc_stream'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
OP_ROLE_KEY
:
OpRole
.
Optimize
})
elif
op
.
type
==
"conditional_block"
:
assert
(
op
.
desc
.
has_attr
(
"sub_block"
))
subblock_idx
=
op
.
desc
.
attr
(
"sub_block"
).
id
subblock_deps
=
program_deps
.
get_sub_block_deps
(
subblock_idx
)
# only prune amp subblock
if
subblock_deps
is
None
or
not
self
.
_is_amp_subblock
(
op
):
continue
# init
reversed_output_vars
=
[]
for
output_name
in
op
.
desc
.
output
(
"Out"
):
if
output_name
in
program_deps
.
_should_removed_var
:
subblock_deps
.
_should_removed_var
.
add
(
output_name
)
program_deps
.
crop_output_var_from_op
(
idx
,
output_name
)
else
:
reversed_output_vars
.
append
(
output_name
)
# prune
for
sub_op_idx
,
_
in
reversed
(
list
(
enumerate
(
subblock_deps
.
_block
.
ops
))):
if
subblock_deps
.
should_remove_op
(
sub_op_idx
):
subblock_deps
.
remove_op
(
sub_op_idx
)
reversed_input_vars
=
[]
for
input_name
in
op
.
desc
.
input
(
'Input'
):
if
input_name
not
in
subblock_deps
.
_should_removed_var
:
reversed_input_vars
.
append
(
input_name
)
else
:
program_deps
.
crop_input_var_from_op
(
idx
,
input_name
)
op
.
desc
.
set_input
(
'Input'
,
reversed_input_vars
)
op
.
desc
.
set_output
(
'Out'
,
reversed_output_vars
)
else
:
if
program_deps
.
should_remove_op
(
idx
):
program_deps
.
remove_op
(
idx
)
block
.
_sync_with_cpp
()
return
def
_remove_cast_op
(
self
,
block
,
sub_prog
,
offset
):
inserted_op_num
=
0
for
op_idx
in
reversed
(
range
(
offset
+
sub_prog
.
_start_idx
,
offset
+
sub_prog
.
_end_idx
)):
op
=
block
.
ops
[
op_idx
]
if
self
.
_is_fp16_cast_op
(
block
,
op
):
block
.
_remove_op
(
op_idx
)
inserted_op_num
-=
1
block
.
_sync_with_cpp
()
return
inserted_op_num
def
_insert_broadcast_ops
(
self
,
block
,
insert_idx
,
broadcast2root
):
"""
_add_broadcast_ops
"""
ring_id
=
-
1
# TODO(mapingshuo): correct OP_ROLE_KEY
for
broadcast_name
,
root_device
in
broadcast2root
:
ring_id
=
(
ring_id
+
1
)
%
self
.
_nrings
block
.
_insert_op
(
insert_idx
,
type
=
'c_broadcast'
,
inputs
=
{
'X'
:
broadcast_name
},
outputs
=
{
'Out'
:
broadcast_name
},
attrs
=
{
'ring_id'
:
ring_id
,
'root'
:
root_device
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
return
def
_insert_allreduce_ops
(
self
,
block
,
insert_idx
,
allreduce_vars
):
"""
_add_allreduce_ops
"""
ring_id
=
-
1
for
var
in
allreduce_vars
:
ring_id
=
(
ring_id
+
1
)
%
self
.
_nrings
block
.
_insert_op
(
insert_idx
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
var
},
outputs
=
{
'Out'
:
var
},
attrs
=
{
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
return
def
_insert_cast_ops
(
self
,
block
,
insert_idx
,
cast_ops
):
"""
_add_cast_ops
"""
for
fp16_name
,
fp32_name
in
cast_ops
.
items
():
block
.
_insert_op
(
insert_idx
,
type
=
"cast"
,
inputs
=
{
"X"
:
fp32_name
},
outputs
=
{
"Out"
:
fp16_name
},
attrs
=
{
"in_dtype"
:
core
.
VarDesc
.
VarType
.
FP32
,
"out_dtype"
:
core
.
VarDesc
.
VarType
.
FP16
})
return
def
_insert_fill_constant_ops
(
self
,
block
,
insert_idx
,
fill_constant_vars
):
"""
_add_fill_constant_ops
"""
for
broadcast_name
in
fill_constant_vars
:
broadcast_var
=
block
.
var
(
broadcast_name
)
block
.
_insert_op
(
insert_idx
,
type
=
"fill_constant"
,
outputs
=
{
"Out"
:
broadcast_var
.
name
},
attrs
=
{
"shape"
:
broadcast_var
.
shape
,
"dtype"
:
broadcast_var
.
dtype
,
"value"
:
0.0
,
})
return
def
_insert_sync_comm_ops
(
self
,
block
,
insert_idx
,
comm_dep_vars
):
"""
_insert_sync_comm_ops
"""
# TODO(mapingshuo) fix OP_ROLE_KEY
for
i
in
range
(
self
.
_nrings
):
block
.
_insert_op
(
insert_idx
,
type
=
'c_sync_comm_stream'
,
inputs
=
{
'X'
:
comm_dep_vars
},
outputs
=
{
'Out'
:
comm_dep_vars
},
attrs
=
{
'ring_id'
:
i
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
return
def
_insert_sync_calc_op
(
self
,
block
,
insert_idx
,
calc_dep_vars
):
"""
_insert_sync_calc_op
"""
# TODO(mapingshuo) fix OP_ROLE_KEY
block
.
_insert_op
(
insert_idx
,
type
=
'c_sync_calc_stream'
,
inputs
=
{
'X'
:
calc_dep_vars
},
outputs
=
{
'Out'
:
calc_dep_vars
},
attrs
=
{
OP_ROLE_KEY
:
OpRole
.
Forward
})
return
def
_add_broadcast_allreduce_v2
(
self
,
block
):
"""
_add_broadcast_allreduce_v2
"""
ring_id
=
-
1
if
len
(
self
.
_sub_progs
)
<
1
:
return
if
self
.
_sub_progs
[
-
1
].
_allreduce_vars
:
self
.
_insert_sync_comm_ops
(
block
,
self
.
_sub_progs
[
-
1
].
_end_idx
,
self
.
_sub_progs
[
-
1
].
_allreduce_vars
)
self
.
_insert_allreduce_ops
(
block
,
self
.
_sub_progs
[
-
1
].
_end_idx
,
self
.
_sub_progs
[
-
1
].
_allreduce_vars
)
for
idx
,
subprog
in
reversed
(
list
(
enumerate
(
self
.
_sub_progs
))):
print
(
"subprog_{}: ({}-{})"
.
format
(
idx
,
subprog
.
_start_idx
,
subprog
.
_end_idx
))
allreduce_vars
=
self
.
_sub_progs
[
idx
-
1
].
_allreduce_vars
if
idx
>
0
else
[]
broadcast_vars
=
self
.
_sub_progs
[
idx
+
1
].
_broadcast_vars
if
idx
<
len
(
self
.
_sub_progs
)
-
1
else
[]
fill_constant_vars
=
self
.
_sub_progs
[
idx
+
2
].
_fill_constant_vars
if
idx
<
len
(
self
.
_sub_progs
)
-
2
else
[]
cast_ops
=
self
.
_sub_progs
[
idx
+
2
].
_cast_ops
if
idx
<
len
(
self
.
_sub_progs
)
-
2
else
{}
# for x in fill_constant_vars:
# print("fill_constant_vars: ", x)
# step1: modify calculate ops
# for op_idx in reversed(range(subprog._start_idx, subprog._end_idx)):
# op = block.ops[op_idx]
# print(_pretty_op_desc_(op.desc, "subprog_op"))
for
op_idx
in
reversed
(
range
(
subprog
.
_start_idx
,
subprog
.
_end_idx
)):
op
=
block
.
ops
[
op_idx
]
for
input_name
in
op
.
desc
.
input_arg_names
():
if
input_name
in
subprog
.
_param2broadcast
and
\
input_name
!=
subprog
.
_param2broadcast
[
input_name
]:
op
.
_rename_input
(
input_name
,
subprog
.
_param2broadcast
[
input_name
])
for
param_name
,
broadcast_name
in
subprog
.
_param2broadcast
.
items
():
if
param_name
!=
broadcast_name
:
block
.
create_var
(
name
=
broadcast_name
,
shape
=
self
.
_main_program
.
global_block
().
var
(
param_name
).
shape
,
dtype
=
self
.
_main_program
.
global_block
().
var
(
param_name
)
.
dtype
,
persistable
=
False
)
# step2: remove cast ops
block
.
_sync_with_cpp
()
subprog
.
_end_idx
+=
self
.
_remove_cast_op
(
block
,
subprog
,
0
)
# step3: add Sync ops
comm_dep_vars
=
allreduce_vars
+
[
x
[
0
]
for
x
in
broadcast_vars
]
if
len
(
comm_dep_vars
)
>
0
:
self
.
_insert_sync_comm_ops
(
block
,
subprog
.
_end_idx
,
comm_dep_vars
,
)
calc_dep_vars
=
fill_constant_vars
+
[
k
for
k
,
v
in
cast_ops
.
items
()
]
if
len
(
calc_dep_vars
)
>
0
:
self
.
_insert_sync_calc_op
(
block
,
subprog
.
_end_idx
,
[
calc_dep_vars
[
-
1
]])
# step4: insert `fill_constant` ops
self
.
_insert_fill_constant_ops
(
block
,
subprog
.
_end_idx
,
fill_constant_vars
)
# step5: add `cast` ops
self
.
_insert_cast_ops
(
block
,
subprog
.
_end_idx
,
cast_ops
)
# step6: add broadcast ops
self
.
_insert_broadcast_ops
(
block
,
subprog
.
_start_idx
,
broadcast_vars
)
# step7: add all_reduce ops
self
.
_insert_allreduce_ops
(
block
,
subprog
.
_start_idx
,
allreduce_vars
)
block
.
_sync_with_cpp
()
if
self
.
_sub_progs
[
0
].
_broadcast_vars
:
self
.
_insert_sync_comm_ops
(
block
,
self
.
_sub_progs
[
0
].
_start_idx
,
[
x
[
0
]
for
x
in
self
.
_sub_progs
[
0
].
_broadcast_vars
])
self
.
_insert_broadcast_ops
(
block
,
self
.
_sub_progs
[
0
].
_start_idx
,
self
.
_sub_progs
[
0
].
_broadcast_vars
)
fill_constant_vars
=
reduce
(
lambda
x
,
y
:
x
.
_fill_constant_vars
+
y
.
_fill_constant_vars
,
self
.
_sub_progs
[:
2
])
# Join
cast_ops
=
{}
for
x
in
self
.
_sub_progs
[:
2
]:
for
k
,
v
in
x
.
_cast_ops
.
items
():
cast_ops
[
k
]
=
v
calc_deps_vars
=
fill_constant_vars
+
[
k
for
k
,
v
in
cast_ops
.
items
()]
if
fill_constant_vars
or
cast_ops
:
self
.
_insert_sync_calc_op
(
block
,
self
.
_sub_progs
[
0
].
_start_idx
,
[
calc_deps_vars
[
-
1
]])
if
fill_constant_vars
:
self
.
_insert_fill_constant_ops
(
block
,
self
.
_sub_progs
[
0
].
_start_idx
,
fill_constant_vars
)
if
cast_ops
:
self
.
_insert_cast_ops
(
block
,
self
.
_sub_progs
[
0
].
_start_idx
,
cast_ops
)
return
def
_prune_startup_program
(
self
,
block
):
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
for
output_name
in
op
.
desc
.
output_arg_names
():
var_device_id
=
self
.
_var_device_id
(
output_name
)
if
var_device_id
==
-
1
or
var_device_id
==
self
.
role_maker
.
worker_index
(
):
continue
print
(
"%d: startup_block remove op %s"
%
(
self
.
role_maker
.
worker_index
(),
op
.
type
))
block
.
_remove_op
(
idx
)
break
for
var_name
,
_
in
block
.
vars
.
items
():
var_device_id
=
self
.
_var_device_id
(
var_name
)
if
var_device_id
==
-
1
or
var_device_id
==
self
.
role_maker
.
worker_index
(
):
continue
print
(
"%d: startup_block remove var %s"
%
(
self
.
role_maker
.
worker_index
(),
var_name
))
block
.
_remove_var
(
var_name
)
block
.
_sync_with_cpp
()
def
_find_broadcast_params
(
self
,
params
,
param2device
):
broadcast_vars
=
set
([])
fp16_params
=
set
([])
fp16_to_fp32
=
{}
main_block
=
self
.
_main_program
.
global_block
()
param_usage
=
{
x
:
0
for
x
in
params
}
for
op
in
main_block
.
ops
:
if
is_optimizer_op
(
op
):
continue
for
input_name
in
op
.
desc
.
input_arg_names
():
if
input_name
in
params
:
param_usage
[
input_name
]
+=
1
for
op
in
main_block
.
ops
:
if
not
self
.
_is_fp16_cast_op
(
main_block
,
op
):
continue
input_name
=
op
.
input_arg_names
[
0
]
output_name
=
op
.
output_arg_names
[
0
]
broadcast_vars
.
add
(
output_name
)
fp16_params
.
add
(
output_name
)
fp16_to_fp32
[
output_name
]
=
input_name
param_usage
[
input_name
]
-=
1
param2device
[
output_name
]
=
param2device
[
input_name
]
for
param
,
usage
in
param_usage
.
items
():
if
usage
>
0
:
broadcast_vars
.
add
(
param
)
return
fp16_params
,
broadcast_vars
,
fp16_to_fp32
def
_set_up
(
self
,
params_grads
):
# step 1: initialize nccl
# TODO(mapingshuo) fix get_trainer_endpoints
print
(
"work idx: "
,
self
.
role_maker
.
worker_index
())
endpoints
=
self
.
role_maker
.
get_trainer_endpoints
()
current_endpoint
=
endpoints
[
self
.
role_maker
.
worker_index
()]
collective_helper
=
CollectiveHelper
(
self
.
role_maker
,
self
.
_nrings
)
for
ring_id
in
range
(
self
.
_nrings
):
collective_helper
.
_init_communicator
(
self
.
_startup_program
,
current_endpoint
,
endpoints
,
self
.
role_maker
.
worker_index
(),
ring_id
,
'6174'
)
startup_block
=
self
.
_startup_program
.
global_block
()
startup_block
.
_sync_with_cpp
()
# step 2: split params
self
.
_params
=
set
([
x
[
0
].
name
for
x
in
params_grads
])
self
.
_param2device
=
self
.
_split_params
([
x
[
0
]
for
x
in
params_grads
])
# step 3: get broadcast vars
self
.
_fp16_params
,
self
.
_broadcast_vars
,
self
.
_fp16_to_params
=
self
.
_find_broadcast_params
(
self
.
_params
,
self
.
_param2device
)
def
minimize_impl
(
self
,
loss
,
startup_program
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
):
if
self
.
user_defined_strategy
.
zero_configs
[
"allreduce"
]:
return
self
.
minimize_impl_allreduce
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
)
ckpts
=
list
(
self
.
user_defined_strategy
.
zero_configs
[
"checkpoints"
])
optimizer
=
self
.
inner_opt
if
len
(
ckpts
)
>
0
:
print
(
"add recompute"
)
print
(
ckpts
)
optimizer
=
fluid
.
optimizer
.
RecomputeOptimizer
(
optimizer
)
optimizer
.
_set_checkpoints
(
ckpts
)
if
self
.
user_defined_strategy
.
zero_configs
[
"amp"
]:
optimizer
=
fluid
.
contrib
.
mixed_precision
.
decorate
(
optimizer
,
use_dynamic_loss_scaling
=
True
)
self
.
_nrings
=
self
.
user_defined_strategy
.
zero_configs
[
"nrings"
]
self
.
_fuse_broadcast_MB_bytes
=
self
.
user_defined_strategy
.
zero_configs
[
"fuse_broadcast_MB_bytes"
]
print
(
"doing zero optimize..."
)
optimize_ops
,
params_grads
=
optimizer
.
minimize
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
)
if
startup_program
is
None
:
startup_program
=
default_startup_program
()
main_block
=
loss
.
block
startup_block
=
startup_program
.
global_block
()
self
.
_main_program
=
main_block
.
program
self
.
_startup_program
=
startup_program
# step1: set_up
self
.
_set_up
(
params_grads
)
# step2: split_program
self
.
_split_program
(
main_block
)
# step3: add broadcast and reduce ops
print
(
"insert broadcast and allreduce"
)
self
.
_add_broadcast_allreduce_v2
(
main_block
)
main_block
.
_sync_with_cpp
()
startup_block
.
_sync_with_cpp
()
# step4: insert reduce_sum for grad
self
.
_insert_scale_loss_grad_ops
(
main_block
,
scale
=
1.0
/
self
.
role_maker
.
worker_num
())
main_block
.
_sync_with_cpp
()
# step5: remove unneeded ops and vars from block
print
(
"main_block remove ops and vars"
)
self
.
_prune_main_program
(
main_block
)
print
(
"startup_block remove ops and vars"
)
self
.
_prune_startup_program
(
startup_block
)
# check op dependecy for broadcast
self
.
_check_broadcast
(
main_block
)
return
optimize_ops
,
params_grads
def
_check_broadcast
(
self
,
block
):
"""
if a var is broadcasted, it should have a sync_comm before
this var is used, if not, raise error.
if the broadcasted var has a fill_constant op, the fill_constant
op should stay forward before the broadcast op, and before a
sync_calc op. Otherwise, raise error.
"""
broadcast_vars
=
{}
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
op
.
type
==
"c_broadcast"
:
var_name
=
op
.
desc
.
input_arg_names
()[
0
]
if
"@BroadCast"
in
var_name
:
if
var_name
in
broadcast_vars
:
print
(
"error: var_name areadly exist: "
,
var_name
)
print
(
"the old pos is "
,
broadcast_vars
[
var_name
][
"broadcast_pos"
])
print
(
"the new pos is "
,
idx
)
assert
(
var_name
not
in
broadcast_vars
)
broadcast_vars
[
var_name
]
=
{
"fill_constant_pos"
:
-
1
,
"broadcast_pos"
:
idx
,
}
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
op
.
type
==
"fill_constant"
:
var_name
=
op
.
desc
.
output_arg_names
()[
0
]
if
var_name
in
broadcast_vars
:
broadcast_vars
[
var_name
][
"fill_constant_pos"
]
=
idx
continue
last_sync_comm_op_idx
=
-
1
last_sync_calc_op_idx
=
-
1
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
op
.
type
==
"c_sync_comm_stream"
:
last_sync_comm_op_idx
=
idx
continue
if
op
.
type
==
"c_sync_calc_stream"
:
last_sync_calc_op_idx
=
idx
continue
if
op
.
type
==
"c_broadcast"
:
var_name
=
op
.
desc
.
input_arg_names
()[
0
]
if
"@BroadCast"
in
var_name
:
if
broadcast_vars
[
var_name
][
"fill_constant_pos"
]
!=
-
1
:
assert
(
last_sync_calc_op_idx
!=
-
1
)
assert
(
broadcast_vars
[
var_name
][
"fill_constant_pos"
]
<
last_sync_calc_op_idx
)
assert
(
last_sync_calc_op_idx
<
idx
)
continue
for
input_name
in
op
.
desc
.
input_arg_names
():
if
input_name
in
broadcast_vars
:
assert
(
broadcast_vars
[
input_name
][
"broadcast_pos"
]
!=
-
1
)
assert
(
broadcast_vars
[
input_name
][
"broadcast_pos"
]
<
last_sync_comm_op_idx
)
assert
(
last_sync_comm_op_idx
<
idx
)
print
(
"check done"
)
return
def
_add_broadcast_allreduce
(
self
,
block
,
sub_prog
,
offset
):
"""
add broadcast and allreduce
"""
# insert reduce ops
inserted_op_num
=
0
ring_id
=
-
1
if
len
(
sub_prog
.
_allreduce_vars
)
>
0
:
for
i
in
range
(
self
.
_nrings
):
block
.
_insert_op
(
offset
+
sub_prog
.
_end_idx
,
type
=
'c_sync_comm_stream'
,
inputs
=
{
'X'
:
sub_prog
.
_allreduce_vars
},
outputs
=
{
'Out'
:
sub_prog
.
_allreduce_vars
},
attrs
=
{
'ring_id'
:
i
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
inserted_op_num
+=
self
.
_nrings
for
var
in
sub_prog
.
_allreduce_vars
:
ring_id
=
(
ring_id
+
1
)
%
self
.
_nrings
block
.
_insert_op
(
offset
+
sub_prog
.
_end_idx
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
var
},
outputs
=
{
'Out'
:
var
},
attrs
=
{
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
inserted_op_num
+=
1
block
.
_insert_op
(
offset
+
sub_prog
.
_end_idx
,
type
=
'c_sync_calc_stream'
,
inputs
=
{
'X'
:
sub_prog
.
_allreduce_vars
[
-
1
]},
outputs
=
{
'Out'
:
sub_prog
.
_allreduce_vars
[
-
1
]},
attrs
=
{
OP_ROLE_KEY
:
OpRole
.
Forward
})
inserted_op_num
+=
1
block
.
_sync_with_cpp
()
# insert broadcast ops
for
op_idx
in
reversed
(
range
(
offset
+
sub_prog
.
_start_idx
,
offset
+
sub_prog
.
_end_idx
)):
op
=
block
.
ops
[
op_idx
]
for
input_name
in
op
.
desc
.
input_arg_names
():
if
input_name
in
sub_prog
.
_param2broadcast
and
\
input_name
!=
sub_prog
.
_param2broadcast
[
input_name
]:
op
.
_rename_input
(
input_name
,
sub_prog
.
_param2broadcast
[
input_name
])
for
param_name
,
broadcast_name
in
sub_prog
.
_param2broadcast
.
items
():
if
param_name
!=
broadcast_name
:
block
.
create_var
(
name
=
broadcast_name
,
shape
=
self
.
_main_program
.
global_block
().
var
(
param_name
).
shape
,
dtype
=
self
.
_main_program
.
global_block
().
var
(
param_name
)
.
dtype
,
persistable
=
False
)
comm_dep_vars
=
[
v
for
k
,
v
in
sub_prog
.
_param2broadcast
.
items
()]
for
i
in
range
(
self
.
_nrings
):
block
.
_insert_op
(
offset
+
sub_prog
.
_start_idx
,
type
=
'c_sync_comm_stream'
,
inputs
=
{
'X'
:
comm_dep_vars
},
outputs
=
{
'Out'
:
comm_dep_vars
},
attrs
=
{
'ring_id'
:
i
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
inserted_op_num
+=
self
.
_nrings
for
param_name
,
broadcast_name
in
sub_prog
.
_param2broadcast
.
items
():
broadcast_var
=
block
.
var
(
broadcast_name
)
root_device
=
self
.
_param2device
[
param_name
]
ring_id
=
(
ring_id
+
1
)
%
self
.
_nrings
block
.
_insert_op
(
offset
+
sub_prog
.
_start_idx
,
type
=
'c_broadcast'
,
inputs
=
{
'X'
:
broadcast_var
.
name
},
outputs
=
{
'Out'
:
broadcast_var
.
name
},
attrs
=
{
'ring_id'
:
ring_id
,
'root'
:
root_device
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
inserted_op_num
+=
1
comm_dep_vars
=
[
v
for
k
,
v
in
sub_prog
.
_param2broadcast
.
items
()
if
k
!=
v
]
if
comm_dep_vars
!=
[]:
block
.
_insert_op
(
offset
+
sub_prog
.
_start_idx
,
type
=
'c_sync_calc_stream'
,
inputs
=
{
'X'
:
comm_dep_vars
[
-
1
]},
outputs
=
{
'Out'
:
comm_dep_vars
[
-
1
]},
attrs
=
{
OP_ROLE_KEY
:
OpRole
.
Forward
})
inserted_op_num
+=
1
for
param_name
,
broadcast_name
in
sub_prog
.
_param2broadcast
.
items
():
if
param_name
!=
broadcast_name
:
broadcast_var
=
block
.
var
(
broadcast_name
)
block
.
_insert_op
(
offset
+
sub_prog
.
_start_idx
,
type
=
"fill_constant"
,
outputs
=
{
"Out"
:
broadcast_var
.
name
},
attrs
=
{
"shape"
:
broadcast_var
.
shape
,
"dtype"
:
broadcast_var
.
dtype
,
"value"
:
0.0
,
})
inserted_op_num
+=
1
for
fp16_name
,
fp32_name
in
sub_prog
.
_cast_ops
.
items
():
block
.
_insert_op
(
offset
+
sub_prog
.
_start_idx
,
type
=
"cast"
,
inputs
=
{
"X"
:
fp32_name
},
outputs
=
{
"Out"
:
fp16_name
},
attrs
=
{
"in_dtype"
:
core
.
VarDesc
.
VarType
.
FP32
,
"out_dtype"
:
core
.
VarDesc
.
VarType
.
FP16
})
inserted_op_num
+=
1
block
.
_sync_with_cpp
()
return
inserted_op_num
def
_broadcast_params
(
self
,
block
):
ring_id
=
-
1
for
param
in
block
.
iter_parameters
():
if
param
.
is_distributed
:
continue
ring_id
=
(
ring_id
+
1
)
%
self
.
_nrings
block
.
append_op
(
type
=
'c_broadcast'
,
inputs
=
{
'X'
:
param
},
outputs
=
{
'Out'
:
param
},
attrs
=
{
'ring_id'
:
ring_id
,
'root'
:
0
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
for
ring_id
in
range
(
self
.
_nrings
):
block
.
append_op
(
type
=
'c_sync_comm_stream'
,
inputs
=
{
'X'
:
param
},
outputs
=
{
'Out'
:
param
},
attrs
=
{
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
# def _insert_broadcast_ops(self, block, fuse_broadcast=False):
# def _insert_cache(cache,
# prepend_comm_sync=False,
# append_comm_sync=False):
# insert_idx = cache["insert_idx"]
# dummy_var_name = cache["dummy_var_name"]
# assert (len(cache["broadcast_ops"]) > 0)
# if prepend_comm_sync:
# insert_idx += self._insert_comm_sync(block, insert_idx,
# [dummy_var_name])
# if len(cache["fill_constant_ops"]) > 0:
# insert_idx += self._insert_fill_constant(
# block, insert_idx, cache["fill_constant_ops"],
# [dummy_var_name])
# insert_idx += self._insert_broadcast_inner(block, insert_idx,
# cache["broadcast_ops"])
# if append_comm_sync:
# insert_idx += self._insert_comm_sync(block, insert_idx,
# [dummy_var_name])
# return insert_idx - cache["insert_idx"]
# print("insert_idx: ", [x["insert_idx"] for x in self._sub_progs])
# move_ahead = 1
# for idx, cache in reversed(list(enumerate(self._sub_progs))):
# if idx < move_ahead:
# cache["insert_idx"] = 0
# else:
# cache["insert_idx"] = self._sub_progs[idx - move_ahead][
# "insert_idx"]
# print("insert_idx: ", [x["insert_idx"] for x in self._sub_progs])
# inserted_op_num = 0
# for idx, cache in enumerate(self._sub_progs):
# prepend_comm_sync = True
# append_comm_sync = True
# cache["insert_idx"] += inserted_op_num
# inserted_op_num += _insert_cache(
# cache,
# prepend_comm_sync=prepend_comm_sync,
# append_comm_sync=append_comm_sync)
# return
def
_insert_allreduce_ops_tmp
(
self
,
block
):
ring_id
=
-
1
grad
=
None
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
is_backward_op
(
op
)
and
\
OP_ROLE_VAR_KEY
in
op
.
attr_names
:
op_role_var
=
op
.
all_attrs
()[
OP_ROLE_VAR_KEY
]
if
len
(
op_role_var
)
==
0
:
continue
assert
len
(
op_role_var
)
%
2
==
0
offset
=
idx
for
i
in
range
(
0
,
len
(
op_role_var
),
2
):
# param = block.vars[op_role_var[i]]
grad
=
block
.
vars
[
op_role_var
[
i
+
1
]]
# TODO(mapingshuo): what is is_distributed
# if param.is_distributed:
# continue
if
offset
==
idx
:
offset
+=
1
block
.
_insert_op
(
offset
,
type
=
'c_sync_calc_stream'
,
inputs
=
{
'X'
:
grad
},
outputs
=
{
'Out'
:
grad
},
attrs
=
{
OP_ROLE_KEY
:
OpRole
.
Backward
})
offset
+=
1
# As we search ops reversedly, we should insert c_allreduce_sum
# op in the same way to keep the ring_id alternate
print
(
"add allreduce op for {}"
.
format
(
grad
.
name
))
ring_id
=
(
ring_id
+
1
)
%
self
.
_nrings
block
.
_insert_op
(
offset
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
grad
},
outputs
=
{
'Out'
:
grad
},
attrs
=
{
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
if
grad
is
None
:
return
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
is_optimizer_op
(
op
):
for
ring_id
in
range
(
self
.
_nrings
):
block
.
_insert_op
(
idx
+
ring_id
,
type
=
'c_sync_comm_stream'
,
inputs
=
{
'X'
:
grad
},
outputs
=
{
'Out'
:
grad
},
attrs
=
{
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
break
def
minimize_impl_allreduce
(
self
,
loss
,
startup_program
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
):
self
.
_nrings
=
self
.
user_defined_strategy
.
zero_configs
[
"nrings"
]
optimizer
=
self
.
inner_opt
if
self
.
user_defined_strategy
.
zero_configs
[
"amp"
]:
optimizer
=
fluid
.
contrib
.
mixed_precision
.
decorate
(
optimizer
,
use_dynamic_loss_scaling
=
True
)
optimize_ops
,
params_grads
=
optimizer
.
minimize
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
)
if
startup_program
is
None
:
startup_program
=
default_startup_program
()
print
(
"work idx: "
,
self
.
role_maker
.
worker_index
())
endpoints
=
self
.
role_maker
.
get_trainer_endpoints
()
current_endpoint
=
endpoints
[
self
.
role_maker
.
worker_index
()]
collective_helper
=
CollectiveHelper
(
self
.
role_maker
,
self
.
_nrings
)
for
ring_id
in
range
(
self
.
_nrings
):
collective_helper
.
_init_communicator
(
startup_program
,
current_endpoint
,
endpoints
,
self
.
role_maker
.
worker_index
(),
ring_id
,
'6174'
)
main_block
=
loss
.
block
startup_block
=
startup_program
.
global_block
()
self
.
_broadcast_params
(
startup_block
)
self
.
_insert_scale_loss_grad_ops
(
main_block
,
scale
=
1.0
/
self
.
role_maker
.
worker_num
())
self
.
_insert_allreduce_ops_tmp
(
main_block
)
print
(
"insert allreduce done"
)
return
optimize_ops
,
params_grads
# def _insert_comm_sync(self, block, insert_idx, var_names):
# for r in range(self._nrings):
# block._insert_op(
# insert_idx,
# type='c_sync_comm_stream',
# inputs={'X': var_names},
# outputs={'Out': var_names},
# attrs={'ring_id': r,
# OP_ROLE_KEY: OpRole.Backward})
# insert_idx += 1
# return self._nrings
# def _insert_broadcast_inner(self, block, insert_idx, broadcast_attrs):
# for attr in broadcast_attrs:
# block._insert_op(insert_idx, **attr)
# insert_idx += 1
# return len(broadcast_attrs)
# def _insert_fill_constant(self, block, insert_idx, fill_constant_attrs,
# var_names):
# for attr in fill_constant_attrs:
# block._insert_op(insert_idx, **attr)
# insert_idx += 1
# block._insert_op(
# insert_idx,
# type='c_sync_calc_stream',
# inputs={'X': var_names},
# outputs={'Out': var_names},
# attrs={OP_ROLE_KEY: OpRole.Backward})
# return len(fill_constant_attrs) + 1
python/paddle/fluid/clip.py
浏览文件 @
e3334f3e
...
...
@@ -847,7 +847,7 @@ def append_gradient_clip_ops(param_grads):
if
g
is
None
:
continue
with
p
.
block
.
program
.
_optimized_guard
(
[
p
,
g
]),
framework
.
name_scope
(
'gra
id
ent_clip_@CLIP'
):
[
p
,
g
]),
framework
.
name_scope
(
'gra
di
ent_clip_@CLIP'
):
param
,
new_grad
=
clip_attr
.
_create_operators
(
param
=
p
,
grad
=
g
)
param_new_grad_name_dict
[
param
.
name
]
=
new_grad
.
name
res
.
append
([
param
,
new_grad
])
...
...
python/paddle/fluid/contrib/mixed_precision/decorator.py
浏览文件 @
e3334f3e
...
...
@@ -16,6 +16,7 @@ from ... import default_main_program
from
...
import
default_startup_program
from
...
import
layers
from
...
import
unique_name
from
...
import
framework
from
.
import
fp16_utils
from
.fp16_utils
import
rewrite_program
from
.fp16_utils
import
update_role_var_grad
...
...
@@ -132,6 +133,7 @@ class OptimizerWithMixedPrecision(object):
gradient respectively, and the scaled loss.
"""
rewrite_program
(
self
.
_train_program
,
self
.
_amp_lists
)
with
framework
.
name_scope
(
'mixed_precision'
):
self
.
_scaled_loss
=
loss
*
self
.
_loss_scaling
self
.
_params_grads
=
self
.
_optimizer
.
backward
(
self
.
_scaled_loss
,
startup_program
,
parameter_list
,
no_grad_set
,
...
...
@@ -156,11 +158,13 @@ class OptimizerWithMixedPrecision(object):
grads
=
[
g
for
_
,
g
in
params_grads
]
with
self
.
_train_program
.
_optimized_guard
(
grads
):
with
framework
.
name_scope
(
'mixed_precision'
):
grads
,
found_inf
=
check_finite_and_unscale
(
grads
,
self
.
_loss_scaling
,
name
=
"find_infinite_scale"
)
if
self
.
_use_dynamic_loss_scaling
:
with
self
.
_train_program
.
_optimized_guard
(
grads
):
with
framework
.
name_scope
(
'mixed_precision'
):
grads
=
update_loss_scaling
(
grads
,
found_inf
,
...
...
python/paddle/fluid/framework.py
浏览文件 @
e3334f3e
...
...
@@ -2063,9 +2063,15 @@ class Operator(object):
%
(
out_proto
.
name
,
len
(
out_args
)))
out_arg_names
=
[]
for
arg
in
out_args
:
if
isinstance
(
arg
,
six
.
string_types
):
out_arg_names
.
append
(
arg
)
else
:
out_arg_names
.
append
(
cpt
.
to_text
(
arg
.
name
))
# TODO(minqiyang): could we remove variable's op in static mode?
if
not
in_dygraph_mode
():
if
isinstance
(
arg
,
six
.
string_types
):
block
.
var
(
arg
).
op
=
self
else
:
arg
.
op
=
self
self
.
desc
.
set_output
(
out_proto
.
name
,
out_arg_names
)
...
...
@@ -2801,7 +2807,6 @@ class Block(object):
return
var
def
_remove_var
(
self
,
name
):
self
.
_sync_with_cpp
()
self
.
desc
.
_remove_var
(
cpt
.
to_bytes
(
name
))
del
self
.
vars
[
name
]
...
...
@@ -2893,7 +2898,6 @@ class Block(object):
Returns:
Operator: the insert Operator.
"""
self
.
_sync_with_cpp
()
op_desc
=
self
.
desc
.
_insert_op
(
index
)
op
=
Operator
(
block
=
self
,
desc
=
op_desc
,
*
args
,
**
kwargs
)
self
.
ops
.
insert
(
index
,
op
)
...
...
@@ -2909,7 +2913,6 @@ class Block(object):
Returns:
None
"""
self
.
_sync_with_cpp
()
self
.
desc
.
_remove_op
(
index
,
index
+
1
)
del
self
.
ops
[
index
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录