Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
1e60a0c4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1e60a0c4
编写于
4月 07, 2021
作者:
J
JZ-LIANG
提交者:
GitHub
4月 07, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[3D-parallelism] Hybrid Model Parallelism (#32074)
上级
363b25aa
变更
11
展开全部
隐藏空白更改
内联
并排
Showing
11 changed file
with
1023 addition
and
152 deletion
+1023
-152
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+11
-7
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
...e/distributed/fleet/meta_optimizers/pipeline_optimizer.py
+5
-0
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
...distributed/fleet/meta_optimizers/sharding/fp16_helper.py
+64
-2
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
...ed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+65
-16
python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
...tributed/fleet/meta_optimizers/sharding/offload_helper.py
+281
-0
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
...addle/distributed/fleet/meta_optimizers/sharding/prune.py
+4
-0
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
...addle/distributed/fleet/meta_optimizers/sharding/utils.py
+60
-6
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+418
-99
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+15
-0
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+48
-8
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
...uid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+52
-14
未找到文件。
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
1e60a0c4
...
@@ -29,14 +29,18 @@ message RecomputeConfig {
...
@@ -29,14 +29,18 @@ message RecomputeConfig {
}
}
message
ShardingConfig
{
message
ShardingConfig
{
optional
float
segment_broadcast_MB
=
1
[
default
=
32.0
];
optional
string
sharding_segment_strategy
=
1
optional
bool
hybrid_dp
=
2
[
default
=
false
];
optional
int32
sharding_degree
=
3
[
default
=
8
];
optional
int32
mp_degree
=
4
[
default
=
1
];
optional
string
sharding_segment_strategy
=
5
[
default
=
'segment_broadcast_MB'
];
[
default
=
'segment_broadcast_MB'
];
repeated
string
segment_anchors
=
6
;
optional
float
segment_broadcast_MB
=
2
[
default
=
32.0
];
optional
int32
gradient_merge_acc_step
=
7
[
default
=
1
];
repeated
string
segment_anchors
=
3
;
optional
int32
sharding_degree
=
4
[
default
=
8
];
optional
int32
mp_degree
=
5
[
default
=
1
];
optional
int32
dp_degree
=
6
[
default
=
1
];
optional
bool
hybrid_dp
=
7
[
default
=
false
];
optional
int32
gradient_merge_acc_step
=
8
[
default
=
1
];
optional
bool
optimize_offload
=
9
[
default
=
false
];
optional
bool
pp_allreduce_in_optimize
=
10
[
default
=
false
];
optional
int32
pp_degree
=
11
[
default
=
1
];
}
}
message
AMPConfig
{
message
AMPConfig
{
...
...
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
100644 → 100755
浏览文件 @
1e60a0c4
...
@@ -45,11 +45,16 @@ class PipelineOptimizer(MetaOptimizerBase):
...
@@ -45,11 +45,16 @@ class PipelineOptimizer(MetaOptimizerBase):
'accumulate_steps'
]
'accumulate_steps'
]
self
.
schedule_mode
=
user_defined_strategy
.
pipeline_configs
[
self
.
schedule_mode
=
user_defined_strategy
.
pipeline_configs
[
'schedule_mode'
]
'schedule_mode'
]
self
.
use_sharding
=
user_defined_strategy
.
sharding
def
_can_apply
(
self
):
def
_can_apply
(
self
):
if
not
self
.
role_maker
.
_is_collective
:
if
not
self
.
role_maker
.
_is_collective
:
return
False
return
False
# FIXME revise for hybrid parallelism
if
self
.
use_sharding
:
return
False
if
self
.
user_defined_strategy
.
pipeline
==
True
:
if
self
.
user_defined_strategy
.
pipeline
==
True
:
return
True
return
True
return
False
return
False
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
浏览文件 @
1e60a0c4
...
@@ -81,7 +81,10 @@ class FP16Utils(object):
...
@@ -81,7 +81,10 @@ class FP16Utils(object):
if
not
FP16Utils
.
is_fp32_cast_op
(
block
,
op
):
if
not
FP16Utils
.
is_fp32_cast_op
(
block
,
op
):
continue
continue
output_name
=
op
.
desc
.
output_arg_names
()[
0
]
output_name
=
op
.
desc
.
output_arg_names
()[
0
]
param_name
=
output_name
.
strip
(
"@GRAD"
)
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
param_name
=
output_name
.
strip
(
"@GRAD@MERGED"
)
if
"@MERGED"
in
output_name
else
output_name
.
strip
(
"@GRAD"
)
if
param_name
not
in
shard
.
global_params
:
if
param_name
not
in
shard
.
global_params
:
raise
ValueError
(
"Output 'X' of cast_op must be a grad of"
raise
ValueError
(
"Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad"
.
format
(
"model param, but {} is not a grad"
.
format
(
...
@@ -105,7 +108,11 @@ class FP16Utils(object):
...
@@ -105,7 +108,11 @@ class FP16Utils(object):
reversed_x
=
[]
reversed_x
=
[]
reversed_x_paramname
=
[]
reversed_x_paramname
=
[]
for
input_name
in
op
.
desc
.
input
(
'X'
):
for
input_name
in
op
.
desc
.
input
(
'X'
):
param_name
=
input_name
.
strip
(
"@GRAD"
)
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if
"@MERGED"
in
input_name
:
param_name
=
input_name
.
strip
(
"@GRAD@MERGED"
)
else
:
param_name
=
input_name
.
strip
(
"@GRAD"
)
if
param_name
not
in
shard
.
global_params
:
if
param_name
not
in
shard
.
global_params
:
raise
ValueError
(
raise
ValueError
(
"Input 'X' of check_finite_and_unscale must"
"Input 'X' of check_finite_and_unscale must"
...
@@ -169,3 +176,58 @@ class FP16Utils(object):
...
@@ -169,3 +176,58 @@ class FP16Utils(object):
OP_ROLE_KEY
:
OpRole
.
Optimize
OP_ROLE_KEY
:
OpRole
.
Optimize
})
})
block
.
_sync_with_cpp
()
block
.
_sync_with_cpp
()
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
@
staticmethod
def
sync_amp_check_nan_inf
(
block
,
ring_id
):
update_loss_scaling_op_idx
=
-
1
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
op
.
type
==
"update_loss_scaling"
:
update_loss_scaling_op_idx
=
idx
inf_var_name
=
op
.
desc
.
input
(
'FoundInfinite'
)[
0
]
op
.
_rename_input
(
inf_var_name
,
inf_var_name
+
"@GLOBAL_WORLD"
)
# not use amp
if
update_loss_scaling_op_idx
==
-
1
:
return
inf_var
=
block
.
var
(
inf_var_name
)
inf_var_int32
=
block
.
create_var
(
name
=
inf_var_name
+
"@cast_int32"
,
shape
=
inf_var
.
shape
,
dtype
=
core
.
VarDesc
.
VarType
.
INT32
)
inf_var_global
=
block
.
create_var
(
name
=
inf_var_name
+
"@GLOBAL_WORLD"
,
shape
=
inf_var
.
shape
,
dtype
=
inf_var
.
dtype
)
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
,
type
=
'cast'
,
inputs
=
{
'X'
:
inf_var
},
outputs
=
{
'Out'
:
inf_var_int32
},
attrs
=
{
"in_dtype"
:
inf_var
.
dtype
,
"out_dtype"
:
inf_var_int32
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
1
,
type
=
'c_allreduce_max'
,
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_int32
},
attrs
=
{
'ring_id'
:
ring_id
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
2
,
type
=
'cast'
,
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_global
},
attrs
=
{
"in_dtype"
:
inf_var_int32
.
dtype
,
"out_dtype"
:
inf_var_global
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
block
.
_sync_with_cpp
()
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
浏览文件 @
1e60a0c4
...
@@ -32,6 +32,7 @@ class GradientClipHelper(object):
...
@@ -32,6 +32,7 @@ class GradientClipHelper(object):
deperated_vars
=
set
()
deperated_vars
=
set
()
deperate_op_idx
=
set
()
deperate_op_idx
=
set
()
reversed_x_paramname
=
[]
reversed_x_paramname
=
[]
global_norm_sum_op_idx
=
-
1
for
idx
,
op
in
enumerate
(
block
.
ops
):
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
not
self
.
_is_gradient_clip_op
(
op
):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
continue
...
@@ -41,7 +42,11 @@ class GradientClipHelper(object):
...
@@ -41,7 +42,11 @@ class GradientClipHelper(object):
for
input_name
in
op
.
desc
.
input_arg_names
():
for
input_name
in
op
.
desc
.
input_arg_names
():
if
input_name
in
deperated_vars
:
if
input_name
in
deperated_vars
:
deperate_op
=
True
deperate_op
=
True
param_name
=
input_name
.
strip
(
"@GRAD"
)
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if
"@MERGED"
in
input_name
:
param_name
=
input_name
.
strip
(
"@GRAD@MERGED"
)
else
:
param_name
=
input_name
.
strip
(
"@GRAD"
)
if
shard
.
is_param
(
param_name
)
and
\
if
shard
.
is_param
(
param_name
)
and
\
not
shard
.
has_param
(
param_name
):
not
shard
.
has_param
(
param_name
):
deperate_op
=
True
deperate_op
=
True
...
@@ -51,7 +56,8 @@ class GradientClipHelper(object):
...
@@ -51,7 +56,8 @@ class GradientClipHelper(object):
if
deperate_op
:
if
deperate_op
:
deperate_op_idx
.
add
(
idx
)
deperate_op_idx
.
add
(
idx
)
for
output_name
in
op
.
desc
.
output_arg_names
():
for
output_name
in
op
.
desc
.
output_arg_names
():
deperated_vars
.
add
(
output_name
)
if
output_name
not
in
op
.
desc
.
input_arg_names
():
deperated_vars
.
add
(
output_name
)
if
not
deperated_vars
:
if
not
deperated_vars
:
# got no gradient_clip op
# got no gradient_clip op
...
@@ -65,6 +71,7 @@ class GradientClipHelper(object):
...
@@ -65,6 +71,7 @@ class GradientClipHelper(object):
continue
continue
reversed_inputs
=
[]
reversed_inputs
=
[]
if
op
.
type
==
"sum"
:
if
op
.
type
==
"sum"
:
global_norm_sum_op_idx
=
idx
for
input_name
in
op
.
desc
.
input_arg_names
():
for
input_name
in
op
.
desc
.
input_arg_names
():
if
input_name
not
in
deperated_vars
:
if
input_name
not
in
deperated_vars
:
reversed_inputs
.
append
(
input_name
)
reversed_inputs
.
append
(
input_name
)
...
@@ -86,20 +93,20 @@ class GradientClipHelper(object):
...
@@ -86,20 +93,20 @@ class GradientClipHelper(object):
OP_ROLE_KEY
:
OpRole
.
Optimize
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
})
# global norm should only be sum within each model parallelism word size when use global group
# global norm should only be sum within each model parallelism word size when use global group
if
pure_dp_degree
>
1
:
if
pure_dp_degree
>
1
:
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
idx
+
2
,
idx
+
2
,
type
=
'scale'
,
type
=
'scale'
,
inputs
=
{
'X'
:
sum_res
},
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
attrs
=
{
'scale'
:
1.0
/
float
(
pure_dp_degree
),
'scale'
:
1.0
/
float
(
pure_dp_degree
),
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'bias'
:
0.0
,
'bias'
:
0.0
,
'bias_after_scale'
:
False
,
'bias_after_scale'
:
False
,
OP_ROLE_KEY
:
OpRole
.
Optimize
OP_ROLE_KEY
:
OpRole
.
Optimize
})
})
# the grad sum here should take the all and only param in the current shard
# the grad sum here should take the all and only param in the current shard
to_check_param
=
set
(
reversed_x_paramname
)
to_check_param
=
set
(
reversed_x_paramname
)
...
@@ -115,3 +122,45 @@ class GradientClipHelper(object):
...
@@ -115,3 +122,45 @@ class GradientClipHelper(object):
block
.
_remove_var
(
var_name
,
sync
=
False
)
block
.
_remove_var
(
var_name
,
sync
=
False
)
block
.
_sync_with_cpp
()
block
.
_sync_with_cpp
()
return
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def
sync_global_norm
(
self
,
block
,
ring_id
,
pure_dp_degree
=
1
):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
if
op
.
type
==
"sum"
:
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
block
.
_insert_op_without_sync
(
idx
+
1
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
'ring_id'
:
ring_id
,
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
# global norm should only be sum within each model parallelism word size
if
pure_dp_degree
>
1
:
block
.
_insert_op_without_sync
(
idx
+
2
,
type
=
'scale'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
'scale'
:
1.0
/
float
(
pure_dp_degree
),
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'bias'
:
0.0
,
'bias_after_scale'
:
False
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
return
python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
0 → 100755
浏览文件 @
1e60a0c4
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
..common
import
is_optimizer_op
,
OP_ROLE_KEY
,
OpRole
from
paddle.fluid
import
core
,
unique_name
class
OffloadHelper
(
object
):
cpu_place_type
=
0
cuda_place_type
=
1
cuda_pinned_place_type
=
2
def
__init__
(
self
):
pass
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
def
_insert_cast_op
(
self
,
block
,
idx
,
src_name
,
dst_name
):
src_var
=
block
.
var
(
src_name
)
if
not
block
.
has_var
(
dst_name
):
block
.
create_var
(
name
=
dst_name
,
shape
=
src_var
.
shape
,
dtype
=
core
.
VarDesc
.
VarType
.
FP16
,
persistable
=
True
)
dst_var
=
block
.
var
(
dst_name
)
assert
dst_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
block
.
_insert_op_without_sync
(
idx
,
type
=
'cast'
,
inputs
=
{
'X'
:
src_var
},
outputs
=
{
'Out'
:
dst_var
},
attrs
=
{
'in_dtype'
:
src_var
.
dtype
,
'out_dtype'
:
dst_var
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
def
_insert_memcpy_op
(
self
,
block
,
idx
,
src_name
,
dst_name
,
dst_place_type
):
src_var
=
block
.
var
(
src_name
)
dst_var
=
block
.
var
(
dst_name
)
block
.
_insert_op_without_sync
(
idx
,
type
=
'memcpy'
,
inputs
=
{
'X'
:
src_var
},
outputs
=
{
'Out'
:
dst_var
},
attrs
=
{
'dst_place_type'
:
dst_place_type
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
def
_insert_fetch_op
(
self
,
block
,
idx
,
src_name
,
dst_name
):
self
.
_insert_memcpy_op
(
block
,
idx
,
src_name
,
dst_name
,
OffloadHelper
.
cuda_place_type
)
def
_insert_offload_op
(
self
,
block
,
idx
,
src_name
,
dst_name
):
self
.
_insert_memcpy_op
(
block
,
idx
,
src_name
,
dst_name
,
OffloadHelper
.
cuda_pinned_place_type
)
def
_get_offload_var_name
(
self
,
name
):
return
unique_name
.
generate
(
name
+
'@offload'
)
def
_create_offload_var
(
self
,
var_name
,
offload_var_name
,
blocks
):
for
block
in
blocks
:
var
=
block
.
var
(
var_name
)
var
.
persistable
=
False
offload_var
=
block
.
create_var
(
name
=
offload_var_name
,
shape
=
var
.
shape
,
dtype
=
var
.
dtype
,
persistable
=
True
)
def
offload_fp32param
(
self
,
block
,
startup_block
):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(p,) = prefetch(p@offload)
(pout,) = adam(p)
(p_fp16) = cast(p)
(p@offload) = memcpy(p)
"""
param_to_idx
=
dict
()
param_to_fp16
=
dict
()
# recompute_var which need rename to fp16_param
fp16_param_to_recompute
=
dict
()
recompute_to_fp16
=
dict
()
def
remove_param
(
input_name
):
param_to_idx
.
pop
(
input_name
)
if
input_name
in
param_to_fp16
:
fp16_param
=
param_to_fp16
.
pop
(
input_name
)
if
fp16_param
in
fp16_param_to_recompute
:
recompute
=
fp16_param_to_recompute
.
pop
(
fp16_param
)
recompute_to_fp16
.
pop
(
recompute
)
# step1: record param
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
op
.
type
in
(
'adam'
,
'momentum'
,
'lars'
,
'lamb'
):
param
=
op
.
desc
.
input
(
"Param"
)[
0
]
param_to_idx
[
param
]
=
idx
# step2: remove param which can't offload
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
is_optimizer_op
(
op
):
break
for
input_name
in
op
.
desc
.
input_arg_names
():
if
input_name
not
in
param_to_idx
:
continue
# param is real used by fp32 op
if
op
.
type
!=
'cast'
:
remove_param
(
input_name
)
continue
# param is only used by cast op,
# which to cast fp32_param to fp16_param
output_name
=
op
.
output_arg_names
[
0
]
if
'cast_fp16'
not
in
output_name
:
remove_param
(
input_name
)
continue
if
'subprog'
not
in
output_name
:
assert
output_name
==
input_name
+
'.cast_fp16'
assert
input_name
not
in
param_to_fp16
,
\
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16
[
input_name
]
=
output_name
else
:
# fp16-->recompute_var
assert
input_name
in
param_to_fp16
,
\
"param must first be cast to fp16"
fp16_param
=
param_to_fp16
[
input_name
]
fp16_param_to_recompute
[
fp16_param
]
=
output_name
recompute_to_fp16
[
output_name
]
=
fp16_param
param_name_to_offload_name
=
dict
()
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
op
.
type
in
(
'adam'
,
'momentum'
,
'lars'
,
'lamb'
):
param
=
op
.
desc
.
input
(
"Param"
)[
0
]
if
param
not
in
param_to_idx
:
continue
# step3.1: create offload_var
offload_var_name
=
self
.
_get_offload_var_name
(
param
)
param_name_to_offload_name
[
param
]
=
offload_var_name
self
.
_create_offload_var
(
param
,
offload_var_name
,
[
block
,
startup_block
])
# step3.2: insert cast op and offload op
self
.
_insert_offload_op
(
block
,
idx
+
1
,
param
,
offload_var_name
)
assert
param
in
param_to_fp16
fp16_param_name
=
param_to_fp16
[
param
]
fp16_param_var
=
block
.
var
(
fp16_param_name
)
fp16_param_var
.
persistable
=
True
self
.
_insert_cast_op
(
block
,
idx
+
1
,
param
,
param_to_fp16
[
param
])
# step3.3: insert fetch op
self
.
_insert_fetch_op
(
block
,
idx
,
offload_var_name
,
param
)
continue
# step3.4: remove cast op
if
op
.
type
==
'cast'
:
input_name
=
op
.
desc
.
input_arg_names
()[
0
]
if
input_name
in
param_to_idx
:
block
.
_remove_op
(
idx
,
sync
=
False
)
continue
# step3.5: change recompute_param to fp16_param
for
input_name
in
op
.
desc
.
input_arg_names
():
if
input_name
in
recompute_to_fp16
:
op
.
_rename_input
(
input_name
,
recompute_to_fp16
[
input_name
])
for
output_name
in
op
.
desc
.
output_arg_names
():
if
output_name
in
recompute_to_fp16
:
op
.
_rename_output
(
output_name
,
recompute_to_fp16
[
output_name
])
# step4: remove recompute_param
for
name
in
recompute_to_fp16
.
keys
():
block
.
_remove_var
(
name
,
sync
=
False
)
# step5: startup_block add offload
visited_vars
=
set
()
for
idx
,
op
in
reversed
(
list
(
enumerate
(
startup_block
.
ops
))):
for
out_name
in
op
.
output_arg_names
:
if
out_name
in
visited_vars
:
continue
if
out_name
in
param_name_to_offload_name
:
var_name
=
out_name
offload_var_name
=
param_name_to_offload_name
[
var_name
]
self
.
_insert_offload_op
(
startup_block
,
idx
+
1
,
var_name
,
offload_var_name
)
self
.
_insert_cast_op
(
startup_block
,
idx
+
1
,
var_name
,
param_to_fp16
[
var_name
])
visited_vars
.
add
(
out_name
)
block
.
_sync_with_cpp
()
startup_block
.
_sync_with_cpp
()
def
offload
(
self
,
block
,
startup_block
):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
(m1out, m2out, pout) = adam(m1, m2, p)
(m1@offload, m2@offload) = memcpy(m1, m2)
"""
vars_name_to_offload_name
=
dict
()
# main_block add offload
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
not
is_optimizer_op
(
op
):
break
vars_name
=
[]
if
op
.
type
==
"adam"
:
# {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
# adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
vars_name
.
append
(
op
.
desc
.
input
(
"Moment1"
)[
0
])
vars_name
.
append
(
op
.
desc
.
input
(
"Moment2"
)[
0
])
elif
op
.
type
==
'momentum'
:
pass
elif
op
.
type
==
'lars'
:
pass
elif
op
.
type
==
'lamb'
:
pass
# step1: create and init offload_var
for
var_name
in
vars_name
:
assert
var_name
not
in
vars_name_to_offload_name
offload_var_name
=
self
.
_get_offload_var_name
(
var_name
)
vars_name_to_offload_name
[
var_name
]
=
offload_var_name
self
.
_create_offload_var
(
var_name
,
offload_var_name
,
[
block
,
startup_block
])
# step2: insert offload op
for
var_name
in
vars_name
:
offload_var_name
=
vars_name_to_offload_name
[
var_name
]
self
.
_insert_offload_op
(
block
,
idx
+
1
,
var_name
,
offload_var_name
)
# step3: insert fetch op
for
var_name
in
vars_name
:
offload_var_name
=
vars_name_to_offload_name
[
var_name
]
self
.
_insert_fetch_op
(
block
,
idx
,
offload_var_name
,
var_name
)
# startup_block add offload
visited_vars
=
set
()
for
idx
,
op
in
reversed
(
list
(
enumerate
(
startup_block
.
ops
))):
for
out_name
in
op
.
output_arg_names
:
if
out_name
in
visited_vars
:
continue
if
out_name
in
vars_name_to_offload_name
:
var_name
=
out_name
offload_var_name
=
vars_name_to_offload_name
[
var_name
]
# insert offload op after var is generated
self
.
_insert_offload_op
(
startup_block
,
idx
+
1
,
var_name
,
offload_var_name
)
visited_vars
.
add
(
out_name
)
block
.
_sync_with_cpp
()
startup_block
.
_sync_with_cpp
()
python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py
100644 → 100755
浏览文件 @
1e60a0c4
...
@@ -126,6 +126,10 @@ class ProgramDeps(object):
...
@@ -126,6 +126,10 @@ class ProgramDeps(object):
def
should_remove_op
(
self
,
op_idx
):
def
should_remove_op
(
self
,
op_idx
):
op
=
self
.
_block
.
ops
[
op_idx
]
op
=
self
.
_block
.
ops
[
op_idx
]
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# remove check_finite_and_unscale op if its input 'X' is empty
if
op
.
type
==
'check_finite_and_unscale'
and
len
(
op
.
input
(
'X'
))
==
0
:
return
True
for
output_name
in
op
.
desc
.
output_arg_names
():
for
output_name
in
op
.
desc
.
output_arg_names
():
if
output_name
not
in
self
.
_should_removed_var
:
if
output_name
not
in
self
.
_should_removed_var
:
return
False
return
False
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
浏览文件 @
1e60a0c4
...
@@ -274,6 +274,10 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
...
@@ -274,6 +274,10 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
"""
"""
insert sync_comm_op for vars
insert sync_comm_op for vars
"""
"""
# NOTE (JZ-LIANG) to be check, may result undefined case
if
len
(
comm_dep_vars
)
==
0
:
return
0
op_role
=
get_valid_op_role
(
block
,
insert_idx
)
op_role
=
get_valid_op_role
(
block
,
insert_idx
)
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
insert_idx
,
insert_idx
,
...
@@ -324,27 +328,45 @@ def insert_cast_ops(block, insert_idx, cast_ops):
...
@@ -324,27 +328,45 @@ def insert_cast_ops(block, insert_idx, cast_ops):
return
return
def
insert_allreduce_ops
(
block
,
insert_idx
,
ring_id
,
allreduce_vars
):
def
insert_allreduce_ops
(
block
,
insert_idx
,
ring_id
,
allreduce_vars
,
op_role
=
OpRole
.
Backward
,
use_calc_stream
=
False
):
"""
"""
_add_allreduce_ops
_add_allreduce_ops
"""
"""
if
len
(
allreduce_vars
)
==
0
:
return
for
var
in
allreduce_vars
:
for
var
in
allreduce_vars
:
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
insert_idx
,
insert_idx
,
type
=
'c_allreduce_sum'
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
var
},
inputs
=
{
'X'
:
var
},
outputs
=
{
'Out'
:
var
},
outputs
=
{
'Out'
:
var
},
attrs
=
{
'ring_id'
:
ring_id
,
attrs
=
{
OP_ROLE_KEY
:
OpRole
.
Backward
})
'ring_id'
:
ring_id
,
'use_calc_stream'
:
use_calc_stream
,
OP_ROLE_KEY
:
op_role
})
return
return
def
insert_reduce_ops
(
block
,
insert_idx
,
ring_id
,
reduce_vars
,
shard
):
def
insert_reduce_ops
(
block
,
insert_idx
,
ring_id
,
reduce_vars
,
shard
,
op_role
=
OpRole
.
Backward
,
use_calc_stream
=
False
):
"""
"""
_add_allreduce_ops
_add_allreduce_ops
"""
"""
for
var
in
reduce_vars
:
for
var
in
reduce_vars
:
root_id
=
get_grad_device
(
var
,
shard
)
root_id
=
get_grad_device
(
var
,
shard
)
assert
root_id
>=
0
,
"root id should be a positive int"
.
format
(
var
)
assert
root_id
>=
0
,
"root id should be a positive int"
.
format
(
var
)
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
...
@@ -355,12 +377,40 @@ def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
...
@@ -355,12 +377,40 @@ def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
attrs
=
{
attrs
=
{
'ring_id'
:
ring_id
,
'ring_id'
:
ring_id
,
'root_id'
:
root_id
,
'root_id'
:
root_id
,
OP_ROLE_KEY
:
OpRole
.
Backward
'use_calc_stream'
:
use_calc_stream
,
OP_ROLE_KEY
:
op_role
})
})
return
return
def
get_grad_device
(
grad_name
,
shard
):
assert
"@GRAD"
in
grad_name
,
"[{}] should be a grad variable."
.
format
(
grad_name
)
base_name
=
None
# mind the traversal order
possible_suffixes
=
[
'.cast_fp16@GRAD@MERGED'
,
'.cast_fp16@GRAD'
,
'@GRAD@MERGED'
,
'@GRAD'
]
for
suffix
in
possible_suffixes
:
if
suffix
in
grad_name
:
base_name
=
re
.
sub
(
suffix
,
''
,
grad_name
)
break
assert
base_name
in
shard
.
global_param2device
,
"[{}] should be a param variable."
.
format
(
base_name
)
return
shard
.
global_param2device
[
base_name
]
def
get_first_check_finite_and_unscale_op_idx
(
block
):
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
op
.
type
==
"check_finite_and_unscale"
:
return
idx
raise
ValueError
(
"check_finite_and_unscale does not exist in block"
)
def
insert_broadcast_ops
(
block
,
insert_idx
,
ring_id
,
broadcast2root
):
def
insert_broadcast_ops
(
block
,
insert_idx
,
ring_id
,
broadcast2root
):
"""
"""
_add_broadcast_ops
_add_broadcast_ops
...
@@ -420,6 +470,7 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
...
@@ -420,6 +470,7 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
outputs
=
{
'Out'
:
loss_grad_var
},
outputs
=
{
'Out'
:
loss_grad_var
},
attrs
=
{
'scale'
:
scale
,
attrs
=
{
'scale'
:
scale
,
OP_ROLE_KEY
:
OpRole
.
Backward
})
OP_ROLE_KEY
:
OpRole
.
Backward
})
break
def
comm_analyse
(
main_program
):
def
comm_analyse
(
main_program
):
...
@@ -502,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
...
@@ -502,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
and part of persistable vars are duplicated and exist in all the ranks with different values.
and part of persistable vars are duplicated and exist in all the ranks with different values.
This function handles the model saving for sharding training.
This function handles the model saving for sharding training.
"""
"""
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if
main_program
.
_pipeline_opt
:
main_program
=
main_program
.
_pipeline_opt
[
'section_program'
][
'program'
]
def
is_opt_vars
(
var
):
def
is_opt_vars
(
var
):
# NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer
# NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
1e60a0c4
此差异已折叠。
点击以展开。
python/paddle/fluid/backward.py
浏览文件 @
1e60a0c4
...
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
...
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
new_op_desc
=
block
.
desc
.
append_op
()
new_op_desc
=
block
.
desc
.
append_op
()
new_op_desc
.
copy_from
(
desc
)
new_op_desc
.
copy_from
(
desc
)
new_op_desc
.
_set_attr
(
op_role_attr_name
,
backward
)
new_op_desc
.
_set_attr
(
op_role_attr_name
,
backward
)
if
desc
.
has_attr
(
'op_device'
):
new_op_desc
.
_set_attr
(
'op_device'
,
desc
.
attr
(
'op_device'
))
result_descs
.
append
(
new_op_desc
)
result_descs
.
append
(
new_op_desc
)
return
result_descs
return
result_descs
...
@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
...
@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
new_op_desc
=
block
.
desc
.
append_op
()
new_op_desc
=
block
.
desc
.
append_op
()
new_op_desc
.
copy_from
(
desc
)
new_op_desc
.
copy_from
(
desc
)
new_op_desc
.
_set_attr
(
op_role_attr_name
,
backward
)
new_op_desc
.
_set_attr
(
op_role_attr_name
,
backward
)
if
desc
.
has_attr
(
'op_device'
):
new_op_desc
.
_set_attr
(
'op_device'
,
desc
.
attr
(
'op_device'
))
result_descs
.
append
(
new_op_desc
)
result_descs
.
append
(
new_op_desc
)
return
result_descs
return
result_descs
...
@@ -843,6 +847,7 @@ def _append_backward_ops_with_checkpoints_(
...
@@ -843,6 +847,7 @@ def _append_backward_ops_with_checkpoints_(
vars_in_memory
=
vars_should_be_hold
+
checkpoints_name
vars_in_memory
=
vars_should_be_hold
+
checkpoints_name
max_calculated_op_position
=
len
(
ops
)
max_calculated_op_position
=
len
(
ops
)
device_attr_name
=
core
.
op_proto_and_checker_maker
.
kOpDeviceAttrName
()
if
recompute_segments
==
[]:
if
recompute_segments
==
[]:
gap_ops
=
ops
[
0
:
max_calculated_op_position
]
gap_ops
=
ops
[
0
:
max_calculated_op_position
]
for
op
in
reversed
(
gap_ops
):
for
op
in
reversed
(
gap_ops
):
...
@@ -852,6 +857,11 @@ def _append_backward_ops_with_checkpoints_(
...
@@ -852,6 +857,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
# Set device for grad_op according to forward Op
if
op
.
desc
.
has_attr
(
device_attr_name
):
op_device
=
op
.
desc
.
attr
(
device_attr_name
)
for
op_desc
in
grad_op_desc
:
op_desc
.
_set_attr
(
device_attr_name
,
op_device
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
grad_op_descs
.
extend
(
added_descs
)
grad_op_descs
.
extend
(
added_descs
)
grad_to_var
.
update
(
op_grad_to_var
)
grad_to_var
.
update
(
op_grad_to_var
)
...
@@ -866,6 +876,11 @@ def _append_backward_ops_with_checkpoints_(
...
@@ -866,6 +876,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
# Set device for grad_op according to forward Op
if
op
.
desc
.
has_attr
(
device_attr_name
):
op_device
=
op
.
desc
.
attr
(
device_attr_name
)
for
op_desc
in
grad_op_desc
:
op_desc
.
_set_attr
(
device_attr_name
,
op_device
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
grad_op_descs
.
extend
(
added_descs
)
grad_op_descs
.
extend
(
added_descs
)
grad_to_var
.
update
(
op_grad_to_var
)
grad_to_var
.
update
(
op_grad_to_var
)
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
1e60a0c4
...
@@ -4033,6 +4033,12 @@ class PipelineOptimizer(object):
...
@@ -4033,6 +4033,12 @@ class PipelineOptimizer(object):
"""
"""
Find the post op that has variable named var_name as input.
Find the post op that has variable named var_name as input.
"""
"""
# bugfix for uniform hybrid parallelism
if
'.cast_fp32'
in
var_name
:
var_name
=
var_name
.
replace
(
'.cast_fp32'
,
''
)
if
'.cast_fp16'
in
var_name
:
var_name
=
var_name
.
replace
(
'.cast_fp16'
,
''
)
post_ops
=
self
.
input_var_to_op
[
var_name
]
post_ops
=
self
.
input_var_to_op
[
var_name
]
if
post_ops
==
None
:
return
None
if
post_ops
==
None
:
return
None
result_op
=
None
result_op
=
None
...
@@ -4114,7 +4120,23 @@ class PipelineOptimizer(object):
...
@@ -4114,7 +4120,23 @@ class PipelineOptimizer(object):
# For LRSched ops, we should put them on all sub-programs to
# For LRSched ops, we should put them on all sub-programs to
# make sure each sub-program update the lr correctly
# make sure each sub-program update the lr correctly
op
.
_set_attr
(
self
.
_op_device_key
,
"gpu:all"
)
op
.
_set_attr
(
self
.
_op_device_key
,
"gpu:all"
)
elif
op
.
type
==
"scale"
and
self
.
_is_backward_op
(
op
):
# bugfix in hybrid parallelism
elif
op
.
type
==
"sum"
and
self
.
_is_backward_op
(
op
):
# For sum ops that compute the sum of @RENAMED@ vars
for
name
in
op
.
desc
.
input_arg_names
():
assert
'@RENAME@'
in
name
,
\
"The op must be sum used to accumulate renamed vars."
assert
len
(
op
.
desc
.
output_arg_names
())
==
1
out_name
=
op
.
desc
.
output_arg_names
()[
0
]
post_op
=
self
.
_find_post_op
(
idx
,
out_name
)
assert
post_op
.
has_attr
(
'op_device'
),
"{} has no op_device attr for var {}"
.
format
(
post_op
.
type
,
out_name
)
device
=
post_op
.
attr
(
self
.
_op_device_key
)
assert
device
,
"The post op must have op_device set."
op
.
_set_attr
(
self
.
_op_device_key
,
device
)
elif
(
op
.
type
==
"cast"
or
op
.
type
==
"scale"
)
and
self
.
_is_backward_op
(
op
):
prev_op
=
self
.
_find_prev_op
(
idx
,
op
.
desc
.
input
(
"X"
)[
0
])
prev_op
=
self
.
_find_prev_op
(
idx
,
op
.
desc
.
input
(
"X"
)[
0
])
op
.
_set_attr
(
self
.
_op_device_key
,
prev_op
.
attr
(
self
.
_op_device_key
))
op
.
_set_attr
(
self
.
_op_device_key
,
prev_op
.
attr
(
self
.
_op_device_key
))
elif
op
.
type
==
"memcpy"
and
not
self
.
_is_optimize_op
(
op
):
elif
op
.
type
==
"memcpy"
and
not
self
.
_is_optimize_op
(
op
):
...
@@ -4249,11 +4271,19 @@ class PipelineOptimizer(object):
...
@@ -4249,11 +4271,19 @@ class PipelineOptimizer(object):
Insert a pair of send and recv ops for every two
Insert a pair of send and recv ops for every two
consecutive ops on different devices.
consecutive ops on different devices.
"""
"""
extra_index_info
=
{
'index'
:
0
}
# A map from var to device where op takes it as input,
# A map from var to device where op takes it as input,
# avoiding multiple send and recv ops.
# avoiding multiple send and recv ops.
input_var_to_device
=
dict
()
input_var_to_device
=
dict
()
# bugfix hybrid parallelism
first_optimize_index
=
None
for
index
,
op
in
enumerate
(
list
(
block
.
ops
)):
if
self
.
_is_optimize_op
(
op
):
first_optimize_index
=
index
break
extra_index_info
=
{
'index'
:
0
,
'first_optimize_index'
:
first_optimize_index
}
for
index
,
op
in
enumerate
(
list
(
block
.
ops
)):
for
index
,
op
in
enumerate
(
list
(
block
.
ops
)):
cur_device
=
op
.
attr
(
self
.
_op_device_key
)
cur_device
=
op
.
attr
(
self
.
_op_device_key
)
...
@@ -4371,17 +4401,26 @@ class PipelineOptimizer(object):
...
@@ -4371,17 +4401,26 @@ class PipelineOptimizer(object):
'peer'
:
1
,
'peer'
:
1
,
})
})
extra_index_info
[
'index'
]
+=
1
extra_index_info
[
'index'
]
+=
1
insert_index
=
None
if
int
(
op_role
)
==
int
(
self
.
_op_role
.
Backward
):
insert_index
=
extra_index_info
[
'first_optimize_index'
]
new_op_role
=
self
.
_op_role
.
Optimize
else
:
insert_index
=
index
new_op_role
=
self
.
_op_role
.
Backward
block
.
_insert_op
(
block
.
_insert_op
(
index
=
index
+
extra_index_info
[
'index'
],
index
=
in
sert_in
dex
+
extra_index_info
[
'index'
],
type
=
'c_sync_comm_stream'
,
type
=
'c_sync_comm_stream'
,
inputs
=
{
'X'
:
[
var
]},
inputs
=
{
'X'
:
[
var
]},
outputs
=
{
'Out'
:
[
var
]},
outputs
=
{
'Out'
:
[
var
]},
attrs
=
{
attrs
=
{
self
.
_op_device_key
:
prev_dev
,
self
.
_op_device_key
:
prev_dev
,
self
.
_op_role_key
:
self
.
_op_role
.
Backward
,
self
.
_op_role_key
:
new_op_role
,
'ring_id'
:
ring_id
,
'ring_id'
:
ring_id
,
})
})
extra_index_info
[
'index'
]
+=
1
if
int
(
op_role
)
==
int
(
self
.
_op_role
.
Forward
):
extra_index_info
[
'index'
]
+=
1
var_shape
=
list
(
var
.
shape
)
var_shape
=
list
(
var
.
shape
)
var_shape
[
0
]
=
self
.
micro_batch_size
if
var_shape
[
var_shape
[
0
]
=
self
.
micro_batch_size
if
var_shape
[
0
]
<
0
else
var_shape
[
0
]
0
]
<
0
else
var_shape
[
0
]
...
@@ -4768,8 +4807,9 @@ class PipelineOptimizer(object):
...
@@ -4768,8 +4807,9 @@ class PipelineOptimizer(object):
# Step4: Special Case: process persistable vars that exist in
# Step4: Special Case: process persistable vars that exist in
# multiple sections
# multiple sections
self
.
_process_persistable_vars_in_multi_sections
(
# FIXME
main_program
,
startup_program
,
program_list
)
# self._process_persistable_vars_in_multi_sections(
# main_program, startup_program, program_list)
# Step5: Add sub blocks for section programs
# Step5: Add sub blocks for section programs
self
.
_add_sub_blocks
(
main_block
,
program_list
)
self
.
_add_sub_blocks
(
main_block
,
program_list
)
...
...
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
浏览文件 @
1e60a0c4
...
@@ -354,6 +354,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
...
@@ -354,6 +354,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
"segment_broadcast_MB"
:
0.2
,
"segment_broadcast_MB"
:
0.2
,
"segment_anchors"
:
None
,
"segment_anchors"
:
None
,
"sharding_degree"
:
2
,
"sharding_degree"
:
2
,
"dp_degree"
:
2
,
"hybrid_dp"
:
True
,
"hybrid_dp"
:
True
,
"gradient_merge_acc_step"
:
1
,
"gradient_merge_acc_step"
:
1
,
"mp_degree"
:
1
"mp_degree"
:
1
...
@@ -422,6 +423,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
...
@@ -422,6 +423,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
"segment_broadcast_MB"
:
0.2
,
"segment_broadcast_MB"
:
0.2
,
"segment_anchors"
:
None
,
"segment_anchors"
:
None
,
"sharding_degree"
:
2
,
"sharding_degree"
:
2
,
"dp_degree"
:
2
,
"hybrid_dp"
:
True
,
"hybrid_dp"
:
True
,
"gradient_merge_acc_step"
:
4
,
"gradient_merge_acc_step"
:
4
,
"mp_degree"
:
1
"mp_degree"
:
1
...
@@ -458,20 +460,56 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
...
@@ -458,20 +460,56 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
fw_bw_ops
=
[
op
.
type
for
op
in
train_prog
.
blocks
[
0
].
ops
]
fw_bw_ops
=
[
op
.
type
for
op
in
train_prog
.
blocks
[
0
].
ops
]
opt_ops
=
[
op
.
type
for
op
in
train_prog
.
blocks
[
2
].
ops
]
opt_ops
=
[
op
.
type
for
op
in
train_prog
.
blocks
[
2
].
ops
]
self
.
assertEqual
(
fw_bw_ops
,
[
self
.
assertEqual
(
fw_bw_ops
,
[
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'c_sync_calc_stream'
,
'c_broadcast'
,
'c_broadcast'
,
'c_broadcast'
,
'fill_constant'
,
'c_broadcast'
,
'c_broadcast'
,
'c_broadcast'
,
'c_sync_comm_stream'
,
'fill_constant'
,
'c_sync_comm_stream'
,
'mul'
,
'elementwise_add'
,
'tanh'
,
'mul'
,
'c_sync_calc_stream'
,
'elementwise_add'
,
'tanh'
,
'mul'
,
'elementwise_add'
,
'softmax'
,
'c_broadcast'
,
'cross_entropy2'
,
'mean'
,
'fill_constant'
,
'scale'
,
'mean_grad'
,
'c_broadcast'
,
'cross_entropy_grad2'
,
'softmax_grad'
,
'elementwise_add_grad'
,
'c_broadcast'
,
'mul_grad'
,
'tanh_grad'
,
'elementwise_add_grad'
,
'mul_grad'
,
'c_broadcast'
,
'tanh_grad'
,
'elementwise_add_grad'
,
'mul_grad'
,
'c_broadcast'
,
'c_sync_calc_stream'
,
'c_reduce_sum'
,
'c_reduce_sum'
,
'c_broadcast'
,
'c_reduce_sum'
,
'c_reduce_sum'
,
'c_reduce_sum'
,
'c_reduce_sum'
,
'c_sync_comm_stream'
,
'c_sync_comm_stream'
,
'elementwise_add'
,
'elementwise_add'
,
'mul'
,
'elementwise_add'
,
'increment'
,
'elementwise_mod'
,
'equal'
,
'elementwise_add'
,
'conditional_block'
'tanh'
,
'mul'
,
'elementwise_add'
,
'tanh'
,
'mul'
,
'elementwise_add'
,
'softmax'
,
'cross_entropy2'
,
'mean'
,
'fill_constant'
,
'scale'
,
'mean_grad'
,
'cross_entropy_grad2'
,
'softmax_grad'
,
'elementwise_add_grad'
,
'mul_grad'
,
'tanh_grad'
,
'elementwise_add_grad'
,
'mul_grad'
,
'tanh_grad'
,
'elementwise_add_grad'
,
'mul_grad'
,
'c_sync_calc_stream'
,
'c_reduce_sum'
,
'c_reduce_sum'
,
'c_reduce_sum'
,
'c_reduce_sum'
,
'c_reduce_sum'
,
'c_reduce_sum'
,
'c_sync_comm_stream'
,
'elementwise_add'
,
'elementwise_add'
,
'elementwise_add'
,
'increment'
,
'elementwise_mod'
,
'equal'
,
'conditional_block'
,
])
])
self
.
assertEqual
(
opt_ops
,
[
self
.
assertEqual
(
opt_ops
,
[
'c_allreduce_sum'
,
'c_allreduce_sum'
,
'c_allreduce_sum'
,
'scale'
,
'c_allreduce_sum'
,
'c_allreduce_sum'
,
'c_allreduce_sum'
,
'scale'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录