PaddlePaddle / Paddle
Unverified commit 0bb7c003, authored Mar 21, 2023 by caozhou, committed by GitHub on Mar 21, 2023.
[Auto Parallel] Add patterns of rule based tuner (#51859)
* add patterns
* add unittest
Parent: cdefcd00
Showing 3 changed files with 646 additions and 251 deletions (+646 −251):
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py (+624 −231)
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py (+17 −4)
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py (+5 −16)
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
@@ -22,24 +22,43 @@ _PATTERNS = {}
 def register_pattern(cls):
     """Register pattern for rule-based tuner."""
-    name = cls.name

-    def register(name):
+    def register():
         global _PATTERNS
-        _PATTERNS[name] = cls()
+        pattern = cls()
+        _PATTERNS[pattern.name] = pattern
+        # sort patterns according to the number of sharded tensors;
+        # when a tensor can be matched by multiple patterns, its dist attr
+        # is set by the first one
+        _PATTERNS = dict(
+            sorted(
+                _PATTERNS.items(), key=lambda x: -x[1].attrs["sharded_tensors"]
+            )
+        )

-    register(name)
+    register()
     return cls


 class BasePattern(Graph):
-    name = "base"
+    """
+    Base class of pattern.
+    BasePattern inherits from Graph; the two important additions are shard_spec and sharded_tensors.
+    shard_spec indicates the shard specification of each tensor node in this pattern under different parallelisms.
+    sharded_tensors is the number of tensors that are sharded.
+    """
+
+    _name = "base"

     def __init__(self):
+        """Every pattern has its own name and build method."""
         super().__init__()
         self.build()

+    @property
+    def name(self):
+        return self.__class__._name
+
     @abstractmethod
     def build(self):
         pass
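
Note on the registration logic above: because `register` re-sorts `_PATTERNS` by descending `sharded_tensors`, later iteration over `_PATTERNS` (e.g. in `match_all_patterns` below) prefers patterns that shard more tensors. A minimal sketch of that ordering behaviour with stand-in classes (the Toy* names are hypothetical, not part of this commit):

_PATTERNS = {}


def register_pattern(cls):
    def register():
        global _PATTERNS
        pattern = cls()
        _PATTERNS[pattern.name] = pattern
        # keep patterns ordered by descending number of sharded tensors
        _PATTERNS = dict(
            sorted(_PATTERNS.items(), key=lambda x: -x[1].attrs["sharded_tensors"])
        )

    register()
    return cls


class ToyPattern:
    _name = "toy"

    def __init__(self, n):
        self.attrs = {"sharded_tensors": n}

    @property
    def name(self):
        return self.__class__._name


@register_pattern
class ToyA(ToyPattern):
    _name = "a"

    def __init__(self):
        super().__init__(1)


@register_pattern
class ToyB(ToyPattern):
    _name = "b"

    def __init__(self):
        super().__init__(4)


print(list(_PATTERNS))  # ['b', 'a'] -- the pattern that shards more tensors comes first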
...
@@ -47,6 +66,8 @@ class BasePattern(Graph):

 @register_pattern
 class QKVPattern(BasePattern):
+    """The QKV pattern defined by GPT model in PaddleFleetX."""
+
     name = "qkv"

     def __init__(self):
...
@@ -55,43 +76,374 @@ class QKVPattern(BasePattern):
     def build(self):
         query = self.add_node(0, **{"type": "var"})

         # define q, k, v weight
         q_weight = self.add_node(1, **{"dim": 2, "type": "param"})
         k_weight = self.add_node(2, **{"dim": 2, "type": "param"})
         v_weight = self.add_node(3, **{"dim": 2, "type": "param"})

         # define q, k, v matmul_v2
-        q_matmul = self.add_node(4, **{"type": "matmul_v2"})
-        k_matmul = self.add_node(5, **{"type": "matmul_v2"})
-        v_matmul = self.add_node(6, **{"type": "matmul_v2"})
+        q_matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
+        k_matmul_v2 = self.add_node(5, **{"type": "matmul_v2"})
+        v_matmul_v2 = self.add_node(6, **{"type": "matmul_v2"})

         # define input edge
-        q_x = self.add_edge(0, 4, **{"input_name": "X"})
-        k_x = self.add_edge(0, 5, **{"input_name": "X"})
-        v_x = self.add_edge(0, 6, **{"input_name": "X"})
-        q_y = self.add_edge(1, 4, **{"input_name": "Y"})
-        k_y = self.add_edge(2, 5, **{"input_name": "Y"})
-        v_y = self.add_edge(3, 6, **{"input_name": "Y"})
+        q_x_edge = self.add_edge(
+            query.id, q_matmul_v2.id, **{"input_name": "X"}
+        )
+        k_x_edge = self.add_edge(
+            query.id, k_matmul_v2.id, **{"input_name": "X"}
+        )
+        v_x_edge = self.add_edge(
+            query.id, v_matmul_v2.id, **{"input_name": "X"}
+        )
+        q_y_edge = self.add_edge(
+            q_weight.id, q_matmul_v2.id, **{"input_name": "Y"}
+        )
+        k_y_edge = self.add_edge(
+            k_weight.id, k_matmul_v2.id, **{"input_name": "Y"}
+        )
+        v_y_edge = self.add_edge(
+            v_weight.id, v_matmul_v2.id, **{"input_name": "Y"}
+        )

         # define q, k, v matmul_v2 output
         q = self.add_node(7, **{"type": "var"})
         k = self.add_node(8, **{"type": "var"})
         v = self.add_node(9, **{"type": "var"})

-        q_out = self.add_edge(4, 7, **{"output_name": "Out"})
-        k_out = self.add_edge(5, 8, **{"output_name": "Out"})
-        v_out = self.add_edge(6, 9, **{"output_name": "Out"})
+        # define output edge
+        q_out_edge = self.add_edge(
+            q_matmul_v2.id, q.id, **{"output_name": "Out"}
+        )
+        k_out_edge = self.add_edge(
+            k_matmul_v2.id, k.id, **{"output_name": "Out"}
+        )
+        v_out_edge = self.add_edge(
+            v_matmul_v2.id, v.id, **{"output_name": "Out"}
+        )

-        # Pattern
-        self.attrs["shard_spec"] = [
-            [(1, 2, 3), [[-1, 0], [-1, 1]]],
-        ]  # 2-tuple list such as [(tensor_id, shard_spec)]
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {
+                0: [0, -1, -1],
+                1: [-1, 1],
+                2: [-1, 1],
+                3: [-1, 1],
+            },
+            "mp_dp": {
+                0: [1, -1, -1],
+                1: [-1, 0],
+                2: [-1, 0],
+                3: [-1, 0],
+            },
+            "mp": {0: [-1, -1, -1], 1: [-1, 0], 2: [-1, 0], 3: [-1, 0]},
+            "dp": {
+                0: [0, -1, -1],
+                1: [-1, -1],
+                2: [-1, -1],
+                3: [-1, -1],
+            },
+        }
+        self.attrs["shard_spec"] = shard_spec

+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 4
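
How to read a shard_spec entry above: each key names a parallelism mode ("dp_mp", "mp_dp", "mp", "dp") and each value maps a pattern node id to a dims_mapping, with one entry per tensor dimension; -1 means replicated, and a non-negative value names the process-mesh axis that dimension is sharded along. This follows Paddle auto-parallel's usual dims_mapping convention; the helper below is illustrative only.

# The QKV "dp_mp" entry, read out loud:
qkv_dp_mp = {
    0: [0, -1, -1],  # node 0 (query var): batch dim sharded on mesh axis 0 (dp)
    1: [-1, 1],      # node 1 (q weight): column dim sharded on mesh axis 1 (mp)
    2: [-1, 1],      # node 2 (k weight)
    3: [-1, 1],      # node 3 (v weight)
}


def describe(dims_mapping):
    return [
        "replicated" if axis == -1 else f"sharded on mesh axis {axis}"
        for axis in dims_mapping
    ]


for node_id, mapping in qkv_dp_mp.items():
    print(node_id, describe(mapping))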
+
+
+@register_pattern
+class RowMatmulPattern(BasePattern):
+    """Row matmul pattern defined by GPT model in PaddleFleetX."""
+
+    name = "row_matmul"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define reshape input
+        input = self.add_node(0, **{"type": "var"})
+
+        # define reshape
+        reshape = self.add_node(1, **{"type": "reshape2"})
+
+        # define reshape input edge
+        x_edge = self.add_edge(input.id, reshape.id, **{"input_name": "X"})
+
+        # define reshape out
+        output = self.add_node(2, **{"type": "var"})
+
+        # define reshape output edge
+        out_edge = self.add_edge(
+            reshape.id, output.id, **{"output_name": "Out"}
+        )
+
+        # define matmul_v2 weight
+        weight = self.add_node(3, **{"dim": 2, "type": "param"})
+
+        # define matmul_v2
+        matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
+
+        # define input edge
+        x_edge = self.add_edge(
+            output.id, matmul_v2.id, **{"input_name": "X"}
+        )
+        y_edge = self.add_edge(
+            weight.id, matmul_v2.id, **{"input_name": "Y"}
+        )
+
+        # define matmul_v2 output
+        output = self.add_node(5, **{"type": "var"})
+
+        # define output edge
+        out_edge = self.add_edge(
+            matmul_v2.id, output.id, **{"output_name": "Out"}
+        )
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {
+                3: [1, -1],
+            },
+            "mp_dp": {
+                3: [0, -1],
+            },
+            "mp": {3: [0, -1]},
+            "dp": {
+                3: [-1, -1],
+            },
+        }
+        self.attrs["shard_spec"] = shard_spec
+
+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 1
+
+
+@register_pattern
+class FFNPattrern(BasePattern):
+    """FFN pattern defined by GPT model in PaddleFleetX."""
+
+    name = "ffn"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        x = self.add_node(0, **{"type": "var"})
+
+        w1_weight = self.add_node(1, **{"dim": 2, "type": "param"})
+        w1_matmul = self.add_node(2, **{"type": "matmul_v2"})
+
+        w1_x = self.add_edge(0, 2, **{"input_name": "X"})
+        w1_y = self.add_edge(1, 2, **{"input_name": "Y"})
+
+        out1 = self.add_node(3, **{"type": "var"})
+        w1_out = self.add_edge(2, 3, **{"output_name": "Out"})
+
+        w1_b = self.add_node(4, **{"dim": 1, "type": "param"})
+        add1 = self.add_node(5, **{"type": "elementwise_add"})
+        add1_x = self.add_edge(3, 5, **{"input_name": "X"})
+        add1_y = self.add_edge(4, 5, **{"input_name": "Y"})
+
+        out2 = self.add_node(6, **{"type": "var"})
+        add1_out = self.add_edge(5, 6, **{"output_name": "Out"})
+
+        gelu = self.add_node(7, **{"type": "gelu"})
+        gelu_x = self.add_edge(6, 7, **{"input_name": "X"})
+
+        out3 = self.add_node(8, **{"type": "var"})
+        gelu_out = self.add_edge(7, 8, **{"output_name": "Out"})
+
+        w2_weight = self.add_node(9, **{"dim": 2, "type": "param"})
+        w2_matmul = self.add_node(10, **{"type": "matmul_v2"})
+
+        w1_x = self.add_edge(8, 10, **{"input_name": "X"})
+        w1_y = self.add_edge(9, 10, **{"input_name": "Y"})
+
+        out4 = self.add_node(11, **{"type": "var"})
+        w2_out = self.add_edge(10, 11, **{"output_name": "Out"})
+
+        w2_b = self.add_node(12, **{"dim": 1, "type": "param"})
+        add2 = self.add_node(13, **{"type": "elementwise_add"})
+        add2_x = self.add_edge(11, 13, **{"input_name": "X"})
+        add2_y = self.add_edge(12, 13, **{"input_name": "Y"})
+
+        out5 = self.add_node(14, **{"type": "var"})
+        add2_out = self.add_edge(13, 14, **{"output_name": "Out"})
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {0: [0, -1, -1], 1: [-1, 1], 9: [1, -1]},
+            "mp_dp": {0: [1, -1, -1], 1: [-1, 0], 9: [0, -1]},
+            "mp": {1: [-1, 0], 9: [0, -1]},
+            "dp": {0: [0, -1, -1], 1: [-1, -1], 9: [-1, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+
+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 2
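
The FFNPattrern above encodes the matmul → add → gelu → matmul → add chain of a GPT feed-forward block. A quick, self-contained way to see which op types those pattern nodes refer to is to build the same chain in a static-graph program and list the block's ops; the tiny program below is illustrative, and the exact op list may vary slightly by Paddle version.

import paddle
from paddle import static

paddle.enable_static()
prog = static.Program()
with static.program_guard(prog):
    x = static.data("x", [2, 4, 16], "float32")
    w1 = static.create_parameter(shape=[16, 64], dtype="float32")
    b1 = static.create_parameter(shape=[64], dtype="float32")
    w2 = static.create_parameter(shape=[64, 16], dtype="float32")
    b2 = static.create_parameter(shape=[16], dtype="float32")
    h = paddle.nn.functional.gelu(paddle.matmul(x, w1) + b1)
    out = paddle.matmul(h, w2) + b2

print([op.type for op in prog.global_block().ops])
# expected something like:
# ['matmul_v2', 'elementwise_add', 'gelu', 'matmul_v2', 'elementwise_add']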
+
+
+@register_pattern
+class SharedWordEmbeddingPattern(BasePattern):
+    """Sharded word embedding pattern defined by GPT model in PaddleFleetX."""
+
+    name = "shared_word_embedding"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define embedding input
+        tokens = self.add_node(0, **{"type": "data"})
+        word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})
+
+        # define embedding
+        embedding = self.add_node(2, **{"type": "lookup_table_v2"})
+
+        # define embedding input edge
+        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
+        w = self.add_edge(1, 2, **{"input_name": "W"})
+
+        # define embedding output
+        out = self.add_node(3, **{"type": "var"})
+
+        # define embedding output edge
+        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})
+
+        # define matmul_v2 input
+        x = self.add_node(4, **{"type": "var"})
+
+        # define matmul_v2
+        matmul = self.add_node(5, **{"type": "matmul_v2"})
+
+        # define matmul_v2 input edge
+        x_edge = self.add_edge(4, 5, **{"input_name": "X"})
+        y_edge = self.add_edge(1, 5, **{"input_name": "Y"})
+
+        # define matmul_v2 output
+        out = self.add_node(6, **{"type": "var"})
+
+        # define matmul_v2 output edge
+        out_edge = self.add_edge(5, 6, **{"output_name": "Out"})
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {0: [0, -1], 1: [1, -1], 4: [0, -1, -1]},
+            "mp_dp": {0: [1, -1], 1: [0, -1], 4: [1, -1, -1]},
+            "mp": {0: [-1, -1], 1: [0, -1], 4: [-1, -1, -1]},
+            "dp": {0: [0, -1], 1: [-1, -1], 4: [0, -1, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+        self.attrs["sharded_tensors"] = 3
+
+
+@register_pattern
+class PositionEmbeddingPattern(BasePattern):
+    """Position embedding pattern defined by GPT model in PaddleFleetX."""
+
+    name = "position_embedding"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define embedding input
+        tokens = self.add_node(0, **{"type": "data"})
+        word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})
+
+        # define embedding
+        embedding = self.add_node(2, **{"type": "lookup_table_v2"})
+
+        # define embedding input edge
+        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
+        w = self.add_edge(1, 2, **{"input_name": "W"})
+
+        # define embedding output
+        out = self.add_node(3, **{"type": "var"})
+
+        # define embedding output edge
+        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {0: [0, -1], 1: [-1, -1], 3: [-1, -1, -1]},
+            "mp_dp": {0: [1, -1], 1: [-1, -1], 3: [1, -1, -1]},
+            "mp": {0: [-1, -1], 1: [-1, -1], 3: [-1, -1, -1]},
+            "dp": {0: [0, -1], 1: [-1, -1], 3: [0, -1, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+
+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 1
+
+
+@register_pattern
+class UnsqueezeDataPattern(BasePattern):
+    """Unsqueeze data pattern defined by GPT model in PaddleFleetX."""
+
+    name = "unsqueeze_data"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define unsqueeze input
+        tokens = self.add_node(0, **{"type": "data"})
+
+        # define unsqueeze
+        unsqueeze = self.add_node(1, **{"type": "unsqueeze2"})
+
+        # define unsqueeze input edge
+        x_edge = self.add_edge(0, 1, **{"input_name": "X"})
+
+        # pattern: pure mp or hybrid dp+mp
+        shard_spec = {
+            "dp_mp": {0: [0, -1]},
+            "mp_dp": {0: [1, -1]},
+            "mp": {0: [-1, -1]},
+            "dp": {0: [0, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+        self.attrs["sharded_tensors"] = 1
+
+
+@register_pattern
+class ReshapeDataPattern(BasePattern):
+    """Reshape data pattern defined by GPT model in PaddleFleetX."""
+
+    name = "reshape_data"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define reshape input
+        data = self.add_node(0, **{"type": "data"})
+
+        # define reshape
+        reshape = self.add_node(1, **{"type": "reshape2"})
+
+        # define reshape input edge
+        x_edge = self.add_edge(0, 1, **{"input_name": "X"})
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {0: [0, -1]},
+            "mp_dp": {0: [1, -1]},
+            "mp": {0: [-1, -1]},
+            "dp": {0: [0, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+
+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 1
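
All of the patterns above follow the same recipe: add var/param/data and op nodes, connect them with edges labelled by the op's input/output slot names, then attach a shard_spec per parallelism mode and a sharded_tensors count. As a sketch of that recipe, a hypothetical pattern for a standalone layer_norm could look like the following; LayerNormPattern is illustrative only (not part of this commit) and assumes BasePattern and register_pattern from this module.

@register_pattern
class LayerNormPattern(BasePattern):
    """Hypothetical pattern matching a standalone layer_norm op."""

    name = "layer_norm_demo"

    def __init__(self):
        super().__init__()

    def build(self):
        x = self.add_node(0, **{"type": "var"})
        scale = self.add_node(1, **{"dim": 1, "type": "param"})
        bias = self.add_node(2, **{"dim": 1, "type": "param"})
        ln = self.add_node(3, **{"type": "layer_norm"})

        self.add_edge(0, 3, **{"input_name": "X"})
        self.add_edge(1, 3, **{"input_name": "Scale"})
        self.add_edge(2, 3, **{"input_name": "Bias"})

        out = self.add_node(4, **{"type": "var"})
        self.add_edge(3, 4, **{"output_name": "Y"})

        # layer_norm params stay replicated; only the input is data-parallel
        self.attrs["shard_spec"] = {
            "dp_mp": {0: [0, -1, -1]},
            "mp_dp": {0: [1, -1, -1]},
            "mp": {0: [-1, -1, -1]},
            "dp": {0: [0, -1, -1]},
        }
        self.attrs["sharded_tensors"] = 1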
+
+
-def convert_to_graph(ops, block):
-    """Convert ops to graph."""
-    graph = Graph()
-    graph.attrs["var_to_id"] = {}  # {var_name: node_id}
-    graph.attrs["id_to_var"] = {}  # {node_id: var_name}
-    graph.attrs["op_to_id"] = {}  # {op_id: node_id}
-    graph.attrs["id_to_op"] = {}  # {node_id: op_id}
+class GraphUtil:
+    """Graph util is used to convert ops to graph or match pattern for graph."""
+
+    @staticmethod
+    def convert_to_graph(block):
+        """Convert ops to graph."""
+        graph = Graph()
+        graph.attrs["var_to_id"] = {}  # {var_name: node_id}
+        graph.attrs["id_to_var_desc_id"] = {}  # {node_id: var_desc_id}
+        graph.attrs["id_to_var_name"] = {}
+        graph.attrs["op_to_id"] = {}  # {op_id: node_id}
+        graph.attrs["id_to_op"] = {}  # {node_id: op}
+        ops = block.ops

         node_id = -1
         for op in ops:
             attrs = op.all_attrs()
...
@@ -101,7 +453,7 @@ def convert_to_graph(ops, block):
             # create op node
             op_node = graph.add_node(node_id, **attrs)
             graph.attrs["op_to_id"][op.desc.id()] = op_node.id
-            graph.attrs["id_to_op"][op_node.id] = op.desc.id()
+            graph.attrs["id_to_op"][op_node.id] = op
             graph._attr_to_nodes[op_node.id] = {}
             for input_name in op.input_names:
                 graph._attr_to_nodes[op_node.id][input_name] = []
...
@@ -114,10 +466,16 @@ def convert_to_graph(ops, block):
                 if var.is_parameter:
                     var_node.attrs["type"] = "param"
                     var_node.attrs["dim"] = len(var.shape)
+                elif var.is_data:
+                    var_node.attrs["type"] = "data"
+                    var_node.attrs["dim"] = len(var.shape)
                 else:
                     var_node.attrs["type"] = "var"
                 graph.attrs["var_to_id"][var_name] = var_node.id
-                graph.attrs["id_to_var"][var_node.id] = var_name
+                graph.attrs["id_to_var_desc_id"][
+                    var_node.id
+                ] = var.desc.original_id()
+                graph.attrs["id_to_var_name"][var_node.id] = var_name
             else:
                 var_node_id = graph.attrs["var_to_id"][var_name]
                 var_node = graph._nodes[var_node_id]
...
@@ -125,7 +483,9 @@ def convert_to_graph(ops, block):
                 # create edge that input -> op
                 input_edge = graph.add_edge(var_node.id, op_node.id)
                 input_edge.attrs["input_name"] = input_name
                 graph._attr_to_nodes[op_node.id][input_name].append(var_node)

             for output_name in op.output_names:
                 graph._attr_to_nodes[op_node.id][output_name] = []
...
@@ -140,7 +500,12 @@ def convert_to_graph(ops, block):
                 else:
                     var_node.attrs["type"] = "var"
                 graph.attrs["var_to_id"][var_name] = var_node.id
-                graph.attrs["id_to_var"][var_node.id] = var_name
+                graph.attrs["id_to_var_desc_id"][
+                    var_node.id
+                ] = var.desc.original_id()
+                graph.attrs["id_to_var_name"][var_node.id] = var_name
             else:
                 var_node_id = graph.attrs["var_to_id"][var_name]
                 var_node = graph._nodes[var_node_id]
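
A sketch of how the converted graph is consumed (mirroring the updated tests further down): build a small static Program, convert its main block, and inspect the bookkeeping attrs that convert_to_graph fills in. The tiny program here is illustrative; the import path is the one used by this commit's tests.

import paddle
from paddle import static
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import GraphUtil

paddle.enable_static()

train_program = static.Program()
start_program = static.Program()
with static.program_guard(train_program, start_program):
    x = static.data(name="x", shape=[4, 8], dtype="float32")
    w = static.create_parameter(shape=[8, 8], dtype="float32")
    y = paddle.matmul(x, w)

graph = GraphUtil.convert_to_graph(train_program.global_block())
print(graph.attrs["var_to_id"])   # {var_name: node_id}
print(graph.attrs["id_to_op"])    # {node_id: op}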
...
@@ -155,24 +520,24 @@ def convert_to_graph(ops, block):
         return graph

-def match(pattern, graph):
+    @staticmethod
+    def match_pattern(pattern, graph):
         def _is_op_node(node):
-            """Judge whether node is op node"""
+            """Judge whether node is op node."""
             if node.attrs["type"] not in ["var", "param", "data"]:
                 return True
             return False

         def _compare_op_node(src, tgt):
-            """Compare whether two op nodes are equal"""
+            """Compare whether two op nodes are equivalent."""
             if src.attrs["type"] != tgt.attrs["type"]:
                 return False
             return True

         def _compare_var_node(src, tgt):
-            """Compare whether two var nodes are equal"""
+            """Compare whether two var nodes are equivalent."""
             for key in src.attrs:
                 if key not in tgt.attrs:
                     return False
...
@@ -183,13 +548,14 @@ def match(pattern, graph):
         def _match_core(src_node, tgt_node):
             nonlocal not_matched
-            # do not support one input name or output name corresponding to multiple vars
+            # does not support one input or output name corresponding to multiple vars
             if not_matched:
                 return

             if _is_op_node(src_node):
                 # compare op node whether equal
                 if not _compare_op_node(src_node, tgt_node):
+                    not_matched = True
                     return

                 result[src_node.id] = tgt_node.id
...
@@ -232,14 +598,14 @@ def match(pattern, graph):
                     _match_core(node, compare_nodes[0])
             else:
-                # compare var node whether equal
+                # compare whether var nodes are equal
                 if not _compare_var_node(src_node, tgt_node):
                     not_matched = True
                     return

                 result[src_node.id] = tgt_node.id

-                # as input for op nodes
+                # as input for op node
                 src_as_input_node_ids = src_edges[src_node.id].keys()
                 for node_id in src_as_input_node_ids:
                     if node_id in result:
...
@@ -264,7 +630,7 @@ def match(pattern, graph):
                         return
                     _match_core(src_nodes[node_id], compare_node)

-                # as output for nodes
+                # as output for op node
                 src_as_output_nodes = src_reverse_adjs[src_node.id]
                 for node in src_as_output_nodes:
                     if node.id in result:
...
@@ -273,10 +639,11 @@ def match(pattern, graph):
                     src_edge = src_edges[node.id][src_node.id]
                     output_name = src_edge.attrs["output_name"]

-                    compare_node_ids = tgt_reverse_adjs[tgt_node.id]
+                    compare_nodes = tgt_reverse_adjs[tgt_node.id]
                     compare_node = None
-                    for node_id in compare_node_ids:
+                    for item in compare_nodes:
+                        node_id = item.id
                         edge = tgt_edges[node_id][tgt_node.id]
                         if edge.attrs["output_name"] == output_name:
                             compare_node = tgt_nodes[node_id]
...
@@ -284,11 +651,12 @@ def match(pattern, graph):
                     if not compare_node:
                         not_matched = True
                         return
-                    _match_core(src_nodes[node_id], compare_node)
+                    _match_core(src_nodes[node.id], compare_node)

         results = []
+        matched_ids = set()
+        matched_op_node_ids = set()
         result = {}
-        has_matched = set()
         src_nodes = pattern.nodes
         src_edges = pattern._adjs
         src_reverse_adjs = pattern._reverse_adjs
...
@@ -297,7 +665,6 @@ def match(pattern, graph):
         tgt_edges = graph._adjs
         tgt_reverse_adjs = graph._reverse_adjs
         tgt_attr_to_nodes = graph._attr_to_nodes
-        not_matched = False

         # starts with an op node
         src_start_node = None
...
@@ -311,27 +678,55 @@ def match(pattern, graph):
         for node_id in tgt_nodes:
             node = tgt_nodes[node_id]
             if node.attrs["type"] == src_start_node.attrs["type"]:
+                not_matched = False
                 _match_core(src_start_node, node)
                 if not not_matched:
                     need_to_append = True
                     for value in result.values():
-                        if value in has_matched:
+                        if value in matched_op_node_ids:
                             result = {}
                             need_to_append = False
                             break
                     if need_to_append:
                         results.append(result)
                         for value in result.values():
-                            has_matched.add(value)
+                            matched_ids.add(value)
+                            if value in graph.attrs["id_to_op"].keys():
+                                matched_op_node_ids.add(value)
                         result = {}
                 else:
                     not_matched = False
                     result = {}
-    return results
+        return results, matched_ids
+
+    @staticmethod
+    def match_all_patterns(graph):
+        # matched_results maps pattern_name to a list of mappings from pattern
+        # node id to graph node id, such as {"pattern_name": [{pattern_node_id: graph_node_id}, ]}
+        matched_results = {}
+        matched_ids = set()
+        for pattern_name in _PATTERNS:
+            pattern = _PATTERNS[pattern_name]
+            results, matched = GraphUtil.match_pattern(pattern, graph)
+            for result in results:
+                has_matched = False
+                for id in result:
+                    if result[id] in matched_ids:
+                        has_matched = True
+                        break
+                if not has_matched:
+                    for item in result:
+                        matched_ids.add(result[id])
+                    if pattern.name not in matched_results:
+                        matched_results[pattern.name] = []
+                    matched_results[pattern.name].append(result)
+
+        return matched_results


 class OperatorClusteringUtil:
+    """Operator clustering util is used to cluster operators to layers."""
+
     common_starts = ["layer_norm", "matmul_v2", "matmul"]

     @staticmethod
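
End-to-end, the new entry point is match_all_patterns, as exercised by the updated test_pattern_match.py below. Assuming a train_program like the one in the earlier sketch:

from paddle.distributed.auto_parallel.tuner.rule_based_tuner import GraphUtil

graph = GraphUtil.convert_to_graph(train_program.global_block())
results = GraphUtil.match_all_patterns(graph)
# Per the comment in match_all_patterns, results has the shape
# {"pattern_name": [{pattern_node_id: graph_node_id, ...}, ...], ...},
# e.g. {"qkv": [{0: 12, 1: 3, ...}], "ffn": [...], ...} (ids illustrative).
print(results)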
...
@@ -506,6 +901,8 @@ class OperatorClusteringUtil:

 class ClusterPartitionUtil:
+    """Cluster partition util is used to get device meshes and process meshes."""
+
     @staticmethod
     def factorization(num):
         factors = []
...
@@ -535,13 +932,11 @@ class ClusterPartitionUtil:
         ],
     ) -> list:
         """
         Partition cluster into possible device meshes.
-
         Args:
             n (int): The number of nodes.
             m (int): The number of single devices on each node.
             filter (list): Functions for filtering useful meshes.
-
         Returns:
             device_meshed (list): The possible device meshes.
         """
...
@@ -573,10 +968,8 @@ class ClusterPartitionUtil:
 def convert_to_process_meshes(device_mesh: list) -> list:
     """
     Transfer device_meshes into possible process meshes.
-
     Args:
         device_mesh (list): [n, m], one device mesh.
-
     Returns:
         process_meshes (list): Possible process_meshes.
     """
...
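
For intuition, here is a toy version of what the two utilities above are for, guessed from the docstrings alone; it is not Paddle's actual implementation.

def factorization(num):
    """All (a, b) with a * b == num, e.g. 8 -> [(1, 8), (2, 4), ...]."""
    return [(a, num // a) for a in range(1, num + 1) if num % a == 0]


def convert_to_process_meshes(device_mesh):
    """Expand one [n, m] device mesh into candidate process mesh shapes."""
    n, m = device_mesh
    meshes = []
    for a, b in factorization(n * m):
        meshes.append([a, b])
    return meshes


print(convert_to_process_meshes([1, 8]))
# [[1, 8], [2, 4], [4, 2], [8, 1]]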
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py
@@ -95,7 +95,7 @@ def get_gpt_model(
     return train_program, start_program, loss, gen_data


-class TestGroupOperators(unittest.TestCase):
+class TestGroupOperatorsAndPatterns(unittest.TestCase):
     def test_gpt(self):
         modeling.init_global()
         train_program = static.Program()
...
@@ -117,17 +117,30 @@ class TestGroupOperators(unittest.TestCase):
         )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
             _PATTERNS,
+            GraphUtil,
             RuleBasedTuner,
-            convert_to_graph,
         )

         dist_context = DistributedContext()
         tuner = RuleBasedTuner(dist_context)
         layers = tuner.cluster_operators(train_program.global_block().ops)
-        layer = layers[0]
-        graph = convert_to_graph(layer, train_program.global_block())
+        graph = GraphUtil.convert_to_graph(train_program.global_block())
         print("graph: ", graph)
         print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
+        print("row_matmul: ", _PATTERNS["row_matmul"].attrs["shard_spec"])
+        print("ffn: ", _PATTERNS["ffn"].attrs["shard_spec"])
+        print(
+            "shared_word_embedding: ",
+            _PATTERNS["shared_word_embedding"].attrs["shard_spec"],
+        )
+        print(
+            "position_embedding: ",
+            _PATTERNS["position_embedding"].attrs["shard_spec"],
+        )
+        print(
+            "unsqueeze_data: ", _PATTERNS["unsqueeze_data"].attrs["shard_spec"]
+        )
+        print("reshape_data: ", _PATTERNS["reshape_data"].attrs["shard_spec"])


 if __name__ == "__main__":
...
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py
@@ -95,7 +95,7 @@ def get_gpt_model(
     return train_program, start_program, loss, gen_data


-class TestGroupOperators(unittest.TestCase):
+class TestPatternMatch(unittest.TestCase):
     def test_gpt(self):
         modeling.init_global()
         train_program = static.Program()
...
@@ -116,26 +116,15 @@ class TestGroupOperators(unittest.TestCase):
             DistributedContext,
         )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
-            _PATTERNS,
+            GraphUtil,
             RuleBasedTuner,
-            convert_to_graph,
-            match,
         )

         dist_context = DistributedContext()
         tuner = RuleBasedTuner(dist_context)
-        layers = tuner.cluster_operators(train_program.global_block().ops)
-        layer = layers[0]
-        graph = convert_to_graph(layer, train_program.global_block())
-        results = match(_PATTERNS["qkv"], graph)
-        shard_tensor_infos = _PATTERNS["qkv"].attrs["shard_spec"]
-        tensor_ids = shard_tensor_infos[0][0]
-        if results:
-            for result in results:
-                for node_id in result:
-                    if node_id in tensor_ids:
-                        print(graph.attrs["id_to_var"][result[node_id]])
-            print("shard_spec: ", shard_tensor_infos[0][1])
+        graph = GraphUtil.convert_to_graph(train_program.global_block())
+        results = GraphUtil.match_all_patterns(graph)
+        print(results)


 if __name__ == "__main__":
...
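
Finally, a sketch of recovering variable names for matched tensors, mirroring what the old test did by hand; it assumes the `graph` and `results` from the test body above and uses the id_to_var_name lookup this commit adds in convert_to_graph:

for pattern_name, matches in results.items():
    for result in matches:
        for pattern_node_id, graph_node_id in result.items():
            # only var/param/data nodes appear in id_to_var_name
            if graph_node_id in graph.attrs["id_to_var_name"]:
                print(
                    pattern_name,
                    pattern_node_id,
                    graph.attrs["id_to_var_name"][graph_node_id],
                )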