Unverified commit 81244fbf
Authored October 26, 2020 by mapingshuo, committed by GitHub on October 26, 2020

add sharding strategy in fleet(#27900)

* add sharding

Parent: 4877bd59

Showing 20 changed files with 1648 additions and 14 deletions (+1648 -14)
paddle/fluid/framework/distributed_strategy.proto  +6 -0
python/paddle/distributed/fleet/base/distributed_strategy.py  +49 -0
python/paddle/distributed/fleet/meta_optimizers/__init__.py  +1 -0
python/paddle/distributed/fleet/meta_optimizers/common.py  +6 -0
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py  +8 -0
python/paddle/distributed/fleet/meta_optimizers/sharding/__init__.py  +13 -0
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py  +154 -0
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py  +90 -0
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py  +131 -0
python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py  +144 -0
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py  +274 -0
python/paddle/distributed/fleet/meta_optimizers/sharding/weight_decay_helper.py  +37 -0
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py  +411 -0
python/paddle/fluid/clip.py  +2 -2
python/paddle/fluid/framework.py  +30 -6
python/paddle/fluid/tests/unittests/CMakeLists.txt  +2 -0
python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py  +14 -3
python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py  +0 -3
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py  +275 -0
python/setup.py.in  +1 -0
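Taken together with the docstrings and unit-test helpers added below, the user-facing entry point of this commit is the new `sharding` flag plus `sharding_configs` on `fleet.DistributedStrategy`. A minimal sketch of how it would be wired up, assuming a multi-worker collective launch; the toy network, role-maker setup, and hyperparameters are illustrative assumptions, not part of this commit:

import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker

paddle.enable_static()
fleet.init(role_maker.PaddleCloudRoleMaker(is_collective=True))

# Toy network, only used to have a loss to minimize.
x = fluid.layers.data(name='x', shape=[32], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='int64')
fc = fluid.layers.fc(input=x, size=2, act='softmax')
avg_cost = fluid.layers.reduce_mean(
    fluid.layers.cross_entropy(input=fc, label=y))

# The switches introduced by this commit: shard parameters and optimizer
# states across workers and fuse broadcasts into ~32 MB groups.
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 32}

optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)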
paddle/fluid/framework/distributed_strategy.proto

@@ -24,6 +24,10 @@ enum Mode {
message RecomputeConfig { repeated string checkpoints = 1; }

message ShardingConfig { optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; }

message AMPConfig {
  optional float init_loss_scaling = 1 [ default = 32768.0 ];
  optional int32 incr_every_n_steps = 2 [ default = 1000 ];

@@ -130,6 +134,7 @@ message DistributedStrategy {
  optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
  optional bool adaptive_localsgd = 24 [ default = false ];
  optional bool fp16_allreduce = 25 [ default = false ];
  optional bool sharding = 26 [ default = false ];

  optional RecomputeConfig recompute_configs = 101;
  optional AMPConfig amp_configs = 102;

@@ -141,6 +146,7 @@ message DistributedStrategy {
  optional LarsConfig lars_configs = 108;
  optional LambConfig lamb_configs = 109;
  optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
  optional ShardingConfig sharding_configs = 111;
  optional BuildStrategy build_strategy = 201;
  optional ExecutionStrategy execution_strategy = 202;
}
python/paddle/distributed/fleet/base/distributed_strategy.py

@@ -611,6 +611,55 @@ class DistributedStrategy(object):
                          "checkpoint_configs")
        assign_configs_value(self.strategy.recompute_configs, configs)

    @property
    def sharding(self):
        """
        Indicating whether we are using sharding Optimizer for memory
        optimization

        Default value: False

        Examples:
          .. code-block:: python
            import paddle.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sharding = True
        """
        return self.strategy.sharding

    @sharding.setter
    @is_strict_auto
    def sharding(self, flag):
        if isinstance(flag, bool):
            self.strategy.sharding = flag
        else:
            print("WARNING: sharding should have value of bool type")

    @property
    def sharding_configs(self):
        """
        Set sharding configurations.

        **Note**:
            fuse_broadcast_MB(float): size of a fused group of broadcasted parameters.

        Examples:
          .. code-block:: python
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sharding = True
            strategy.sharding_configs = {"fuse_broadcast_MB": 32}
        """
        return get_msg_dict(self.strategy.sharding_configs)

    @sharding_configs.setter
    @is_strict_auto
    def sharding_configs(self, configs):
        check_configs_key(self.strategy.sharding_configs, configs,
                          "sharding_configs")
        assign_configs_value(self.strategy.sharding_configs, configs)

    @property
    def pipeline(self):
        """
python/paddle/distributed/fleet/meta_optimizers/__init__.py

@@ -24,3 +24,4 @@ from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
from .fp16_allreduce_optimizer import FP16AllReduceOptimizer
from .sharding_optimizer import ShardingOptimizer
python/paddle/distributed/fleet/meta_optimizers/common.py

@@ -99,6 +99,12 @@ class CollectiveHelper(object):
                OP_ROLE_KEY: OpRole.Forward
            })

    def _wait(self, current_endpoint, endpoints):
        assert (self.wait_port)
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)
        wait_server_ready(other_endpoints)

    def _broadcast_params(self):
        block = self.startup_program.global_block()
        ring_id = -1
python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py

@@ -30,6 +30,10 @@ class DGCOptimizer(MetaOptimizerBase):
        super(DGCOptimizer, self)._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy)

    def _init_dgc_opt(self):
        if self.dgc_opt is not None:
            return

        opt = self.inner_opt

        if not self.role_maker._is_collective:

@@ -86,13 +90,16 @@ class DGCOptimizer(MetaOptimizerBase):
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        self._init_dgc_opt()
        return self.dgc_opt.backward(loss, startup_program, parameter_list,
                                     no_grad_set, callbacks)

    def apply_gradients(self, params_grads):
        self._init_dgc_opt()
        return self.dgc_opt.apply_gradients(params_grads=params_grads)

    def apply_optimize(self, loss, startup_program, params_grads):
        self._init_dgc_opt()
        return self.dgc_opt.apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads)

@@ -101,6 +108,7 @@ class DGCOptimizer(MetaOptimizerBase):
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        self._init_dgc_opt()
        optimize_ops, params_grads = \
            self.dgc_opt.minimize(loss, startup_program, parameter_list,
                                  no_grad_set)
python/paddle/distributed/fleet/meta_optimizers/sharding/__init__.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from paddle.fluid import core


class FP16Utils(object):
    def __init__(self):
        pass

    @staticmethod
    def is_fp16_cast_op(block, op, params):
        if op.type != "cast":
            return False
        if is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        if input_name not in params:
            return False
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP32 or \
                output_var.dtype != core.VarDesc.VarType.FP16:
            return False
        return True

    @staticmethod
    def is_fp32_cast_op(block, op):
        if op.type != "cast":
            return False
        if not is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP16 or \
                output_var.dtype != core.VarDesc.VarType.FP32:
            return False
        return True

    @staticmethod
    def remove_cast_op(block, params, segment, offset):
        inserted_op_num = 0
        for op_idx in reversed(
                range(offset + segment._start_idx, offset + segment._end_idx)):
            op = block.ops[op_idx]
            if FP16Utils.is_fp16_cast_op(block, op, params):
                block._remove_op(op_idx, sync=False)
                inserted_op_num -= 1
        block._sync_with_cpp()
        return inserted_op_num

    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, nrings):
        # remove cast
        for idx, op in reversed(list(enumerate(block.ops))):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            param_name = output_name.strip("@GRAD")
            if param_name not in shard.global_params:
                raise ValueError("Input 'X' of check_finite_and_unscale must"
                                 "be grads, but {} is not a grad".format(
                                     input_name))
            if output_name in reduced_grads_to_param:
                continue
            if shard.has_param(param_name):
                continue
            block._remove_op(idx, sync=False)
            block._remove_var(output_name, sync=False)

        block._sync_with_cpp()
        update_loss_scaling_op_idx = -1
        inf_var_name = ''
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                for input_name in op.desc.input('X'):
                    param_name = input_name.strip("@GRAD")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must"
                            "be grads, but {} is not a grad".format(input_name))
                    if shard.has_param(param_name):
                        reversed_x.append(input_name)
                op.desc.set_input('X', reversed_x)
                op.desc.set_output('Out', reversed_x)
        if update_loss_scaling_op_idx == -1:
            return
        inf_var = block.var(inf_var_name)
        inf_var_fp32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
        inf_var_sharding = block.create_var(
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_fp32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_fp32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
                            [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='c_allreduce_max',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_fp32},
            attrs={'ring_id': 0,
                   OP_ROLE_KEY: OpRole.Optimize})
        comm_op_num = insert_sync_comm_ops(
            block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 3 + comm_op_num,
            type='cast',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_sharding},
            attrs={
                "in_dtype": inf_var_fp32.dtype,
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole


class GradientClipHelper(object):
    def __init__(self):
        pass

    def _is_gradient_clip_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/gradient_clip")

    def prune_gradient_clip(self, block, shard):
        deperated_vars = set()
        deperate_op_idx = set()
        for idx, op in enumerate(block.ops):
            if not self._is_gradient_clip_op(op):
                continue
            if op.type == "sum":
                continue
            deperate_op = False
            for input_name in op.desc.input_arg_names():
                if input_name in deperated_vars:
                    deperate_op = True
                param_name = input_name.strip("@GRAD")
                if shard.is_param(param_name) and \
                        not shard.has_param(param_name):
                    deperate_op = True

            if deperate_op:
                deperate_op_idx.add(idx)
                for output_name in op.desc.output_arg_names():
                    deperated_vars.add(output_name)

        if not deperated_vars:
            # got no gradient_clip op
            return

        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_gradient_clip_op(op):
                continue
            if idx in deperate_op_idx:
                block._remove_op(idx, sync=False)
                continue
            reversed_inputs = []
            if op.type == "sum":
                for input_name in op.desc.input_arg_names():
                    if input_name not in deperated_vars:
                        reversed_inputs.append(input_name)
                op.desc.set_input("X", reversed_inputs)
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_comm_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_allreduce_sum',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={OP_ROLE_KEY: OpRole.Optimize})

        for var_name in deperated_vars:
            block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
        return
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ProgramDeps(object):
    def __init__(self, block, start_vars, end_vars):
        self._block = block
        # vars where to start to build the deps
        self._start_vars = start_vars
        # vars where to stop to build the deps
        self._end_vars = end_vars
        # var name -> op idxs which depends on this var
        self._var_to_use_op = {}
        # sub block deps which is a subset of this topo
        self._sub_block_deps = {}
        # var name -> op idxs which generate var
        self._var_to_generate_op = {}
        self._should_removed_var = set()
        self._father_block_deps = None
        self._build_deps()

    def get_sub_block_deps(self, idx):
        if idx in self._sub_block_deps:
            return self._sub_block_deps[idx]
        else:
            return None

    def get_var_deps(self, var_name):
        if var_name in self._var_to_use_op:
            return self._var_to_use_op[var_name]
        else:
            return None

    def _build_deps(self, ):
        for var_name in self._start_vars:
            self._var_to_use_op[var_name] = []
            self._var_to_generate_op[var_name] = []

        for idx, op in enumerate(self._block.ops):
            if op.type in [
                    "c_allreduce_sum", "c_sync_comm_stream",
                    "c_calc_comm_stream"
            ]:
                continue
            input_vars = op.desc.input_arg_names()
            output_vars = op.desc.output_arg_names()
            deps_reduce = False
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    deps_reduce = True
            if not deps_reduce:
                continue
            for input_name in input_vars:
                if input_name in self._var_to_use_op:
                    self._var_to_use_op[input_name].append(idx)
            for output_name in output_vars:
                if output_name not in self._var_to_use_op:
                    self._var_to_use_op[output_name] = []
                if output_name not in self._var_to_generate_op:
                    self._var_to_generate_op[output_name] = [idx]
                else:
                    self._var_to_generate_op[output_name].append(idx)
            if op.type == "conditional_block":
                # subblock
                assert (op.desc.has_attr("sub_block"))
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = ProgramDeps(
                    self._block.program.block(subblock_idx),
                    op.desc.input_arg_names(), op.desc.output_arg_names())
                self._sub_block_deps[subblock_idx] = subblock_deps
                subblock_deps._father_block_deps = self

    def crop_input_var_from_op(self, op_idx, var_name):
        if var_name in self._var_to_use_op:
            # update var -> dep_var_op
            if self._var_to_use_op[var_name] != []:
                if op_idx not in self._var_to_use_op[var_name]:
                    raise ValueError(
                        "op_idx: {} is not in self._var_to_use_op[{}], "
                        "self._var_to_use_op[{}] is {}".format(
                            op_idx, var_name, var_name,
                            self._var_to_use_op[var_name]))
                self._var_to_use_op[var_name].remove(op_idx)
            # update _should_removed_var
            if var_name in self._start_vars:
                self._should_removed_var.discard(var_name)
            elif self._var_to_use_op[var_name] == []:
                # no more deps of this var
                self._should_removed_var.add(var_name)
            elif self._var_to_generate_op[var_name][-1] >= self._var_to_use_op[
                    var_name][-1]:
                # there are circle in the graph
                self._should_removed_var.add(var_name)
            else:
                # input_name should not be deleted
                self._should_removed_var.discard(var_name)

    def crop_output_var_from_op(self, op_idx, var_name):
        if var_name in self._var_to_generate_op:
            assert (op_idx in self._var_to_generate_op[var_name])
            self._var_to_generate_op[var_name].remove(op_idx)
        if self._block.has_var(var_name):
            if var_name not in self._var_to_generate_op or self._var_to_generate_op[
                    var_name] == []:
                self._block._remove_var(var_name, sync=False)

    def remove_op(self, op_idx):
        # update deps
        op = self._block.ops[op_idx]
        for input_name in op.desc.input_arg_names():
            self.crop_input_var_from_op(op_idx, input_name)
        for output_name in op.desc.output_arg_names():
            self.crop_output_var_from_op(op_idx, output_name)
        self._block._remove_op(op_idx, sync=False)

    def should_remove_op(self, op_idx):
        op = self._block.ops[op_idx]
        for output_name in op.desc.output_arg_names():
            if output_name not in self._should_removed_var:
                return False
        return True
python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils


class Shard(object):
    def __init__(self, ):
        self.global_params = set([])
        self.worker_idx = -1
        self.worker_num = -1
        self.global_param2device = {}

    def setup(self, params_grads, worker_idx, worker_num):
        # param names of all devices
        self.global_params = set([x[0].name for x in params_grads])
        # _param(str) -> device_id(int)
        self.worker_idx = worker_idx
        self.worker_num = worker_num
        # global_param2device contains fp32 params and fp16 params
        self.global_param2device = self._split_params(params_grads, worker_idx,
                                                      worker_num)

    def has_param(self, var_name):
        return var_name in self.global_param2device and \
            self._var_device_id(var_name) == self.worker_idx

    def has_opt_var(self, var_name):
        return self._var_device_id(var_name) == self.worker_idx

    def has_var(self, var_name):
        return self._var_device_id(var_name) == -1 or \
            self._var_device_id(var_name) == self.worker_idx

    def _split_params(self, params_grads, worker_idx, worker_num):
        param2device = {}
        total_param_mem = 0.0
        param2mem = []
        for param in [x[0] for x in params_grads]:
            mem = get_var_size(param)
            total_param_mem += mem
            param2mem.append((param.name, mem))
        device2params = {x: [] for x in range(worker_num)}
        device_idx = 0
        mem_accu = 0.0
        for param_name, mem in param2mem:
            if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num:
                device_idx += 1
            device2params[device_idx].append(param_name)
            param2device[param_name] = device_idx
            mem_accu += mem
        return param2device

    def _var_device_id(self, var_name):
        if var_name in self.global_param2device:
            return self.global_param2device[var_name]
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_param2device:
                return self.global_param2device[base_name]
        return -1

    def find_broadcast_params(self, block):
        broadcast_vars = set([])
        fp16_params = set([])
        fp16_to_fp32 = {}

        param_usage = {x: 0 for x in self.global_params}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.global_params:
                    param_usage[input_name] += 1

        for op in block.ops:
            if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            fp16_params.add(output_name)
            fp16_to_fp32[output_name] = input_name
            param_usage[input_name] -= 1
            self.global_param2device[output_name] = self.global_param2device[
                input_name]

        for param, usage in param_usage.items():
            if usage > 0:
                broadcast_vars.add(param)
        return broadcast_vars

    def device(self, var_name):
        return self._var_device_id(var_name)

    def is_param(self, var_name):
        return var_name in self.global_params

    def is_opti_var(self, var_name):
        if var_name in self.global_params:
            return True
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_params:
                return True
        return False


class ProgramSegment(object):
    def __init__(self, block):
        self._block = block
        self._allreduce_vars = []
        # sub program start idx
        self._start_idx = -1
        # sub program end idx
        self._end_idx = -1
        # param name to broadcast name
        self._param2broadcast = {}
        self._broadcast_vars = []
        # cast op pairs, fp16 name (str) -> fp32 name (str)
        self._cast_ops = {}
        # fill constant vars
        self._fill_constant_vars = []
        # parameter mems
        self._param_mem = 0.0
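For reference, the greedy split in Shard._split_params walks the parameters in order and advances to the next device once the accumulated size passes that device's even share of the total. A standalone sketch of the same idea, using made-up (name, size-in-MB) pairs instead of Paddle parameters:

def split_by_size(param_sizes, worker_num):
    """Greedy, order-preserving split mirroring Shard._split_params,
    operating on plain (name, size_in_MB) pairs."""
    total = sum(size for _, size in param_sizes)
    param2device = {}
    device_idx = 0
    accumulated = 0.0
    for name, size in param_sizes:
        # move on to the next device once we cross its even share of the total
        if accumulated > total * (device_idx + 1) / worker_num:
            device_idx += 1
        param2device[name] = device_idx
        accumulated += size
    return param2device


# e.g. four parameters of 30/10/25/35 MB split over 2 workers
print(split_by_size([("w0", 30.0), ("b0", 10.0), ("w1", 25.0), ("b1", 35.0)], 2))
# {'w0': 0, 'b0': 0, 'w1': 0, 'b1': 1}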
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import core
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY

import re


def check_broadcast(block):
    """
    if a var is broadcasted, it should have a sync_comm before
    this var is used, if not, raise error.
    if the broadcasted var has a fill_constant op, the fill_constant
    op should stay forward before the broadcast op, and before a
    sync_calc op. Otherwise, raise error.
    """
    broadcast_vars = {}
    for idx, op in enumerate(block.ops):
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            if "@BroadCast" in var_name:
                if var_name in broadcast_vars:
                    raise ValueError("var_name areadly exist: {}"
                                     "the old pos is {}, the new pos is {}".
                                     format(var_name, broadcast_vars[var_name][
                                         "broadcast_pos"], idx))
                broadcast_vars[var_name] = {
                    "fill_constant_pos": -1,
                    "broadcast_pos": idx,
                }

    for idx, op in enumerate(block.ops):
        if op.type == "fill_constant":
            var_name = op.desc.output_arg_names()[0]
            if var_name in broadcast_vars:
                broadcast_vars[var_name]["fill_constant_pos"] = idx
            continue

    last_sync_comm_op_idx = -1
    last_sync_calc_op_idx = -1
    for idx, op in enumerate(block.ops):
        if op.type == "c_sync_comm_stream":
            last_sync_comm_op_idx = idx
            continue
        if op.type == "c_sync_calc_stream":
            last_sync_calc_op_idx = idx
            continue
        if op.type == "c_broadcast":
            var_name = op.desc.input_arg_names()[0]
            if "@BroadCast" in var_name:
                if broadcast_vars[var_name]["fill_constant_pos"] != -1:
                    assert (last_sync_calc_op_idx != -1)
                    assert (broadcast_vars[var_name]["fill_constant_pos"] <
                            last_sync_calc_op_idx)
                    assert (last_sync_calc_op_idx < idx)
                continue
        for input_name in op.desc.input_arg_names():
            if input_name in broadcast_vars:
                assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
                assert (broadcast_vars[input_name]["broadcast_pos"] <
                        last_sync_comm_op_idx)
                assert (last_sync_comm_op_idx < idx)
    return


def check_allreduce_sum(block):
    """
    if a Var is allreduced, the op order should be:
        - 0: op that generate Var
        - 1: sync_calc
        - 2: allreduce_sum op
        - 3: sync_comm
        - 4: op that use Var
    """
    var_status = {}
    for op in block.ops:
        if op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            var_status[var_name] = -1

    for op in block.ops:
        if op.type == "c_sync_calc_stream":
            for var_name in var_status:
                if var_name in var_status and var_status[var_name] == 0:
                    var_status[var_name] = 1
        elif op.type == "c_allreduce_sum":
            var_name = op.desc.input_arg_names()[0]
            if var_status[var_name] == -1:
                raise ValueError("{} is not generated, but you are"
                                 "trying to all-reduce it".format(var_name))
            if var_status[var_name] == 0:
                raise ValueError("There should be a sync_calc op "
                                 "after generate Var: {} and before the"
                                 "c_allreduce_sum op".format(var_name))
            assert (var_status[var_name] == 1)
            var_status[var_name] = 2
        elif op.type == "c_sync_comm_stream":
            for var_name in op.desc.input_arg_names():
                if var_name in var_status and var_status[var_name] == 2:
                    var_status[var_name] = 3
        else:
            for input_name in op.desc.input_arg_names():
                if input_name in var_status:
                    if var_status[input_name] != 3:
                        raise ValueError("There should be a sync_comm op "
                                         "after allreduce the Var: {}".format(
                                             var_name))
            for output_name in op.desc.output_arg_names():
                if output_name in var_status and \
                        var_status[output_name] == -1:
                    var_status[output_name] = 0
    return


def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
    """
    _insert_sync_calc_op
    """
    op_role = block.ops[insert_idx].attr('op_role')
    block._insert_op_without_sync(
        insert_idx,
        type='c_sync_calc_stream',
        inputs={'X': calc_dep_vars},
        outputs={'Out': calc_dep_vars},
        attrs={OP_ROLE_KEY: op_role})
    return


def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
    """
    _insert_sync_comm_ops
    """
    op_role = block.ops[insert_idx].attr('op_role')
    for i in range(nrings):
        block._insert_op_without_sync(
            insert_idx,
            type='c_sync_comm_stream',
            inputs={'X': comm_dep_vars},
            outputs={'Out': comm_dep_vars},
            attrs={'ring_id': i,
                   OP_ROLE_KEY: op_role})
    return nrings


def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
    """
    _add_fill_constant_ops
    """
    op_role = block.ops[insert_idx].attr('op_role')
    for broadcast_name in fill_constant_vars:
        broadcast_var = block.var(broadcast_name)
        block._insert_op_without_sync(
            insert_idx,
            type="fill_constant",
            outputs={"Out": broadcast_var.name},
            attrs={
                "shape": broadcast_var.shape,
                "dtype": broadcast_var.dtype,
                "value": 0.0,
                OP_ROLE_KEY: op_role
            })
    return


def insert_cast_ops(block, insert_idx, cast_ops):
    """
    _add_cast_ops
    """
    op_role = block.ops[insert_idx].attr('op_role')
    for fp16_name, fp32_name in cast_ops.items():
        block._insert_op_without_sync(
            insert_idx,
            type="cast",
            inputs={"X": fp32_name},
            outputs={"Out": fp16_name},
            attrs={
                "in_dtype": core.VarDesc.VarType.FP32,
                "out_dtype": core.VarDesc.VarType.FP16,
                OP_ROLE_KEY: op_role
            })
    return


def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
    """
    _add_allreduce_ops
    """
    ring_id = -1
    for var in allreduce_vars:
        ring_id = (ring_id + 1) % nrings
        block._insert_op_without_sync(
            insert_idx,
            type='c_allreduce_sum',
            inputs={'X': var},
            outputs={'Out': var},
            attrs={'ring_id': ring_id,
                   OP_ROLE_KEY: OpRole.Backward})
    return


def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
    """
    _add_broadcast_ops
    """
    ring_id = -1
    op_role = block.ops[insert_idx].attr('op_role')
    for broadcast_name, root_device in broadcast2root:
        ring_id = (ring_id + 1) % nrings
        block._insert_op_without_sync(
            insert_idx,
            type='c_broadcast',
            inputs={'X': broadcast_name},
            outputs={'Out': broadcast_name},
            attrs={
                'ring_id': ring_id,
                'root': root_device,
                OP_ROLE_KEY: op_role
            })
    return


DtypeToSize = {
    core.VarDesc.VarType.FP16: 2,
    core.VarDesc.VarType.FP32: 4,
    core.VarDesc.VarType.FP64: 8,
    core.VarDesc.VarType.INT16: 2,
    core.VarDesc.VarType.INT32: 4,
    core.VarDesc.VarType.INT64: 8,
    core.VarDesc.VarType.BOOL: 1,
    core.VarDesc.VarType.UINT8: 1,
}


def get_var_size(param):
    """
    input:
        - param: var
    return:
        var size in Bytes
    """
    assert -1 not in param.shape
    return reduce(lambda x, y: x * y,
                  param.shape) * DtypeToSize[param.dtype] / 1024.0 / 1024.0


def insert_scale_loss_grad_ops(block, scale=1.0):
    '''
    In order to keep the learning rate consistent in different numbers of
    training workers, we scale the loss grad by the number of workers
    '''
    for idx, op in reversed(list(enumerate(block.ops))):
        if is_loss_grad_op(op):
            loss_grad_var = block.vars[op.output_arg_names[0]]
            block._insert_op_without_sync(
                idx + 1,
                type='scale',
                inputs={'X': loss_grad_var},
                outputs={'Out': loss_grad_var},
                attrs={'scale': scale,
                       OP_ROLE_KEY: OpRole.Backward})
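Note that, despite its docstring, get_var_size returns the size in MB: the element count times the per-dtype byte width from DtypeToSize, divided by 1024 twice. That is the unit the fuse_broadcast_MB threshold in _split_program is compared against. A quick standalone check of the arithmetic, with shapes chosen purely for illustration:

from functools import reduce

def var_size_mb(shape, dtype_bytes):
    # product of dims * bytes per element, converted to MB (as get_var_size does)
    return reduce(lambda x, y: x * y, shape) * dtype_bytes / 1024.0 / 1024.0

print(var_size_mb([1024, 1024], 4))  # FP32 [1024, 1024] -> 4.0 MB
print(var_size_mb([512, 256], 2))    # FP16 [512, 256]   -> 0.25 MB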
python/paddle/distributed/fleet/meta_optimizers/sharding/weight_decay_helper.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_VAR_KEY


class WeightDecayHelper(object):
    def __init__(self):
        pass

    def _is_weight_decay_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/regularization")

    def prune_weight_decay(self, block, shard):
        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_weight_decay_op(op):
                continue
            if OP_ROLE_VAR_KEY not in op.attr_names:
                raise ValueError(
                    "The Weight Dacay op should hold op_role_var attribute"
                    "but the {} op does not hold op_role_var".format(op.type))
            op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
            if not shard.has_param(op_role_var[0]):
                block._remove_op(idx, sync=False)
        block._sync_with_cpp()
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import unique_name, core
import paddle.fluid as fluid

from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from functools import reduce

__all__ = ["ShardingOptimizer"]


class ShardingOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        super(ShardingOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        self.meta_optimizers_white_list = [
            "RecomputeOptimizer",
            "AMPOptimizer",
        ]
        self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
        self._main_program = None
        self._startup_program = None
        self._segments = []
        # params and fp16 params is for broadcast
        self._params = set([])
        self._broadcast_vars = set([])
        # reduced grads to param name
        self._reduced_grads_to_param = {}
        self._shard = Shard()

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False
        if self.role_maker._worker_num() <= 1:
            return False
        return self.user_defined_strategy.sharding

    def _disable_strategy(self, dist_strategy):
        dist_strategy.sharding = False
        dist_strategy.sharding_configs = {}

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.sharding = True
        dist_strategy.sharding_configs = {"fuse_broadcast_MB": 32}

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        self._nrings = self.user_defined_strategy.nccl_comm_num
        self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[
            "fuse_broadcast_MB"]

        if self.inner_opt is None:
            raise ValueError(
                "self.inner_opt of ShardingOptimizer should not be None.")
        optimize_ops, params_grads = self.inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set)

        if startup_program is None:
            startup_program = default_startup_program()
        main_block = loss.block
        startup_block = startup_program.global_block()
        self._main_program = main_block.program
        self._startup_program = startup_program

        # step1: set_up
        self._set_up(params_grads)

        # step2: split_program
        self._split_program(main_block)

        # step3: add broadcast and reduce ops
        self._add_broadcast_allreduce(main_block)
        main_block._sync_with_cpp()
        startup_block._sync_with_cpp()

        # step4: insert reduce_sum for grad
        insert_scale_loss_grad_ops(
            main_block, scale=1.0 / self.role_maker._worker_num())
        main_block._sync_with_cpp()

        # step5: remove unneeded ops and vars from block
        self._prune_main_program(main_block)
        self._prune_startup_program(startup_block)

        # check op dependecy
        check_broadcast(main_block)
        check_allreduce_sum(main_block)
        self._wait()
        return optimize_ops, params_grads

    def _set_up(self, params_grads):
        # step 1: initialize nccl
        worker_idx = self.role_maker._worker_index()
        endpoints = self.role_maker._get_trainer_endpoints()
        current_endpoint = endpoints[worker_idx]
        self._collective_helper = CollectiveHelper(self.role_maker,
                                                   self._nrings)
        for ring_id in range(self._nrings):
            self._collective_helper._init_communicator(
                self._startup_program, current_endpoint, endpoints, worker_idx,
                ring_id, None)
        startup_block = self._startup_program.global_block()
        startup_block._sync_with_cpp()

        # step 2: split params
        self._params = set([x[0].name for x in params_grads])
        self._shard.setup(params_grads, worker_idx,
                          self.role_maker._worker_num())

        # step 3: get broadcast vars
        self._broadcast_vars = self._shard.find_broadcast_params(
            self._main_program.global_block())

    def _wait(self, ):
        endpoints = self.role_maker._get_trainer_endpoints()
        current_endpoint = endpoints[self.role_maker._worker_index()]
        if self.role_maker._worker_index() == 0:
            self._collective_helper._wait(current_endpoint, endpoints)

    def _split_program(self, block):
        for op_idx, op in reversed(list(enumerate(block.ops))):
            if int(op.attr('op_role')) != int(OpRole.Optimize):
                last_backward_op_idx = op_idx + 1
                break
        segment = ProgramSegment(block)
        segment._end_idx = last_backward_op_idx
        for op_idx in reversed(range(last_backward_op_idx)):
            op = block.ops[op_idx]
            assert (int(op.attr('op_role')) != int(OpRole.Optimize))
            if segment._param_mem >= self._fuse_broadcast_MB:
                segment._start_idx = op_idx + 1
                self._segments.insert(0, segment)
                segment = ProgramSegment(block)
                segment._end_idx = op_idx + 1

            # find broadcast vars
            for input_name in op.desc.input_arg_names():
                if input_name not in self._broadcast_vars:
                    continue
                if input_name in segment._param2broadcast:
                    # skip broadcast because it reuse the old broadcast var
                    broadcast_name = segment._param2broadcast[input_name]
                    if input_name != broadcast_name:
                        op._rename_input(input_name, broadcast_name)
                    continue
                if self._shard.has_param(input_name):
                    broadcast_var_name = input_name
                else:
                    broadcast_var_name = unique_name.generate(input_name +
                                                              "@BroadCast")
                    segment._fill_constant_vars.append(broadcast_var_name)
                segment._param2broadcast[input_name] = broadcast_var_name
                segment._broadcast_vars.append(
                    (broadcast_var_name, self._shard.device(input_name)))
                segment._param_mem += get_var_size(
                    self._main_program.global_block().var(input_name))

            # find reduce vars
            if is_backward_op(op) and \
                    OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
                if len(op_role_var) != 0:
                    assert len(op_role_var) % 2 == 0
                    for i in range(0, len(op_role_var), 2):
                        param, reduced_grad = op_role_var[i], op_role_var[i + 1]
                        segment._allreduce_vars.append(reduced_grad)
                        assert (
                            reduced_grad not in self._reduced_grads_to_param)
                        self._reduced_grads_to_param[reduced_grad] = param

            # find cast op
            if FP16Utils.is_fp16_cast_op(block, op, self._params):
                fp32_param = op.desc.input_arg_names()[0]
                fp16_param = op.desc.output_arg_names()[0]
                if self._shard.has_param(fp32_param):
                    segment._cast_ops[fp16_param] = fp32_param

        if segment._param_mem > 0:
            segment._start_idx = 0
            self._segments.insert(0, segment)
        return

    def _prune_main_program(self, block):
        """
        calculate deps from allredce op to optimize op,
        remove ops and vars not needed in this worker
        """
        weightdecay_helper = WeightDecayHelper()
        weightdecay_helper.prune_weight_decay(block, self._shard)
        FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
                             self._nrings)
        gradientclip_helper = GradientClipHelper()
        gradientclip_helper.prune_gradient_clip(block, self._shard)

        # build prog deps
        reduced_grads = []
        for idx, op in enumerate(block.ops):
            input_names = op.desc.input_arg_names()
            output_names = op.desc.output_arg_names()
            if op.type == "c_allreduce_sum":
                assert (len(output_names) == 1)
                output_name = output_names[0]
                reduced_grads.append(output_name)

        pruned_opti_vars = []
        for var_name in list(block.vars.keys()):
            if self._shard.is_opti_var(var_name) and \
                    not self._shard.has_opt_var(var_name):
                pruned_opti_vars.append(var_name)
        program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars)

        # Init
        for var_name in program_deps._end_vars:
            program_deps._should_removed_var.add(var_name)

        # Prune
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type in [
                    "c_allreduce_sum", "c_sync_comm_stream",
                    "c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init"
            ]:
                pass
            elif op.type == "conditional_block":
                assert (op.desc.has_attr("sub_block"))
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = program_deps.get_sub_block_deps(subblock_idx)
                # only prune amp subblock
                if subblock_deps is None or not self._is_amp_subblock(op):
                    continue
                # init
                reversed_output_vars = []
                for output_name in op.desc.output("Out"):
                    if output_name in program_deps._should_removed_var:
                        subblock_deps._should_removed_var.add(output_name)
                        program_deps.crop_output_var_from_op(idx, output_name)
                    else:
                        reversed_output_vars.append(output_name)
                # prune
                for sub_op_idx, _ in reversed(
                        list(enumerate(subblock_deps._block.ops))):
                    if subblock_deps.should_remove_op(sub_op_idx):
                        subblock_deps.remove_op(sub_op_idx)
                reversed_input_vars = []
                for input_name in op.desc.input('Input'):
                    if input_name not in subblock_deps._should_removed_var:
                        reversed_input_vars.append(input_name)
                    else:
                        program_deps.crop_input_var_from_op(idx, input_name)
                op.desc.set_input('Input', reversed_input_vars)
                op.desc.set_output('Out', reversed_output_vars)
            else:
                if program_deps.should_remove_op(idx):
                    program_deps.remove_op(idx)

        block._sync_with_cpp()
        return

    def _add_broadcast_allreduce(self, block):
        """
        _add_broadcast_allreduce
        """
        ring_id = -1
        if len(self._segments) < 1:
            return

        if self._segments[-1]._allreduce_vars:
            insert_sync_comm_ops(block, self._segments[-1]._end_idx,
                                 self._nrings,
                                 self._segments[-1]._allreduce_vars)
            insert_allreduce_ops(block, self._segments[-1]._end_idx,
                                 self._nrings,
                                 self._segments[-1]._allreduce_vars)

        for idx, segment in reversed(list(enumerate(self._segments))):
            allreduce_vars = self._segments[
                idx - 1]._allreduce_vars if idx > 0 else []
            broadcast_vars = self._segments[
                idx + 1]._broadcast_vars if idx < len(self._segments) - 1 else []
            fill_constant_vars = self._segments[
                idx + 2]._fill_constant_vars if idx < len(
                    self._segments) - 2 else []
            cast_ops = self._segments[idx + 2]._cast_ops if idx < len(
                self._segments) - 2 else {}

            for op_idx in reversed(range(segment._start_idx, segment._end_idx)):
                op = block.ops[op_idx]
                for input_name in op.desc.input_arg_names():
                    if input_name in segment._param2broadcast and \
                            input_name != segment._param2broadcast[input_name]:
                        op._rename_input(input_name,
                                         segment._param2broadcast[input_name])

            for param_name, broadcast_name in segment._param2broadcast.items():
                if param_name != broadcast_name:
                    block.create_var(
                        name=broadcast_name,
                        shape=self._main_program.global_block().var(
                            param_name).shape,
                        dtype=self._main_program.global_block().var(param_name)
                        .dtype,
                        persistable=False)

            # step1: remove cast ops
            block._sync_with_cpp()
            segment._end_idx += FP16Utils.remove_cast_op(block, self._params,
                                                         segment, 0)

            # step2: add Sync ops
            comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
            if len(comm_dep_vars) > 0:
                insert_sync_comm_ops(
                    block,
                    segment._end_idx,
                    self._nrings,
                    comm_dep_vars, )
            calc_dep_vars = fill_constant_vars + [
                k for k, v in cast_ops.items()
            ] + self._segments[idx]._allreduce_vars

            if len(calc_dep_vars) > 0:
                insert_sync_calc_op(block, segment._end_idx,
                                    [calc_dep_vars[-1]])

            # step3: insert `fill_constant` ops
            insert_fill_constant_ops(block, segment._end_idx,
                                     fill_constant_vars)

            # step4: add `cast` ops
            insert_cast_ops(block, segment._end_idx, cast_ops)

            # step5: add broadcast ops
            insert_broadcast_ops(block, segment._start_idx, self._nrings,
                                 broadcast_vars)

            # step6: add all_reduce ops
            insert_allreduce_ops(block, segment._start_idx, self._nrings,
                                 allreduce_vars)

            block._sync_with_cpp()

        if self._segments[0]._broadcast_vars:
            insert_sync_comm_ops(
                block, self._segments[0]._start_idx, self._nrings,
                [x[0] for x in self._segments[0]._broadcast_vars])
            insert_broadcast_ops(block, self._segments[0]._start_idx,
                                 self._nrings,
                                 self._segments[0]._broadcast_vars)

        fill_constant_vars = []
        for x in self._segments[:2]:
            fill_constant_vars += x._fill_constant_vars

        # Join
        cast_ops = {}
        for x in self._segments[:2]:
            for k, v in x._cast_ops.items():
                cast_ops[k] = v

        calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()]
        if fill_constant_vars or cast_ops:
            insert_sync_calc_op(block, self._segments[0]._start_idx,
                                [calc_deps_vars[-1]])

        if fill_constant_vars:
            insert_fill_constant_ops(block, self._segments[0]._start_idx,
                                     fill_constant_vars)

        if cast_ops:
            insert_cast_ops(block, self._segments[0]._start_idx, cast_ops)

        return

    def _prune_startup_program(self, block):
        for idx, op in reversed(list(enumerate(block.ops))):
            for output_name in op.desc.output_arg_names():
                if self._shard.has_var(output_name):
                    continue
                #TODO why do we remove op, when only one var is removed
                block._remove_op(idx, sync=False)
                break

        for var_name in list(block.vars.keys()):
            if self._shard.has_var(var_name):
                continue
            block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
python/paddle/fluid/clip.py

@@ -669,7 +669,7 @@ def append_gradient_clip_ops(param_grads):
        if g is None:
            continue
        with p.block.program._optimized_guard(
-            [p, g]), framework.name_scope('gradient_clip_@CLIP'):
+            [p, g]), framework.name_scope('gradient_clip'):
            clip_attr = getattr(p, 'gradient_clip_attr', None)
            if clip_attr is None:
                return param_grads

@@ -685,7 +685,7 @@ def append_gradient_clip_ops(param_grads):
        if g is None:
            continue
        with p.block.program._optimized_guard(
-            [p, g]), framework.name_scope('gradient_clip_@CLIP'):
+            [p, g]), framework.name_scope('gradient_clip'):
            param, new_grad = clip_attr._create_operators(param=p, grad=g)
            param_new_grad_name_dict[param.name] = new_grad.name
            res.append([param, new_grad])
python/paddle/fluid/framework.py

@@ -2100,10 +2100,16 @@ class Operator(object):
                                     % (out_proto.name, len(out_args)))
                out_arg_names = []
                for arg in out_args:
-                    out_arg_names.append(cpt.to_text(arg.name))
+                    if isinstance(arg, six.string_types):
+                        out_arg_names.append(arg)
+                    else:
+                        out_arg_names.append(cpt.to_text(arg.name))
                    # TODO(minqiyang): could we remove variable's op in static mode?
                    if not in_dygraph_mode():
-                        arg.op = self
+                        if isinstance(arg, six.string_types):
+                            block.var(arg).op = self
+                        else:
+                            arg.op = self
                self.desc.set_output(out_proto.name, out_arg_names)

        if op_attrs is not None:

@@ -2837,8 +2843,9 @@ class Block(object):
        self._sync_with_cpp()
        return var

-    def _remove_var(self, name):
-        self._sync_with_cpp()
+    def _remove_var(self, name, sync=True):
+        if sync == True:
+            self._sync_with_cpp()
        self.desc._remove_var(cpt.to_bytes(name))
        del self.vars[name]

@@ -2936,7 +2943,23 @@ class Block(object):
        self.ops.insert(index, op)
        return op

-    def _remove_op(self, index):
+    def _insert_op_without_sync(self, index, *args, **kwargs):
+        """
+        Insert an Operator according to the giving arguments,
+        without sync_with_cpp to meke the compilation faster.
+
+        Args:
+            index(int): the place that the operator to insert.
+
+        Returns:
+            Operator: the insert Operator.
+        """
+        op_desc = self.desc._insert_op(index)
+        op = Operator(block=self, desc=op_desc, *args, **kwargs)
+        self.ops.insert(index, op)
+        return op
+
+    def _remove_op(self, index, sync=True):
        """
        Remove the specific position operator.

@@ -2946,7 +2969,8 @@ class Block(object):
        Returns:
            None
        """
-        self._sync_with_cpp()
+        if sync == True:
+            self._sync_with_cpp()
        self.desc._remove_op(index, index + 1)
        del self.ops[index]
python/paddle/fluid/tests/unittests/CMakeLists.txt

@@ -41,6 +41,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_sharding_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer)

@@ -461,6 +462,7 @@ if(WITH_DISTRIBUTE)
    py_test_modules(test_fleet_recompute_meta_optimizer MODULES test_fleet_recompute_meta_optimizer ENVS ${dist_ENVS})
    py_test_modules(test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS ${dist_ENVS})
    py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
    py_test_modules(test_fleet_sharding_meta_optimizer MODULES test_fleet_sharding_meta_optimizer ENVS ${dist_ENVS})
    py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS})
    py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS})
    py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS})
python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py

@@ -55,14 +55,22 @@ class TestFleetMetaOptimizer(unittest.TestCase):
                  strategy,
                  train_prog,
                  startup_prog,
-                  name='momentum'):
+                  name='momentum',
+                  regularization=None,
+                  grad_clip=None):
        with fluid.program_guard(train_prog, startup_prog):
            with fluid.unique_name.guard():
                if name == 'momentum':
                    optimizer = paddle.fluid.optimizer.Momentum(
-                        learning_rate=0.01, momentum=0.9)
+                        learning_rate=0.01,
+                        momentum=0.9,
+                        regularization=regularization,
+                        grad_clip=grad_clip)
                elif name == 'adam':
-                    optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
+                    optimizer = paddle.fluid.optimizer.Adam(
+                        learning_rate=0.01,
+                        regularization=regularization,
+                        grad_clip=grad_clip)
                optimizer = fleet.distributed_optimizer(
                    optimizer, strategy=strategy)
                optimizer.minimize(loss)

@@ -121,5 +129,8 @@ class TestFleetMetaOptimizer(unittest.TestCase):
        elif name == "gradient_merge":
            strategy.gradient_merge = True
            strategy.gradient_merge_configs = {"k_steps": 2, "avg": True}
+        elif name == "sharding":
+            strategy.sharding = True
+            strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}
        else:
            raise NotImplementedError()
python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py

@@ -32,9 +32,6 @@ class TestFleetGradientMergeMetaOptimizer(TestFleetMetaOptimizer):
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        vars = [x.name for x in train_prog.list_vars()]
-        with open("main_program", 'w') as f:
-            f.write(str(train_prog))
-
        self.assertIn('@GradientMerge', ''.join(vars))

    def test_recom_gm_optimizer(self):
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from fleet_meta_optimizer_base import TestFleetMetaOptimizer

paddle.enable_static()


class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
    def test_sharding_optimizer(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
            ]))

        self.assertEqual(ops, [
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_sync_comm_stream', 'momentum', 'momentum', 'momentum'
        ])
    def test_sharding_amp_optimizer(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        self.set_strategy(strategy, 'amp')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertIn('cast', ops)
        self.assertIn('check_finite_and_unscale', ops)
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0",
                "learning_rate_0", "loss_scaling_0", "num_bad_steps_0",
                "num_good_steps_0"
            ]))

        self.assertEqual(ops, [
            'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
            'fill_constant', 'c_sync_calc_stream', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_sync_comm_stream', 'cast', 'mul',
            'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
            'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
            'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
            'elementwise_mul', 'fill_constant', 'scale',
            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
            'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad',
            'c_sync_calc_stream', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_sync_comm_stream', 'cast', 'cast', 'cast',
            'check_finite_and_unscale', 'cast', 'c_sync_calc_stream',
            'c_allreduce_max', 'c_sync_comm_stream', 'cast',
            'update_loss_scaling', 'momentum', 'momentum', 'momentum'
        ])
    def test_sharding_recompute_optimizer(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        self.set_strategy(strategy, 'recompute')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]

        self.assertIn('@BroadCast', ''.join(vars))
        self.assertIn('subprog', ''.join(vars))
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
            ]))

        self.assertEqual(ops, [
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'mul',
            'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
            'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad',
            'mul_grad', 'c_sync_calc_stream', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_sync_comm_stream',
            'momentum', 'momentum', 'momentum'
        ])
    def test_sharding_amp_recompute_optimizer(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        self.set_strategy(strategy, 'recompute')
        self.set_strategy(strategy, 'amp')
        self.optimizer(avg_cost, strategy, train_prog, startup_prog)

        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]

        self.assertIn('@BroadCast', ''.join(vars))
        self.assertIn('subprog', ''.join(vars))
        self.assertIn('cast', ops)
        self.assertIn('check_finite_and_unscale', ops)
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0",
                "learning_rate_0", "loss_scaling_0", "num_bad_steps_0",
                "num_good_steps_0"
            ]))

        self.assertEqual(ops, [
            'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', 'cast', 'cast',
            'mul', 'cast', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
            'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
            'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
            'elementwise_mul', 'fill_constant', 'scale',
            'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast',
            'mul', 'cast', 'elementwise_add', 'cast', 'tanh_grad', 'cast',
            'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast',
            'elementwise_add', 'cast', 'tanh_grad', 'cast',
            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_sync_comm_stream', 'cast', 'cast', 'cast',
            'check_finite_and_unscale', 'cast', 'c_sync_calc_stream',
            'c_allreduce_max', 'c_sync_comm_stream', 'cast',
            'update_loss_scaling', 'momentum', 'momentum', 'momentum'
        ])
    def test_sharding_weight_decay(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        regularization = paddle.fluid.regularizer.L2Decay(0.0001)
        self.optimizer(
            avg_cost,
            strategy,
            train_prog,
            startup_prog,
            regularization=regularization)
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
            ]))

        self.assertEqual(ops, [
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_sync_comm_stream', 'scale', 'sum', 'scale', 'sum', 'scale',
            'sum', 'momentum', 'momentum', 'momentum'
        ])
    def test_sharding_gradient_clip(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
        avg_cost, strategy = self.net(train_prog, startup_prog)
        self.set_strategy(strategy, 'sharding')
        clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
        self.optimizer(
            avg_cost, strategy, train_prog, startup_prog, grad_clip=clip)
        parameters = [
            x.name for x in train_prog.list_vars() if x.persistable == True
        ]
        ops = [op.type for op in avg_cost.block.ops]
        vars = [x.name for x in train_prog.list_vars()]
        self.assertIn('@BroadCast', ''.join(vars))
        self.assertEqual(
            set(parameters),
            set([
                "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
                "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
            ]))

        self.assertEqual(ops, [
            'fill_constant', 'fill_constant', 'fill_constant',
            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
            'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
            'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
            'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
            'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'tanh_grad',
            'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
            'c_sync_comm_stream', 'square', 'reduce_sum', 'square',
            'reduce_sum', 'square', 'reduce_sum', 'sum', 'c_sync_calc_stream',
            'c_allreduce_sum', 'c_sync_comm_stream', 'sqrt', 'fill_constant',
            'elementwise_max', 'elementwise_div', 'elementwise_mul',
            'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
            'momentum'
        ])
if __name__ == "__main__":
    unittest.main()
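The tests above keep re-using two introspection patterns that are also handy on their own when debugging a sharded program; a small standalone sketch (it assumes `train_prog` is a fluid.Program that has already been through fleet's minimize, as in the tests):

# Not part of the diff: list which persistable variables (parameters and optimizer
# states) this rank still owns after sharding, and dump the main block's op sequence.
persistable_vars = [x.name for x in train_prog.list_vars() if x.persistable]
op_types = [op.type for op in train_prog.global_block().ops]
print(persistable_vars)
print(op_types)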
python/setup.py.in
...
...
@@ -148,6 +148,7 @@ packages=['paddle',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.meta_optimizers.sharding',
'paddle.distributed.fleet.runtime',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.data_generator',
...
...
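Since setup.py.in decides which subpackages ship in the wheel, a quick sanity check after installing a build that includes this change is simply importing the new package path; a hedged sketch using only the standard library:

import importlib

# Succeeds only if the 'paddle.distributed.fleet.meta_optimizers.sharding'
# entry added above actually made it into the installed wheel.
mod = importlib.import_module("paddle.distributed.fleet.meta_optimizers.sharding")
print(mod.__file__)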