Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
30315ac9
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
30315ac9
编写于
11月 29, 2022
作者:
C
caozhou
提交者:
GitHub
11月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto Parallel] Add pattern match (#48464)
* add pattern match * add unittest
上级
41946522
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
407 addition
and
63 deletion
+407
-63
python/paddle/distributed/auto_parallel/graph.py
python/paddle/distributed/auto_parallel/graph.py
+14
-2
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
...addle/distributed/auto_parallel/tuner/rule_based_tuner.py
+240
-52
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+1
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py
...uid/tests/unittests/auto_parallel/test_group_operators.py
+1
-1
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py
...addle/fluid/tests/unittests/auto_parallel/test_pattern.py
+9
-8
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py
...fluid/tests/unittests/auto_parallel/test_pattern_match.py
+142
-0
未找到文件。
python/paddle/distributed/auto_parallel/graph.py
浏览文件 @
30315ac9
...
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License
from
collections
import
OrderedDict
class
Node
:
def
__init__
(
self
,
id
,
**
attrs
):
...
...
@@ -100,6 +102,8 @@ class Graph:
# Attributes for Graph
self
.
_attrs
=
{}
self
.
_attrs
.
update
(
attrs
)
self
.
_reverse_adjs
=
{}
self
.
_attr_to_nodes
=
{}
@
property
def
nodes
(
self
):
...
...
@@ -120,6 +124,7 @@ class Graph:
node
=
Node
(
node_id
,
**
attrs
)
self
.
_nodes
[
node_id
]
=
node
self
.
_adjs
[
node_id
]
=
{}
self
.
_reverse_adjs
[
node_id
]
=
[]
else
:
self
.
_nodes
[
node_id
].
attrs
.
update
(
attrs
)
...
...
@@ -134,14 +139,21 @@ class Graph:
if
src_id
not
in
self
.
_nodes
:
src_node
=
Node
(
src_id
)
self
.
_nodes
[
src_id
]
=
src_node
self
.
_adjs
[
src_id
]
=
{}
# for one tensor to multiple ops
self
.
_adjs
[
src_id
]
=
OrderedDict
()
self
.
_reverse_adjs
[
src_id
]
=
[]
if
tgt_id
not
in
self
.
_nodes
:
tgt_node
=
Node
(
tgt_id
)
self
.
_nodes
[
tgt_id
]
=
tgt_node
self
.
_adjs
[
tgt_id
]
=
{}
# for one tensor to multiple ops
self
.
_adjs
[
tgt_id
]
=
OrderedDict
()
self
.
_reverse_adjs
[
tgt_id
]
=
[]
# add the edge
edge
=
Edge
(
src_id
,
tgt_id
,
**
attrs
)
self
.
_adjs
[
src_id
][
tgt_id
]
=
edge
# add the reverse adj
self
.
_reverse_adjs
[
tgt_id
].
append
(
self
.
nodes
[
src_id
])
return
edge
def
__len__
(
self
):
...
...
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
浏览文件 @
30315ac9
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
abc
import
ABC
,
abstractmethod
from
abc
import
abstractmethod
from
..graph
import
Graph
...
...
@@ -32,6 +32,57 @@ def register_pattern(cls):
return
cls
class BasePattern(Graph):
    """Base class for pattern graphs.

    A pattern is itself a ``Graph``. The constructor first initialises the
    underlying graph and then calls ``build`` so that a subclass can add its
    nodes, edges and attributes.
    """

    name = "base"

    def __init__(self):
        super(BasePattern, self).__init__()
        # Populate the freshly initialised graph with the pattern's content.
        self.build()

    @abstractmethod
    def build(self):
        """Fill in the pattern graph; overridden by concrete patterns."""
@register_pattern
class QKVPattern(BasePattern):
    """Pattern in which one var feeds three ``matmul_v2`` ops.

    Node ids used by the pattern:
        0        shared input var (edge attribute ``input_name`` = "X")
        1, 2, 3  2-D params, one per matmul (``input_name`` = "Y")
        4, 5, 6  ``matmul_v2`` op nodes
        7, 8, 9  output vars, one per matmul (``output_name`` = "Out")
    """

    name = "qkv"

    def __init__(self):
        super().__init__()

    def build(self):
        """Add the nodes and edges listed above and set ``shard_spec``."""
        weight_ids = (1, 2, 3)
        matmul_ids = (4, 5, 6)
        output_ids = (7, 8, 9)

        # Nodes are created in ascending id order: input, weights, matmuls.
        self.add_node(0, **{"type": "var"})
        for weight_id in weight_ids:
            self.add_node(weight_id, **{"dim": 2, "type": "param"})
        for matmul_id in matmul_ids:
            self.add_node(matmul_id, **{"type": "matmul_v2"})

        # The shared input is the "X" operand of every matmul ...
        for matmul_id in matmul_ids:
            self.add_edge(0, matmul_id, **{"input_name": "X"})
        # ... and each matmul gets its own weight as the "Y" operand.
        for weight_id, matmul_id in zip(weight_ids, matmul_ids):
            self.add_edge(weight_id, matmul_id, **{"input_name": "Y"})

        # Each matmul produces its own output var through "Out".
        for output_id in output_ids:
            self.add_node(output_id, **{"type": "var"})
        for matmul_id, output_id in zip(matmul_ids, output_ids):
            self.add_edge(matmul_id, output_id, **{"output_name": "Out"})

        # Pattern: list of [tensor_ids, shard_spec] entries.
        self.attrs["shard_spec"] = [
            [weight_ids, [[-1, 0], [-1, 1]]],
        ]
def
convert_to_graph
(
ops
,
block
):
"""Convert ops to graph."""
graph
=
Graph
()
...
...
@@ -50,7 +101,9 @@ def convert_to_graph(ops, block):
op_node
=
graph
.
add_node
(
node_id
,
**
attrs
)
graph
.
attrs
[
"op_to_id"
][
op
.
desc
.
id
()]
=
op_node
.
id
graph
.
attrs
[
"id_to_op"
][
op_node
.
id
]
=
op
.
desc
.
id
()
graph
.
_attr_to_nodes
[
op_node
.
id
]
=
{}
for
input_name
in
op
.
input_names
:
graph
.
_attr_to_nodes
[
op_node
.
id
][
input_name
]
=
[]
for
var_name
in
op
.
input
(
input_name
):
if
var_name
not
in
graph
.
attrs
[
"var_to_id"
]:
# create var node
...
...
@@ -59,6 +112,7 @@ def convert_to_graph(ops, block):
var
=
block
.
_var_recursive
(
var_name
)
if
var
.
is_parameter
:
var_node
.
attrs
[
"type"
]
=
"param"
var_node
.
attrs
[
"dim"
]
=
len
(
var
.
shape
)
else
:
var_node
.
attrs
[
"type"
]
=
"var"
graph
.
attrs
[
"var_to_id"
][
var_name
]
=
var_node
.
id
...
...
@@ -70,8 +124,10 @@ def convert_to_graph(ops, block):
# create edge that input -> op
input_edge
=
graph
.
add_edge
(
var_node
.
id
,
op_node
.
id
)
input_edge
.
attrs
[
"input_name"
]
=
input_name
graph
.
_attr_to_nodes
[
op_node
.
id
][
input_name
].
append
(
var_node
)
for
output_name
in
op
.
output_names
:
graph
.
_attr_to_nodes
[
op_node
.
id
][
output_name
]
=
[]
for
var_name
in
op
.
output
(
output_name
):
if
var_name
not
in
graph
.
attrs
[
"var_to_id"
]:
# create var node
...
...
@@ -92,64 +148,189 @@ def convert_to_graph(ops, block):
output_edge
=
graph
.
add_edge
(
op_node
.
id
,
var_node
.
id
)
output_edge
.
attrs
[
"output_name"
]
=
output_name
graph
.
_attr_to_nodes
[
op_node
.
id
][
output_name
].
append
(
var_node
)
return
graph
class
BasePattern
(
ABC
):
name
=
"base"
def
match
(
pattern
,
graph
):
def
_is_op_node
(
node
):
"""Judge whether node is op node"""
if
node
.
attrs
[
"type"
]
not
in
[
"var"
,
"param"
,
"data"
]:
return
True
def
__init__
(
self
):
self
.
graph
=
None
self
.
build
()
return
False
@
abstractmethod
def
build
(
self
):
pass
def
_compare_op_node
(
src
,
tgt
):
"""Compare whether two op nodes are equal"""
if
src
.
attrs
[
"type"
]
!=
tgt
.
attrs
[
"type"
]:
return
False
return
True
@
register_pattern
class
QKVPattern
(
BasePattern
):
name
=
"qkv"
def
_compare_var_node
(
src
,
tgt
):
"""Compare whether two var nodes are equal"""
for
key
in
src
.
attrs
:
if
key
not
in
tgt
.
attrs
:
return
False
if
src
.
attrs
[
key
]
!=
tgt
.
attrs
[
key
]:
return
False
def
__init__
(
self
):
super
().
__init__
()
return
True
def
build
(
self
):
self
.
graph
=
Graph
()
def
_match_core
(
src_node
,
tgt_node
):
nonlocal
not_matched
# do not support one input name or output name corresponding to multiple vars
if
not_matched
:
return
if
_is_op_node
(
src_node
):
# compare op node whether equal
if
not
_compare_op_node
(
src_node
,
tgt_node
):
return
result
[
src_node
.
id
]
=
tgt_node
.
id
# input var nodes
src_input_nodes
=
src_reverse_adjs
[
src_node
.
id
]
for
node
in
src_input_nodes
:
# has visited
if
node
.
id
in
result
:
continue
edge
=
src_edges
[
node
.
id
][
src_node
.
id
]
input_name
=
edge
.
attrs
[
"input_name"
]
# NOTE: do not support one input name or output name corresponding to multiple vars
compare_nodes
=
tgt_attr_to_nodes
[
tgt_node
.
id
].
get
(
input_name
,
None
)
if
not
compare_nodes
:
not_matched
=
True
return
_match_core
(
node
,
compare_nodes
[
0
])
# output var nodes
src_output_node_ids
=
src_edges
[
src_node
.
id
].
keys
()
for
node_id
in
src_output_node_ids
:
# has visited
if
node_id
in
result
:
continue
node
=
src_nodes
[
node_id
]
edge
=
src_edges
[
src_node
.
id
][
node_id
]
output_name
=
edge
.
attrs
[
"output_name"
]
query
=
self
.
graph
.
add_node
(
0
,
**
{
"type"
:
"var"
})
# NOTE: do not support one input name or output name corresponding to multiple vars
compare_nodes
=
tgt_attr_to_nodes
[
tgt_node
.
id
].
get
(
output_name
,
None
)
if
not
compare_nodes
:
not_matched
=
True
return
_match_core
(
node
,
compare_nodes
[
0
])
q_weight
=
self
.
graph
.
add_node
(
1
,
**
{
"dim"
:
2
,
"type"
:
"param"
})
k_weight
=
self
.
graph
.
add_node
(
2
,
**
{
"dim"
:
2
,
"type"
:
"param"
})
v_weight
=
self
.
graph
.
add_node
(
3
,
**
{
"dim"
:
2
,
"type"
:
"param"
})
else
:
# compare var node whether equal
if
not
_compare_var_node
(
src_node
,
tgt_node
):
not_matched
=
True
return
q_matmul
=
self
.
graph
.
add_node
(
4
,
**
{
"type"
:
"matmul_v2"
})
k_matmul
=
self
.
graph
.
add_node
(
5
,
**
{
"type"
:
"matmul_v2"
})
v_matmul
=
self
.
graph
.
add_node
(
6
,
**
{
"type"
:
"matmul_v2"
})
result
[
src_node
.
id
]
=
tgt_node
.
id
q_x
=
self
.
graph
.
add_edge
(
0
,
4
,
**
{
"input_name"
:
"X"
})
k_x
=
self
.
graph
.
add_edge
(
0
,
5
,
**
{
"input_name"
:
"X"
})
v_x
=
self
.
graph
.
add_edge
(
0
,
6
,
**
{
"input_name"
:
"X"
})
q_y
=
self
.
graph
.
add_edge
(
1
,
4
,
**
{
"input_name"
:
"Y"
})
k_y
=
self
.
graph
.
add_edge
(
2
,
5
,
**
{
"input_name"
:
"Y"
})
v_y
=
self
.
graph
.
add_edge
(
3
,
6
,
**
{
"input_name"
:
"Y"
})
# as input for op nodes
src_as_input_node_ids
=
src_edges
[
src_node
.
id
].
keys
()
for
node_id
in
src_as_input_node_ids
:
if
node_id
in
result
:
continue
q
=
self
.
graph
.
add_node
(
7
,
**
{
"type"
:
"var"
})
k
=
self
.
graph
.
add_node
(
8
,
**
{
"type"
:
"var"
})
v
=
self
.
graph
.
add_node
(
9
,
**
{
"type"
:
"var"
})
src_edge
=
src_edges
[
src_node
.
id
][
node_id
]
input_name
=
src_edge
.
attrs
[
"input_name"
]
compare_node_ids
=
tgt_edges
[
tgt_node
.
id
].
keys
()
compare_node
=
None
for
compare_node_id
in
compare_node_ids
:
edge
=
tgt_edges
[
tgt_node
.
id
][
compare_node_id
]
if
(
edge
.
attrs
[
"input_name"
]
==
input_name
and
compare_node_id
not
in
result
.
values
()
):
compare_node
=
tgt_nodes
[
compare_node_id
]
break
q_out
=
self
.
graph
.
add_edge
(
7
,
4
,
**
{
"output_name"
:
"Out"
})
k_out
=
self
.
graph
.
add_edge
(
8
,
5
,
**
{
"output_name"
:
"Out"
})
v_out
=
self
.
graph
.
add_edge
(
9
,
6
,
**
{
"output_name"
:
"Out"
})
if
not
compare_node
:
not_matched
=
True
return
_match_core
(
src_nodes
[
node_id
],
compare_node
)
# Pattern
self
.
graph
.
attrs
[
"shard_tensor"
]
=
[
(
1
,
2
,
3
),
[[
-
1
,
0
],
[
-
1
,
1
]],
]
# 2-tuple such as (tensor_id, patterns)
# as output for nodes
src_as_output_nodes
=
src_reverse_adjs
[
src_node
.
id
]
for
node
in
src_as_output_nodes
:
if
node
.
id
in
result
:
continue
src_edge
=
src_edges
[
node
.
id
][
src_node
.
id
]
output_name
=
src_edge
.
attrs
[
"output_name"
]
compare_node_ids
=
tgt_reverse_adjs
[
tgt_node
.
id
]
class
OperatorGroupUtil
:
compare_node
=
None
for
node_id
in
compare_node_ids
:
edge
=
tgt_edges
[
node_id
][
tgt_node
.
id
]
if
edge
.
attrs
[
"output_name"
]
==
output_name
:
compare_node
=
tgt_nodes
[
node_id
]
break
if
not
compare_node
:
not_matched
=
True
return
_match_core
(
src_nodes
[
node_id
],
compare_node
)
results
=
[]
result
=
{}
has_matched
=
set
()
src_nodes
=
pattern
.
nodes
src_edges
=
pattern
.
_adjs
src_reverse_adjs
=
pattern
.
_reverse_adjs
tgt_nodes
=
graph
.
nodes
tgt_edges
=
graph
.
_adjs
tgt_reverse_adjs
=
graph
.
_reverse_adjs
tgt_attr_to_nodes
=
graph
.
_attr_to_nodes
not_matched
=
False
# starts with a op node
src_start_node
=
None
for
node_id
in
src_nodes
:
node
=
src_nodes
[
node_id
]
if
node
.
attrs
[
"type"
]
not
in
[
"var"
,
"param"
,
"data"
]:
src_start_node
=
node
break
assert
src_start_node
is
not
None
for
node_id
in
tgt_nodes
:
node
=
tgt_nodes
[
node_id
]
if
node
.
attrs
[
"type"
]
==
src_start_node
.
attrs
[
"type"
]:
_match_core
(
src_start_node
,
node
)
if
not
not_matched
:
need_to_append
=
True
for
value
in
result
.
values
():
if
value
in
has_matched
:
result
=
{}
need_to_append
=
False
break
if
need_to_append
:
results
.
append
(
result
)
for
value
in
result
.
values
():
has_matched
.
add
(
value
)
result
=
{}
else
:
not_matched
=
False
result
=
{}
return
results
class
OperatorClusteringUtil
:
common_starts
=
[
"layer_norm"
,
"matmul_v2"
,
"matmul"
]
@
staticmethod
...
...
@@ -257,7 +438,10 @@ class OperatorGroupUtil:
min_index
=
min
(
index_group
)
if
max_index
-
min_index
>=
k
:
longest_sub_seq
=
seq
[
min_index
:
min_index
+
k
]
if
longest_sub_seq
[
0
]
in
OperatorGroupUtil
.
common_starts
:
if
(
longest_sub_seq
[
0
]
in
OperatorClusteringUtil
.
common_starts
):
return
longest_sub_seq
if
longest_sub_seq
is
not
None
:
return
longest_sub_seq
...
...
@@ -325,9 +509,9 @@ class RuleBasedTuner:
self
.
_dist_context
=
dist_context
self
.
_mode
=
mode
def
group
_operators
(
self
,
ops
):
def
cluster
_operators
(
self
,
ops
):
"""
Group
operators to layers.
Cluster
operators to layers.
Args:
ops (list): A operator list.
...
...
@@ -337,7 +521,7 @@ class RuleBasedTuner:
"""
seq
=
[
op
.
type
for
op
in
ops
]
while
not
Operator
Group
Util
.
stop_replace
(
seq
):
while
not
Operator
Clustering
Util
.
stop_replace
(
seq
):
to_replace_seq
=
[]
to_replace_idxes
=
[]
has_append
=
False
...
...
@@ -351,12 +535,16 @@ class RuleBasedTuner:
elif
isinstance
(
seq
,
list
)
and
has_append
:
break
ranks
=
OperatorGroupUtil
.
get_ranks
(
to_replace_seq
)
suffixes
=
OperatorGroupUtil
.
get_suffixes
(
ranks
)
heights
=
OperatorGroupUtil
.
get_heights
(
suffixes
,
to_replace_seq
)
longest_sub_seq
=
OperatorGroupUtil
.
get_longest_repeated_sub_seq
(
ranks
=
OperatorClusteringUtil
.
get_ranks
(
to_replace_seq
)
suffixes
=
OperatorClusteringUtil
.
get_suffixes
(
ranks
)
heights
=
OperatorClusteringUtil
.
get_heights
(
suffixes
,
to_replace_seq
)
longest_sub_seq
=
(
OperatorClusteringUtil
.
get_longest_repeated_sub_seq
(
suffixes
,
heights
,
to_replace_seq
)
)
has_merged
=
False
if
longest_sub_seq
is
None
:
for
i
in
range
(
to_replace_idxes
[
-
1
]
+
1
,
len
(
seq
)):
...
...
@@ -374,10 +562,10 @@ class RuleBasedTuner:
seq
=
[
to_replace_seq
]
break
decomposed_sub_seq
=
Operator
Group
Util
.
get_decomposed_sub_seq
(
decomposed_sub_seq
=
Operator
Clustering
Util
.
get_decomposed_sub_seq
(
longest_sub_seq
)
to_replace_seq
=
Operator
Group
Util
.
replace_by_decomposed_seq
(
to_replace_seq
=
Operator
Clustering
Util
.
replace_by_decomposed_seq
(
decomposed_sub_seq
,
to_replace_seq
)
result
=
seq
[:
to_replace_idxes
[
0
]]
...
...
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
30315ac9
...
...
@@ -120,4 +120,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_fp16_assign MODULES test_fp16_assign
)
py_test_modules
(
test_group_operators MODULES test_group_operators
)
py_test_modules
(
test_pattern MODULES test_pattern
)
py_test_modules
(
test_pattern_match MODULES test_pattern_match
)
endif
()
python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py
浏览文件 @
30315ac9
...
...
@@ -121,7 +121,7 @@ class TestGroupOperators(unittest.TestCase):
dist_context
=
DistributedContext
()
tuner
=
RuleBasedTuner
(
dist_context
)
layers
=
tuner
.
group
_operators
(
train_program
.
global_block
().
ops
)
layers
=
tuner
.
cluster
_operators
(
train_program
.
global_block
().
ops
)
op_types
=
[]
for
layer
in
layers
:
tmp
=
[]
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py
浏览文件 @
30315ac9
...
...
@@ -14,6 +14,7 @@
import
sys
import
unittest
import
numpy
as
np
import
paddle
...
...
@@ -22,8 +23,8 @@ import paddle.static as static
sys
.
path
.
append
(
".."
)
import
auto_parallel_gpt_model
as
modeling
from
auto_parallel_gpt_model
import
(
GPTModel
,
GPTForPretraining
,
GPTModel
,
GPTPretrainingCriterion
,
)
...
...
@@ -111,22 +112,22 @@ class TestGroupOperators(unittest.TestCase):
sequence_len
,
vocab_size
,
)
from
paddle.distributed.auto_parallel.dist_context
import
(
DistributedContext
,
)
from
paddle.distributed.auto_parallel.tuner.rule_based_tuner
import
(
_PATTERNS
,
RuleBasedTuner
,
convert_to_graph
,
_PATTERNS
,
)
from
paddle.distributed.auto_parallel.dist_context
import
(
DistributedContext
,
)
dist_context
=
DistributedContext
()
tuner
=
RuleBasedTuner
(
dist_context
)
layers
=
tuner
.
group
_operators
(
train_program
.
global_block
().
ops
)
layers
=
tuner
.
cluster
_operators
(
train_program
.
global_block
().
ops
)
layer
=
layers
[
0
]
graph
=
convert_to_graph
(
layer
,
train_program
.
global_block
())
print
(
graph
)
print
(
"qkv: "
,
_PATTERNS
[
"qkv"
].
graph
)
print
(
"graph: "
,
graph
)
print
(
"qkv: "
,
_PATTERNS
[
"qkv"
].
attrs
[
"shard_spec"
]
)
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py
0 → 100644
浏览文件 @
30315ac9
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
sys
import
unittest
import
numpy
as
np
import
paddle
import
paddle.static
as
static
sys
.
path
.
append
(
".."
)
import
auto_parallel_gpt_model
as
modeling
from
auto_parallel_gpt_model
import
(
GPTForPretraining
,
GPTModel
,
GPTPretrainingCriterion
,
)
def get_gpt_model(
    train_program, start_program, place, batch_size, sequence_len, vocab_size
):
    """Build a small GPT pretraining network in the given static programs.

    Args:
        train_program: program passed to ``static.program_guard`` as the
            main program.
        start_program: program passed to ``static.program_guard`` as the
            startup program.
        place: accepted for signature compatibility; not read in this body.
        batch_size: leading dimension of every input placeholder and the
            number of samples produced by ``gen_data``.
        sequence_len: sequence dimension of the placeholders and samples.
        vocab_size: exclusive upper bound of the random token and label ids
            drawn in ``gen_data`` (the model itself is built with a fixed
            ``vocab_size=1000``).

    Returns:
        The tuple ``(train_program, start_program, loss, gen_data)``.
    """
    with static.program_guard(train_program, start_program):
        # Input placeholders, declared in the same order as before.
        tokens = paddle.static.data(
            name="tokens", shape=[batch_size, sequence_len], dtype="int64"
        )
        position_ids = paddle.static.data(
            name="position_ids",
            shape=[batch_size, sequence_len],
            dtype="int64",
        )
        attention_mask = paddle.static.data(
            name="attention_mask",
            shape=[batch_size, 1, sequence_len, sequence_len],
            dtype="float32",
        )
        labels = paddle.static.data(
            name="labels", shape=[batch_size, sequence_len], dtype="int64"
        )
        loss_mask = paddle.static.data(
            name="loss_mask", shape=[batch_size, sequence_len], dtype="float32"
        )

        gpt_config = dict(
            vocab_size=1000,
            hidden_size=64,
            num_hidden_layers=2,
            num_attention_heads=8,
            intermediate_size=256,
            hidden_act="gelu",
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            max_position_embeddings=1024,
            type_vocab_size=1,
            initializer_range=0.02,
            pad_token_id=0,
            eos_token_id=7,
            bos_token_id=0,
            eol_token_id=3,
        )
        backbone = GPTModel(**gpt_config)
        pretraining_model = GPTForPretraining(
            backbone, vocab_size=1000, hidden_size=64, initializer_range=0.02
        )
        predictions = pretraining_model(tokens, position_ids, attention_mask)
        loss_fn = GPTPretrainingCriterion()
        loss = loss_fn(predictions, labels, loss_mask)

    def gen_data():
        """Return five per-sample lists of NumPy arrays, seeded with 2021."""
        np.random.seed(2021)
        token_batch, position_batch, mask_batch = [], [], []
        label_batch, loss_mask_batch = [], []
        # Tokens are drawn before labels within each sample; keeping that
        # order keeps the generated random stream unchanged.
        for _ in range(batch_size):
            token_batch.append(np.random.randint(vocab_size, size=sequence_len))
            position_batch.append(np.arange(sequence_len))
            mask_batch.append([np.tril(np.ones(sequence_len))])
            label_batch.append(np.random.randint(vocab_size, size=sequence_len))
            loss_mask_batch.append(np.ones(sequence_len))
        return (
            token_batch,
            position_batch,
            mask_batch,
            label_batch,
            loss_mask_batch,
        )

    return train_program, start_program, loss, gen_data
class TestGroupOperators(unittest.TestCase):
    """Runs the "qkv" pattern matcher on the first clustered GPT layer."""

    def test_gpt(self):
        """Build the GPT program, match the qkv pattern and print the hits.

        The test makes no assertions; it only prints the matched entries and
        the shard spec.
        """
        modeling.init_global()
        place = paddle.set_device("gpu")
        batch_size, sequence_len, vocab_size = 8, 512, 1000

        train_program, start_program, loss, gen_data = get_gpt_model(
            static.Program(),
            static.Program(),
            place,
            batch_size,
            sequence_len,
            vocab_size,
        )

        from paddle.distributed.auto_parallel.dist_context import (
            DistributedContext,
        )
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            _PATTERNS,
            RuleBasedTuner,
            convert_to_graph,
            match,
        )

        tuner = RuleBasedTuner(DistributedContext())
        main_block_ops = train_program.global_block().ops
        first_layer = tuner.cluster_operators(main_block_ops)[0]
        graph = convert_to_graph(first_layer, train_program.global_block())

        qkv_pattern = _PATTERNS["qkv"]
        results = match(qkv_pattern, graph)
        shard_tensor_infos = qkv_pattern.attrs["shard_spec"]
        tensor_ids = shard_tensor_infos[0][0]

        # Nothing is printed when the matcher finds no occurrence.
        if not results:
            return
        for result in results:
            # Each result maps a pattern node id to the matched graph node id.
            for pattern_node_id in result:
                if pattern_node_id not in tensor_ids:
                    continue
                print(graph.attrs["id_to_var"][result[pattern_node_id]])
        print("shard_spec: ", shard_tensor_infos[0][1])
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录