PaddlePaddle / Paddle

Commit 81244fbf (unverified)
Authored by mapingshuo on Oct 26, 2020; committed via GitHub on Oct 26, 2020.

add sharding strategy in fleet (#27900)

* add sharding

Parent commit: 4877bd59
Showing 20 changed files with 1,648 additions and 14 deletions (+1648, -14)
paddle/fluid/framework/distributed_strategy.proto                                  +6    -0
python/paddle/distributed/fleet/base/distributed_strategy.py                       +49   -0
python/paddle/distributed/fleet/meta_optimizers/__init__.py                        +1    -0
python/paddle/distributed/fleet/meta_optimizers/common.py                          +6    -0
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py                   +8    -0
python/paddle/distributed/fleet/meta_optimizers/sharding/__init__.py               +13   -0
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py            +154  -0
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py   +90   -0
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py                  +131  -0
python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py                  +144  -0
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py                  +274  -0
python/paddle/distributed/fleet/meta_optimizers/sharding/weight_decay_helper.py    +37   -0
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py              +411  -0
python/paddle/fluid/clip.py                                                        +2    -2
python/paddle/fluid/framework.py                                                   +30   -6
python/paddle/fluid/tests/unittests/CMakeLists.txt                                 +2    -0
python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py                   +14   -3
python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py    +0    -3
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py          +275  -0
python/setup.py.in                                                                 +1    -0
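For orientation before the per-file diffs, here is a minimal usage sketch of the new switch, based on the docstrings added in distributed_strategy.py below. The sketch is illustrative only: the optimizer choice, the loss variable, and the multi-trainer launch setup are assumptions, not part of this commit.

    # Minimal sketch (assumes a collective fleet job launched with multiple trainers,
    # e.g. via fleetrun; the model/optimizer below are placeholders).
    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    # size (MB) of a fused group of broadcasted parameters
    strategy.sharding_configs = {"fuse_broadcast_MB": 32}

    optimizer = paddle.fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    # optimizer.minimize(loss)  # loss comes from a user-defined static-graph network

When enabled, ShardingOptimizer splits parameters and optimizer states across the trainers and inserts the broadcast/allreduce/sync ops shown in the diffs that follow.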
paddle/fluid/framework/distributed_strategy.proto

@@ -24,6 +24,10 @@ enum Mode {
 message RecomputeConfig { repeated string checkpoints = 1; }
 
+message ShardingConfig {
+  optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
+}
+
 message AMPConfig {
   optional float init_loss_scaling = 1 [ default = 32768.0 ];
   optional int32 incr_every_n_steps = 2 [ default = 1000 ];

@@ -130,6 +134,7 @@ message DistributedStrategy {
   optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
   optional bool adaptive_localsgd = 24 [ default = false ];
   optional bool fp16_allreduce = 25 [ default = false ];
+  optional bool sharding = 26 [ default = false ];
 
   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;

@@ -141,6 +146,7 @@ message DistributedStrategy {
   optional LarsConfig lars_configs = 108;
   optional LambConfig lamb_configs = 109;
   optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
+  optional ShardingConfig sharding_configs = 111;
   optional BuildStrategy build_strategy = 201;
   optional ExecutionStrategy execution_strategy = 202;
 }
python/paddle/distributed/fleet/base/distributed_strategy.py

@@ -611,6 +611,55 @@ class DistributedStrategy(object):
                          "checkpoint_configs")
        assign_configs_value(self.strategy.recompute_configs, configs)

    @property
    def sharding(self):
        """
        Indicating whether we are using sharding Optimizer for memory
        optimization

        Default value: False

        Examples:
          .. code-block:: python

            import paddle.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sharding = True
        """
        return self.strategy.sharding

    @sharding.setter
    @is_strict_auto
    def sharding(self, flag):
        if isinstance(flag, bool):
            self.strategy.sharding = flag
        else:
            print("WARNING: sharding should have value of bool type")

    @property
    def sharding_configs(self):
        """
        Set sharding configurations.

        **Note**:
            fuse_broadcast_MB(float): size of a fused group of broadcasted parameters.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sharding = True
            strategy.sharding_configs = {"fuse_broadcast_MB": 32}
        """
        return get_msg_dict(self.strategy.sharding_configs)

    @sharding_configs.setter
    @is_strict_auto
    def sharding_configs(self, configs):
        check_configs_key(self.strategy.sharding_configs, configs,
                          "sharding_configs")
        assign_configs_value(self.strategy.sharding_configs, configs)

    @property
    def pipeline(self):
        """
python/paddle/distributed/fleet/meta_optimizers/__init__.py

@@ -24,3 +24,4 @@ from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
 from .dgc_optimizer import DGCOptimizer
 from .lamb_optimizer import LambOptimizer
 from .fp16_allreduce_optimizer import FP16AllReduceOptimizer
+from .sharding_optimizer import ShardingOptimizer
python/paddle/distributed/fleet/meta_optimizers/common.py

@@ -99,6 +99,12 @@ class CollectiveHelper(object):
                    OP_ROLE_KEY: OpRole.Forward
                })

    def _wait(self, current_endpoint, endpoints):
        assert (self.wait_port)
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)
        wait_server_ready(other_endpoints)

    def _broadcast_params(self):
        block = self.startup_program.global_block()
        ring_id = -1
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py

@@ -30,6 +30,10 @@ class DGCOptimizer(MetaOptimizerBase):
        super(DGCOptimizer, self)._set_basic_info(loss, role_maker,
                                                  user_defined_optimizer,
                                                  user_defined_strategy)

    def _init_dgc_opt(self):
        if self.dgc_opt is not None:
            return

        opt = self.inner_opt

        if not self.role_maker._is_collective:

@@ -86,13 +90,16 @@ class DGCOptimizer(MetaOptimizerBase):
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        self._init_dgc_opt()
        return self.dgc_opt.backward(loss, startup_program, parameter_list,
                                     no_grad_set, callbacks)

    def apply_gradients(self, params_grads):
        self._init_dgc_opt()
        return self.dgc_opt.apply_gradients(params_grads=params_grads)

    def apply_optimize(self, loss, startup_program, params_grads):
        self._init_dgc_opt()
        return self.dgc_opt.apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads)

@@ -101,6 +108,7 @@ class DGCOptimizer(MetaOptimizerBase):
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        self._init_dgc_opt()
        optimize_ops, params_grads = \
            self.dgc_opt.minimize(loss, startup_program,
                                  parameter_list, no_grad_set)
python/paddle/distributed/fleet/meta_optimizers/sharding/__init__.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from paddle.fluid import core


class FP16Utils(object):
    def __init__(self):
        pass

    @staticmethod
    def is_fp16_cast_op(block, op, params):
        if op.type != "cast":
            return False
        if is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        if input_name not in params:
            return False
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP32 or \
            output_var.dtype != core.VarDesc.VarType.FP16:
            return False
        return True

    @staticmethod
    def is_fp32_cast_op(block, op):
        if op.type != "cast":
            return False
        if not is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP16 or \
            output_var.dtype != core.VarDesc.VarType.FP32:
            return False
        return True

    @staticmethod
    def remove_cast_op(block, params, segment, offset):
        inserted_op_num = 0
        for op_idx in reversed(
                range(offset + segment._start_idx, offset + segment._end_idx)):
            op = block.ops[op_idx]
            if FP16Utils.is_fp16_cast_op(block, op, params):
                block._remove_op(op_idx, sync=False)
                inserted_op_num -= 1
        block._sync_with_cpp()
        return inserted_op_num

    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, nrings):
        # remove cast
        for idx, op in reversed(list(enumerate(block.ops))):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            param_name = output_name.strip("@GRAD")
            if param_name not in shard.global_params:
                raise ValueError("Input 'X' of check_finite_and_unscale must"
                                 "be grads, but {} is not a grad".format(
                                     input_name))
            if output_name in reduced_grads_to_param:
                continue
            if shard.has_param(param_name):
                continue
            block._remove_op(idx, sync=False)
            block._remove_var(output_name, sync=False)

        block._sync_with_cpp()
        update_loss_scaling_op_idx = -1
        inf_var_name = ''
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                for input_name in op.desc.input('X'):
                    param_name = input_name.strip("@GRAD")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must"
                            "be grads, but {} is not a grad".format(input_name))
                    if shard.has_param(param_name):
                        reversed_x.append(input_name)
                op.desc.set_input('X', reversed_x)
                op.desc.set_output('Out', reversed_x)
        if update_loss_scaling_op_idx == -1:
            return
        inf_var = block.var(inf_var_name)
        inf_var_fp32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
        inf_var_sharding = block.create_var(
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_fp32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_fp32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
                            [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='c_allreduce_max',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_fp32},
            attrs={'ring_id': 0,
                   OP_ROLE_KEY: OpRole.Optimize})
        comm_op_num = insert_sync_comm_ops(block,
                                           update_loss_scaling_op_idx + 3,
                                           nrings, [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 3 + comm_op_num,
            type='cast',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_sharding},
            attrs={
                "in_dtype": inf_var_fp32.dtype,
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole


class GradientClipHelper(object):
    def __init__(self):
        pass

    def _is_gradient_clip_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/gradient_clip")

    def prune_gradient_clip(self, block, shard):
        deperated_vars = set()
        deperate_op_idx = set()
        for idx, op in enumerate(block.ops):
            if not self._is_gradient_clip_op(op):
                continue
            if op.type == "sum":
                continue
            deperate_op = False
            for input_name in op.desc.input_arg_names():
                if input_name in deperated_vars:
                    deperate_op = True
                param_name = input_name.strip("@GRAD")
                if shard.is_param(param_name) and \
                    not shard.has_param(param_name):
                    deperate_op = True

            if deperate_op:
                deperate_op_idx.add(idx)
                for output_name in op.desc.output_arg_names():
                    deperated_vars.add(output_name)

        if not deperated_vars:
            # got no gradient_clip op
            return

        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_gradient_clip_op(op):
                continue
            if idx in deperate_op_idx:
                block._remove_op(idx, sync=False)
                continue
            reversed_inputs = []
            if op.type == "sum":
                for input_name in op.desc.input_arg_names():
                    if input_name not in deperated_vars:
                        reversed_inputs.append(input_name)
                op.desc.set_input("X", reversed_inputs)
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_comm_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_allreduce_sum',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={OP_ROLE_KEY: OpRole.Optimize})

        for var_name in deperated_vars:
            block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
        return
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ProgramDeps(object):
    def __init__(self, block, start_vars, end_vars):
        self._block = block
        # vars where to start to build the deps
        self._start_vars = start_vars
        # vars where to stop to build the deps
        self._end_vars = end_vars
        # var name -> op idxs which depends on this var
        self._var_to_use_op = {}
        # sub block deps which is a subset of this topo
        self._sub_block_deps = {}
        # var name -> op idxs which generate var
        self._var_to_generate_op = {}
        self._should_removed_var = set()
        self._father_block_deps = None
        self._build_deps()

    def get_sub_block_deps(self, idx):
        if idx in self._sub_block_deps:
            return self._sub_block_deps[idx]
        else:
            return None

    def get_var_deps(self, var_name):
        if var_name in self._var_to_use_op:
            return self._var_to_use_op[var_name]
        else:
            return None

    def _build_deps(self, ):
        for var_name in self._start_vars:
            self._var_to_use_op[var_name] = []
            self._var_to_generate_op[var_name] = []

        for idx, op in enumerate(self._block.ops):
            if op.type in [
                    "c_allreduce_sum", "c_sync_comm_stream",
                    "c_calc_comm_stream"
            ]:
                continue
            input_vars = op.desc.input_arg_names()
            output_vars = op.desc.output_arg_names()
            deps_reduce = False
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    deps_reduce = True
            if not deps_reduce:
                continue
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    self._var_to_use_op[input_name].append(idx)
            for output_name in output_vars:
                if output_name not in self._var_to_use_op:
                    self._var_to_use_op[output_name] = []
                if output_name not in self._var_to_generate_op:
                    self._var_to_generate_op[output_name] = [idx]
                else:
                    self._var_to_generate_op[output_name].append(idx)
            if op.type == "conditional_block":
                # subblock
                assert (op.desc.has_attr("sub_block"))
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = ProgramDeps(
                    self._block.program.block(subblock_idx),
                    op.desc.input_arg_names(), op.desc.output_arg_names())
                self._sub_block_deps[subblock_idx] = subblock_deps
                subblock_deps._father_block_deps = self

    def crop_input_var_from_op(self, op_idx, var_name):
        if var_name in self._var_to_use_op:
            # update var -> dep_var_op
            if self._var_to_use_op[var_name] != []:
                if op_idx not in self._var_to_use_op[var_name]:
                    raise ValueError(
                        "op_idx: {} is not in self._var_to_use_op[{}], "
                        "self._var_to_use_op[{}] is {}".format(
                            op_idx, var_name, var_name,
                            self._var_to_use_op[var_name]))
                self._var_to_use_op[var_name].remove(op_idx)
        # update _should_removed_var
        if var_name in self._start_vars:
            self._should_removed_var.discard(var_name)
        elif self._var_to_use_op[var_name] == []:
            # no more deps of this var
            self._should_removed_var.add(var_name)
        elif self._var_to_generate_op[var_name][-1] >= self._var_to_use_op[
                var_name][-1]:
            # there are circle in the graph
            self._should_removed_var.add(var_name)
        else:
            # input_name should not be deleted
            self._should_removed_var.discard(var_name)

    def crop_output_var_from_op(self, op_idx, var_name):
        if var_name in self._var_to_generate_op:
            assert (op_idx in self._var_to_generate_op[var_name])
            self._var_to_generate_op[var_name].remove(op_idx)
        if self._block.has_var(var_name):
            if var_name not in self._var_to_generate_op or self._var_to_generate_op[
                    var_name] == []:
                self._block._remove_var(var_name, sync=False)

    def remove_op(self, op_idx):
        # update deps
        op = self._block.ops[op_idx]
        for input_name in op.desc.input_arg_names():
            self.crop_input_var_from_op(op_idx, input_name)
        for output_name in op.desc.output_arg_names():
            self.crop_output_var_from_op(op_idx, output_name)
        self._block._remove_op(op_idx, sync=False)

    def should_remove_op(self, op_idx):
        op = self._block.ops[op_idx]
        for output_name in op.desc.output_arg_names():
            if output_name not in self._should_removed_var:
                return False
        return True
python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils


class Shard(object):
    def __init__(self, ):
        self.global_params = set([])
        self.worker_idx = -1
        self.worker_num = -1
        self.global_param2device = {}

    def setup(self, params_grads, worker_idx, worker_num):
        # param names of all devices
        self.global_params = set([x[0].name for x in params_grads])
        # _param(str) -> device_id(int)
        self.worker_idx = worker_idx
        self.worker_num = worker_num
        # global_param2device contains fp32 params and fp16 params
        self.global_param2device = self._split_params(params_grads, worker_idx,
                                                      worker_num)

    def has_param(self, var_name):
        return var_name in self.global_param2device and \
            self._var_device_id(var_name) == self.worker_idx

    def has_opt_var(self, var_name):
        return self._var_device_id(var_name) == self.worker_idx

    def has_var(self, var_name):
        return self._var_device_id(var_name) == -1 or \
            self._var_device_id(var_name) == self.worker_idx

    def _split_params(self, params_grads, worker_idx, worker_num):
        param2device = {}
        total_param_mem = 0.0
        param2mem = []
        for param in [x[0] for x in params_grads]:
            mem = get_var_size(param)
            total_param_mem += mem
            param2mem.append((param.name, mem))
        device2params = {x: [] for x in range(worker_num)}
        device_idx = 0
        mem_accu = 0.0
        for param_name, mem in param2mem:
            if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num:
                device_idx += 1
            device2params[device_idx].append(param_name)
            param2device[param_name] = device_idx
            mem_accu += mem
        return param2device

    def _var_device_id(self, var_name):
        if var_name in self.global_param2device:
            return self.global_param2device[var_name]
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_param2device:
                return self.global_param2device[base_name]
        return -1

    def find_broadcast_params(self, block):
        broadcast_vars = set([])
        fp16_params = set([])
        fp16_to_fp32 = {}

        param_usage = {x: 0 for x in self.global_params}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.global_params:
                    param_usage[input_name] += 1

        for op in block.ops:
            if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            fp16_params.add(output_name)
            fp16_to_fp32[output_name] = input_name
            param_usage[input_name] -= 1
            self.global_param2device[output_name] = self.global_param2device[
                input_name]

        for param, usage in param_usage.items():
            if usage > 0:
                broadcast_vars.add(param)
        return broadcast_vars

    def device(self, var_name):
        return self._var_device_id(var_name)

    def is_param(self, var_name):
        return var_name in self.global_params

    def is_opti_var(self, var_name):
        if var_name in self.global_params:
            return True
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_params:
                return True
        return False


class ProgramSegment(object):
    def __init__(self, block):
        self._block = block
        self._allreduce_vars = []
        # sub program start idx
        self._start_idx = -1
        # sub program end idx
        self._end_idx = -1
        # param name to broadcast name
        self._param2broadcast = {}
        self._broadcast_vars = []
        # cast op pairs, fp16 name (str) -> fp32 name (str)
        self._cast_ops = {}
        # fill constant vars
        self._fill_constant_vars = []
        # parameter mems
        self._param_mem = 0.0
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import core
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY

import re


def check_broadcast(block):
    """
    if a var is broadcasted, it should have a sync_comm before
    this var is used, if not, raise error.
    if the broadcasted var has a fill_constant op, the fill_constant
    op should stay forward before the broadcast op, and before a
    sync_calc op. Otherwise, raise error.
    """
    broadcast_vars = {}
    for idx, op in enumerate(block.ops):
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            if "@BroadCast" in var_name:
                if var_name in broadcast_vars:
                    raise ValueError("var_name areadly exist: {}"
                                     "the old pos is {}, the new pos is {}".
                                     format(var_name, broadcast_vars[var_name][
                                         "broadcast_pos"], idx))
                broadcast_vars[var_name] = {
                    "fill_constant_pos": -1,
                    "broadcast_pos": idx,
                }

    for idx, op in enumerate(block.ops):
        if op.type == "fill_constant":
            var_name = op.desc.output_arg_names()[0]
            if var_name in broadcast_vars:
                broadcast_vars[var_name]["fill_constant_pos"] = idx
            continue

    last_sync_comm_op_idx = -1
    last_sync_calc_op_idx = -1
    for idx, op in enumerate(block.ops):
        if op.type == "c_sync_comm_stream":
            last_sync_comm_op_idx = idx
            continue
        if op.type == "c_sync_calc_stream":
            last_sync_calc_op_idx = idx
            continue
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            if "@BroadCast" in var_name:
                if broadcast_vars[var_name]["fill_constant_pos"] != -1:
                    assert (last_sync_calc_op_idx != -1)
                    assert (broadcast_vars[var_name]["fill_constant_pos"] <
                            last_sync_calc_op_idx)
                    assert (last_sync_calc_op_idx < idx)
                continue
        for input_name in op.desc.input_arg_names():
            if input_name in broadcast_vars:
                assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
                assert (broadcast_vars[input_name]["broadcast_pos"] <
                        last_sync_comm_op_idx)
                assert (last_sync_comm_op_idx < idx)
    return


def check_allreduce_sum(block):
    """
    if a Var is allreduced, the op order should be:
        - 0: op that generate Var
        - 1: sync_calc
        - 2: allreduce_sum op
        - 3: sync_comm
        - 4: op that use Var
    """
    var_status = {}
    for op in block.ops:
        if op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            var_status[var_name] = -1

    for op in block.ops:
        if op.type == "c_sync_calc_stream":
            for var_name in var_status:
                if var_name in var_status and var_status[var_name] == 0:
                    var_status[var_name] = 1
        elif op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            if var_status[var_name] == -1:
                raise ValueError("{} is not generated, but you are"
                                 "trying to all-reduce it".format(var_name))
            if var_status[var_name] == 0:
                raise ValueError("There should be a sync_calc op "
                                 "after generate Var: {} and before the"
                                 "c_allreduce_sum op".format(var_name))
            assert (var_status[var_name] == 1)
            var_status[var_name] = 2
        elif op.type == "c_sync_comm_stream":
            for var_name in op.desc.input_arg_names():
                if var_name in var_status and var_status[var_name] == 2:
                    var_status[var_name] = 3
        else:
            for input_name in op.desc.input_arg_names():
                if input_name in var_status:
                    if var_status[input_name] != 3:
                        raise ValueError("There should be a sync_comm op "
                                         "after allreduce the Var: {}".format(
                                             var_name))
            for output_name in op.desc.output_arg_names():
                if output_name in var_status and \
                    var_status[output_name] == -1:
                    var_status[output_name] = 0
    return


def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
    """
    _insert_sync_calc_op
    """
    op_role = block.ops[insert_idx].attr('op_role')
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_calc_stream',
        inputs={'X': calc_dep_vars},
        outputs={'Out': calc_dep_vars},
        attrs={OP_ROLE_KEY: op_role})
    return


def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
    """
    _insert_sync_comm_ops
    """
    op_role = block.ops[insert_idx].attr('op_role')
    for i in range(nrings):
        block._insert_op_without_sync(
            insert_idx,
            type='c_sync_comm_stream',
            inputs={'X': comm_dep_vars},
            outputs={'Out': comm_dep_vars},
            attrs={'ring_id': i,
                   OP_ROLE_KEY: op_role})
    return nrings


def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
    """
    _add_fill_constant_ops
    """
    op_role = block.ops[insert_idx].attr('op_role')
    for broadcast_name in fill_constant_vars:
        broadcast_var = block.var(broadcast_name)
        block._insert_op_without_sync(
            insert_idx,
            type="fill_constant",
            outputs={"Out": broadcast_var.name},
            attrs={
                "shape": broadcast_var.shape,
                "dtype": broadcast_var.dtype,
                "value": 0.0,
                OP_ROLE_KEY: op_role
            })
    return


def insert_cast_ops(block, insert_idx, cast_ops):
    """
    _add_cast_ops
    """
    op_role = block.ops[insert_idx].attr('op_role')
    for fp16_name, fp32_name in cast_ops.items():
        block._insert_op_without_sync(
            insert_idx,
            type="cast",
            inputs={"X": fp32_name},
            outputs={"Out": fp16_name},
            attrs={
                "in_dtype": core.VarDesc.VarType.FP32,
                "out_dtype": core.VarDesc.VarType.FP16,
                OP_ROLE_KEY: op_role
            })
    return


def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
    """
    _add_allreduce_ops
    """
    ring_id = -1
    for var in allreduce_vars:
        ring_id = (ring_id + 1) % nrings
        block._insert_op_without_sync(
            insert_idx,
            type='c_allreduce_sum',
            inputs={'X': var},
            outputs={'Out': var},
            attrs={'ring_id': ring_id,
                   OP_ROLE_KEY: OpRole.Backward})
    return


def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
    """
    _add_broadcast_ops
    """
    ring_id = -1
    op_role = block.ops[insert_idx].attr('op_role')
    for broadcast_name, root_device in broadcast2root:
        ring_id = (ring_id + 1) % nrings
        block._insert_op_without_sync(
            insert_idx,
            type='c_broadcast',
            inputs={'X': broadcast_name},
            outputs={'Out': broadcast_name},
            attrs={
                'ring_id': ring_id,
                'root': root_device,
                OP_ROLE_KEY: op_role
            })
    return


DtypeToSize = {
    core.VarDesc.VarType.FP16: 2,
    core.VarDesc.VarType.FP32: 4,
    core.VarDesc.VarType.FP64: 8,
    core.VarDesc.VarType.INT16: 2,
    core.VarDesc.VarType.INT32: 4,
    core.VarDesc.VarType.INT64: 8,
    core.VarDesc.VarType.BOOL: 1,
    core.VarDesc.VarType.UINT8: 1,
}


def get_var_size(param):
    """
    input:
        - param: var
    return:
        var size in Bytes
    """
    assert -1 not in param.shape
    return reduce(lambda x, y: x * y,
                  param.shape) * DtypeToSize[param.dtype] / 1024.0 / 1024.0


def insert_scale_loss_grad_ops(block, scale=1.0):
    '''
    In order to keep the learning rate consistent in different numbers of
    training workers, we scale the loss grad by the number of workers
    '''
    for idx, op in reversed(list(enumerate(block.ops))):
        if is_loss_grad_op(op):
            loss_grad_var = block.vars[op.output_arg_names[0]]
            block._insert_op_without_sync(
                idx + 1,
                type='scale',
                inputs={'X': loss_grad_var},
                outputs={'Out': loss_grad_var},
                attrs={'scale': scale,
                       OP_ROLE_KEY: OpRole.Backward})
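A quick worked example of the size accounting above (my own illustration, not part of the diff): although the docstring says Bytes, get_var_size effectively returns megabytes, so a [1024, 1024] FP32 parameter contributes 1024 * 1024 * 4 / 1024 / 1024 = 4.0 MB to a segment's _param_mem, and _split_program in sharding_optimizer.py starts a new segment once that accumulator reaches fuse_broadcast_MB.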
python/paddle/distributed/fleet/meta_optimizers/sharding/weight_decay_helper.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_VAR_KEY


class WeightDecayHelper(object):
    def __init__(self):
        pass

    def _is_weight_decay_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/regularization")

    def prune_weight_decay(self, block, shard):
        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_weight_decay_op(op):
                continue
            if OP_ROLE_VAR_KEY not in op.attr_names:
                raise ValueError(
                    "The Weight Dacay op should hold op_role_var attribute"
                    "but the {} op does not hold op_role_var".format(op.type))
            op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
            if not shard.has_param(op_role_var[0]):
                block._remove_op(idx, sync=False)
        block._sync_with_cpp()
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import unique_name, core
import paddle.fluid as fluid

from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from functools import reduce

__all__ = ["ShardingOptimizer"]


class ShardingOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        super(ShardingOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = [
            "RecomputeOptimizer",
            "AMPOptimizer",
        ]
        self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
        self._main_program = None
        self._startup_program = None
        self._segments = []
        # params and fp16 params is for broadcast
        self._params = set([])
        self._broadcast_vars = set([])
        # reduced grads to param name
        self._reduced_grads_to_param = {}
        self._shard = Shard()

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False
        if self.role_maker._worker_num() <= 1:
            return False
        return self.user_defined_strategy.sharding

    def _disable_strategy(self, dist_strategy):
        dist_strategy.sharding = False
        dist_strategy.sharding_configs = {}

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.sharding = True
        dist_strategy.sharding_configs = {"fuse_broadcast_MB": 32}

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        self._nrings = self.user_defined_strategy.nccl_comm_num
        self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[
            "fuse_broadcast_MB"]

        if self.inner_opt is None:
            raise ValueError(
                "self.inner_opt of ShardingOptimizer should not be None.")
        optimize_ops, params_grads = self.inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set)

        if startup_program is None:
            startup_program = default_startup_program()
        main_block = loss.block
        startup_block = startup_program.global_block()
        self._main_program = main_block.program
        self._startup_program = startup_program

        # step1: set_up
        self._set_up(params_grads)

        # step2: split_program
        self._split_program(main_block)

        # step3: add broadcast and reduce ops
        self._add_broadcast_allreduce(main_block)
        main_block._sync_with_cpp()
        startup_block._sync_with_cpp()

        # step4: insert reduce_sum for grad
        insert_scale_loss_grad_ops(
            main_block, scale=1.0 / self.role_maker._worker_num())
        main_block._sync_with_cpp()

        # step5: remove unneeded ops and vars from block
        self._prune_main_program(main_block)
        self._prune_startup_program(startup_block)

        # check op dependecy
        check_broadcast(main_block)
        check_allreduce_sum(main_block)
        self._wait()
        return optimize_ops, params_grads

    def _set_up(self, params_grads):
        # step 1: initialize nccl
        worker_idx = self.role_maker._worker_index()
        endpoints = self.role_maker._get_trainer_endpoints()
        current_endpoint = endpoints[worker_idx]
        self._collective_helper = CollectiveHelper(self.role_maker,
                                                   self._nrings)
        for ring_id in range(self._nrings):
            self._collective_helper._init_communicator(
                self._startup_program, current_endpoint, endpoints, worker_idx,
                ring_id, None)
        startup_block = self._startup_program.global_block()
        startup_block._sync_with_cpp()

        # step 2: split params
        self._params = set([x[0].name for x in params_grads])
        self._shard.setup(params_grads, worker_idx,
                          self.role_maker._worker_num())

        # step 3: get broadcast vars
        self._broadcast_vars = self._shard.find_broadcast_params(
            self._main_program.global_block())

    def _wait(self, ):
        endpoints = self.role_maker._get_trainer_endpoints()
        current_endpoint = endpoints[self.role_maker._worker_index()]
        if self.role_maker._worker_index() == 0:
            self._collective_helper._wait(current_endpoint, endpoints)

    def _split_program(self, block):
        for op_idx, op in reversed(list(enumerate(block.ops))):
            if int(op.attr('op_role')) != int(OpRole.Optimize):
                last_backward_op_idx = op_idx + 1
                break
        segment = ProgramSegment(block)
        segment._end_idx = last_backward_op_idx
        for op_idx in reversed(range(last_backward_op_idx)):
            op = block.ops[op_idx]
            assert (int(op.attr('op_role')) != int(OpRole.Optimize))
            if segment._param_mem >= self._fuse_broadcast_MB:
                segment._start_idx = op_idx + 1
                self._segments.insert(0, segment)
                segment = ProgramSegment(block)
                segment._end_idx = op_idx + 1

            # find broadcast vars
            for input_name in op.desc.input_arg_names():
                if input_name not in self._broadcast_vars:
                    continue
                if input_name in segment._param2broadcast:
                    # skip broadcast because it reuse the old broadcast var
                    broadcast_name = segment._param2broadcast[input_name]
                    if input_name != broadcast_name:
                        op._rename_input(input_name, broadcast_name)
                    continue
                if self._shard.has_param(input_name):
                    broadcast_var_name = input_name
                else:
                    broadcast_var_name = unique_name.generate(input_name +
                                                              "@BroadCast")
                    segment._fill_constant_vars.append(broadcast_var_name)
                segment._param2broadcast[input_name] = broadcast_var_name
                segment._broadcast_vars.append(
                    (broadcast_var_name, self._shard.device(input_name)))
                segment._param_mem += get_var_size(
                    self._main_program.global_block().var(input_name))

            # find reduce vars
            if is_backward_op(op) and \
                    OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
                if len(op_role_var) != 0:
                    assert len(op_role_var) % 2 == 0
                    for i in range(0, len(op_role_var), 2):
                        param, reduced_grad = op_role_var[i], op_role_var[i + 1]
                        segment._allreduce_vars.append(reduced_grad)
                        assert (
                            reduced_grad not in self._reduced_grads_to_param)
                        self._reduced_grads_to_param[reduced_grad] = param

            # find cast op
            if FP16Utils.is_fp16_cast_op(block, op, self._params):
                fp32_param = op.desc.input_arg_names()[0]
                fp16_param = op.desc.output_arg_names()[0]
                if self._shard.has_param(fp32_param):
                    segment._cast_ops[fp16_param] = fp32_param

        if segment._param_mem > 0:
            segment._start_idx = 0
            self._segments.insert(0, segment)
        return

    def _prune_main_program(self, block):
        """
        calculate deps from allredce op to optimize op,
        remove ops and vars not needed in this worker
        """
        weightdecay_helper = WeightDecayHelper()
        weightdecay_helper.prune_weight_decay(block, self._shard)
        FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
                             self._nrings)
        gradientclip_helper = GradientClipHelper()
        gradientclip_helper.prune_gradient_clip(block, self._shard)

        # build prog deps
        reduced_grads = []
        for idx, op in enumerate(block.ops):
            input_names = op.desc.input_arg_names()
            output_names = op.desc.output_arg_names()
            if op.type == "c_allreduce_sum":
                assert (len(output_names) == 1)
                output_name = output_names[0]
                reduced_grads.append(output_name)

        pruned_opti_vars = []
        for var_name in list(block.vars.keys()):
            if self._shard.is_opti_var(var_name) and \
                not self._shard.has_opt_var(var_name):
                pruned_opti_vars.append(var_name)
        program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars)

        # Init
        for var_name in program_deps._end_vars:
            program_deps._should_removed_var.add(var_name)

        # Prune
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type in [
                    "c_allreduce_sum", "c_sync_comm_stream",
                    "c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init"
            ]:
                pass
            elif op.type == "conditional_block":
                assert (op.desc.has_attr("sub_block"))
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = program_deps.get_sub_block_deps(subblock_idx)
                # only prune amp subblock
                if subblock_deps is None or not self._is_amp_subblock(op):
                    continue
                # init
                reversed_output_vars = []
                for output_name in op.desc.output("Out"):
                    if output_name in program_deps._should_removed_var:
                        subblock_deps._should_removed_var.add(output_name)
                        program_deps.crop_output_var_from_op(idx, output_name)
                    else:
                        reversed_output_vars.append(output_name)
                # prune
                for sub_op_idx, _ in reversed(
                        list(enumerate(subblock_deps._block.ops))):
                    if subblock_deps.should_remove_op(sub_op_idx):
                        subblock_deps.remove_op(sub_op_idx)
                reversed_input_vars = []
                for input_name in op.desc.input('Input'):
                    if input_name not in subblock_deps._should_removed_var:
                        reversed_input_vars.append(input_name)
                    else:
                        program_deps.crop_input_var_from_op(idx, input_name)
                op.desc.set_input('Input', reversed_input_vars)
                op.desc.set_output('Out', reversed_output_vars)
            else:
                if program_deps.should_remove_op(idx):
                    program_deps.remove_op(idx)

        block._sync_with_cpp()
        return

    def _add_broadcast_allreduce(self, block):
        """
        _add_broadcast_allreduce
        """
        ring_id = -1
        if len(self._segments) < 1:
            return

        if self._segments[-1]._allreduce_vars:
            insert_sync_comm_ops(block, self._segments[-1]._end_idx,
                                 self._nrings,
                                 self._segments[-1]._allreduce_vars)
            insert_allreduce_ops(block, self._segments[-1]._end_idx,
                                 self._nrings,
                                 self._segments[-1]._allreduce_vars)

        for idx, segment in reversed(list(enumerate(self._segments))):
            allreduce_vars = self._segments[
                idx - 1]._allreduce_vars if idx > 0 else []
            broadcast_vars = self._segments[
                idx + 1]._broadcast_vars if idx < len(self._segments) - 1 else []
            fill_constant_vars = self._segments[
                idx + 2]._fill_constant_vars if idx < len(
                    self._segments) - 2 else []
            cast_ops = self._segments[idx + 2]._cast_ops if idx < len(
                self._segments) - 2 else {}

            for op_idx in reversed(range(segment._start_idx, segment._end_idx)):
                op = block.ops[op_idx]
                for input_name in op.desc.input_arg_names():
                    if input_name in segment._param2broadcast and \
                        input_name != segment._param2broadcast[input_name]:
                        op._rename_input(input_name,
                                         segment._param2broadcast[input_name])

            for param_name, broadcast_name in segment._param2broadcast.items():
                if param_name != broadcast_name:
                    block.create_var(
                        name=broadcast_name,
                        shape=self._main_program.global_block().var(
                            param_name).shape,
                        dtype=self._main_program.global_block().var(param_name)
                        .dtype,
                        persistable=False)

            # step1: remove cast ops
            block._sync_with_cpp()
            segment._end_idx += FP16Utils.remove_cast_op(block, self._params,
                                                         segment, 0)

            # step2: add Sync ops
            comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
            if len(comm_dep_vars) > 0:
                insert_sync_comm_ops(
                    block,
                    segment._end_idx,
                    self._nrings,
                    comm_dep_vars, )
            calc_dep_vars = fill_constant_vars + [
                k for k, v in cast_ops.items()
            ] + self._segments[idx]._allreduce_vars

            if len(calc_dep_vars) > 0:
                insert_sync_calc_op(block, segment._end_idx,
                                    [calc_dep_vars[-1]])

            # step3: insert `fill_constant` ops
            insert_fill_constant_ops(block, segment._end_idx,
                                     fill_constant_vars)

            # step4: add `cast` ops
            insert_cast_ops(block, segment._end_idx, cast_ops)

            # step5: add broadcast ops
            insert_broadcast_ops(block, segment._start_idx, self._nrings,
                                 broadcast_vars)

            # step6: add all_reduce ops
            insert_allreduce_ops(block, segment._start_idx, self._nrings,
                                 allreduce_vars)

            block._sync_with_cpp()

        if self._segments[0]._broadcast_vars:
            insert_sync_comm_ops(
                block, self._segments[0]._start_idx, self._nrings,
                [x[0] for x in self._segments[0]._broadcast_vars])
            insert_broadcast_ops(block, self._segments[0]._start_idx,
                                 self._nrings,
                                 self._segments[0]._broadcast_vars)

        fill_constant_vars = []
        for x in self._segments[:2]:
            fill_constant_vars += x._fill_constant_vars

        # Join
        cast_ops = {}
        for x in self._segments[:2]:
            for k, v in x._cast_ops.items():
                cast_ops[k] = v

        calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()]
        if fill_constant_vars or cast_ops:
            insert_sync_calc_op(block, self._segments[0]._start_idx,
                                [calc_deps_vars[-1]])

        if fill_constant_vars:
            insert_fill_constant_ops(block, self._segments[0]._start_idx,
                                     fill_constant_vars)

        if cast_ops:
            insert_cast_ops(block, self._segments[0]._start_idx, cast_ops)

        return

    def _prune_startup_program(self, block):
        for idx, op in reversed(list(enumerate(block.ops))):
            for output_name in op.desc.output_arg_names():
                if self._shard.has_var(output_name):
                    continue
                #TODO why do we remove op, when only one var is removed
                block._remove_op(idx, sync=False)
                break

        for var_name in list(block.vars.keys()):
            if self._shard.has_var(var_name):
                continue
            block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
python/paddle/fluid/clip.py

@@ -669,7 +669,7 @@ def append_gradient_clip_ops(param_grads):
         if g is None:
             continue
         with p.block.program._optimized_guard(
-                [p, g]), framework.name_scope('gradient_clip_@CLIP'):
+                [p, g]), framework.name_scope('gradient_clip'):
             clip_attr = getattr(p, 'gradient_clip_attr', None)
             if clip_attr is None:
                 return param_grads

@@ -685,7 +685,7 @@ def append_gradient_clip_ops(param_grads):
         if g is None:
             continue
         with p.block.program._optimized_guard(
-                [p, g]), framework.name_scope('gradient_clip_@CLIP'):
+                [p, g]), framework.name_scope('gradient_clip'):
             param, new_grad = clip_attr._create_operators(param=p, grad=g)
             param_new_grad_name_dict[param.name] = new_grad.name
             res.append([param, new_grad])
python/paddle/fluid/framework.py

@@ -2100,10 +2100,16 @@ class Operator(object):
                                 % (out_proto.name, len(out_args)))
                     out_arg_names = []
                     for arg in out_args:
-                        out_arg_names.append(cpt.to_text(arg.name))
+                        if isinstance(arg, six.string_types):
+                            out_arg_names.append(arg)
+                        else:
+                            out_arg_names.append(cpt.to_text(arg.name))
                         # TODO(minqiyang): could we remove variable's op in static mode?
                         if not in_dygraph_mode():
-                            arg.op = self
+                            if isinstance(arg, six.string_types):
+                                block.var(arg).op = self
+                            else:
+                                arg.op = self
                     self.desc.set_output(out_proto.name, out_arg_names)

         if op_attrs is not None:

@@ -2837,8 +2843,9 @@ class Block(object):
         self._sync_with_cpp()
         return var

-    def _remove_var(self, name):
-        self._sync_with_cpp()
+    def _remove_var(self, name, sync=True):
+        if sync == True:
+            self._sync_with_cpp()
         self.desc._remove_var(cpt.to_bytes(name))
         del self.vars[name]

@@ -2936,7 +2943,23 @@ class Block(object):
         self.ops.insert(index, op)
         return op

-    def _remove_op(self, index):
+    def _insert_op_without_sync(self, index, *args, **kwargs):
+        """
+        Insert an Operator according to the giving arguments,
+        without sync_with_cpp to meke the compilation faster.
+
+        Args:
+            index(int): the place that the operator to insert.
+
+        Returns:
+            Operator: the insert Operator.
+        """
+        op_desc = self.desc._insert_op(index)
+        op = Operator(block=self, desc=op_desc, *args, **kwargs)
+        self.ops.insert(index, op)
+        return op
+
+    def _remove_op(self, index, sync=True):
         """
         Remove the specific position operator.

@@ -2946,7 +2969,8 @@ class Block(object):
         Returns:
             None
         """
-        self._sync_with_cpp()
+        if sync == True:
+            self._sync_with_cpp()
         self.desc._remove_op(index, index + 1)
         del self.ops[index]
python/paddle/fluid/tests/unittests/CMakeLists.txt

@@ -41,6 +41,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
+list(APPEND MIXED_DIST_TEST_OPS test_fleet_sharding_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer)
 list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer)

@@ -461,6 +462,7 @@ if(WITH_DISTRIBUTE)
     py_test_modules(test_fleet_recompute_meta_optimizer MODULES test_fleet_recompute_meta_optimizer ENVS ${dist_ENVS})
     py_test_modules(test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS ${dist_ENVS})
     py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
+    py_test_modules(test_fleet_sharding_meta_optimizer MODULES test_fleet_sharding_meta_optimizer ENVS ${dist_ENVS})
     py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS})
     py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS})
     py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS})
python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py

@@ -55,14 +55,22 @@ class TestFleetMetaOptimizer(unittest.TestCase):
                   strategy,
                   train_prog,
                   startup_prog,
-                  name='momentum'):
+                  name='momentum',
+                  regularization=None,
+                  grad_clip=None):
         with fluid.program_guard(train_prog, startup_prog):
             with fluid.unique_name.guard():
                 if name == 'momentum':
                     optimizer = paddle.fluid.optimizer.Momentum(
-                        learning_rate=0.01, momentum=0.9)
+                        learning_rate=0.01,
+                        momentum=0.9,
+                        regularization=regularization,
+                        grad_clip=grad_clip)
                 elif name == 'adam':
-                    optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
+                    optimizer = paddle.fluid.optimizer.Adam(
+                        learning_rate=0.01,
+                        regularization=regularization,
+                        grad_clip=grad_clip)
                 optimizer = fleet.distributed_optimizer(
                     optimizer, strategy=strategy)
                 optimizer.minimize(loss)

@@ -121,5 +129,8 @@ class TestFleetMetaOptimizer(unittest.TestCase):
         elif name == "gradient_merge":
             strategy.gradient_merge = True
             strategy.gradient_merge_configs = {"k_steps": 2, "avg": True}
+        elif name == "sharding":
+            strategy.sharding = True
+            strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}
         else:
             raise NotImplementedError()
python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py

@@ -32,9 +32,6 @@ class TestFleetGradientMergeMetaOptimizer(TestFleetMetaOptimizer):
         self.optimizer(avg_cost, strategy, train_prog, startup_prog)

         vars = [x.name for x in train_prog.list_vars()]
-        with open("main_program", 'w') as f:
-            f.write(str(train_prog))
-
         self.assertIn('@GradientMerge', ''.join(vars))

     def test_recom_gm_optimizer(self):
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from fleet_meta_optimizer_base import TestFleetMetaOptimizer

paddle.enable_static()

class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
    def test_sharding_optimizer(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
            ]))
        self.assertEqual(ops, [
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_sync_comm_stream', 'momentum', 'momentum', 'momentum'
        ])

    def test_sharding_amp_optimizer(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        self.set_strategy(strategy, 'amp')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertIn('cast', ops)
        self.assertIn('check_finite_and_unscale', ops)
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0",
                "learning_rate_0", "loss_scaling_0", "num_bad_steps_0",
                "num_good_steps_0"
            ]))
        self.assertEqual(ops, [
            'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
            'fill_constant', 'c_sync_calc_stream', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_sync_comm_stream', 'cast', 'mul',
            'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
            'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
            'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
            'elementwise_mul', 'fill_constant', 'scale',
            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad',
            'c_sync_calc_stream', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_sync_comm_stream', 'cast', 'cast', 'cast',
            'check_finite_and_unscale', 'cast', 'c_sync_calc_stream',
            'c_allreduce_max', 'c_sync_comm_stream', 'cast',
            'update_loss_scaling', 'momentum', 'momentum', 'momentum'
        ])

    def test_sharding_recompute_optimizer(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        self.set_strategy(strategy, 'recompute')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertIn('subprog', ''.join(vars))
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
            ]))
        self.assertEqual(ops, [
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'mul',
            'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
            'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad',
            'mul_grad', 'c_sync_calc_stream', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_sync_comm_stream',
            'momentum', 'momentum', 'momentum'
        ])

    def test_sharding_amp_recompute_optimizer(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        self.set_strategy(strategy, 'recompute')
        self.set_strategy(strategy, 'amp')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertIn('subprog', ''.join(vars))
        self.assertIn('cast', ops)
        self.assertIn('check_finite_and_unscale', ops)
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0",
                "learning_rate_0", "loss_scaling_0", "num_bad_steps_0",
                "num_good_steps_0"
            ]))
        self.assertEqual(ops, [
            'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', 'cast', 'cast',
            'mul', 'cast', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
            'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
            'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
            'elementwise_mul', 'fill_constant', 'scale',
            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast',
            'mul', 'cast', 'elementwise_add', 'cast', 'tanh_grad', 'cast',
            'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast',
            'elementwise_add', 'cast', 'tanh_grad', 'cast',
            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_sync_comm_stream', 'cast', 'cast', 'cast',
            'check_finite_and_unscale', 'cast', 'c_sync_calc_stream',
            'c_allreduce_max', 'c_sync_comm_stream', 'cast',
            'update_loss_scaling', 'momentum', 'momentum', 'momentum'
        ])

    def test_sharding_weight_decay(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        regularization = paddle.fluid.regularizer.L2Decay(0.0001)
        self.optimizer(
            avg_cost,
            strategy,
            train_prog,
            startup_prog,
            regularization=regularization)
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
            ]))
        self.assertEqual(ops, [
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_sync_comm_stream', 'scale', 'sum', 'scale', 'sum', 'scale',
            'sum', 'momentum', 'momentum', 'momentum'
        ])

    def test_sharding_gradient_clip(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
        self.optimizer(
            avg_cost, strategy, train_prog, startup_prog, grad_clip=clip)
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
            ]))
        self.assertEqual(ops, [
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_sync_comm_stream', 'square', 'reduce_sum', 'square',
            'reduce_sum', 'square', 'reduce_sum', 'sum', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_sync_comm_stream', 'sqrt', 'fill_constant',
            'elementwise_max', 'elementwise_div', 'elementwise_mul',
            'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
            'momentum'
        ])

if __name__ == "__main__":
    unittest.main()
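To run only this new module locally, a small driver like the following should work (assumption: executed from python/paddle/fluid/tests/unittests in a WITH_DISTRIBUTE build, matching the CMake registration above):

import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_fleet_sharding_meta_optimizer")
unittest.TextTestRunner(verbosity=2).run(suite)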
python/setup.py.in
View file @ 81244fbf
...
...
@@ -148,6 +148,7 @@ packages=['paddle',
          'paddle.distributed.fleet',
          'paddle.distributed.fleet.base',
          'paddle.distributed.fleet.meta_optimizers',
+         'paddle.distributed.fleet.meta_optimizers.sharding',
          'paddle.distributed.fleet.runtime',
          'paddle.distributed.fleet.dataset',
          'paddle.distributed.fleet.data_generator',
...
...