PaddlePaddle / Paddle
Commit c47853f6 (unverified)
Authored by zhaoyingli on Apr 19, 2023; committed via GitHub on Apr 19, 2023
[AutoParallel] add gradient_merge master_grad & 1F1B pass (#52647)
Parent: 6f3c9643
Showing 18 changed files with 1,532 additions and 648 deletions (+1532 −648)
Changed files:

  paddle/fluid/distributed/fleet_executor/fleet_executor.cc                          +10   -32
  python/paddle/distributed/auto_parallel/constants.py                               +11   -2
  python/paddle/distributed/auto_parallel/dist_loader.py                             +97   -63
  python/paddle/distributed/auto_parallel/dist_saver.py                              +33   -18
  python/paddle/distributed/auto_parallel/engine.py                                  +36   -27
  python/paddle/distributed/auto_parallel/parallelizer_v2.py                         +9    -7
  python/paddle/distributed/auto_parallel/partitioner.py                             +234  -139
  python/paddle/distributed/auto_parallel/reshard.py                                 +68   -14
  python/paddle/distributed/auto_parallel/strategy.py                                +9    -0
  python/paddle/distributed/auto_parallel/utils.py                                   +7    -0
  python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py       +175  -104
  python/paddle/distributed/passes/auto_parallel_gradient_merge.py                   +309  -153
  python/paddle/distributed/passes/auto_parallel_pipeline.py                         +325  -88
  python/paddle/fluid/tests/unittests/auto_parallel/1F1B_pass_unittest.py            +126  -0
  python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt                   +3    -0
  python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py                 +7    -0
  python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py  +16   -1
  python/paddle/fluid/tests/unittests/auto_parallel/test_pass_1F1B.py                +57   -0
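The new options touched by this commit are exercised through the auto-parallel Strategy object. Below is a minimal usage sketch, assuming the fleet.auto Engine/Strategy front end and the config fields that appear in this diff (gradient_merge, pipeline.schedule_mode, dp_optimization); the model, loss, optimizer, and dataset names are placeholders.

import paddle
from paddle.distributed.fleet import auto

# A minimal sketch, not taken from this commit's tests.
strategy = auto.Strategy()
strategy.gradient_merge.enable = True
strategy.gradient_merge.k_steps = 4           # accumulate 4 micro-batches
strategy.pipeline.enable = True
strategy.pipeline.schedule_mode = "1F1B"      # new 1F1B schedule
strategy.pipeline.accumulate_steps = 4
strategy.dp_optimization.enable = True        # new dp_optimization section

# my_model, my_loss, my_optimizer, my_dataset are placeholders.
engine = auto.Engine(my_model, my_loss, my_optimizer, strategy=strategy)
engine.fit(my_dataset, batch_size=64, epochs=1)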
paddle/fluid/distributed/fleet_executor/fleet_executor.cc

@@ -110,12 +110,15 @@ void PreventVarsDelete(
 std::vector<std::string> GetUnusedVarsAfterWhile(
     const framework::ProgramDesc& program_desc,
+    TaskNode* cond_task,
     const std::vector<std::string>& vars_not_gc) {
   // NOTE: Since while op won't appear in task node, in order to analyze
   // the vars which should be free after calling while op, we rebuild the
   // whole program and get the unused vars after calling while op.
-  // vars in parent block should not be free until the while op is finished.
+  // The vars in while block should not be free until the while op is finished.
+  // The local vars will be free while running op in sub block.
+  // In a word, the vars need to be free after while op is:
+  // 1. Vars in parent block and being used in while block.
+  // 2. Local vars only defined in while block.
   // The unused vars above will be free in cond interceptor.
   std::vector<std::string> while_block_vars;
   std::vector<std::unique_ptr<framework::OperatorBase>> ops;

@@ -129,27 +132,12 @@ std::vector<std::string> GetUnusedVarsAfterWhile(
       for (const auto& var_name : pair.second) {
         while_block_vars.emplace_back(var_name);
       }
     }
   }
+  for (auto& var : program_desc.Block(1).AllVars()) {
+    while_block_vars.emplace_back(var->Name());
+  }
   return while_block_vars;
 }
-
-std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
-GetSubUnusedVars(const framework::ProgramDesc& program_desc,
-                 const std::set<TaskNode*>& sub_block_tasks,
-                 const std::vector<std::string>& vars_not_gc) {
-  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
-  for (auto* task_node : sub_block_tasks) {
-    for (const auto& op : task_node->ops()) {
-      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
-    }
-  }
-  auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
-  for (auto& unique_op : ops) {
-    unique_op.release();
-  }
-  PreventVarsDelete(&unused_vars, vars_not_gc);
-  return unused_vars;
-}
 }  // namespace

@@ -174,13 +162,8 @@ void FleetExecutor::Init(
   for (const auto& task_node : task_nodes) {
     if (task_node->type() == "Cond") {
       GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
       while_block_vars =
-          GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars);
-      for (auto* task_node : sub_block_tasks) {
-        for (auto iter : task_node->vars_to_dtype()) {
-          while_block_vars.emplace_back(iter.first);
-        }
-      }
+          GetUnusedVarsAfterWhile(program_desc, task_node, inference_root_scope_vars);
       VLOG(3) << "Vars will be gced after while op";
       for (auto var : while_block_vars) {
         VLOG(3) << var;

@@ -210,9 +193,6 @@ void FleetExecutor::Init(
     unique_op.release();
   }
-  auto sub_unused_vars =
-      GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars);
-
   // NOTE: For inference, the vars in inference_root_scope_vars
   // shouldn't be deleted during inf, for that they may be the result of the
   // inf. If they are GCed, it will cause error during ZeroCopy the result.

@@ -223,8 +203,6 @@ void FleetExecutor::Init(
   for (auto task_node : task_nodes) {
     if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
       task_node->SetUnusedVars(global_unused_vars);
-    } else {
-      task_node->SetUnusedVars(sub_unused_vars);
     }
     int64_t interceptor_id = task_node->task_id();
     interceptor_id_to_task.emplace(interceptor_id, task_node);
python/paddle/distributed/auto_parallel/constants.py

@@ -117,9 +117,9 @@ set_field_default_config(QAT, "activation_bits", 8)
 set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
 set_field_default_config(QAT, "algo", None)
 
 #########################################
 # auto tuning configuration
 #########################################
 TUNING = "tuning"
 set_field_default_config(TUNING, "enable", False)
 set_field_default_config(TUNING, "batch_size", 1)

@@ -135,3 +135,12 @@ set_field_default_config(TUNING, "verbose", True)
 DATASET = "dataset"
 set_field_default_config(DATASET, "enable", False)
 set_field_default_config(DATASET, "num_shards", 1)
+
+#########################################
+# data parallel configuration
+#########################################
+DP_OPTIMIZATION = "dp_optimization"
+set_field_default_config(DP_OPTIMIZATION, "enable", False)
+set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
+set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
+set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)
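The new DP_OPTIMIZATION defaults are registered per category/field via set_field_default_config. A toy sketch of that registration pattern is shown below; it is a self-contained stand-in, not Paddle's implementation.

# Toy sketch of category/field default registration, not Paddle's code.
_defaults = {}

def set_field_default_config(category, field, default_value):
    _defaults.setdefault(category, {})[field] = default_value

DP_OPTIMIZATION = "dp_optimization"
set_field_default_config(DP_OPTIMIZATION, "enable", False)
set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)

assert _defaults[DP_OPTIMIZATION]["fuse_grad_size_in_MB"] == 32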
python/paddle/distributed/auto_parallel/dist_loader.py

@@ -17,12 +17,18 @@ import numpy as np
 import paddle
 from paddle.io import BatchSampler, IterableDataset
-from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
-from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn
+from paddle.fluid.dataloader.batch_sampler import (
+    _InfiniteIterableSampler,
+    DistributedBatchSampler,
+)
+from paddle.fluid.dataloader.dataloader_iter import (
+    _DatasetKind,
+    default_collate_fn,
+    default_convert_fn,
+)
 
 
 class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     def __iter__(self):
         raise NotImplementedError

@@ -43,8 +49,8 @@ class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
 
 class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
     def __init__(
         self,
         dataset,
         feed_list=None,
         capacity=None,

@@ -60,7 +66,9 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         collate_fn=None,
         split_data=True,
         data_parallel_world_size=[],
-        data_parallel_rank=[]):
+        data_parallel_rank=[],
+        acc_steps=1,
+    ):
         self.dataset = dataset
         self.feed_list = feed_list
         self.capacity = capacity

@@ -79,6 +87,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             assert len(data_parallel_rank) == len(feed_list)
         self.dp_world_sizes = data_parallel_world_size
         self.dp_ranks = data_parallel_rank
+        self.acc_steps = acc_steps
 
         if isinstance(dataset, IterableDataset):
             self.dataset_kind = _DatasetKind.ITER

@@ -90,12 +99,15 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         else:
             if isinstance(dataset, IterableDataset):
                 self.batch_sampler = _InfiniteIterableSampler(
                     dataset, batch_size)
             else:
-                self.batch_sampler = BatchSampler(dataset,
-                                                  batch_size=batch_size,
-                                                  shuffle=False,
-                                                  drop_last=drop_last)
+                self.batch_sampler = BatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    shuffle=False,
+                    drop_last=drop_last,
+                )
 
         self.auto_collate_batch = self.batch_sampler is not None
         self.sampler_iter = iter(self.index_sampler)

@@ -106,8 +118,12 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         self.collate_fn = collate_fn or default_convert_fn
         self.dataset_fetcher = _DatasetKind.create_fetcher(
-            self.dataset_kind, self.dataset, self.auto_collate_batch,
-            self.collate_fn, self.drop_last)
+            self.dataset_kind,
+            self.dataset,
+            self.auto_collate_batch,
+            self.collate_fn,
+            self.drop_last,
+        )
 
         self._steps = self._infer_steps()
         self._inner_dataloader = self._create_inner_dataloader()

@@ -136,9 +152,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             if isinstance(self.dataset, IterableDataset):
                 steps_per_epoch = None
             elif self.batch_size is None:
-                steps_per_epoch = len(self.dataset)
+                steps_per_epoch = len(self.dataset) // self.acc_steps
             else:
-                steps_per_epoch = len(self.dataset) // self.batch_size
+                steps_per_epoch = (
+                    len(self.dataset) // self.batch_size // self.acc_steps
+                )
         except:
             raise ValueError(
                 "Pleace set `steps_per_epoch` or implement `__len__` methond in dataset class."

@@ -156,18 +174,21 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             return _InfiniteIterableSampler(self.dataset, 1)
 
     def _create_inner_dataloader(self):
         def data_generator():
             while True:
                 try:
                     indices = next(self.sampler_iter)
                     batch = self.dataset_fetcher.fetch(indices)
                     if batch is None:
                         break
                 except StopIteration:
                     self.dataset_fetcher = _DatasetKind.create_fetcher(
-                        self.dataset_kind, self.dataset,
-                        self.auto_collate_batch, self.collate_fn,
-                        self.drop_last)
+                        self.dataset_kind,
+                        self.dataset,
+                        self.auto_collate_batch,
+                        self.collate_fn,
+                        self.drop_last,
+                    )
                     break
 
                 partial_data = []

@@ -178,11 +199,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
                         continue
                     batch_size = array.shape[0]
-                    assert batch_size % self.dp_world_sizes[i] == 0, \
-                        "batch_size [{}] is not divisible by dp_world_size [{}]".format(
-                            str(batch_size), str(self.dp_world_sizes[i]))
+                    assert (
+                        batch_size % self.dp_world_sizes[i] == 0
+                    ), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
+                        str(batch_size), str(self.dp_world_sizes[i])
+                    )
                     partial_data.append(
-                        np.split(array,
-                                 self.dp_world_sizes[i])[self.dp_ranks[i]])
+                        np.split(array, self.dp_world_sizes[i])[
+                            self.dp_ranks[i]
+                        ]
+                    )
 
                 yield partial_data

@@ -194,15 +220,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
                 iterable=False,
                 return_list=self.return_list,
                 use_multiprocess=self.use_multiprocess,
-                drop_last=self.drop_last)
+                drop_last=self.drop_last,
+            )
             dataloader.set_batch_generator(data_generator, self.places)
             return dataloader
 
 
 class DistributedDataLoader(DistributedDataLoaderBase):
     def __init__(
         self,
         dataset,
         feed_list=None,
         places=None,

@@ -220,7 +247,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
         steps_per_epoch=None,
         split_data=True,
         data_parallel_world_size=[],
-        data_parallel_rank=[]):
+        data_parallel_rank=[],
+    ):
         self.dataset = dataset
         self.feed_list = feed_list
         self.return_list = return_list

@@ -241,8 +269,13 @@ class DistributedDataLoader(DistributedDataLoaderBase):
         self.split_data = split_data
         # TODO: rank info
         self.batch_sampler = DistributedBatchSampler(
-            self.dataset, self.batch_size, self.dp_world_sizes[0],
-            self.dp_ranks[0], self.shuffle, self.drop_last)
+            self.dataset,
+            self.batch_size,
+            self.dp_world_sizes[0],
+            self.dp_ranks[0],
+            self.shuffle,
+            self.drop_last,
+        )
         self._inner_dataloader = self._create_inner_dataloader()
 
     def __iter__(self):

@@ -263,7 +296,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
             use_buffer_reader=self.use_buffer_reader,
             use_shared_memory=self.use_shared_memory,
             timeout=self.timeout,
-            worker_init_fn=self.worker_init_fn)
+            worker_init_fn=self.worker_init_fn,
+        )
         self.data = (x for x in dataloader)
         return dataloader
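The new acc_steps argument shrinks the reported steps per epoch, since each optimizer step now consumes acc_steps micro-batches. A tiny arithmetic sketch with made-up numbers:

# Sketch of the steps_per_epoch arithmetic added above (made-up numbers).
dataset_len = 1024
batch_size = 8
acc_steps = 4   # micro-batches accumulated per optimizer step

steps_without_acc = dataset_len // batch_size               # 128
steps_with_acc = dataset_len // batch_size // acc_steps     # 32
print(steps_without_acc, steps_with_acc)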
python/paddle/distributed/auto_parallel/dist_saver.py

@@ -18,6 +18,7 @@ import errno
 import pickle
 import warnings
 import logging
+import collections
 import numpy as np
 
 import paddle

@@ -53,16 +54,13 @@ def _process_path(path):
 class DistributedSaver:
     def __init__(self):
         self._logger = get_logger(logging.INFO)
 
     def save(self, path, serial_program, dist_main_program, dist_context):
         def _save_state(program, path, mode="param"):
             state = {
                 k: np.array(v)
                 for k, v in program.state_dict(mode).items()
             }
             with open(path, "wb") as f:
                 pickle.dump(state, f)

@@ -108,8 +106,9 @@ class DistributedSaver:
         def _load_file(filename, dirname, suffix="pdparams"):
             file_list = []
             for file in os.listdir(dirname):
-                if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
-                                  file):
+                if check_filename(
+                    '{}(.*)_dist(.*).{}'.format(filename, suffix), file
+                ):
                     file_list.append(os.path.join(dirname, file))
             file_list.sort()
             return file_list

@@ -137,14 +136,16 @@ class DistributedSaver:
         # load path.pdparam and path.pdopt
         param_state_dict = _load_state(filename, dirname)
-        opt_state_dict = _load_state(filename, dirname,
-                                     "pdopt") if load_optimizer else {}
+        opt_state_dict = (
+            _load_state(filename, dirname, "pdopt") if load_optimizer else {}
+        )
         state_dict = dict(param_state_dict, **opt_state_dict)
 
         # load path.pdattr
         dist_attr_file_list = _load_file(filename, dirname, "pdattr")
         self._logger.info(
-            "Load distributed attribute file: {}".format(dist_attr_file_list))
+            "Load distributed attribute file: {}".format(dist_attr_file_list)
+        )
         dist_attr = {}
         for dist_attr_file in dist_attr_file_list:
             with open(dist_attr_file, 'rb') as f:

@@ -196,12 +197,24 @@ class DistributedSaver:
             used_inputs += op.input_arg_names
             used_outputs += op.output_arg_names
 
-        dist_feed_vars_names = list(set(feed_vars_names) & set(used_inputs))
-        dist_fetch_vars_names = list(set(fetch_vars_names) & set(used_outputs))
-
-        dist_feed_vars = [
-            global_block.vars[name] for name in dist_feed_vars_names
-        ]
+        # delete duplicated elements and keep order
+        feed_vars_names = list({}.fromkeys(feed_vars_names).keys())
+        used_inputs = list({}.fromkeys(used_inputs).keys())
+        fetch_vars_names = list({}.fromkeys(fetch_vars_names).keys())
+        used_outputs = list({}.fromkeys(used_outputs).keys())
+
+        dist_feed_vars_names = [
+            var_name for var_name in feed_vars_names if var_name in used_inputs
+        ]
+        dist_fetch_vars_names = [
+            var_name for var_name in fetch_vars_names if var_name in used_outputs
+        ]
+
+        dist_feed_vars = list(
+            reversed([global_block.vars[name] for name in dist_feed_vars_names])
+        )
         dist_fetch_vars = [
             global_block.vars[name] for name in dist_fetch_vars_names
         ]

@@ -209,11 +222,13 @@ class DistributedSaver:
         # NOTE: `paddle.static.save_inference_model` does not support subblock.
         dist_filename = filename + "_dist" + str(rank_id)
         dist_path = os.path.join(dirname, dist_filename)
-        paddle.static.save_inference_model(dist_path,
-                                           dist_feed_vars,
-                                           dist_fetch_vars,
-                                           exe,
-                                           program=dist_main_prog)
+        paddle.static.save_inference_model(
+            dist_path,
+            dist_feed_vars,
+            dist_fetch_vars,
+            exe,
+            program=dist_main_prog,
+        )
 
     def _save_rank_mapping(self, dirname):
         path = os.path.join(dirname, 'rank_mapping.csv')
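The saver now deduplicates feed/fetch names while preserving their original order, instead of taking an unordered set intersection. The trick relies on dict keys preserving insertion order, as this plain-Python sketch shows:

# Order-preserving de-duplication as used above (plain Python, no Paddle needed).
feed_vars_names = ["x", "label", "x", "mask", "label"]
deduped = list({}.fromkeys(feed_vars_names).keys())
print(deduped)  # ['x', 'label', 'mask'] -- duplicates dropped, order kept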
python/paddle/distributed/auto_parallel/engine.py

@@ -225,6 +225,11 @@ class Engine:
         self._planned_mode = None
         self._dygraph_mode = False
         self._tuning = self._strategy.tuning
+        self._acc_steps = 1
+        if self._strategy.gradient_merge.enable:
+            self._acc_steps = self._strategy.gradient_merge.k_steps
+        elif self._strategy.pipeline.enable:
+            self._acc_steps = self._strategy.pipeline.accumulate_steps
 
         self.history = None

@@ -388,7 +393,12 @@ class Engine:
         if self.main_program._pipeline_opt:
             assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
             fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
-            fwd_task = fleet_opt["tasks"][0]
+            fwd_task = None
+            if self._strategy.pipeline.schedule_mode == "1F1B":
+                fwd_task = fleet_opt["tasks"][1]
+            elif self._strategy.pipeline.schedule_mode == "stream":
+                fwd_task = fleet_opt["tasks"][0]
+            assert fwd_task is not None
             fwd_prog = fwd_task.get_program()
             fwd_block = fwd_prog.global_block()

@@ -438,8 +448,6 @@ class Engine:
             ), "user_fetches must be a list, but receive {}".format(
                 type(user_fetches).__name__
             )
-        else:
-            user_fetches = []
         fetch_names = []
         fetch_indices = []

@@ -466,7 +474,7 @@ class Engine:
             _process_fetch_group("metrics_" + str(i), var_list)
         if mode == "predict":
             _process_fetch_group("outputs", fetch_vars["outputs"])
-        for usr_fetch in user_fetches:
+        for usr_fetch in user_fetches or []:
             var_name = _to_name_str(usr_fetch)
             fetch(var_name)
         user_fetches_collection = [

@@ -903,6 +911,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             train_data, train_sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:

@@ -931,7 +940,7 @@ class Engine:
             save_dir=save_dir,
             verbose=verbose,
             metrics=self._metrics_name(),
-            acc_step=self._k_steps,
+            acc_step=self._acc_steps,
         )
         cbks.on_begin('train')

@@ -965,7 +974,7 @@ class Engine:
                 val_logs = self.evaluate(
                     valid_data,
                     valid_sample_split,
-                    batch_size,
+                    batch_size * self._acc_steps,
                     valid_steps,
                     log_freq,
                     collate_fn,

@@ -1046,6 +1055,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             valid_data, valid_sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:

@@ -1152,6 +1162,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             test_data, test_sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:

@@ -1214,6 +1225,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:

@@ -1256,6 +1268,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:

@@ -1371,14 +1384,6 @@ class Engine:
         steps_per_epoch=None,
     ):
-        if self._strategy.gradient_merge and batch_size is not None:
-            assert (
-                batch_size % self._k_steps == 0
-            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
-                batch_size, self._k_steps
-            )
-            batch_size //= self._k_steps
-
         dist_context = self._dist_contexts[self._mode]
         dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
         dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]

@@ -1440,14 +1445,6 @@ class Engine:
         collate_fn=None,
     ):
-        if self._strategy.gradient_merge and batch_size is not None:
-            assert (
-                batch_size % self._k_steps == 0
-            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
-                batch_size, self._k_steps
-            )
-            batch_size //= self._k_steps
-
         dist_context = self._dist_contexts[self._mode]
         dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
         dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]

@@ -1487,6 +1484,9 @@ class Engine:
             split_data=self._strategy.split_data,
             data_parallel_world_size=self._dp_world_sizes,
             data_parallel_rank=self._dp_ranks,
+            acc_steps=1
+            if not self._strategy.pipeline.enable
+            else self._acc_steps,
         )
         self._prepare_reader(feed_list)
         return dataloader

@@ -1498,9 +1498,18 @@ class Engine:
         )
         self._optimization_tuning(self._mode, tune_data, batch_size)
 
+    def _validate_batch_size(self, batch_size):
+        if batch_size is None:
+            return None
+        assert (
+            batch_size % self._acc_steps == 0
+        ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
+            batch_size, self._acc_steps
+        )
+        return batch_size // self._acc_steps
+
     def _validate_spec(self, specs):
         specs = to_list(specs)
-        self._k_steps = self._strategy.gradient_merge.k_steps
         if specs is not None:
             for i, spec in enumerate(specs):
                 if not isinstance(spec, InputSpec):

@@ -1513,14 +1522,14 @@ class Engine:
                         i, spec
                     )
                 )
-                if self._k_steps > 1:
+                if self._acc_steps > 1:
                     shape = list(spec.shape)
                     assert (
-                        shape[0] % self._k_steps == 0
+                        shape[0] % self._acc_steps == 0
                     ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
-                        spec.shape[0], self._k_steps
+                        spec.shape[0], self._acc_steps
                     )
-                    shape[0] //= self._k_steps
+                    shape[0] //= self._acc_steps
                     spec.shape = shape
         return specs or []
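The per-mode gradient-merge divisibility checks are replaced by a single _validate_batch_size helper that divides the user batch size by the accumulation steps. A standalone sketch of that contract (not the Engine class itself; acc_steps stands for either gradient_merge.k_steps or pipeline.accumulate_steps):

# Standalone sketch of the _validate_batch_size contract shown above.
def validate_batch_size(batch_size, acc_steps):
    if batch_size is None:
        return None
    assert batch_size % acc_steps == 0, (
        "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
            batch_size, acc_steps
        )
    )
    return batch_size // acc_steps

print(validate_batch_size(64, 4))  # 16 -> micro-batch size fed per step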
python/paddle/distributed/auto_parallel/parallelizer_v2.py

@@ -297,12 +297,14 @@ class Parallelizer:
         if self._strategy is None:
             return
 
-        config = {}
-        config["dist_context"] = self._dist_context
-        config["global_rank"] = rank
-        config["use_sharding"] = self._strategy.sharding.enable
-        dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
-        dp_pass.apply([main_program], [startup_program], self._pass_context)
+        # data parallel optimization
+        if self._strategy.dp_optimization.enable:
+            config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
+            config["dist_context"] = self._dist_context
+            config["global_rank"] = rank
+            config["use_sharding"] = self._strategy.sharding.enable
+            dp_pass = new_pass(
+                "auto_parallel_data_parallel_optimization", config
+            )
+            dp_pass.apply([main_program], [startup_program], self._pass_context)
 
         if self._strategy.sharding.enable:
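The data-parallel optimization pass is now gated on dp_optimization.enable and seeded from the strategy's own config dict rather than an empty one. A reduced sketch of that pattern follows; new_pass and the pass name are taken from the diff, while the surrounding objects (strategy, dist_context, rank, programs, pass_context) are placeholders.

import copy

from paddle.distributed.passes import new_pass

# Reduced sketch of the gating pattern above; all arguments are placeholders.
def maybe_apply_dp_pass(strategy, dist_context, rank,
                        main_program, startup_program, pass_context):
    if not strategy.dp_optimization.enable:
        return
    config = copy.deepcopy(strategy.dp_optimization.to_dict())
    config["dist_context"] = dist_context
    config["global_rank"] = rank
    config["use_sharding"] = strategy.sharding.enable
    dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
    dp_pass.apply([main_program], [startup_program], pass_context)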
python/paddle/distributed/auto_parallel/partitioner.py

The diff for this file (+234 −139) is collapsed in the page view and is not shown here.
python/paddle/distributed/auto_parallel/reshard.py

@@ -422,11 +422,11 @@ class Inserter:
         )
         inputs = {'X': [tensor]}
         outputs = {"Out": [out]}
-        attrs = {"in_place": False}
-        slice_op = block._insert_op(
+        attrs = {"in_place": False, "op_role": op_role}
+        assign_op = block._insert_op(
             idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
         )
-        slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
+        assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
         return out
 
     # use split once

@@ -1217,6 +1217,8 @@ class Resharder:
             shape_x[0] <= shape_y[0] < shape_x[1]
         ):
             overlapped = True
+        if shape_x == [0, 0] and shape_y == [0, 0]:
+            overlapped = True
         return overlapped
 
     def is_unshard(self, dims_mapping):

@@ -1304,6 +1306,14 @@ class Resharder:
             # judge whether need reshard by process_mesh
             if tensor_process_mesh != op_process_mesh:
                 is_reshard = True
+                # not reshard data in send/recv scene
+                if (
+                    tensor_process_mesh != op_process_mesh
+                    and len(tensor_process_mesh.process_ids)
+                    == len(op_process_mesh.process_ids)
+                    and dist_tensor.serial_tensor.is_data
+                ):
+                    is_reshard = False
         else:
             op_output_dims_mapping = dist_attr[1]
             if all(

@@ -1585,10 +1595,10 @@ class Resharder:
                         if i == 0:
                             all_partition_index_list.append(process_index[j][1])
                 for process in group:
-                    # append slice op desc
-                    slice_starts = []
-                    slice_ends = []
-                    slices_axes = []
+                    min_comm_group = copy.deepcopy(group)
+                    all_partition_index_list_copied = copy.deepcopy(
+                        all_partition_index_list
+                    )
                     target_partition_index = Resharder.compute_partition_index(
                         process,
                         complete_shape,

@@ -1596,12 +1606,56 @@ class Resharder:
                         target_process_shape,
                         target_process_group,
                     )
-                    for idx, item in enumerate(target_partition_index):
-                        slice_starts.append(item[0])
-                        slice_ends.append(item[1])
-                        slices_axes.append(idx)
-
-                    to_slice_tensor_shape = dist_tensor.global_sizes()
+                    for _process in group:
+                        source_partition_index = (
+                            Resharder.compute_partition_index(
+                                _process,
+                                complete_shape,
+                                source_dims_mapping,
+                                source_process_shape,
+                                source_process_group,
+                            )
+                        )
+                        if not all(
+                            _
+                            for _ in list(
+                                map(
+                                    self.is_overlapped,
+                                    source_partition_index,
+                                    target_partition_index,
+                                )
+                            )
+                        ):
+                            min_comm_group.remove(_process)
+                            all_partition_index_list_copied.remove(
+                                source_partition_index
+                            )
+
+                    concatenated_partition_index_list = []
+                    for partition_index in all_partition_index_list_copied:
+                        Resharder.concat_partitions(
+                            concatenated_partition_index_list, partition_index
+                        )
+                    concatenated_partition_index = (
+                        concatenated_partition_index_list[0]
+                    )
+
+                    slice_starts = []
+                    slice_ends = []
+                    slices_axes = []
+                    to_slice_tensor_shape = []
+                    for idx, item in enumerate(concatenated_partition_index):
+                        slice_starts.append(
+                            target_partition_index[idx][0] - item[0]
+                        )
+                        slice_ends.append(
+                            target_partition_index[idx][1] - item[0]
+                        )
+                        slices_axes.append(idx)
+                        to_slice_tensor_shape.append(item[1] - item[0])
 
                     slice_op_desc = SliceOpDesc(
                         starts=slice_starts,
                         ends=slice_ends,

@@ -1616,16 +1670,16 @@ class Resharder:
                     op_desc_seq[process] = (
                         [
                             AllGatherOpDesc(
-                                group=group,
+                                group=min_comm_group,
                                 shape=allgather_shape,
                                 is_bool=(source_tensor.dtype == paddle.bool),
                             ),
                             ConcatOpDesc(
-                                partition_index_list=all_partition_index_list
+                                partition_index_list=all_partition_index_list_copied
                             ),
                             slice_op_desc,
                         ]
-                        if len(group) > 1
+                        if len(min_comm_group) > 1
                         else [slice_op_desc]
                     )
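The reshard change above builds a minimal communication group by keeping only the source partitions that overlap the target partition; is_overlapped also gains an explicit case for empty [0, 0] partitions. A standalone sketch of the interval-overlap test (the same half-open-range logic, written outside the Resharder class; the first disjunct is assumed from the surrounding code rather than shown in this hunk):

# Standalone sketch of the partition-overlap test used by Resharder above.
def is_overlapped(shape_x, shape_y):
    overlapped = False
    if (shape_y[0] <= shape_x[0] < shape_y[1]) or (
        shape_x[0] <= shape_y[0] < shape_x[1]
    ):
        overlapped = True
    if shape_x == [0, 0] and shape_y == [0, 0]:
        overlapped = True
    return overlapped

print(is_overlapped([0, 4], [2, 6]))  # True: [0,4) and [2,6) intersect
print(is_overlapped([0, 4], [4, 8]))  # False: half-open ranges only touch
print(is_overlapped([0, 0], [0, 0]))  # True: new empty-partition case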
python/paddle/distributed/auto_parallel/strategy.py

@@ -123,6 +123,12 @@ class DatasetConfig(BaseConfig):
         super(DatasetConfig, self).__init__(category, config_dict)
 
 
+class DPOptimizationConfig(BaseConfig):
+    def __init__(self, config_dict=None):
+        category = constants.DP_OPTIMIZATION
+        super(DPOptimizationConfig, self).__init__(category, config_dict)
+
+
 class Strategy(BaseConfig):
     """
     The `Strategy` object is used to configure the paralleization and optimization beheviors.

@@ -194,3 +200,6 @@ class Strategy(BaseConfig):
         config_dict = self._config_dict.get(constants.DATASET, None)
         self.dataset = DatasetConfig(config_dict)
+
+        config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
+        self.dp_optimization = DPOptimizationConfig(config_dict)
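With the new DPOptimizationConfig, the data-parallel knobs become part of the user-facing Strategy. A short usage sketch, assuming the fleet.auto front end and the field names registered in constants.py above:

from paddle.distributed.fleet import auto

# Usage sketch for the new dp_optimization section (field names from constants.py).
strategy = auto.Strategy()
strategy.dp_optimization.enable = True
strategy.dp_optimization.fuse_all_reduce_ops = True
strategy.dp_optimization.fuse_grad_size_in_MB = 32
strategy.dp_optimization.overlap_comm_cacl = True
print(strategy.dp_optimization.to_dict())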
python/paddle/distributed/auto_parallel/utils.py

@@ -1252,6 +1252,7 @@ def set_grad_var_shape(program, dist_context):
         "fused_softmax_mask_upper_triangle_grad",
         "flatten_contiguous_range_grad",
         "relu_grad",
+        "exp_grad",
     ]
     forward_list = [
         "reshape2",

@@ -1270,6 +1271,7 @@ def set_grad_var_shape(program, dist_context):
         "fused_softmax_mask_upper_triangle",
         "flatten_contiguous_range",
         "relu",
+        "exp",
     ]
     if op.type in need_set_shape_list:
         for forward_op in block.ops:

@@ -1320,6 +1322,11 @@ def is_forward_op(op):
     )
 
 
+def is_fillconst_op_for_micro_batch(op):
+    op_role = int(op.attr('op_role'))
+    return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.LRSched))
+
+
 def is_backward_op(op):
     return OP_ROLE_KEY in op.attr_names and int(
         op.all_attrs()[OP_ROLE_KEY]
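The new is_fillconst_op_for_micro_batch helper identifies per-micro-batch fill_constant ops by their op_role attribute (tagged as OpRole.LRSched). A minimal sketch of that kind of op-role predicate over a mock op; MockOp is a stand-in, not a real Paddle operator object:

from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY

# Minimal sketch of the op-role predicate; MockOp stands in for a real operator.
class MockOp:
    def __init__(self, op_role):
        self.attr_names = [OP_ROLE_KEY]
        self._op_role = op_role

    def attr(self, name):
        assert name == 'op_role'
        return self._op_role

def is_fillconst_op_for_micro_batch(op):
    op_role = int(op.attr('op_role'))
    return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.LRSched))

print(is_fillconst_op_for_micro_batch(MockOp(int(OpRole.LRSched))))   # True
print(is_fillconst_op_for_micro_batch(MockOp(int(OpRole.Backward))))  # False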
python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py

@@ -18,15 +18,31 @@ import numpy as np
 import paddle
 from paddle.fluid import core, unique_name
 from paddle.fluid.framework import default_main_program
-from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
-from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
-from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
+from paddle.distributed.fleet.meta_optimizers.common import (
+    OpRole,
+    OP_ROLE_KEY,
+    OP_ROLE_VAR_KEY,
+)
+from paddle.distributed.auto_parallel.operators.common import (
+    is_data_parallel_scale_op,
+    is_data_parallel_reduce_op,
+)
+from paddle.distributed.auto_parallel.utils import (
+    is_loss_grad_op,
+    is_optimize_op,
+    is_backward_op,
+    ring_id_to_process_group,
+    find_higher_order_backward_op,
+)
 from .pass_base import PassBase, PassType, register_pass
 
 # add new optimizers supporting rescale_grad here
 __rescale_grad_supported_opts__ = [
-    'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
-    'merge_momentum'
+    'lars_momentum',
+    'sparse_momentum',
+    'dgc_momentum',
+    'momentum',
+    'merge_momentum',
 ]
 
 # a heuristic number

@@ -52,6 +68,9 @@ class DataParallelOptimizationPass(PassBase):
         self.set_attr("dist_context", None)
         self.set_attr("global_rank", -1)
         self.set_attr("use_sharding", False)
+        self.set_attr("fuse_all_reduce_ops", False)
+        self.set_attr("fuse_grad_size_in_MB", 32)
+        self.set_attr("overlap_comm_cacl", False)
         # {grad1: group1, grad2: group1, grad3: group2}
         # record the order for fuse grad data memory
         self._grad_name_to_group_map = OrderedDict()

@@ -62,8 +81,9 @@ class DataParallelOptimizationPass(PassBase):
     def _check_self(self):
         if self.get_attr("dist_context") is None:
             return False
-        if (not isinstance(self.get_attr("global_rank"),
-                           int)) or self.get_attr("global_rank") < 0:
+        if (
+            not isinstance(self.get_attr("global_rank"), int)
+        ) or self.get_attr("global_rank") < 0:
             return False
 
         return True

@@ -80,12 +100,17 @@ class DataParallelOptimizationPass(PassBase):
         self.global_rank = int(self.get_attr("global_rank"))
         self.use_sharding = self.get_attr("use_sharding")
+        overlap_comm_cacl = self.get_attr("overlap_comm_cacl")
+        fuse_all_reduce_ops = self.get_attr("fuse_all_reduce_ops")
 
         with paddle.static.program_guard(main_program, startup_program):
             self._analyze_program()
 
             if self.is_data_parallel_applied():
-                self._prune_grad_scaling()
-                self._calc_comm_overlap()
-                grad_group = self._fuse_allreduce()
+                if overlap_comm_cacl:
+                    self._prune_grad_scaling()
+                    self._calc_comm_overlap()
+                if fuse_all_reduce_ops:
+                    grad_group = self._fuse_allreduce()
 
                 # self.summary(grad_group)
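The pass now prunes grad scaling and overlaps communication only when overlap_comm_cacl is set, and fuses allreduces only when fuse_all_reduce_ops is set, with both flags fed from the dp_optimization config. A reduced control-flow sketch; the _prune/_overlap/_fuse callables are stand-ins for the pass's own methods, not real API:

# Reduced sketch of the gating added in the apply path above.
def apply_dp_optimizations(overlap_comm_cacl, fuse_all_reduce_ops,
                           _prune, _overlap, _fuse):
    if overlap_comm_cacl:
        _prune()    # prune redundant grad scaling
        _overlap()  # overlap allreduce with backward computation
    if fuse_all_reduce_ops:
        return _fuse()  # fuse per-grad allreduces into coalesced groups
    return None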
...
@@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase):
),
"Unexception: comm op [{}] has NOT ring id."
.
format
(
str
(
op
))
),
"Unexception: comm op [{}] has NOT ring id."
.
format
(
str
(
op
))
group
=
ring_id_to_process_group
(
op
.
attr
(
"ring_id"
))
group
=
ring_id_to_process_group
(
op
.
attr
(
"ring_id"
))
assert
group
is
not
None
,
"Unexception: data parallel group of [{}] from op [{}] is None"
.
format
(
assert
(
grad_name
,
str
(
op
))
group
is
not
None
),
"Unexception: data parallel group of [{}] from op [{}] is None"
.
format
(
grad_name
,
str
(
op
)
)
self
.
_grad_name_to_group_map
[
grad_name
]
=
group
self
.
_grad_name_to_group_map
[
grad_name
]
=
group
...
@@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase):
# TODO support multiple optimizers in on network in future.
# TODO support multiple optimizers in on network in future.
# here we assume that the optimizer is unique in network.
# here we assume that the optimizer is unique in network.
elif
is_optimize_op
(
elif
(
op
)
and
op
.
type
in
__rescale_grad_supported_opts__
:
is_optimize_op
(
op
)
and
op
.
type
in
__rescale_grad_supported_opts__
):
self
.
_support_rescale_grad
=
True
self
.
_support_rescale_grad
=
True
not_synchronized_grads
=
[]
not_synchronized_grads
=
[]
for
grad_name
in
scaled_grads
:
for
grad_name
in
scaled_grads
:
if
grad_name
not
in
self
.
_grad_name_to_group_map
:
if
grad_name
not
in
self
.
_grad_name_to_group_map
:
not_synchronized_grads
.
append
(
grad_name
)
not_synchronized_grads
.
append
(
grad_name
)
assert
len
(
assert
(
len
(
not_synchronized_grads
)
==
0
),
"Unexception: gradients [{}] is scaled BUT NOT synchronized."
.
format
(
not_synchronized_grads
not_synchronized_grads
)
==
0
,
"Unexception: gradients [{}] is scaled BUT NOT synchronized."
.
format
(
)
not_synchronized_grads
)
def
is_data_parallel_applied
(
self
):
def
is_data_parallel_applied
(
self
):
return
len
(
self
.
_group_to_grad_name_map
)
>
0
return
len
(
self
.
_group_to_grad_name_map
)
>
0
...
@@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase):
def
_could_be_prune
(
self
):
def
_could_be_prune
(
self
):
return
self
.
dist_context
.
gradient_scale
and
(
return
self
.
dist_context
.
gradient_scale
and
(
self
.
_support_rescale_grad
or
self
.
_all_dp_groups_same_degree
())
self
.
_support_rescale_grad
or
self
.
_all_dp_groups_same_degree
()
)
def
_all_dp_groups_same_degree
(
self
):
def
_all_dp_groups_same_degree
(
self
):
return
len
(
return
(
set
([
len
(
set
(
[
len
(
group
.
ranks
)
len
(
group
.
ranks
)
for
group
in
self
.
_group_to_grad_name_map
.
keys
()
for
group
in
self
.
_group_to_grad_name_map
.
keys
()
]))
==
1
]
)
)
==
1
)
def
_scale_backward_initial_grad
(
self
):
def
_scale_backward_initial_grad
(
self
):
...
@@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase):
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
is_loss_grad_op
(
op
):
if
is_loss_grad_op
(
op
):
assert
op
.
type
==
'fill_constant'
,
\
assert
op
.
type
==
'fill_constant'
,
(
"loss_grad_op must be fill_constant op, "
\
"loss_grad_op must be fill_constant op, "
"but this op is {}"
.
format
(
op
.
type
)
"but this op is {}"
.
format
(
op
.
type
)
)
assert
op
.
has_attr
(
'value'
)
assert
op
.
has_attr
(
'value'
)
loss_scale
=
float
(
op
.
attr
(
'value'
))
loss_scale
=
float
(
op
.
attr
(
'value'
))
loss_scale
=
loss_scale
/
dp_degree
loss_scale
=
loss_scale
/
dp_degree
...
@@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase):
scaled_grads
=
set
()
scaled_grads
=
set
()
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
is_optimize_op
(
if
(
op
)
and
op
.
type
in
__rescale_grad_supported_opts__
:
is_optimize_op
(
op
)
and
op
.
type
in
__rescale_grad_supported_opts__
):
assert
op
.
has_attr
(
assert
op
.
has_attr
(
'rescale_grad'
'rescale_grad'
),
"Unexception: op [{}] is supported to have [rescale_grad] attribute."
.
format
(
),
"Unexception: op [{}] is supported to have [rescale_grad] attribute."
.
format
(
str
(
op
))
str
(
op
)
assert
len
(
)
op
.
input
(
"Grad"
)
assert
(
)
==
1
,
"Unexception: op [{}] is supported to have only one input grad var."
.
format
(
len
(
op
.
input
(
"Grad"
))
==
1
str
(
op
))
),
"Unexception: op [{}] is supported to have only one input grad var."
.
format
(
str
(
op
)
)
grad_name
=
op
.
input
(
"Grad"
)[
0
]
grad_name
=
op
.
input
(
"Grad"
)[
0
]
dp_degree
=
len
(
dp_degree
=
len
(
list
(
self
.
_grad_name_to_group_map
[
grad_name
].
ranks
))
list
(
self
.
_grad_name_to_group_map
[
grad_name
].
ranks
)
)
scaled_grads
.
add
(
grad_name
)
scaled_grads
.
add
(
grad_name
)
rescale_grad
=
float
(
op
.
attr
(
'rescale_grad'
))
/
dp_degree
rescale_grad
=
float
(
op
.
attr
(
'rescale_grad'
))
/
dp_degree
op
.
_set_attr
(
'rescale_grad'
,
rescale_grad
)
op
.
_set_attr
(
'rescale_grad'
,
rescale_grad
)
assert
scaled_grads
==
set
(
self
.
_grad_name_to_group_map
.
keys
(
assert
scaled_grads
==
set
(
)),
"Unexception: gradients [{}] are unscaled."
.
format
(
self
.
_grad_name_to_group_map
.
keys
()
set
(
self
.
_grad_name_to_group_map
.
keys
())
-
scaled_grads
)
),
"Unexception: gradients [{}] are unscaled."
.
format
(
set
(
self
.
_grad_name_to_group_map
.
keys
())
-
scaled_grads
)
def
_could_be_overlap
(
self
):
def
_could_be_overlap
(
self
):
# NOTE current different nccl comm will use different cuda stream
# NOTE current different nccl comm will use different cuda stream
...
@@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase):
op
.
_set_attr
(
'use_calc_stream'
,
False
)
op
.
_set_attr
(
'use_calc_stream'
,
False
)
ring_id
=
op
.
attr
(
"ring_id"
)
ring_id
=
op
.
attr
(
"ring_id"
)
block
.
_insert_op_without_sync
(
idx
,
block
.
_insert_op_without_sync
(
idx
,
type
=
'c_wait_compute'
,
type
=
'c_wait_compute'
,
inputs
=
{
'X'
:
[]},
inputs
=
{
'X'
:
[]},
outputs
=
{
'Out'
:
[]},
outputs
=
{
'Out'
:
[]},
attrs
=
{
attrs
=
{
'op_role'
:
OpRole
.
Backward
,
'ring_id'
:
ring_id
},
'op_role'
:
OpRole
.
Backward
,
)
'ring_id'
:
ring_id
})
block
.
_sync_with_cpp
()
block
.
_sync_with_cpp
()
...
@@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase):
# other ops that might use communicating grad
# other ops that might use communicating grad
else
:
else
:
for
input_var_name
in
op
.
input_arg_names
:
for
input_var_name
in
op
.
input_arg_names
:
for
ring_id
,
unsync_grad_names
in
ring_id_to_un_sync_grad_map
.
items
(
for
(
):
ring_id
,
unsync_grad_names
,
)
in
ring_id_to_un_sync_grad_map
.
items
():
if
input_var_name
in
unsync_grad_names
:
if
input_var_name
in
unsync_grad_names
:
# need to sync before op_i
# need to sync before op_i
if
i
in
op_idx_to_sync_ring_id_map
:
if
i
in
op_idx_to_sync_ring_id_map
:
...
@@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase):
...
@@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase):
for
i
in
sorted
(
indices
,
reverse
=
True
):
for
i
in
sorted
(
indices
,
reverse
=
True
):
for
ring_id
in
op_idx_to_sync_ring_id_map
[
i
]:
for
ring_id
in
op_idx_to_sync_ring_id_map
[
i
]:
block
.
_insert_op_without_sync
(
i
,
block
.
_insert_op_without_sync
(
i
,
type
=
'c_wait_comm'
,
type
=
'c_wait_comm'
,
inputs
=
{
'X'
:
[]},
inputs
=
{
'X'
:
[]},
outputs
=
{
'Out'
:
[]},
outputs
=
{
'Out'
:
[]},
attrs
=
{
attrs
=
{
'op_role'
:
OpRole
.
Backward
,
'ring_id'
:
ring_id
},
'op_role'
:
OpRole
.
Backward
,
)
'ring_id'
:
ring_id
})
def
_could_be_fuse
(
self
):
def
_could_be_fuse
(
self
):
# TODO support gradient fuse higher order gradient.
# TODO support gradient fuse higher order gradient.
...
@@ -423,36 +469,51 @@ class DataParallelOptimizationPass(PassBase):
        for i, group in enumerate(grad_groups[::-1]):

            # create coalecse tensor
            group.coalesce_var = block.create_var(
                name=unique_name.generate('coalecse_grad_{}'.format(i)),
                dtype=group.dtype,
                persistable=False,
                stop_gradient=True,
            )

            # update allreduce & scale op
            if group.scale_op_idx != -1:
                scale_op = block.ops[group.scale_op_idx]
                assert (
                    scale_op.type == 'scale'
                ), "should found scale op but found {}".format(str(scale_op))
                scale_op._rename_input(
                    scale_op.input_arg_names[0], group.coalesce_var.name
                )
                scale_op._rename_output(
                    scale_op.output_arg_names[0], group.coalesce_var.name
                )

            allreduce_op = block.ops[group.allreduce_op_idx]
            assert (
                allreduce_op.type == 'c_allreduce_sum'
            ), "should found c_allreduce_sum op but found {}".format(
                str(allreduce_op)
            )
            allreduce_op._rename_input(
                allreduce_op.input_arg_names[0], group.coalesce_var.name
            )
            allreduce_op._rename_output(
                allreduce_op.output_arg_names[0], group.coalesce_var.name
            )

            # remvoe un-used op
            remove_op_indices = (
                group.remove_wait_op_indices
                + group.remove_allreduce_op_indices
                + group.remove_scale_op_indices
            )
            for idx in sorted(remove_op_indices, reverse=True):
                assert (
                    block.ops[idx].type in remove_op_types
                ), "Unexception: try to remove op {}".format(
                    str(block.ops[idx].type())
                )
                block._remove_op(idx)

            # insert coalecse op
...
@@ -464,22 +525,23 @@ class DataParallelOptimizationPass(PassBase):
                concated_ranks.append(len(shape))

            grad_names = [grad.name for grad in group.gradients]
            block._insert_op_without_sync(
                group.coalesce_op_idx,
                type="coalesce_tensor",
                inputs={"Input": grad_names},
                outputs={
                    "Output": grad_names,
                    "FusedOutput": group.coalesce_var,
                },
                attrs={
                    "copy_data": False,
                    "use_align": True,
                    "dtype": group.dtype,
                    "concated_shapes": concated_shapes,
                    "concated_ranks": concated_ranks,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

        block._sync_with_cpp()
        # TODO update dist attr
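Editor's note: the coalesce_tensor op above fuses a group of gradients into one contiguous buffer so a single allreduce can replace many small ones. A rough numpy sketch of the buffer layout only follows; the allreduce itself and all helper names are illustrative, not part of the pass.

import numpy as np

# Flatten the gradients into one contiguous array (where a single allreduce would
# run), then recover the per-gradient views from the recorded shapes, analogous to
# the concated_shapes/concated_ranks bookkeeping above.
def coalesce(grads):
    shapes = [g.shape for g in grads]
    flat = np.concatenate([g.ravel() for g in grads])
    return flat, shapes

def split(flat, shapes):
    out, offset = [], 0
    for shape in shapes:
        size = int(np.prod(shape))
        out.append(flat[offset:offset + size].reshape(shape))
        offset += size
    return out

grads = [np.ones((2, 3)), np.arange(4.0).reshape(2, 2)]
flat, shapes = coalesce(grads)
restored = split(flat, shapes)
assert all(np.array_equal(a, b) for a, b in zip(grads, restored))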
...
@@ -487,6 +549,7 @@ class DataParallelOptimizationPass(PassBase):
    def summary(self, grad_groups=[]):
        # TODO: add logger module
        import logging

        self._logger = logging.getLogger()
        self._logger.propagate = False
        if not self._logger.handlers:
...
@@ -500,26 +563,31 @@ class DataParallelOptimizationPass(PassBase):
        if len(grad_groups) > 0:
            self._logger.info(
                "origin {} allreduce ops are fused into {} coalecse allreduce ops.".format(
                    len(self._grad_name_to_group_map.keys()), len(grad_groups)
                )
            )
            self._logger.info("gradient fusing group are following: ")
            fused_grads = set()
            for i, group in enumerate(grad_groups):
                self._logger.info(
                    "coalecse gradient [{}] is composed by: {}".format(
                        i, [grad.name for grad in group.gradients]
                    )
                )
                fused_grads.update([grad.name for grad in group.gradients])
            individual_grads = set(self._grad_name_to_group_map.keys()) - set(
                fused_grads
            )
            self._logger.info(
                "the following [{}] gradients are not fused: ".format(
                    len(individual_grads)
                )
            )
            self._logger.info("individual gradient {}".format(individual_grads))


class GradientsGroup(object):
    def __init__(self, ops, max_group_size):
        self.max_group_size = max_group_size
        self.ops = ops
...
@@ -575,8 +643,11 @@ class GradientsGroup(object):
            grad_op_idx -= 1

        grad_op = self.ops[grad_op_idx]
        assert (
            grad_var.name in grad_op.output_arg_names
        ), "grad [{}] should be output of {}".format(
            grad_var.name, str(grad_op)
        )
        self.coalesce_op_idx = grad_op_idx

    def finalize(self):
...
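Editor's note: GradientsGroup collects consecutive data-parallel gradients into fusion groups bounded by max_group_size. The sketch below is a rough, framework-free illustration of that grouping, assuming the bound is on accumulated element count; the real class also tracks scale/allreduce/coalesce op indices, which is omitted here.

import numpy as np

# Hypothetical helper: pack gradients (name, shape) into groups whose total
# element count stays under max_group_size.
def group_gradients(named_shapes, max_group_size):
    groups, current, current_size = [], [], 0
    for name, shape in named_shapes:
        numel = int(np.prod(shape))
        if current and current_size + numel > max_group_size:
            groups.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += numel
    if current:
        groups.append(current)
    return groups

print(group_gradients(
    [("w1@GRAD", (128, 64)), ("b1@GRAD", (64,)), ("w2@GRAD", (64, 64))],
    max_group_size=9000,
))  # -> [['w1@GRAD', 'b1@GRAD'], ['w2@GRAD']]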
python/paddle/distributed/passes/auto_parallel_gradient_merge.py
(This diff is collapsed in the web view and not expanded here.)
python/paddle/distributed/passes/auto_parallel_pipeline.py
(This diff is collapsed in the web view and not expanded here.)
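Editor's note: the collapsed auto_parallel_gradient_merge.py diff backs the strategy.gradient_merge options (k_steps, avg) exercised by the unittests below. As a rough numpy illustration of those semantics only, not of the pass itself, the sketch accumulates micro-batch gradients into a float32 buffer and applies an update every k_steps; keeping the buffer in float32 is an assumption that mirrors the pure-fp16 setting (amp.use_pure_fp16) used in the tests, and all names are illustrative.

import numpy as np

def train(param, micro_grads_fp16, k_steps=2, avg=True, lr=0.1):
    acc = np.zeros_like(param, dtype=np.float32)  # persistent accumulation buffer
    for step, g in enumerate(micro_grads_fp16, start=1):
        acc += g.astype(np.float32)
        if step % k_steps == 0:
            update = acc / k_steps if avg else acc
            param = param - lr * update            # optimizer step every k_steps
            acc[:] = 0.0                           # reset the buffer
    return param

param = np.array([1.0, 1.0], dtype=np.float32)
grads = [np.array([0.5, 0.5], dtype=np.float16)] * 4
print(train(param, grads, k_steps=2, avg=True))  # two updates of lr * 0.5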
python/paddle/fluid/tests/unittests/auto_parallel/1F1B_pass_unittest.py
0 → 100644
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import random
import numpy as np

import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv

from get_gpt_model import generate_model, FakeDataset

paddle.enable_static()


def apply_pass(use_1f1b=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_1f1b:
        pipeline = strategy.pipeline
        pipeline.enable = True
        pipeline.schedule_mode = "1F1B"
        pipeline.accumulate_steps = 2
    else:
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = 2
        gradient_merge.avg = True

    amp = strategy.amp
    amp.enable = True
    amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
    amp.custom_black_list = [
        'c_softmax_with_cross_entropy',
        'elementwise_div',
        'reduce_sum',
    ]
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    amp.use_pure_fp16 = True

    return strategy


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class Test1F1BPass(unittest.TestCase):
    def setUp(self):
        self.rtol = 1e-5
        self.atol = 1e-8
        self.batch_size = 2
        self.batch_num = 10
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        paddle.seed(2021)
        np.random.seed(2021)
        random.seed(2021)
        paddle.distributed.fleet.init(is_collective=True)
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_1f1b=False):
        reset_prog()

        strategy = apply_pass(use_1f1b)
        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model("pp")

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_results(self, ref_losses, check_losses):
        np.testing.assert_allclose(
            ref_losses,
            check_losses,
            rtol=self.rtol,
            atol=self.atol,
            err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
                __class__, ref_losses, check_losses, ref_losses - check_losses
            ),
        )

    def test_1f1b_pass(self):
        # navie_pp+gradient_merge training
        engine_pp = self.get_engine()
        history = engine_pp.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_pp._strategy.pipeline.enable == False

        # pp2 1f1b merge training
        engine_1f1b = self.get_engine(True)
        history = engine_1f1b.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_1f1b._strategy.pipeline.enable == True

        # NOTE: every sample data from dataset is all the same
        if paddle.distributed.get_rank() == 1:
            losses_pp = np.array(history.history["loss"])
            losses_1f1b = np.array(history.history["loss"])
            self.check_results(losses_pp, losses_1f1b)


if __name__ == "__main__":
    unittest.main()
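Editor's note: pipeline.schedule_mode = "1F1B" in the test above selects the one-forward-one-backward micro-batch schedule implemented by the collapsed auto_parallel_pipeline.py pass. Below is a rough, framework-free sketch of the per-stage phase order only, not Paddle's scheduler; all names are illustrative.

# For a given pipeline stage, run a few warm-up forwards, then alternate one
# forward with one backward, then drain the remaining backwards.
def one_f_one_b_schedule(stage_id, num_stages, num_micro_batches):
    warmup = min(num_stages - stage_id - 1, num_micro_batches)
    steady = num_micro_batches - warmup
    phases = ["F"] * warmup
    for _ in range(steady):
        phases += ["F", "B"]
    phases += ["B"] * warmup
    return phases

# With 2 stages and accumulate_steps=2 (as in the unittest above):
print(one_f_one_b_schedule(0, 2, 2))  # stage 0: ['F', 'F', 'B', 'B']
print(one_f_one_b_schedule(1, 2, 2))  # stage 1: ['F', 'B', 'F', 'B']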
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...
@@ -69,6 +69,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
  set_tests_properties(test_engine_callbacks
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
  py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
  set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                 TIMEOUT 50)
  py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
                  ${dist_ENVS})
...
python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py
...
@@ -89,6 +89,12 @@ def generate_model(strategy, dropout_prob=0.0):
        modeling._global_parallel_strategy = "mp"
    elif strategy == "dp":
        modeling._global_parallel_strategy = "dp"
    elif strategy == "pp":
        modeling._global_parallel_strategy = "pp"
        modeling.PP_MESH_LIST = [
            auto.ProcessMesh(mesh=[0]),
            auto.ProcessMesh(mesh=[1]),
        ]
    else:
        raise ValueError("Only support serial, mp2 and dp2.")
...
@@ -108,6 +114,7 @@ def generate_model(strategy, dropout_prob=0.0):
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        pp_degree=2 if strategy == "pp" else None,
    )
    model = GPTForPretraining(
        gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
...
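Editor's note: the new "pp" branch above gives each pipeline stage its own single-rank process mesh. A minimal sketch of building such a list for an arbitrary stage count, assuming the stages occupy consecutive ranks; the helper name is hypothetical.

from paddle.distributed.fleet import auto

# One single-rank ProcessMesh per pipeline stage, for ranks 0..num_stages-1.
def build_pp_mesh_list(num_stages):
    return [auto.ProcessMesh(mesh=[rank]) for rank in range(num_stages)]

# The two-stage case used by the test model:
PP_MESH_LIST = build_pp_mesh_list(2)  # meshes over rank [0] and rank [1]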
python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py
...
@@ -19,7 +19,7 @@ import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
-from get_gpt_model import generate_model, create_data_holder, FakeDataset
+from get_gpt_model import generate_model, FakeDataset

paddle.enable_static()
...
@@ -28,12 +28,25 @@ def apply_pass(use_gradient_merge=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_gradient_merge:
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = 4
        gradient_merge.avg = True

    amp = strategy.amp
    amp.enable = True
    amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
    amp.custom_black_list = [
        'c_softmax_with_cross_entropy',
        'elementwise_div',
        'reduce_sum',
    ]
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    amp.use_pure_fp16 = True

    return strategy
...
@@ -88,6 +101,7 @@ class TestGradientMergePass(unittest.TestCase):
        history = dp_engine.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert dp_engine._strategy.gradient_merge.enable == False
        dp_losses = np.array(history.history["loss"])

        # dp2 gradient merge training
...
@@ -95,6 +109,7 @@ class TestGradientMergePass(unittest.TestCase):
        history = gm_engine.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert gm_engine._strategy.gradient_merge.enable == True
        gm_losses = np.array(history.history["loss"])

        # avg_loss = 0
...
python/paddle/fluid/tests/unittests/auto_parallel/test_pass_1F1B.py
0 → 100644
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest
import os
import sys
import shutil
import subprocess

from paddle.distributed.fleet.launch_utils import run_with_coverage


class Test1F1BPass(unittest.TestCase):
    def test_pp2(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )

        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()


if __name__ == "__main__":
    unittest.main()