PaddlePaddle / Paddle — commit c47853f6 (unverified)

[AutoParallel] add gradient_merge master_grad & 1F1B pass (#52647)

Authored by zhaoyingli on Apr 19, 2023; committed via GitHub on Apr 19, 2023.
Parent commit: 6f3c9643
Showing 18 changed files with 1,532 additions and 648 deletions.
Changed files:
- paddle/fluid/distributed/fleet_executor/fleet_executor.cc (+10 −32)
- python/paddle/distributed/auto_parallel/constants.py (+11 −2)
- python/paddle/distributed/auto_parallel/dist_loader.py (+97 −63)
- python/paddle/distributed/auto_parallel/dist_saver.py (+33 −18)
- python/paddle/distributed/auto_parallel/engine.py (+36 −27)
- python/paddle/distributed/auto_parallel/parallelizer_v2.py (+9 −7)
- python/paddle/distributed/auto_parallel/partitioner.py (+234 −139)
- python/paddle/distributed/auto_parallel/reshard.py (+68 −14)
- python/paddle/distributed/auto_parallel/strategy.py (+9 −0)
- python/paddle/distributed/auto_parallel/utils.py (+7 −0)
- python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py (+175 −104)
- python/paddle/distributed/passes/auto_parallel_gradient_merge.py (+309 −153)
- python/paddle/distributed/passes/auto_parallel_pipeline.py (+325 −88)
- python/paddle/fluid/tests/unittests/auto_parallel/1F1B_pass_unittest.py (+126 −0)
- python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt (+3 −0)
- python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py (+7 −0)
- python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py (+16 −1)
- python/paddle/fluid/tests/unittests/auto_parallel/test_pass_1F1B.py (+57 −0)
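
The diffs below add a 1F1B pipeline schedule, gradient-merge accumulation steps, and a dp_optimization config section to the auto-parallel strategy. The sketch that follows is a hypothetical usage example, not part of this commit: it assumes the `paddle.distributed.fleet.auto` API of this era, and only the option names (gradient_merge, pipeline.schedule_mode, dp_optimization, and their fields) are taken from the strategy.py, constants.py, and engine.py changes shown below.

# Hypothetical sketch: enabling the options this commit touches.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()

# Pipeline parallelism with the 1F1B schedule added by this PR;
# engine.py derives its accumulation step count from accumulate_steps.
strategy.pipeline.enable = True
strategy.pipeline.schedule_mode = "1F1B"
strategy.pipeline.accumulate_steps = 4

# Alternatively, gradient merge accumulates k_steps micro-batch gradients
# per optimizer step (engine.py checks gradient_merge before pipeline):
# strategy.gradient_merge.enable = True
# strategy.gradient_merge.k_steps = 4

# New data-parallel optimization pass configuration (constants.py defaults).
strategy.dp_optimization.enable = True
strategy.dp_optimization.fuse_all_reduce_ops = True
strategy.dp_optimization.fuse_grad_size_in_MB = 32
strategy.dp_optimization.overlap_comm_cacl = True

# engine = auto.Engine(model, loss, optimizer, metrics, strategy=strategy)
# engine.fit(dataset, batch_size=32, epochs=1)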
paddle/fluid/distributed/fleet_executor/fleet_executor.cc

@@ -110,12 +110,15 @@ void PreventVarsDelete(
std::vector<std::string> GetUnusedVarsAfterWhile(
    const framework::ProgramDesc& program_desc,
    TaskNode* cond_task,
    const std::vector<std::string>& vars_not_gc) {
  // NOTE: Since while op won't appear in task node, in order to analyze
  // the vars which should be free after calling while op, we rebuild the
  // whole program and get the unused vars after calling while op.
  // vars in parent block should not be free until the while op is finished.
  // The local vars will be free while running op in sub block.
  // The vars in while block should not be free until the while op is finished.
  // In a word, the vars need to be free after while op is:
  // 1. Vars in parent block and being used in while block.
  // 2. Local vars only defined in while block.
  // The unused vars above will be free in cond interceptor.
  std::vector<std::string> while_block_vars;
  std::vector<std::unique_ptr<framework::OperatorBase>> ops;

@@ -129,29 +132,14 @@ std::vector<std::string> GetUnusedVarsAfterWhile(
      for (const auto& var_name : pair.second) {
        while_block_vars.emplace_back(var_name);
      }
      for (auto& var : program_desc.Block(1).AllVars()) {
        while_block_vars.emplace_back(var->Name());
      }
    }
  }
  return while_block_vars;
}

std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
GetSubUnusedVars(const framework::ProgramDesc& program_desc,
                 const std::set<TaskNode*>& sub_block_tasks,
                 const std::vector<std::string>& vars_not_gc) {
  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
  for (auto* task_node : sub_block_tasks) {
    for (const auto& op : task_node->ops()) {
      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
    }
  }
  auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
  for (auto& unique_op : ops) {
    unique_op.release();
  }
  PreventVarsDelete(&unused_vars, vars_not_gc);
  return unused_vars;
}
}  // namespace

void FleetExecutor::Init(

@@ -174,13 +162,8 @@ void FleetExecutor::Init(
  for (const auto& task_node : task_nodes) {
    if (task_node->type() == "Cond") {
      GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
      while_block_vars =
          GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars);
      for (auto* task_node : sub_block_tasks) {
        for (auto iter : task_node->vars_to_dtype()) {
          while_block_vars.emplace_back(iter.first);
        }
      }
      while_block_vars = GetUnusedVarsAfterWhile(
          program_desc, task_node, inference_root_scope_vars);
      VLOG(3) << "Vars will be gced after while op";
      for (auto var : while_block_vars) {
        VLOG(3) << var;

@@ -210,9 +193,6 @@ void FleetExecutor::Init(
    unique_op.release();
  }
  auto sub_unused_vars =
      GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars);
  // NOTE: For inference, the vars in inference_root_scope_vars
  // shouldn't be deleted during inf, for that they may be the result of the
  // inf. If they are GCed, it will cause error during ZeroCopy the result.

@@ -223,8 +203,6 @@ void FleetExecutor::Init(
  for (auto task_node : task_nodes) {
    if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
      task_node->SetUnusedVars(global_unused_vars);
    } else {
      task_node->SetUnusedVars(sub_unused_vars);
    }
    int64_t interceptor_id = task_node->task_id();
    interceptor_id_to_task.emplace(interceptor_id, task_node);
python/paddle/distributed/auto_parallel/constants.py

@@ -117,9 +117,9 @@ set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)

#########################################
# auto tuning configuration
#########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)

@@ -135,3 +135,12 @@ set_field_default_config(TUNING, "verbose", True)
DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)

#########################################
# data parallel configuration
#########################################
DP_OPTIMIZATION = "dp_optimization"
set_field_default_config(DP_OPTIMIZATION, "enable", False)
set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)
python/paddle/distributed/auto_parallel/dist_loader.py

@@ -17,12 +17,18 @@ import numpy as np
import paddle
from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import (
    _InfiniteIterableSampler,
    DistributedBatchSampler,
)
from paddle.fluid.dataloader.dataloader_iter import (
    _DatasetKind,
    default_collate_fn,
    default_convert_fn,
)


class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def __iter__(self):
        raise NotImplementedError

@@ -43,24 +49,26 @@ class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
    def __init__(
        self,
        dataset,
        feed_list=None,
        capacity=None,
        use_double_buffer=True,
        iterable=True,
        return_list=False,
        use_multiprocess=False,
        drop_last=True,
        places=None,
        batch_size=1,
        epochs=1,
        steps_per_epoch=None,
        collate_fn=None,
        split_data=True,
        data_parallel_world_size=[],
        data_parallel_rank=[],
        acc_steps=1,
    ):
        self.dataset = dataset
        self.feed_list = feed_list
        self.capacity = capacity

@@ -79,6 +87,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
        assert len(data_parallel_rank) == len(feed_list)
        self.dp_world_sizes = data_parallel_world_size
        self.dp_ranks = data_parallel_rank
        self.acc_steps = acc_steps

        if isinstance(dataset, IterableDataset):
            self.dataset_kind = _DatasetKind.ITER

@@ -90,12 +99,15 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
        else:
            if isinstance(dataset, IterableDataset):
                self.batch_sampler = _InfiniteIterableSampler(dataset, batch_size)
            else:
                self.batch_sampler = BatchSampler(
                    dataset,
                    batch_size=batch_size,
                    shuffle=False,
                    drop_last=drop_last,
                )

        self.auto_collate_batch = self.batch_sampler is not None
        self.sampler_iter = iter(self.index_sampler)

@@ -106,8 +118,12 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
        self.collate_fn = collate_fn or default_convert_fn
        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind,
            self.dataset,
            self.auto_collate_batch,
            self.collate_fn,
            self.drop_last,
        )

        self._steps = self._infer_steps()
        self._inner_dataloader = self._create_inner_dataloader()

@@ -136,9 +152,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
            if isinstance(self.dataset, IterableDataset):
                steps_per_epoch = None
            elif self.batch_size is None:
                steps_per_epoch = len(self.dataset) // self.acc_steps
            else:
                steps_per_epoch = (
                    len(self.dataset) // self.batch_size // self.acc_steps
                )
        except:
            raise ValueError(
                "Pleace set `steps_per_epoch` or implement `__len__` methond in dataset class."

@@ -156,18 +174,21 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
            return _InfiniteIterableSampler(self.dataset, 1)

    def _create_inner_dataloader(self):
        def data_generator():
            while True:
                try:
                    indices = next(self.sampler_iter)
                    batch = self.dataset_fetcher.fetch(indices)
                    if batch is None:
                        break
                except StopIteration:
                    self.dataset_fetcher = _DatasetKind.create_fetcher(
                        self.dataset_kind,
                        self.dataset,
                        self.auto_collate_batch,
                        self.collate_fn,
                        self.drop_last,
                    )
                    break

                partial_data = []

@@ -178,11 +199,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
                        continue

                    batch_size = array.shape[0]
                    assert (
                        batch_size % self.dp_world_sizes[i] == 0
                    ), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
                        str(batch_size), str(self.dp_world_sizes[i])
                    )
                    partial_data.append(
                        np.split(array, self.dp_world_sizes[i])[self.dp_ranks[i]]
                    )

                yield partial_data

@@ -194,33 +220,35 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
            iterable=False,
            return_list=self.return_list,
            use_multiprocess=self.use_multiprocess,
            drop_last=self.drop_last,
        )
        dataloader.set_batch_generator(data_generator, self.places)
        return dataloader


class DistributedDataLoader(DistributedDataLoaderBase):
    def __init__(
        self,
        dataset,
        feed_list=None,
        places=None,
        return_list=True,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        collate_fn=None,
        num_workers=0,
        use_buffer_reader=True,
        use_shared_memory=True,
        timeout=0,
        worker_init_fn=None,
        epochs=1,
        steps_per_epoch=None,
        split_data=True,
        data_parallel_world_size=[],
        data_parallel_rank=[],
    ):
        self.dataset = dataset
        self.feed_list = feed_list
        self.return_list = return_list

@@ -241,8 +269,13 @@ class DistributedDataLoader(DistributedDataLoaderBase):
        self.split_data = split_data
        # TODO: rank info
        self.batch_sampler = DistributedBatchSampler(
            self.dataset,
            self.batch_size,
            self.dp_world_sizes[0],
            self.dp_ranks[0],
            self.shuffle,
            self.drop_last,
        )
        self._inner_dataloader = self._create_inner_dataloader()

    def __iter__(self):

@@ -263,7 +296,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
            use_buffer_reader=self.use_buffer_reader,
            use_shared_memory=self.use_shared_memory,
            timeout=self.timeout,
            worker_init_fn=self.worker_init_fn,
        )
        self.data = (x for x in dataloader)
        return dataloader
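
The new acc_steps argument changes how the generator-based loader counts steps and splits feeds. The snippet below is purely illustrative arithmetic with hypothetical numbers (not from the diff), mirroring the steps_per_epoch and per-rank split logic shown above.

# Hypothetical numbers: how the loader derives steps and per-rank shards
# when gradient accumulation is enabled.
dataset_len = 1600      # samples
batch_size = 8          # samples fed per step
acc_steps = 4           # micro-batches accumulated per optimizer step
dp_world_size = 2       # data-parallel degree

steps_per_epoch = dataset_len // batch_size // acc_steps  # 50 optimizer steps
assert batch_size % dp_world_size == 0, "batch_size must divide dp_world_size"
per_rank_batch = batch_size // dp_world_size               # 4 samples per rank
print(steps_per_epoch, per_rank_batch)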
python/paddle/distributed/auto_parallel/dist_saver.py

@@ -18,6 +18,7 @@ import errno
import pickle
import warnings
import logging
import collections
import numpy as np
import paddle

@@ -53,16 +54,13 @@ def _process_path(path):
class DistributedSaver:
    def __init__(self):
        self._logger = get_logger(logging.INFO)

    def save(self, path, serial_program, dist_main_program, dist_context):
        def _save_state(program, path, mode="param"):
            state = {
                k: np.array(v) for k, v in program.state_dict(mode).items()
            }
            with open(path, "wb") as f:
                pickle.dump(state, f)

@@ -108,8 +106,9 @@ class DistributedSaver:
        def _load_file(filename, dirname, suffix="pdparams"):
            file_list = []
            for file in os.listdir(dirname):
                if check_filename(
                    '{}(.*)_dist(.*).{}'.format(filename, suffix), file
                ):
                    file_list.append(os.path.join(dirname, file))
            file_list.sort()
            return file_list

@@ -137,14 +136,16 @@ class DistributedSaver:
        # load path.pdparam and path.pdopt
        param_state_dict = _load_state(filename, dirname)
        opt_state_dict = (
            _load_state(filename, dirname, "pdopt") if load_optimizer else {}
        )
        state_dict = dict(param_state_dict, **opt_state_dict)

        # load path.pdattr
        dist_attr_file_list = _load_file(filename, dirname, "pdattr")
        self._logger.info(
            "Load distributed attribute file: {}".format(dist_attr_file_list)
        )
        dist_attr = {}
        for dist_attr_file in dist_attr_file_list:
            with open(dist_attr_file, 'rb') as f:

@@ -196,12 +197,24 @@ class DistributedSaver:
            used_inputs += op.input_arg_names
            used_outputs += op.output_arg_names

        dist_feed_vars_names = list(set(feed_vars_names) & set(used_inputs))
        dist_fetch_vars_names = list(set(fetch_vars_names) & set(used_outputs))

        # delete duplicated elements and keep order
        feed_vars_names = list({}.fromkeys(feed_vars_names).keys())
        used_inputs = list({}.fromkeys(used_inputs).keys())
        fetch_vars_names = list({}.fromkeys(fetch_vars_names).keys())
        used_outputs = list({}.fromkeys(used_outputs).keys())

        dist_feed_vars_names = [
            var_name for var_name in feed_vars_names if var_name in used_inputs
        ]
        dist_fetch_vars_names = [
            var_name
            for var_name in fetch_vars_names
            if var_name in used_outputs
        ]

        dist_feed_vars = list(
            reversed([global_block.vars[name] for name in dist_feed_vars_names])
        )
        dist_fetch_vars = [
            global_block.vars[name] for name in dist_fetch_vars_names
        ]

@@ -209,11 +222,13 @@ class DistributedSaver:
        # NOTE: `paddle.static.save_inference_model` does not support subblock.
        dist_filename = filename + "_dist" + str(rank_id)
        dist_path = os.path.join(dirname, dist_filename)
        paddle.static.save_inference_model(
            dist_path,
            dist_feed_vars,
            dist_fetch_vars,
            exe,
            program=dist_main_prog,
        )

    def _save_rank_mapping(self, dirname):
        path = os.path.join(dirname, 'rank_mapping.csv')
python/paddle/distributed/auto_parallel/engine.py

@@ -225,6 +225,11 @@ class Engine:
        self._planned_mode = None
        self._dygraph_mode = False
        self._tuning = self._strategy.tuning
        self._acc_steps = 1
        if self._strategy.gradient_merge.enable:
            self._acc_steps = self._strategy.gradient_merge.k_steps
        elif self._strategy.pipeline.enable:
            self._acc_steps = self._strategy.pipeline.accumulate_steps

        self.history = None

@@ -388,7 +393,12 @@ class Engine:
        if self.main_program._pipeline_opt:
            assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
            fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
            fwd_task = None
            if self._strategy.pipeline.schedule_mode == "1F1B":
                fwd_task = fleet_opt["tasks"][1]
            elif self._strategy.pipeline.schedule_mode == "stream":
                fwd_task = fleet_opt["tasks"][0]
            assert fwd_task is not None
            fwd_prog = fwd_task.get_program()
            fwd_block = fwd_prog.global_block()

@@ -438,8 +448,6 @@ class Engine:
            ), "user_fetches must be a list, but receive {}".format(
                type(user_fetches).__name__
            )
        else:
            user_fetches = []
        fetch_names = []
        fetch_indices = []

@@ -466,7 +474,7 @@ class Engine:
            _process_fetch_group("metrics_" + str(i), var_list)
        if mode == "predict":
            _process_fetch_group("outputs", fetch_vars["outputs"])
        for usr_fetch in user_fetches or []:
            var_name = _to_name_str(usr_fetch)
            fetch(var_name)
        user_fetches_collection = [

@@ -903,6 +911,7 @@ class Engine:
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            train_data, train_sample_split, batch_size
        )
        batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:

@@ -931,7 +940,7 @@ class Engine:
            save_dir=save_dir,
            verbose=verbose,
            metrics=self._metrics_name(),
            acc_step=self._acc_steps,
        )

        cbks.on_begin('train')

@@ -965,7 +974,7 @@ class Engine:
                val_logs = self.evaluate(
                    valid_data,
                    valid_sample_split,
                    batch_size * self._acc_steps,
                    valid_steps,
                    log_freq,
                    collate_fn,

@@ -1046,6 +1055,7 @@ class Engine:
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            valid_data, valid_sample_split, batch_size
        )
        batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:

@@ -1152,6 +1162,7 @@ class Engine:
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            test_data, test_sample_split, batch_size
        )
        batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:

@@ -1214,6 +1225,7 @@ class Engine:
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )
        batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:

@@ -1256,6 +1268,7 @@ class Engine:
        self._inputs_spec, self._labels_spec = self._prepare_data_spec(
            dataset, sample_split, batch_size
        )
        batch_size = self._validate_batch_size(batch_size)
        if not self._has_prepared[self._mode]:
            self._prepare_program(self._mode)
        else:

@@ -1371,14 +1384,6 @@ class Engine:
        steps_per_epoch=None,
    ):
        if self._strategy.gradient_merge and batch_size is not None:
            assert (
                batch_size % self._k_steps == 0
            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
                batch_size, self._k_steps
            )
            batch_size //= self._k_steps
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]

@@ -1440,14 +1445,6 @@ class Engine:
        collate_fn=None,
    ):
        if self._strategy.gradient_merge and batch_size is not None:
            assert (
                batch_size % self._k_steps == 0
            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
                batch_size, self._k_steps
            )
            batch_size //= self._k_steps
        dist_context = self._dist_contexts[self._mode]
        dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
        dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]

@@ -1487,6 +1484,9 @@ class Engine:
            split_data=self._strategy.split_data,
            data_parallel_world_size=self._dp_world_sizes,
            data_parallel_rank=self._dp_ranks,
            acc_steps=1
            if not self._strategy.pipeline.enable
            else self._acc_steps,
        )
        self._prepare_reader(feed_list)
        return dataloader

@@ -1498,9 +1498,18 @@ class Engine:
        )
        self._optimization_tuning(self._mode, tune_data, batch_size)

    def _validate_batch_size(self, batch_size):
        if batch_size is None:
            return None
        assert (
            batch_size % self._acc_steps == 0
        ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
            batch_size, self._acc_steps
        )
        return batch_size // self._acc_steps

    def _validate_spec(self, specs):
        specs = to_list(specs)
        self._k_steps = self._strategy.gradient_merge.k_steps
        if specs is not None:
            for i, spec in enumerate(specs):
                if not isinstance(spec, InputSpec):

@@ -1513,14 +1522,14 @@ class Engine:
                        i, spec
                    )
                )
                if self._acc_steps > 1:
                    shape = list(spec.shape)
                    assert (
                        shape[0] % self._acc_steps == 0
                    ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
                        spec.shape[0], self._acc_steps
                    )
                    shape[0] //= self._acc_steps
                    spec.shape = shape
        return specs or []
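
The engine changes centralize accumulation handling in the new _validate_batch_size helper: the user-facing batch size must be divisible by _acc_steps, and the program is built for the resulting micro-batch. The standalone sketch below uses hypothetical numbers and mirrors that contract for illustration only.

# Hypothetical illustration of the _validate_batch_size contract above.
def validate_batch_size(batch_size, acc_steps):
    if batch_size is None:
        return None
    assert batch_size % acc_steps == 0, (
        "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
            batch_size, acc_steps
        )
    )
    return batch_size // acc_steps

micro_batch = validate_batch_size(32, 4)  # 8 samples per micro-batch
eval_batch = micro_batch * 4              # evaluate() scales back up by acc_steps
print(micro_batch, eval_batch)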
python/paddle/distributed/auto_parallel/parallelizer_v2.py

@@ -297,13 +297,15 @@ class Parallelizer:
        if self._strategy is None:
            return

        # data parallel optimization
        if self._strategy.dp_optimization.enable:
            config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
            config["dist_context"] = self._dist_context
            config["global_rank"] = rank
            config["use_sharding"] = self._strategy.sharding.enable
            dp_pass = new_pass(
                "auto_parallel_data_parallel_optimization", config
            )
            dp_pass.apply([main_program], [startup_program], self._pass_context)

        if self._strategy.sharding.enable:
            config = copy.deepcopy(self._strategy.sharding.to_dict())
python/paddle/distributed/auto_parallel/partitioner.py

@@ -13,24 +13,25 @@
# limitations under the License

import copy
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Program, Parameter
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
from .utils import (
    is_forward_op,
    is_backward_op,
    is_loss_op,
    is_optimize_op,
    is_fillconst_op_for_micro_batch,
)
from .operators.common import BACKWARD_ONLY_DIST_OPS

__varname_not_in_block__ = ["lod_tensor_blocking_queue"]
__not_shape_var_type__ = [
    core.VarDesc.VarType.READER,
    core.VarDesc.VarType.STEP_SCOPES,
]

@@ -39,7 +40,7 @@ class Partitioner(object):
    warning:: Partitioner is experimental and subject to change.

    Partitioner convert a program into another program.
    Given a serial program which has been auto completed with shard annotation, the Partitioner
    convert the serial program into a "distributed" program. The Partitioner will modify the serial
    program in following two ways, which is also the major difference between serial and distributed program:
    1. partition op: replace a serial op into its corresponding dist op infered from the shard annotation

@@ -56,25 +57,29 @@ class Partitioner(object):
        """
        if not isinstance(dist_context, DistributedContext):
            raise TypeError(
                "dist_context be paddle.fluid.DistributedContext, got %s here"
                % type(dist_context)
            )

        self._dist_context = dist_context
        self._rank_id = rank_id
        self._serial2dist_varname_mapping = {}
        self._dist_varname_suffix = ""

    def partition(
        self, serial_main_program, serial_startup_program, params_grads
    ):
        if not isinstance(serial_main_program, (Program)):
            raise TypeError(
                "main_program be paddle.fluid.framework.program, got %s here"
                % type(serial_main_program)
            )

        # check if shard annotated serial program valid
        if not self._is_valid_annotated_program(serial_main_program):
            raise RuntimeError(
                "Not all vars or ops are annotated in main program !"
            )

        # init distop helper
        dist_op_context = self._dist_context.dist_op_context

@@ -86,24 +91,33 @@ class Partitioner(object):
            partitioned_startup_prog = None
        else:
            partitioned_startup_prog = self.partition_startup_program(
                serial_main_program, serial_startup_program
            )
        dist_op_context.dst_startup_program = partitioned_startup_prog

        # partition main program
        (
            partitioned_main_prog,
            partitioned_params_grads,
        ) = self.partition_main_program(serial_main_program, params_grads)

        return (
            partitioned_main_prog,
            partitioned_startup_prog,
            partitioned_params_grads,
        )

    def partition_startup_program(
        self, serial_main_program, serial_startup_program
    ):
        if not isinstance(serial_startup_program, (Program)):
            raise TypeError(
                "dist_context be paddle.fluid.framework.program, got %s here"
                % type(serial_startup_program)
            )

        partitioned_startup_prog = Program()
        ref_block = serial_main_program.global_block()
        target_block = partitioned_startup_prog.global_block()
        var2shape = {}

@@ -114,27 +128,33 @@ class Partitioner(object):
            assert var.persistable
            new_name = var.name + self._dist_varname_suffix
            temp_varname_map[var.name] = new_name
            target_shape = _partition_var(
                self._dist_context, ref_block, target_block, var.name, new_name
            )
            var2shape[new_name] = target_shape

        # ops
        for op in serial_startup_program.global_block().ops:
            # TODO if var not belong to this rank, should be filtered
            output_vars = op.desc.output_arg_names()
            assert (
                len(output_vars) == 1
            ), "initializer should output only ONE variable, but got [{}]".format(
                str(op.desc)
            )
            assert (
                temp_varname_map[output_vars[0]] in var2shape
            ), "try to initialize [{}] which is not a persistable var".format(
                output_vars[0]
            )
            new_op_desc = target_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            new_op_desc._rename_output(
                output_vars[0], temp_varname_map[output_vars[0]]
            )
            new_op_desc._set_attr(
                "shape", var2shape[temp_varname_map[output_vars[0]]]
            )
        target_block._sync_with_cpp()

        # set distribute atrribute

@@ -142,14 +162,17 @@ class Partitioner(object):
            assert new_op.type == new_op_desc.type()
            assert new_op.desc == new_op_desc
            output_var = target_block.var(output_vars[0])
            output_var_attr = (
                self._dist_context.get_tensor_dist_attr_for_program(output_var)
            )
            op_attr = OperatorDistributedAttribute()
            op_attr.process_mesh = output_var_attr.process_mesh
            op_attr.set_output_dims_mapping(
                output_var.name, output_var_attr.dims_mapping
            )
            op_attr.set_input_dims_mapping(
                output_var.name, output_var_attr.dims_mapping
            )
            self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)

        return partitioned_startup_prog

@@ -160,7 +183,7 @@ class Partitioner(object):
        2. replace local op with corresponding dist op
        """

        partitioned_main_prog = Program()
        dist_op_context = self._dist_context.dist_op_context
        dist_op_context.dst_main_program = partitioned_main_prog

@@ -171,7 +194,8 @@ class Partitioner(object):
                target_block = partitioned_main_prog.blocks[0]
            else:
                target_block = partitioned_main_prog._create_block(
                    parent_idx=ref_block.parent_idx
                )
            assert ref_block.idx == target_block.idx
            target_block._set_forward_block_idx(ref_block.forward_block_idx)
            dist_op_context.work_block = target_block

@@ -186,8 +210,9 @@ class Partitioner(object):
            for attr_name in op.all_attrs():
                if op.attr_type(attr_name) == core.AttrType.BLOCK:
                    relative_id = op._block_attr_id(attr_name)
                    op._set_attr(
                        attr_name, partitioned_main_prog.block(relative_id)
                    )

        partitioned_params_and_grads = []
        for p, g in params_and_grads:

@@ -198,7 +223,8 @@ class Partitioner(object):
            else:
                assert g.name in self._serial2dist_varname_mapping
                dist_g = self._get_dist_var_by_serial_var(
                    g, partitioned_main_prog
                )
            partitioned_params_and_grads.append((dist_p, dist_g))

        return partitioned_main_prog, partitioned_params_and_grads

@@ -222,71 +248,116 @@ class Partitioner(object):
        for idx in range(len(serial_ops)):
            if idx <= last_fwd_op_idx:
                forward_op_id2forward_op[
                    serial_ops[idx].desc.original_id()
                ] = serial_ops[idx]

        # partiiton
        appended_grad_times = 0
        for idx, op in enumerate(serial_ops):
            op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
            if is_backward_op(op) and (
                is_forward_op(serial_ops[idx - 1])
                or is_loss_op(serial_ops[idx - 1])
            ):
                if not op_dist_attr.is_recompute:
                    appended_grad_times += 1

            # partititon input variables
            for serial_input_varname in op.desc.input_arg_names():
                if (
                    serial_input_varname
                    not in self._serial2dist_varname_mapping
                ):
                    new_varname = (
                        serial_input_varname + self._dist_varname_suffix
                    )
                    if ref_block.has_var(serial_input_varname):
                        _partition_var(
                            self._dist_context,
                            ref_block,
                            target_block,
                            serial_input_varname,
                            new_varname,
                        )
                    else:
                        for varname_not_in_block in __varname_not_in_block__:
                            assert (
                                varname_not_in_block in serial_input_varname
                            ), "{} is not found".format(serial_input_varname)

                    self._serial2dist_varname_mapping[
                        serial_input_varname
                    ] = new_varname

            # partition output vars
            for serial_output_varname in op.desc.output_arg_names():
                if (
                    serial_output_varname
                    not in self._serial2dist_varname_mapping
                ):
                    new_varname = (
                        serial_output_varname + self._dist_varname_suffix
                    )
                    _partition_var(
                        self._dist_context,
                        ref_block,
                        target_block,
                        serial_output_varname,
                        new_varname,
                    )
                    self._serial2dist_varname_mapping[
                        serial_output_varname
                    ] = new_varname

            # partition op
            if (
                is_forward_op(op)
                or op_dist_attr.is_recompute
                or is_fillconst_op_for_micro_batch(op)
            ):
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_forward_impl = _get_dist_op_forward_implement(
                    op, self._dist_context
                )
                dist_op_forward_impl.forward(
                    self._dist_context, **kinputs, **koutputs
                )
            elif is_backward_op(op):
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_backward_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op
                )
                grad_var_to_var = (
                    self._dist_context.dist_op_context.grad_var_to_var[
                        appended_grad_times
                    ]
                )
                dist_op_backward_impl.backward(
                    self._dist_context,
                    **kinputs,
                    **koutputs,
                    **{"grad_var_to_var": grad_var_to_var}
                )
            elif is_optimize_op(op):
                # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of 1F1B PASS
                kinputs, koutputs = dist_op_context.prepare_context(op)
                dist_op_opt_impl = _get_dist_op_backward_implement(
                    op, self._dist_context, forward_op_id2forward_op
                )
                dist_op_opt_impl.backward(
                    self._dist_context,
                    **kinputs,
                    **koutputs,
                    **{"grad_var_to_var": {}}
                )
            else:
                raise NotImplementedError(
                    "partitioner only support forward and backward, optimize ops, but got {}".format(
                        str(op)
                    )
                )

    def _is_valid_annotated_program(self, program):

@@ -298,13 +369,16 @@ class Partitioner(object):
        ]
        var_dist_attrs = [
            self._dist_context.get_tensor_dist_attr_for_program(var)
            for var in vars_
            if (var.type not in __not_shape_var_type__)
        ]

        all_ops_annotated = all(
            dist_attr is not None for dist_attr in op_dist_attrs
        )
        all_vars_annotated = all(
            dist_attr is not None for dist_attr in var_dist_attrs
        )

        return all_ops_annotated and all_vars_annotated

@@ -328,22 +402,26 @@ def _get_dist_shape(var, dist_attr):
    assert len(var_shape) == len(
        mapping
    ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
        var_shape, mapping
    )
    new_shape = []
    for idx in range(len(var_shape)):
        if var_shape[idx] == -1 or mapping[idx] == -1:
            new_shape.append(var_shape[idx])
        else:
            assert (
                var_shape[idx] % mesh[mapping[idx]] == 0
            ), "un-event partition: var_shape[idx]=[{}], mesh[{}]".format(
                var_shape[idx], mesh[mapping[idx]]
            )
            new_shape.append(var_shape[idx] // mesh[mapping[idx]])

    return new_shape


def _partition_parameter(
    dist_context, src_var, dst_block, dst_varname, dst_shape
):
    # NOTE hack to copied Parameter
    # not initialized parameter, need to initialize it
    copied_kwargs = {}

@@ -353,39 +431,45 @@ def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
    copied_kwargs['do_model_average'] = src_var.do_model_average
    copied_kwargs['need_clip'] = src_var.need_clip

    param = Parameter(
        block=dst_block,
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
        **copied_kwargs
    )

    return param


def _partition_intermediate_var(
    dist_context, src_var, dst_block, dst_varname, dst_shape
):
    var = dst_block.create_var(
        type=src_var.type,
        name=dst_varname,
        shape=dst_shape,
        dtype=src_var.dtype,
        lod_level=src_var.lod_level,
        persistable=src_var.persistable,
        error_clip=src_var.error_clip,
        stop_gradient=src_var.stop_gradient,
        is_data=src_var.is_data,
        belong_to_optimizer=src_var.belong_to_optimizer,
    )

    return var


def _partition_var(
    dist_context, src_block, dst_block, src_varname, dst_varname
):
    """
    partition include: split + replicate
    """

@@ -393,44 +477,53 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
    if src_var.type in __not_shape_var_type__:
        persist = getattr(src_var, 'persistable', False)
        new_var = dst_block.create_var(
            type=src_var.type,
            name=dst_varname,
            persistable=persist,
            stop_gradient=True,
        )
        target_shape = None
    else:
        dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
        target_shape = _get_dist_shape(src_var, dist_attr)

        if isinstance(src_var, Parameter):
            new_var = _partition_parameter(
                dist_context, src_var, dst_block, dst_varname, target_shape
            )
        else:
            new_var = _partition_intermediate_var(
                dist_context, src_var, dst_block, dst_varname, target_shape
            )

    dist_attr = copy.deepcopy(
        dist_context.get_tensor_dist_attr_for_program(src_var)
    )
    assert dist_attr is not None
    dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)

    return target_shape


def _get_dist_op_backward_implement(
    backward_op, dist_context, forward_op_id2forward_op
):
    dist_op_context = dist_context.dist_op_context
    if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
        forward_op_id = dist_op_context.grad_op_id_to_op_id[
            backward_op.desc.original_id()
        ]
        forward_op = forward_op_id2forward_op[forward_op_id]
        forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
            forward_op
        )
        dist_op_impl_container = get_distributed_operator_impl_container(
            forward_op_dist_attr.impl_type
        )
        dist_op_impl = dist_op_impl_container.get_impl(
            forward_op_dist_attr.impl_idx
        )
        return dist_op_impl

    # # NOTE trick for dist ops that only have backward implement

@@ -438,7 +531,8 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
        op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
        assert op_dist_attr.impl_idx >= 0
        dist_op_impl = get_distributed_operator_impl_container(
            op_dist_attr.impl_type
        ).get_impl(op_dist_attr.impl_idx)
        return dist_op_impl

    dist_op = get_distributed_operator_impl_container("default")

@@ -448,6 +542,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
def _get_dist_op_forward_implement(forward_op, dist_context):
    dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
    dist_op_impl_container = get_distributed_operator_impl_container(
        dist_attr.impl_type
    )
    dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx)
    return dist_op_impl
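
_get_dist_shape above computes the per-rank shape of a sharded tensor: every dimension whose dims_mapping entry is not -1 is divided by the size of the mapped process-mesh axis, while -1 means the dimension is replicated. The following worked example uses hypothetical shapes only and mirrors that logic.

# Hypothetical worked example of the _get_dist_shape rule.
def get_dist_shape(var_shape, mapping, mesh):
    new_shape = []
    for idx in range(len(var_shape)):
        if var_shape[idx] == -1 or mapping[idx] == -1:
            new_shape.append(var_shape[idx])          # replicated dim
        else:
            assert var_shape[idx] % mesh[mapping[idx]] == 0
            new_shape.append(var_shape[idx] // mesh[mapping[idx]])
    return new_shape

# A [1024, 768] weight sharded along its first axis over a [4, 2] process
# mesh becomes [256, 768] on every rank.
print(get_dist_shape([1024, 768], [0, -1], [4, 2]))  # [256, 768]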
python/paddle/distributed/auto_parallel/reshard.py

@@ -422,11 +422,11 @@ class Inserter:
        )
        inputs = {'X': [tensor]}
        outputs = {"Out": [out]}
        attrs = {"in_place": False, "op_role": op_role}
        assign_op = block._insert_op(
            idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
        )
        assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
        return out

    # use split once

@@ -1217,6 +1217,8 @@ class Resharder:
            shape_x[0] <= shape_y[0] < shape_x[1]
        ):
            overlapped = True
        if shape_x == [0, 0] and shape_y == [0, 0]:
            overlapped = True
        return overlapped

    def is_unshard(self, dims_mapping):

@@ -1304,6 +1306,14 @@ class Resharder:
                # judge whether need reshard by process_mesh
                if tensor_process_mesh != op_process_mesh:
                    is_reshard = True
                    # not reshard data in send/recv scene
                    if (
                        tensor_process_mesh != op_process_mesh
                        and len(tensor_process_mesh.process_ids)
                        == len(op_process_mesh.process_ids)
                        and dist_tensor.serial_tensor.is_data
                    ):
                        is_reshard = False
            else:
                op_output_dims_mapping = dist_attr[1]
                if all(

@@ -1585,10 +1595,10 @@ class Resharder:
                        if i == 0:
                            all_partition_index_list.append(process_index[j][1])
                for process in group:
                    # append slice op desc
                    slice_starts = []
                    slice_ends = []
                    slices_axes = []
                    min_comm_group = copy.deepcopy(group)
                    all_partition_index_list_copied = copy.deepcopy(
                        all_partition_index_list
                    )
                    target_partition_index = Resharder.compute_partition_index(
                        process,
                        complete_shape,

@@ -1596,12 +1606,56 @@ class Resharder:
                        target_process_shape,
                        target_process_group,
                    )
                    for idx, item in enumerate(target_partition_index):
                        slice_starts.append(item[0])
                        slice_ends.append(item[1])
                    for _process in group:
                        source_partition_index = (
                            Resharder.compute_partition_index(
                                _process,
                                complete_shape,
                                source_dims_mapping,
                                source_process_shape,
                                source_process_group,
                            )
                        )
                        if not all(
                            _
                            for _ in list(
                                map(
                                    self.is_overlapped,
                                    source_partition_index,
                                    target_partition_index,
                                )
                            )
                        ):
                            min_comm_group.remove(_process)
                            all_partition_index_list_copied.remove(
                                source_partition_index
                            )

                    concatenated_partition_index_list = []
                    for partition_index in all_partition_index_list_copied:
                        Resharder.concat_partitions(
                            concatenated_partition_index_list, partition_index
                        )
                    concatenated_partition_index = (
                        concatenated_partition_index_list[0]
                    )

                    slice_starts = []
                    slice_ends = []
                    slices_axes = []
                    to_slice_tensor_shape = []
                    for idx, item in enumerate(concatenated_partition_index):
                        slice_starts.append(
                            target_partition_index[idx][0] - item[0]
                        )
                        slice_ends.append(
                            target_partition_index[idx][1] - item[0]
                        )
                        slices_axes.append(idx)
                        to_slice_tensor_shape.append(item[1] - item[0])

                    to_slice_tensor_shape = dist_tensor.global_sizes()

                    slice_op_desc = SliceOpDesc(
                        starts=slice_starts,
                        ends=slice_ends,

@@ -1616,16 +1670,16 @@ class Resharder:
                    op_desc_seq[process] = (
                        [
                            AllGatherOpDesc(
                                group=min_comm_group,
                                shape=allgather_shape,
                                is_bool=(source_tensor.dtype == paddle.bool),
                            ),
                            ConcatOpDesc(
                                partition_index_list=all_partition_index_list_copied
                            ),
                            slice_op_desc,
                        ]
                        if len(min_comm_group) > 1
                        else [slice_op_desc]
                    )
python/paddle/distributed/auto_parallel/strategy.py

@@ -123,6 +123,12 @@ class DatasetConfig(BaseConfig):
        super(DatasetConfig, self).__init__(category, config_dict)


class DPOptimizationConfig(BaseConfig):
    def __init__(self, config_dict=None):
        category = constants.DP_OPTIMIZATION
        super(DPOptimizationConfig, self).__init__(category, config_dict)


class Strategy(BaseConfig):
    """
    The `Strategy` object is used to configure the paralleization and optimization beheviors.

@@ -194,3 +200,6 @@ class Strategy(BaseConfig):
        config_dict = self._config_dict.get(constants.DATASET, None)
        self.dataset = DatasetConfig(config_dict)

        config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
        self.dp_optimization = DPOptimizationConfig(config_dict)
python/paddle/distributed/auto_parallel/utils.py

@@ -1252,6 +1252,7 @@ def set_grad_var_shape(program, dist_context):
        "fused_softmax_mask_upper_triangle_grad",
        "flatten_contiguous_range_grad",
        "relu_grad",
        "exp_grad",
    ]
    forward_list = [
        "reshape2",

@@ -1270,6 +1271,7 @@ def set_grad_var_shape(program, dist_context):
        "fused_softmax_mask_upper_triangle",
        "flatten_contiguous_range",
        "relu",
        "exp",
    ]
    if op.type in need_set_shape_list:
        for forward_op in block.ops:

@@ -1320,6 +1322,11 @@ def is_forward_op(op):
    )


def is_fillconst_op_for_micro_batch(op):
    op_role = int(op.attr('op_role'))
    return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.LRSched))


def is_backward_op(op):
    return OP_ROLE_KEY in op.attr_names and int(op.all_attrs()[OP_ROLE_KEY]
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py

@@ -18,15 +18,31 @@ import numpy as np
import paddle
from paddle.fluid import core, unique_name
from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import (
    OpRole,
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
)
from paddle.distributed.auto_parallel.operators.common import (
    is_data_parallel_scale_op,
    is_data_parallel_reduce_op,
)
from paddle.distributed.auto_parallel.utils import (
    is_loss_grad_op,
    is_optimize_op,
    is_backward_op,
    ring_id_to_process_group,
    find_higher_order_backward_op,
)
from .pass_base import PassBase, PassType, register_pass

# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
    'lars_momentum',
    'sparse_momentum',
    'dgc_momentum',
    'momentum',
    'merge_momentum',
]

# a heuristic number

@@ -41,7 +57,7 @@ def numel(var):
class DataParallelOptimizationPass(PassBase):
    """
    Apply Optimizations that specialized for data parallelism in Auto Parallel.
    1. prune grad scaling
    2. overlap comm and calc
    3. fuse allreduce
    """

@@ -52,6 +68,9 @@ class DataParallelOptimizationPass(PassBase):
        self.set_attr("dist_context", None)
        self.set_attr("global_rank", -1)
        self.set_attr("use_sharding", False)
        self.set_attr("fuse_all_reduce_ops", False)
        self.set_attr("fuse_grad_size_in_MB", 32)
        self.set_attr("overlap_comm_cacl", False)
        # {grad1: group1, grad2: group1, grad3: group2}
        # record the order for fuse grad data memory
        self._grad_name_to_group_map = OrderedDict()

@@ -62,8 +81,9 @@ class DataParallelOptimizationPass(PassBase):
    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False
        if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr(
            "global_rank"
        ) < 0:
            return False

        return True

@@ -80,13 +100,18 @@ class DataParallelOptimizationPass(PassBase):
        self.global_rank = int(self.get_attr("global_rank"))
        self.use_sharding = self.get_attr("use_sharding")
        overlap_comm_cacl = self.get_attr("overlap_comm_cacl")
        fuse_all_reduce_ops = self.get_attr("fuse_all_reduce_ops")

        with paddle.static.program_guard(main_program, startup_program):
            self._analyze_program()

            if self.is_data_parallel_applied():
                if overlap_comm_cacl:
                    self._prune_grad_scaling()
                    self._calc_comm_overlap()
                if fuse_all_reduce_ops:
                    grad_group = self._fuse_allreduce()

        # self.summary(grad_group)

@@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase):
                ), "Unexception: comm op [{}] has NOT ring id.".format(str(op))
                group = ring_id_to_process_group(op.attr("ring_id"))

                assert (
                    group is not None
                ), "Unexception: data parallel group of [{}] from op [{}] is None".format(
                    grad_name, str(op)
                )

                self._grad_name_to_group_map[grad_name] = group

@@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase):
            # TODO support multiple optimizers in on network in future.
            # here we assume that the optimizer is unique in network.
            elif (
                is_optimize_op(op)
                and op.type in __rescale_grad_supported_opts__
            ):
                self._support_rescale_grad = True

        not_synchronized_grads = []
        for grad_name in scaled_grads:
            if grad_name not in self._grad_name_to_group_map:
                not_synchronized_grads.append(grad_name)
        assert (
            len(not_synchronized_grads) == 0
        ), "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
            not_synchronized_grads
        )

    def is_data_parallel_applied(self):
        return len(self._group_to_grad_name_map) > 0

@@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase):
    def _could_be_prune(self):
        return self.dist_context.gradient_scale and (
            self._support_rescale_grad or self._all_dp_groups_same_degree()
        )

    def _all_dp_groups_same_degree(self):
        return (
            len(
                set(
                    [
                        len(group.ranks)
                        for group in self._group_to_grad_name_map.keys()
                    ]
                )
            )
            == 1
        )

    def _scale_backward_initial_grad(self):

@@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase):
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                assert op.type == 'fill_constant', (
                    "loss_grad_op must be fill_constant op, "
                    "but this op is {}".format(op.type)
                )
                assert op.has_attr('value')
                loss_scale = float(op.attr('value'))
                loss_scale = loss_scale / dp_degree

@@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase):
        scaled_grads = set()

        for idx, op in reversed(list(enumerate(block.ops))):
            if (
                is_optimize_op(op)
                and op.type in __rescale_grad_supported_opts__
            ):
                assert op.has_attr(
                    'rescale_grad'
                ), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format(
                    str(op)
                )
                assert (
                    len(op.input("Grad")) == 1
                ), "Unexception: op [{}] is supported to have only one input grad var.".format(
                    str(op)
                )

                grad_name = op.input("Grad")[0]
                dp_degree = len(
                    list(self._grad_name_to_group_map[grad_name].ranks)
                )
                scaled_grads.add(grad_name)

                rescale_grad = float(op.attr('rescale_grad')) / dp_degree
                op._set_attr('rescale_grad', rescale_grad)

        assert scaled_grads == set(
            self._grad_name_to_group_map.keys()
        ), "Unexception: gradients [{}] are unscaled.".format(
            set(self._grad_name_to_group_map.keys()) - scaled_grads
        )

    def _could_be_overlap(self):
        # NOTE current different nccl comm will use different cuda stream

@@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase):
            op._set_attr('use_calc_stream', False)
            ring_id = op.attr("ring_id")

            block._insert_op_without_sync(
                idx,
                type='c_wait_compute',
                inputs={'X': []},
                outputs={'Out': []},
                attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
            )

        block._sync_with_cpp()

@@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase):
            # other ops that might use communicating grad
            else:
                for input_var_name in op.input_arg_names:
                    for (
                        ring_id,
                        unsync_grad_names,
                    ) in ring_id_to_un_sync_grad_map.items():
                        if input_var_name in unsync_grad_names:
                            # need to sync before op_i
                            if i in op_idx_to_sync_ring_id_map:

@@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase):
        for i in sorted(indices, reverse=True):
            for ring_id in op_idx_to_sync_ring_id_map[i]:
                block._insert_op_without_sync(
                    i,
                    type='c_wait_comm',
                    inputs={'X': []},
                    outputs={'Out': []},
                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
                )

    def _could_be_fuse(self):
        # TODO support gradient fuse higher order gradient.

@@ -350,9 +396,9 @@ class DataParallelOptimizationPass(PassBase):
        """
        conditions for gradients to be grouped:
        1. group size < max_fuse_numel
        2. same dp group
        3. same dtype
        4. dependency: grad would NOT be used by other ops within group segment
gradients inside same group would be fuse into one coalesce tensor
"""
...
...
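The grouping docstring above can be read as a simple predicate over a candidate group and the next gradient. A minimal, illustrative sketch in plain Python; the `could_be_grouped` helper and the `Group` fields used here are assumptions for this sketch, not the pass's own API:

# Illustrative only: mirrors the four grouping rules from the docstring above.
def could_be_grouped(group, grad, grad_group_ranks, max_fuse_numel, used_between):
    # 1. the fused buffer stays below the fuse threshold
    if group.numel + grad.numel >= max_fuse_numel:
        return False
    # 2. same data-parallel group (same ranks take part in the allreduce)
    if group.ranks != grad_group_ranks:
        return False
    # 3. same dtype, so one coalesced tensor can hold every member
    if group.dtype != grad.dtype:
        return False
    # 4. the grad must not be read by other ops inside the group segment
    if used_between(grad, group.segment_start, group.segment_end):
        return False
    return True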
@@ -423,36 +469,51 @@ class DataParallelOptimizationPass(PassBase):
        for i, group in enumerate(grad_groups[::-1]):

            # create coalecse tensor
            group.coalesce_var = block.create_var(
                name=unique_name.generate('coalecse_grad_{}'.format(i)),
                dtype=group.dtype,
                persistable=False,
                stop_gradient=True,
            )

            # update allreduce & scale op
            if group.scale_op_idx != -1:
                scale_op = block.ops[group.scale_op_idx]
                assert (
                    scale_op.type == 'scale'
                ), "should found scale op but found {}".format(str(scale_op))
                scale_op._rename_input(
                    scale_op.input_arg_names[0], group.coalesce_var.name
                )
                scale_op._rename_output(
                    scale_op.output_arg_names[0], group.coalesce_var.name
                )

            allreduce_op = block.ops[group.allreduce_op_idx]
            assert (
                allreduce_op.type == 'c_allreduce_sum'
            ), "should found c_allreduce_sum op but found {}".format(
                str(allreduce_op)
            )
            allreduce_op._rename_input(
                allreduce_op.input_arg_names[0], group.coalesce_var.name
            )
            allreduce_op._rename_output(
                allreduce_op.output_arg_names[0], group.coalesce_var.name
            )

            # remvoe un-used op
            remove_op_indices = (
                group.remove_wait_op_indices
                + group.remove_allreduce_op_indices
                + group.remove_scale_op_indices
            )
            for idx in sorted(remove_op_indices, reverse=True):
                assert (
                    block.ops[idx].type in remove_op_types
                ), "Unexception: try to remove op {}".format(
                    str(block.ops[idx].type())
                )
                block._remove_op(idx)

            # insert coalecse op
...
...
@@ -464,22 +525,23 @@ class DataParallelOptimizationPass(PassBase):
                concated_ranks.append(len(shape))

            grad_names = [grad.name for grad in group.gradients]
            block._insert_op_without_sync(
                group.coalesce_op_idx,
                type="coalesce_tensor",
                inputs={"Input": grad_names},
                outputs={
                    "Output": grad_names,
                    "FusedOutput": group.coalesce_var,
                },
                attrs={
                    "copy_data": False,
                    "use_align": True,
                    "dtype": group.dtype,
                    "concated_shapes": concated_shapes,
                    "concated_ranks": concated_ranks,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

        block._sync_with_cpp()

        # TODO update dist attr
...
...
@@ -487,6 +549,7 @@ class DataParallelOptimizationPass(PassBase):
    def summary(self, grad_groups=[]):
        # TODO: add logger module
        import logging

        self._logger = logging.getLogger()
        self._logger.propagate = False
        if not self._logger.handlers:
...
...
@@ -500,26 +563,31 @@ class DataParallelOptimizationPass(PassBase):
        if len(grad_groups) > 0:
            self._logger.info(
                "origin {} allreduce ops are fused into {} coalecse allreduce ops.".format(
                    len(self._grad_name_to_group_map.keys()), len(grad_groups)
                )
            )
            self._logger.info("gradient fusing group are following: ")
            fused_grads = set()
            for i, group in enumerate(grad_groups):
                self._logger.info(
                    "coalecse gradient [{}] is composed by: {}".format(
                        i, [grad.name for grad in group.gradients]
                    )
                )
                fused_grads.update([grad.name for grad in group.gradients])
            individual_grads = set(self._grad_name_to_group_map.keys()) - set(
                fused_grads
            )
            self._logger.info(
                "the following [{}] gradients are not fused: ".format(
                    len(individual_grads)
                )
            )
            self._logger.info("individual gradient {}".format(individual_grads))


class GradientsGroup(object):
    def __init__(self, ops, max_group_size):
        self.max_group_size = max_group_size
        self.ops = ops
...
...
@@ -575,8 +643,11 @@ class GradientsGroup(object):
            grad_op_idx -= 1

        grad_op = self.ops[grad_op_idx]
        assert (
            grad_var.name in grad_op.output_arg_names
        ), "grad [{}] should be output of {}".format(grad_var.name, str(grad_op))
        self.coalesce_op_idx = grad_op_idx

    def finalize(self):
...
...
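The fusion above replaces one c_allreduce_sum per gradient with a single allreduce over a coalesced buffer. A rough sketch of the idea outside the static graph; the function name and the injected `allreduce` callable are illustrative, not part of the pass:

import numpy as np

# Illustrative only: fuse several same-dtype gradients into one flat buffer,
# run one (simulated) allreduce, then split the result back per gradient.
def fused_allreduce(grads, allreduce):
    shapes = [g.shape for g in grads]
    flat = np.concatenate([g.ravel() for g in grads])   # plays the role of coalesce_tensor
    flat = allreduce(flat)                               # one communication instead of len(grads)
    out, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(flat[offset:offset + size].reshape(shape))
        offset += size
    return out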
python/paddle/distributed/passes/auto_parallel_gradient_merge.py (view file @ c47853f6)
...
...
@@ -12,23 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from collections import OrderedDict
from typing import List, Tuple, Dict, Any

import paddle
from paddle.framework import core
from paddle.fluid import layers
from paddle.fluid.framework import program_guard, device_guard
from paddle.distributed.fleet.meta_optimizers.common import (
    OpRole,
    OP_ROLE_KEY,
    OP_ROLE_VAR_KEY,
)
from paddle.distributed.auto_parallel.utils import (
    set_var_dist_attr,
    is_optimize_op,
    is_backward_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from paddle.distributed.auto_parallel.process_group import (
    get_world_process_group,
)
from paddle.distributed.auto_parallel.operators.common import (
    is_data_parallel_reduce_op,
    is_data_parallel_scale_op,
)
from .pass_base import PassBase, PassType, register_pass
-from paddle.distributed.auto_parallel.utils import set_var_dist_attr, is_optimize_op, OpRole, OP_ROLE_KEY
-from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
-from paddle.distributed.auto_parallel.process_group import get_world_process_group

world_process_group = get_world_process_group()


-def _remove_and_get_optimizer_op(main_program, dist_context):
+def is_gradient_clip_op(op_desc):
+    return op_desc.has_attr("op_namescope") and op_desc.attr(
+        "op_namescope"
+    ).startswith("/gradient_clip")
+
+
+def _remove_and_get_ops(main_program, dist_context):
    # 1 create tmp block
    # 2 mv optimizer op from global program to tmp block
    # 3 del the op from dist_context
...
...
@@ -36,101 +53,119 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
    temp_block = main_program._create_block()
    removed_op_idx = []
    optimize_ops_desc = []
    allreduce_sum_desc = []
    for idx, op in enumerate(main_block.ops):
        if is_optimize_op(op):
            # append optimizer op to tmp block
            new_op_desc = temp_block.desc.append_op()
            new_op_desc.copy_from(op.desc)
            optimize_ops_desc.append(new_op_desc)
            removed_op_idx.append(idx)

            # del op from dist_context
            if dist_context:
                dist_context.del_dist_op_for_program(op)

        # append allreduce_op and scale_op to tmp block
        if is_backward_op(op):
            if is_data_parallel_reduce_op(op) or is_data_parallel_scale_op(op):
                assert len(op.desc.output_arg_names()) == 1
                new_op_desc = temp_block.desc.append_op()
                new_op_desc.copy_from(op.desc)
                allreduce_sum_desc.append(new_op_desc)
                removed_op_idx.append(idx)
                dist_context.del_dist_op_for_program(op)

    for idx in removed_op_idx[::-1]:
        main_block._remove_op(idx, sync=False)
    main_block._sync_with_cpp()

-    return optimize_ops_desc
+    return optimize_ops_desc, allreduce_sum_desc


-def _get_gm_cond_var(main_program, k_steps, dist_context):
+def _create_gm_cond_var(main_program, k_steps, dist_context):
    main_block = main_program.global_block()
    # Add const var
    k_step_var = layers.create_global_var(
        name="gradient_merge_k",
        shape=[1],
        value=int(k_steps),
        dtype='int32',
        persistable=True,
        force_cpu=True,
    )
    set_var_dist_attr(dist_context, k_step_var, [-1], world_process_group.ranks)

    zero_var = layers.create_global_var(
        name="gradient_merge_zero",
        shape=[1],
        value=int(0),
        dtype='int32',
        persistable=True,
        force_cpu=True,
    )
    set_var_dist_attr(dist_context, zero_var, [-1], world_process_group.ranks)

    # Add step var & cond var
    step_var = layers.create_global_var(
        name="gradient_merge_step",
        shape=[1],
        value=int(0),
        dtype='int32',
        persistable=True,
        force_cpu=True,
    )
    set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)

    cond_var = main_block.create_var(
        name="gradient_merge_cond", shape=[1], dtype='bool'
    )
    set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)

-    with device_guard("cpu"):
+    with paddle.static.device_guard("cpu"):
        # step_var += 1
        increment_op = main_block.append_op(
            type='increment',
            inputs={'X': [step_var]},
            outputs={'Out': [step_var]},
            attrs={'step': float(1.0), OP_ROLE_KEY: OpRole.Backward},
        )
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            increment_op, world_process_group.ranks, [-1], dist_context
        )

        # step_var %= k_step
        elementwise_mod_op = main_block.append_op(
            type='elementwise_mod',
            inputs={'X': step_var, 'Y': k_step_var},
            outputs={'Out': step_var},
            attrs={
                'axis': -1,
                'use_mkldnn': False,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            elementwise_mod_op, world_process_group.ranks, [-1], dist_context
        )

        # cond_var = (step_var == 0)
        equal_op = main_block.append_op(
            type='equal',
            inputs={'X': step_var, 'Y': zero_var},
            outputs={'Out': cond_var},
            attrs={OP_ROLE_KEY: OpRole.Backward},
        )
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            equal_op, world_process_group.ranks, [-1], dist_context
        )

    return cond_var
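The three ops appended above (increment, elementwise_mod, equal) implement a small modular counter on CPU. A minimal sketch of the same arithmetic with plain Python ints, not Paddle ops:

# Illustrative sketch of the gradient-merge condition for k_steps accumulation steps.
step = 0
def should_apply_optimizer(k_steps):
    global step
    step = (step + 1) % k_steps   # increment followed by elementwise_mod
    return step == 0              # equal: the optimizer block runs every k_steps-th step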
-def _append_gradient_merge_backward_op(
-    main_program, startup_program, params_grads: List[Tuple[Any, Any]], dist_context
-) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
+def _append_gradient_merge_backward_op(
+    main_program,
+    startup_program,
+    params_grads,
+    master_grad,
+    dist_context,
+):
    main_block = main_program.global_block()
    startup_block = startup_program.global_block()
...
...
@@ -148,149 +183,260 @@ def _append_gradient_merge_backward_op(
    for param, grad in params_grads:
        param_name = param.name
        param_var = main_block.var(param_name)
        assert param_var is not None

        dst_dtype = (
            core.VarDesc.VarType.FP32 if master_grad else param_var.dtype
        )

        # 2.1 crate param@GRAD@MERGE var in startup_block
        startup_gradient_merge_var = startup_block.create_var(
-            name=param_name + "@GRAD@GradientMerge",
+            name=param_name + "@GRAD@MERGED",
            shape=param_var.shape,
            dtype=dst_dtype,
            persistable=True,
        )
        startup_block.append_op(
            type="fill_constant",
            outputs={"Out": startup_gradient_merge_var},
            attrs={
                "shape": param_var.shape,
                "dtype": dst_dtype,
                "value": float(0),
            },
        )

        # 2.2 crate param@GRAD@MERGE var in main_block
        ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var)
        assert ref_dist_attr is not None
        gradient_merge_var = main_block.create_var(
            name=param_name + "@GRAD@MERGED",
            shape=param_var.shape,
            dtype=dst_dtype,
            persistable=True,
        )
        ref_process_mesh = ref_dist_attr.process_mesh
        ref_dims_mapping = ref_dist_attr.dims_mapping
        set_var_dist_attr(
            dist_context, gradient_merge_var, ref_dims_mapping, ref_process_mesh
        )

        # 2.3 grad_merge += grad
        grad_name = grad.name
        if grad.dtype != dst_dtype:
            cast_grad_name = grad_name + "@TMP"
            cast_grad_var = main_block.create_var(
                name=cast_grad_name,
                shape=grad.shape,
                dtype=dst_dtype,
                persistable=False,
                stop_gradient=grad.stop_gradient,
            )
            set_var_dist_attr(
                dist_context, cast_grad_var, ref_dims_mapping, ref_process_mesh
            )
            cast_op = main_block.append_op(
                type="cast",
                inputs={"X": grad},
                outputs={"Out": cast_grad_var},
                attrs={
                    "in_dtype": grad.dtype,
                    "out_dtype": dst_dtype,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                cast_op, ref_process_mesh, ref_dims_mapping, dist_context
            )
            grad = cast_grad_var

        new_grad_op = main_block.append_op(
            type="elementwise_add",
            inputs={'X': grad, 'Y': gradient_merge_var},
            outputs={'Out': gradient_merge_var},
            attrs={
                'axis': -1,
                'use_mkldnn': False,
                OP_ROLE_KEY: OpRole.Backward,
            },
        )
        new_params_to_grads.append([param, gradient_merge_var])
-        grad_to_gradient_merge[grad.name] = gradient_merge_var.name
+        grad_to_gradient_merge[grad_name] = gradient_merge_var.name
        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
            new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context
        )
    return new_params_to_grads, grad_to_gradient_merge
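With master_grad enabled, the accumulation above keeps an FP32 merged buffer no matter what dtype the incoming gradient has, by inserting a cast before the elementwise_add. A minimal sketch of the numerical pattern using NumPy stand-ins rather than program ops:

import numpy as np

# Illustrative only: fp16 gradients accumulated into an fp32 "master" buffer.
merged = np.zeros((4,), dtype=np.float32)           # param@GRAD@MERGED, filled with 0
for _ in range(4):                                   # k_steps micro-batches
    grad_fp16 = np.full((4,), 0.1, dtype=np.float16)
    merged += grad_fp16.astype(np.float32)           # cast + elementwise_add
merged /= 4                                          # scale by 1/k_steps when avg=True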
def _rename_arg_names(op_desc, var_name_dict):
    for input_name in op_desc.input_arg_names():
        if input_name in var_name_dict:
            op_desc._rename_input(input_name, var_name_dict[input_name])

    for output_name in op_desc.output_arg_names():
        if output_name in var_name_dict:
            op_desc._rename_output(output_name, var_name_dict[output_name])


-def _create_cond_block_and_update_optimizer(
-    main_program,
-    cond_var,
-    new_params_to_grads: List[Tuple[Any, Any]],
-    grad_to_gradient_merge: Dict[str, str],
-    optimize_ops_desc: List[Any],
-    k_steps,
-    avg,
-):
+def _create_cond_block_and_update_optimizer(
+    main_program,
+    cond_var,
+    params_grads,
+    new_params_to_grads,
+    grad_to_gradient_merge,
+    optimize_ops_desc,
+    allreduce_sum_desc,
+    k_steps,
+    avg,
+    master_grad,
+):
    def true_apply_gradient():
        cur_block_idx = main_program.current_block_idx
        cur_block = main_program.current_block()

        # cur_block's forward_block & backward_block is itself
        cur_block._set_forward_block_idx(cur_block_idx)
        op_maker = core.op_proto_and_checker_maker

        # record grads_name to insert c_allreduce_sum op
        grads_name = [grad.name for _, grad in params_grads]
        # append c_allreduce_sum ops and scale ops
        for op_desc in allreduce_sum_desc:
            outputs_name = op_desc.output_arg_names()
            assert len(outputs_name) == 1
            if outputs_name[0] in grads_name:
                new_op_desc = cur_block.desc.append_op()
                new_op_desc.copy_from(op_desc)
                _rename_arg_names(new_op_desc, grad_to_gradient_merge)
                new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
        cur_block._sync_with_cpp()

        if avg:
            for _, new_grad in new_params_to_grads:
                # grad /= k_steps
                cur_block.append_op(
                    type='scale',
                    inputs={'X': new_grad},
                    outputs={'Out': new_grad},
                    attrs={
                        'scale': 1.0 / k_steps,
                        'bias': 0.0,
                        'bias_after_scale': False,
                    },
                )
                new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)

        cast_name_dict = {}
        # append optimizer ops
        for op_desc in optimize_ops_desc:
            if master_grad and is_gradient_clip_op(op_desc):
                if op_desc.type() == "cast":
                    if (
                        op_desc.attr('out_dtype') in [4, 22]
                        and op_desc.attr('in_dtype') == 5
                    ):
                        cast_name_dict[
                            op_desc.output_arg_names()[0]
                        ] = op_desc.input_arg_names()[0]
                    elif (
                        op_desc.attr('in_dtype') in [4, 22]
                        and op_desc.attr('out_dtype') == 5
                    ):
                        cast_name_dict[
                            op_desc.output_arg_names()[0]
                        ] = op_desc.input_arg_names()[0]
                    continue
                for out_name in op_desc.output_arg_names():
                    out_var = cur_block._var_recursive(out_name)
                    out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
                _rename_arg_names(op_desc, cast_name_dict)

            new_op_desc = cur_block.desc.append_op()
            new_op_desc.copy_from(op_desc)

            # update input/output
            _rename_arg_names(new_op_desc, grad_to_gradient_merge)

            # remove op_role_var
            if new_op_desc.has_attr(OP_ROLE_VAR_KEY):
                new_op_desc.remove_attr(OP_ROLE_VAR_KEY)

            # op's update Grad
            if core.grad_var_suffix() in new_op_desc.input_arg_names():
                grad_value = new_op_desc.input("Grad")[0]
-                # TODO FIXME(xym) support fp16
-                grad_merge_value = grad_value + '@GradientMerge'
+                grad_merge_value = grad_value + '@MERGED'
                new_op_desc.set_input("Grad", [grad_merge_value])

        main_program.global_block()._sync_with_cpp()
        cur_block._sync_with_cpp()

        # clear gradient_merge_vars
        for _, new_grad in new_params_to_grads:
            layers.fill_constant(
                shape=new_grad.shape,
                dtype=new_grad.dtype,
                value=0.0,
                out=new_grad,
            )
            new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)

    layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)

    cond_op = main_program.global_block().ops[-1]
    cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize)


def parse_program(
    main_program,
    startup_program,
    params_grads,
    k_steps,
    avg,
    master_grad,
    dist_context,
):
-    # 1 remove optimizer_op from main_program
-    optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context)
+    # 1 remove optimizer_op, allreduce_sum_op and scale_op from main_program
+    optimize_ops_desc, allreduce_sum_desc = _remove_and_get_ops(
+        main_program, dist_context
+    )

    # back to block 0
    main_program._rollback()

    # 2 append gradient merge backward op to main_program
    (
        new_params_to_grads,
        grad_to_gradient_merge,
    ) = _append_gradient_merge_backward_op(
        main_program, startup_program, params_grads, master_grad, dist_context
    )

    # 3 create gradient_merge_cond
-    cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
+    cond_var = _create_gm_cond_var(main_program, k_steps, dist_context)

    # 4 create ConditionalBlock and append gradient merge optimizer ops
    _create_cond_block_and_update_optimizer(
        main_program,
        cond_var,
        params_grads,
        new_params_to_grads,
        grad_to_gradient_merge,
        optimize_ops_desc,
        allreduce_sum_desc,
        k_steps,
        avg,
        master_grad,
    )
@register_pass("auto_parallel_gradient_merge_pass")
class GradientMergePass(PassBase):
    def __init__(self):
        super(GradientMergePass, self).__init__()
        self.set_attr("k_steps", -1)
        self.set_attr("avg", True)
        self.set_attr("master_grad", False)

    def _check_self(self):
        if self.get_attr("k_steps") < 1:
...
...
@@ -306,10 +452,20 @@ class GradientMergePass(PassBase):
    def _apply_single_impl(self, main_program, startup_program, context):
        k_steps = self.get_attr("k_steps", -1)
        avg = self.get_attr("avg", False)
        master_grad = self.get_attr("master_grad", False)
        dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")
        # TODO(zyl): make master_grad configurable
        master_grad = True
        with paddle.static.program_guard(main_program, startup_program):
            parse_program(
                main_program,
                startup_program,
                params_grads,
                k_steps,
                avg,
                master_grad,
                dist_context,
            )

        main_program._sync_with_cpp()
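On the user side the pass is driven through the distributed strategy. A minimal configuration sketch, mirroring the gradient-merge unit test further below (a 2-GPU semi-auto setup is assumed):

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.auto_mode = "semi"

gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4   # accumulate 4 micro-batches before each optimizer step
gradient_merge.avg = True    # average (rather than sum) the merged gradients

# engine = auto.Engine(model, loss, optimizer, strategy=strategy)
# engine.fit(dataset, batch_size=2)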
python/paddle/distributed/passes/auto_parallel_pipeline.py (view file @ c47853f6)
...
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from logging import exception
import os

from paddle.fluid import core
...
...
@@ -26,6 +25,7 @@ from paddle.distributed.auto_parallel.utils import (
    is_backward_op,
    is_optimize_op,
    is_lr_sched_op,
    is_fillconst_op_for_micro_batch,
)
...
...
@@ -38,6 +38,12 @@ __not_shape_var_type__ = [
]


def is_reshard_op(op):
    return op.has_attr('op_namescope') and "/auto_parallel/reshard" in op.attr(
        'op_namescope'
    )


@register_pass("auto_parallel_pipeline")
class PipelinePass(PassBase):
    def __init__(self):
...
...
@@ -59,8 +65,17 @@ class PipelinePass(PassBase):
        self._gen_bsz = self.get_attr("generation_batch_size")
        self._program = main_program

        self._cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
        trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
        self._nrank = len(trainer_endpoints)

        # compute current pp stage
        self._pp_stages = len(self._dist_context.process_meshes)
        self._cur_pp_stage = self._get_pp_stage(self._cur_rank)

        if self._mode == "1F1B":
-            raise NotImplementedError("1F1B has not been implemented")
+            self._insert_sync_ops_for_1f1b()
+            self._task_1f1b()
        elif self._mode == "F-Then-B":
            raise NotImplementedError("F-Then-B has not been implemented")
        elif self._mode == "stream":
...
...
@@ -103,6 +118,93 @@ class PipelinePass(PassBase):
            block._sync_with_cpp()
    def _insert_sync_ops_for_1f1b(self):
        """
        This implementation refers to lots of Paddle/python/paddle/fluid/optimizer.py.
        The difference between this function with 'PipelineOptimizer' is that
        'send_v2' op and 'recv_v2' op have been inserted in program by 'reshard'.
        """
        for block in self._program.blocks:
            offset = 0
            first_optimize_index = None
            for index, op in enumerate(list(block.ops)):
                if is_optimize_op(op):
                    first_optimize_index = index
                    break

            # insert sync ops
            for index, op in enumerate(list(block.ops)):
                if op.type == 'send_v2':
                    # step1: set 'use_calc_stream' False
                    op._set_attr("use_calc_stream", False)
                    op_role = op.attr('op_role')
                    ring_id = op.attr('ring_id')
                    # step2: insert 'c_sync_calc_stream' op before 'send_v2' op
                    var_name = op.input_arg_names[0]
                    var = block.var(var_name)
                    block._insert_op_without_sync(
                        index=index + offset,
                        type="c_sync_calc_stream",
                        inputs={'X': [var]},
                        outputs={'Out': [var]},
                        attrs={'op_role': op_role},
                    )
                    offset += 1
                    # step3: insert 'c_sync_comm_stream' op after 'send_v2' op or
                    # before the first optimize op
                    if int(op_role) == int(OpRole.Backward):
                        index = first_optimize_index + offset
                        new_op_role = OpRole.Optimize
                    else:
                        index = index + offset + 1
                        new_op_role = OpRole.Backward
                    sync_comm_op = block._insert_op_without_sync(
                        index=index,
                        type="c_sync_comm_stream",
                        inputs={'X': [var]},
                        outputs={'Out': [var]},
                        attrs={
                            'op_role': new_op_role,
                            'ring_id': ring_id,
                        },
                    )
                    # step4: If 'send_v2' op in forward parse, set 'pipeline_flag' to distinguish
                    # whether the 'c_sync_comm_stream' op is inserted for pipeline.
                    if int(op_role) == int(OpRole.Forward):
                        sync_comm_op._set_attr('pipeline_flag', '')
                    offset += 1
            block._sync_with_cpp()

            offset = 0
            backward_recv_index = None
            for index, op in enumerate(block.ops):
                if op.type == "recv_v2" and is_backward_op(op):
                    backward_recv_index = index
                    break
            if backward_recv_index is None:
                continue

            # replace 'c_sync_comm_stream' op with 'nop' op
            for index, op in enumerate(list(block.ops)):
                if index >= backward_recv_index:
                    break
                if op.type == 'c_sync_comm_stream' and op.has_attr(
                    'pipeline_flag'
                ):
                    var_name = op.output_arg_names[0]
                    var = block.var(var_name)
                    block._remove_op(index + offset, sync=False)
                    offset -= 1
                    block._insert_op_without_sync(
                        index=backward_recv_index,
                        type="nop",
                        inputs={'X': [var]},
                        outputs={'Out': [var]},
                        attrs={'op_role': OpRole.Backward},
                    )
            block._sync_with_cpp()
    def _create_param(self, dst_block, src_var):
        copied_kwargs = {}
        copied_kwargs['trainable'] = src_var.trainable
...
...
@@ -190,16 +292,185 @@ class PipelinePass(PassBase):
                break
        return pp_idx

-    def _task_stream(self):
-        cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
-        trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
-        nrank = len(trainer_endpoints)
-        num_of_functionality = 5
+    def _task_1f1b(self):
+        # create fwd, bwd, opt program with op_role
        num_of_functionality = 4
        lr_prog = Program()
        fwd_prog = Program()
        bwd_prog = Program()
        opt_prog = Program()
        for idx, src_block in enumerate(self._program.blocks):
            if idx == 0:
                lr_block = lr_prog.block(0)
                fwd_block = fwd_prog.block(0)
                bwd_block = bwd_prog.block(0)
                opt_block = opt_prog.block(0)
            else:
                lr_block = lr_prog._create_block(parent_idx=src_block.parent_idx)
                fwd_block = fwd_prog._create_block(
                    parent_idx=src_block.parent_idx
                )
                bwd_block = bwd_prog._create_block(
                    parent_idx=src_block.parent_idx
                )
                opt_block = opt_prog._create_block(
                    parent_idx=src_block.parent_idx
                )
                lr_block._set_forward_block_idx(src_block.forward_block_idx)
                fwd_block._set_forward_block_idx(src_block.forward_block_idx)
                bwd_block._set_forward_block_idx(src_block.forward_block_idx)
                opt_block._set_forward_block_idx(src_block.forward_block_idx)

            # split the program based on the op_role
            for op in src_block.ops:
                if is_lr_sched_op(op):
                    self._create_program(src_block, lr_block, op)
                if is_forward_op(op) or is_fillconst_op_for_micro_batch(op):
                    self._create_program(src_block, fwd_block, op)
                elif is_backward_op(op):
                    self._create_program(src_block, bwd_block, op)
                elif is_optimize_op(op):
                    self._create_program(src_block, opt_block, op)
                else:
                    raise ValueError(
                        "The op role: "
                        + str(op.attr('op_role'))
                        + " isn't one of LRSched, Forward, Backward or Optimizer."
                    )

-        # compute current pp stage
-        pp_stages = len(self._dist_context.process_meshes)
-        cur_pp_stage = self._get_pp_stage(cur_rank)

        lr_prog._sync_with_cpp()
        fwd_prog._sync_with_cpp()
        bwd_prog._sync_with_cpp()
        opt_prog._sync_with_cpp()

        lr_prog._rollback()
        fwd_prog._rollback()
        bwd_prog._rollback()
        opt_prog._rollback()

        # Create task nodes.
        lr_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            program=lr_prog,
            task_id=int(self._cur_rank * num_of_functionality + 0),
            node_type="Amplifier",
            lazy_initialize=True,
        )
        lr_task_node.set_run_pre_steps(self._acc_steps)
        fwd_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            program=fwd_prog,
            task_id=int(self._cur_rank * num_of_functionality + 1),
            node_type="Compute",
            lazy_initialize=True,
        )
        bwd_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            program=bwd_prog,
            task_id=int(self._cur_rank * num_of_functionality + 2),
            node_type="Compute",
            lazy_initialize=True,
        )
        opt_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            program=opt_prog,
            task_id=int(self._cur_rank * num_of_functionality + 3),
            node_type="Amplifier",
            lazy_initialize=True,
        )
        opt_task_node.set_run_pre_steps(self._acc_steps)
        opt_task_node.set_run_at_offset(self._acc_steps - 1)
        task_nodes = {
            "lr": lr_task_node,
            "fwd": fwd_task_node,
            "bwd": bwd_task_node,
            "opt": opt_task_node,
        }

        # get upstream ranks and downstream ranks of cur_rank
        up_down_streams = self._dist_context.up_down_streams
        pp_upstream = up_down_streams.ups(self._cur_rank)
        pp_downstream = up_down_streams.downs(self._cur_rank)

        # set upstream/downstream for task_nodes of cur_rank
        for i, (task_role, task_node) in enumerate(task_nodes.items()):
            cur_id = int(self._cur_rank * num_of_functionality + i)
            ups = []
            downs = []

            # set upstream/downstream and buffersize in pipeline stage
            pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
            prev_id = cur_id - 1
            next_id = cur_id + 1
            if task_role != "lr":
                buf_size = pp_buff_size if task_role == "bwd" else 2
                ups.append((prev_id, buf_size))
            if task_role != "opt":
                buf_size = pp_buff_size if task_role == "fwd" else 2
                downs.append((next_id, buf_size))

            # set upstream/downstream and buffersize cross pipeline stage
            for upstream in pp_upstream:
                upstream_id = int(upstream * num_of_functionality + i)
                if task_role == "fwd":
                    if upstream != -1:
                        ups.append((upstream_id, 2))
                elif task_role == "bwd":
                    if upstream != -1:
                        downs.append((upstream_id, 2))
            for downstream in pp_downstream:
                downstream_id = int(downstream * num_of_functionality + i)
                if task_role == "fwd":
                    if downstream != -1:
                        downs.append((downstream_id, 2))
                elif task_role == "bwd":
                    if downstream != -1:
                        ups.append((downstream_id, 2))

            for up in ups:
                print(
                    "Task:",
                    cur_id,
                    "'s upstream includes:",
                    up[0],
                    ", buffer size is:",
                    up[1],
                )
                task_node.add_upstream_task(up[0], up[1])
            for down in downs:
                print(
                    "Task:",
                    cur_id,
                    "'s downstream includes:",
                    down[0],
                    ", buffer size is:",
                    down[1],
                )
                task_node.add_downstream_task(down[0], down[1])

        # record global message: task_id_to_rank
        task_id_to_rank = {}
        for i in range(self._nrank):
            for j in range(num_of_functionality):
                task_id_to_rank[int(i * num_of_functionality + j)] = i

        self._program._pipeline_opt = {}
        self._program._pipeline_opt['fleet_opt'] = {
            "tasks": list(task_nodes.values()),
            "task_id_to_rank": task_id_to_rank,
            "num_micro_batches": self._acc_steps,
        }
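Each rank owns num_of_functionality consecutive task ids (lr, fwd, bwd, opt), so the owning rank of any task can be recovered by integer division. A small illustrative check; the names here are local to this sketch:

# Illustrative only: 1F1B task-id layout for 2 ranks and 4 functionalities per rank.
num_of_functionality = 4
task_id_to_rank = {
    rank * num_of_functionality + j: rank
    for rank in range(2)
    for j in range(num_of_functionality)
}
assert task_id_to_rank[5] == 1   # rank 1, functionality offset 1 ("fwd")
assert all(tid // num_of_functionality == r for tid, r in task_id_to_rank.items())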
    def _task_stream(self):
        num_of_functionality = 5
        start_prog = Program()
        cond_prog = Program()
        end_prog = Program()
...
...
@@ -207,6 +478,7 @@ class PipelinePass(PassBase):
        recv_prog = Program()

        cond_var_name = None
        # record the varnames related to the while cond vars and communicate by nccl
        send_vars_name = set()
        recv_vars_name = dict()
        for ib, src_block in enumerate(self._program.blocks):
...
...
@@ -231,38 +503,23 @@ class PipelinePass(PassBase):
                        src_block, end_block, op, force_create=True
                    )
            elif ib == 1:
                # NOTE: The while block will be split to two separate blocks.
                # The send_block:
                #   include all ops about tansformer generation
                #   execlude the nccl op about the while cond var
                # The recv_block:
                #   include all ops about the while cond var
                #   execlude the nccl op about the while cond var
                # the nccl op about cond var:
                #   put these varnames in the task node and do communication by brpc
                send_block = send_prog.block(0)
                recv_block = recv_prog.block(0)

                is_after_send_op = False
                is_after_recv_op = False
-                for op in src_block.ops:
+                for i, op in enumerate(src_block.ops):
                    if op.type == "send_v2" and not is_after_send_op:
                        is_after_send_op = True
-                        if cur_pp_stage == pp_stages - 1:
-                            if op.type in ["c_sync_calc_stream", "nop"]:
-                                continue
-                            if (
-                                op.type not in ["recv_2", "assign"]
-                                and op.has_attr('op_namescope')
-                                and "/auto_parallel/reshard" in op.attr('op_namescope')
-                            ):
-                                if (
-                                    len(op.desc.input_arg_names()) > 0
-                                    and "@RESHARD" not in op.desc.input_arg_names()[0]
-                                ):
-                                    send_vars_name.add(op.desc.input_arg_names()[0])
-                                continue
-                            if op.type == "send_v2":
-                                continue
-                            self._create_program(
-                                src_block, send_block, op, force_create=True
-                            )
-                        continue

                    if (
                        is_after_send_op
...
...
@@ -270,45 +527,21 @@ class PipelinePass(PassBase):
                        and op.type == "recv_v2"
                    ):
                        is_after_recv_op = True
-                        if op.has_attr('op_namescope') and "/auto_parallel/reshard" in op.attr('op_namescope'):
-                            var_name = op.desc.output_arg_names()[0]
-                            index = var_name.find("@")
-                            if index > 0:
-                                old_var_name = var_name[:index]
-                            else:
-                                old_var_name = var_name
-                            recv_vars_name[var_name] = old_var_name
-                            if not src_block._find_var_recursive(old_var_name):
-                                src_var = src_block._var_recursive(var_name)
-                                recv_block.create_var(
-                                    type=src_var.type,
-                                    name=old_var_name,
-                                    shape=src_var.shape,
-                                    dtype=src_var.dtype,
-                                    lod_level=src_var.lod_level,
-                                    persistable=src_var.persistable,
-                                    error_clip=src_var.error_clip,
-                                    stop_gradient=src_var.stop_gradient,
-                                    is_data=src_var.is_data,
-                                    belong_to_optimizer=src_var.belong_to_optimizer,
-                                )
-                            continue
-                        self._create_program(
-                            src_block, recv_block, op, force_create=True
-                        )
-                        continue

                    if not is_after_send_op or not is_after_recv_op:
-                        if cur_pp_stage == pp_stages - 1:
-                            if op.type in ["c_sync_calc_stream", "nop"]:
-                                continue
+                        if self._cur_pp_stage == self._pp_stages - 1:
+                            # the c_sync_calc_stream about c_allgather cannot be removed
+                            if (
+                                op.type == "c_sync_calc_stream"
+                                and src_block.ops[i + 1].type == "send_v2"
+                            ):
+                                continue
+                            if op.type == "nop":
+                                continue
+                            # HACKCODE: the varname of send_v2 op, cast op should be recorded for brpc comm
                            if (
-                                op.type not in ["recv_2", "assign"]
+                                op.type not in ["recv_2", "assign", "c_allgather"]
                                and op.has_attr('op_namescope')
                                and "/auto_parallel/reshard" in op.attr('op_namescope')
...
...
@@ -327,13 +560,16 @@ class PipelinePass(PassBase):
                            self._create_program(
                                src_block, send_block, op, force_create=True
                            )
                        continue

                    if is_after_send_op and is_after_recv_op:
                        # HACKCODE: the varname of recv_v2 op, assign op should be recorded for brpc comm
                        if op.has_attr(
                            'op_namescope'
                        ) and "/auto_parallel/reshard" in op.attr('op_namescope'):
                            # remove the suffix of "@RESHARD"
                            var_name = op.desc.output_arg_names()[0]
                            index = var_name.find("@")
                            if index > 0:
...
...
@@ -365,6 +601,7 @@ class PipelinePass(PassBase):
                            self._create_program(
                                src_block, recv_block, op, force_create=True
                            )
                        continue
            else:
                raise Exception("Only support generation condition.")
...
...
@@ -406,52 +643,52 @@ class PipelinePass(PassBase):
            vars_to_shape = recv_task_node_var_shape

        start_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            node_type="Start",
            task_id=int(self._cur_rank * num_of_functionality + 0),
            program=start_prog,
            lazy_initialize=True,
        )
        cond_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            node_type="Cond",
            task_id=int(self._cur_rank * num_of_functionality + 1),
            program=cond_prog,
            cond_var_name=cond_var_name,
            lazy_initialize=True,
        )
        send_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            node_type="Compute",
            task_id=int(self._cur_rank * num_of_functionality + 2),
            program=send_prog,
            lazy_initialize=True,
        )
        recv_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            node_type="Compute",
            task_id=int(self._cur_rank * num_of_functionality + 3),
            program=recv_prog,
            lazy_initialize=True,
            vars_to_dtype=vars_to_dtype,
            vars_to_shape=vars_to_shape,
        )
        end_task_node = TaskNode(
            rank=self._cur_rank,
            max_run_times=self._acc_steps,
            node_type="Compute",
            task_id=int(self._cur_rank * num_of_functionality + 4),
            program=end_prog,
            lazy_initialize=True,
        )

        # add dependencies for task nodes intra stage
        inf = -1
        pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
        start_task_node.add_downstream_task(
            cond_task_node.task_id(), self._gen_bsz
        )
...
...
@@ -560,12 +797,12 @@ class PipelinePass(PassBase):
        # add dependencies for task nodes inter stage
        # get upstream ranks and downstream ranks of cur_rank
        up_down_streams = self._dist_context.up_down_streams
        pp_upstream = up_down_streams.ups(self._cur_rank)
        pp_downstream = up_down_streams.downs(self._cur_rank)

        for upstream_rank in pp_upstream:
            upstream_pp_stage = self._get_pp_stage(upstream_rank)
            if upstream_pp_stage < self._pp_stages - 1:
                upstream_task_id = int(upstream_rank * num_of_functionality + 2)
                send_task_node.add_upstream_task(upstream_task_id)
                print(
...
...
@@ -587,8 +824,8 @@ class PipelinePass(PassBase):
                    ", buffer size is:",
                    2,
                )
        for downstream_rank in pp_downstream:
            if self._cur_pp_stage < self._pp_stages - 1:
                downstream_task_id = int(
                    downstream_rank * num_of_functionality + 2
                )
...
...
@@ -616,7 +853,7 @@ class PipelinePass(PassBase):
        )
        task_id_to_rank = {}
        for i in range(self._nrank):
            for j in range(num_of_functionality):
                task_id_to_rank[int(i * num_of_functionality + j)] = i
        self._program._pipeline_opt = {
...
...
python/paddle/fluid/tests/unittests/auto_parallel/1F1B_pass_unittest.py (new file, mode 100644; view file @ c47853f6)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import random
import numpy as np

import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv

from get_gpt_model import generate_model, FakeDataset

paddle.enable_static()


def apply_pass(use_1f1b=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_1f1b:
        pipeline = strategy.pipeline
        pipeline.enable = True
        pipeline.schedule_mode = "1F1B"
        pipeline.accumulate_steps = 2
    else:
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = 2
        gradient_merge.avg = True

    amp = strategy.amp
    amp.enable = True
    amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
    amp.custom_black_list = [
        'c_softmax_with_cross_entropy',
        'elementwise_div',
        'reduce_sum',
    ]
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    amp.use_pure_fp16 = True
    return strategy


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class Test1F1BPass(unittest.TestCase):
    def setUp(self):
        self.rtol = 1e-5
        self.atol = 1e-8
        self.batch_size = 2
        self.batch_num = 10
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        paddle.seed(2021)
        np.random.seed(2021)
        random.seed(2021)
        paddle.distributed.fleet.init(is_collective=True)
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_1f1b=False):
        reset_prog()

        strategy = apply_pass(use_1f1b)
        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model("pp")

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_results(self, ref_losses, check_losses):
        np.testing.assert_allclose(
            ref_losses,
            check_losses,
            rtol=self.rtol,
            atol=self.atol,
            err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
                __class__, ref_losses, check_losses, ref_losses - check_losses
            ),
        )

    def test_1f1b_pass(self):
        # navie_pp+gradient_merge training
        engine_pp = self.get_engine()
        history = engine_pp.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_pp._strategy.pipeline.enable == False

        # pp2 1f1b merge training
        engine_1f1b = self.get_engine(True)
        history = engine_1f1b.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_1f1b._strategy.pipeline.enable == True

        # NOTE: every sample data from dataset is all the same
        if paddle.distributed.get_rank() == 1:
            losses_pp = np.array(history.history["loss"])
            losses_1f1b = np.array(history.history["loss"])
            self.check_results(losses_pp, losses_1f1b)


if __name__ == "__main__":
    unittest.main()
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt (view file @ c47853f6)
...
...
@@ -69,6 +69,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
  set_tests_properties(test_engine_callbacks
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
  py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
  set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                 TIMEOUT 50)
  py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
                  ${dist_ENVS})
...
...
python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py (view file @ c47853f6)
...
...
@@ -89,6 +89,12 @@ def generate_model(strategy, dropout_prob=0.0):
        modeling._global_parallel_strategy = "mp"
    elif strategy == "dp":
        modeling._global_parallel_strategy = "dp"
    elif strategy == "pp":
        modeling._global_parallel_strategy = "pp"
        modeling.PP_MESH_LIST = [
            auto.ProcessMesh(mesh=[0]),
            auto.ProcessMesh(mesh=[1]),
        ]
    else:
        raise ValueError("Only support serial, mp2 and dp2.")
...
...
@@ -108,6 +114,7 @@ def generate_model(strategy, dropout_prob=0.0):
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        pp_degree=2 if strategy == "pp" else None,
    )
    model = GPTForPretraining(
        gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
...
...
python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py (view file @ c47853f6)
...
...
@@ -19,7 +19,7 @@ import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv

-from get_gpt_model import generate_model, create_data_holder, FakeDataset
+from get_gpt_model import generate_model, FakeDataset

paddle.enable_static()
...
...
@@ -28,12 +28,25 @@ def apply_pass(use_gradient_merge=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_gradient_merge:
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = 4
        gradient_merge.avg = True

    amp = strategy.amp
    amp.enable = True
    amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
    amp.custom_black_list = [
        'c_softmax_with_cross_entropy',
        'elementwise_div',
        'reduce_sum',
    ]
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    amp.use_pure_fp16 = True
    return strategy
...
...
@@ -88,6 +101,7 @@ class TestGradientMergePass(unittest.TestCase):
        history = dp_engine.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert dp_engine._strategy.gradient_merge.enable == False
        dp_losses = np.array(history.history["loss"])

        # dp2 gradient merge training
...
...
@@ -95,6 +109,7 @@ class TestGradientMergePass(unittest.TestCase):
        history = gm_engine.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert gm_engine._strategy.gradient_merge.enable == True
        gm_losses = np.array(history.history["loss"])

        # avg_loss = 0
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_1F1B.py (new file, mode 100644; view file @ c47853f6)

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage


class Test1F1BPass(unittest.TestCase):
    def test_pp2(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )

        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()


if __name__ == "__main__":
    unittest.main()