PaddlePaddle / Paddle · commit 0bb7c003 (unverified)

[Auto Parallel] Add patterns of rule based tuner (#51859)

* add patterns
* add unittest

Authored by caozhou on Mar 21, 2023; committed via GitHub on Mar 21, 2023.
Parent: cdefcd00
Showing 3 changed files with 646 additions and 251 deletions (+646 −251):

python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py  (+624 −231)
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py  (+17 −4)
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py  (+5 −16)
python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py

@@ -22,24 +22,43 @@ _PATTERNS = {}
```diff
 def register_pattern(cls):
     """Register pattern for rule-based tuner."""

-    name = cls.name
-
-    def register(name):
-        global _PATTERNS
-        _PATTERNS[name] = cls()
+    def register():
+        global _PATTERNS
+        pattern = cls()
+        _PATTERNS[pattern.name] = pattern
+        # Sort patterns by the number of sharded tensors; when a tensor can
+        # be matched by multiple patterns, its dist attr is set by the first.
+        _PATTERNS = dict(
+            sorted(
+                _PATTERNS.items(),
+                key=lambda x: -x[1].attrs["sharded_tensors"],
+            )
+        )

-    register(name)
+    register()
     return cls


 class BasePattern(Graph):
-    name = "base"
+    """
+    Base class of pattern.
+    BasePattern inherits Graph; the two important additions are shard_spec
+    and sharded_tensors. shard_spec gives the shard specification of each
+    tensor node in this pattern under different parallelisms, and
+    sharded_tensors is the number of tensors that are sharded.
+    """
+
+    _name = "base"

     def __init__(self):
+        """Every pattern has its own name and build method."""
         super().__init__()
         self.build()

+    @property
+    def name(self):
+        return self.__class__._name
+
     @abstractmethod
     def build(self):
         pass
```
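To make the registry mechanics concrete, here is a minimal self-contained sketch (plain Python, no Paddle imports; `FakePattern` and its subclasses are hypothetical stand-ins for `BasePattern`) showing how re-sorting on every registration keeps `_PATTERNS` ordered by `sharded_tensors`:

```python
_PATTERNS = {}


def register_pattern(cls):
    # Same mechanics as the real decorator: instantiate, key by name, re-sort.
    global _PATTERNS
    pattern = cls()
    _PATTERNS[pattern.name] = pattern
    _PATTERNS = dict(
        sorted(_PATTERNS.items(), key=lambda x: -x[1].attrs["sharded_tensors"])
    )
    return cls


class FakePattern:
    def __init__(self, name, sharded_tensors):
        self.name = name
        self.attrs = {"sharded_tensors": sharded_tensors}


@register_pattern
class OneShardPattern(FakePattern):
    def __init__(self):
        super().__init__("one_shard", 1)


@register_pattern
class FourShardPattern(FakePattern):
    def __init__(self):
        super().__init__("four_shard", 4)


print(list(_PATTERNS))  # ['four_shard', 'one_shard'] -- most sharded first
```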
@@ -47,6 +66,8 @@ class BasePattern(Graph):

```diff
 @register_pattern
 class QKVPattern(BasePattern):
+    """The QKV pattern defined by the GPT model in PaddleFleetX."""
+
     name = "qkv"

     def __init__(self):
```
@@ -55,81 +76,388 @@ class QKVPattern(BasePattern):

The new `build` wires edges through node `.id` attributes instead of bare integers, and replaces the old 2-tuple `shard_spec` (`[[(1, 2, 3), [[-1, 0], [-1, 1]]]]`, commented as `[(tensor_id, shard_spec)]`) with a per-strategy dict. The removed lines interleaved through this hunk belong to the old module-level `convert_to_graph(ops, block)` and `match(pattern, graph)` helpers, which are superseded by the new `GraphUtil` class further down.

```python
    def build(self):
        query = self.add_node(0, **{"type": "var"})

        # define q, k, v weight
        q_weight = self.add_node(1, **{"dim": 2, "type": "param"})
        k_weight = self.add_node(2, **{"dim": 2, "type": "param"})
        v_weight = self.add_node(3, **{"dim": 2, "type": "param"})

        # define q, k, v matmul_v2
        q_matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
        k_matmul_v2 = self.add_node(5, **{"type": "matmul_v2"})
        v_matmul_v2 = self.add_node(6, **{"type": "matmul_v2"})

        # define input edges
        q_x_edge = self.add_edge(
            query.id, q_matmul_v2.id, **{"input_name": "X"}
        )
        k_x_edge = self.add_edge(
            query.id, k_matmul_v2.id, **{"input_name": "X"}
        )
        v_x_edge = self.add_edge(
            query.id, v_matmul_v2.id, **{"input_name": "X"}
        )
        q_y_edge = self.add_edge(
            q_weight.id, q_matmul_v2.id, **{"input_name": "Y"}
        )
        k_y_edge = self.add_edge(
            k_weight.id, k_matmul_v2.id, **{"input_name": "Y"}
        )
        v_y_edge = self.add_edge(
            v_weight.id, v_matmul_v2.id, **{"input_name": "Y"}
        )

        # define q, k, v matmul_v2 outputs
        q = self.add_node(7, **{"type": "var"})
        k = self.add_node(8, **{"type": "var"})
        v = self.add_node(9, **{"type": "var"})

        # define output edges
        q_out_edge = self.add_edge(
            q_matmul_v2.id, q.id, **{"output_name": "Out"}
        )
        k_out_edge = self.add_edge(
            k_matmul_v2.id, k.id, **{"output_name": "Out"}
        )
        v_out_edge = self.add_edge(
            v_matmul_v2.id, v.id, **{"output_name": "Out"}
        )

        # define shard_spec
        shard_spec = {
            "dp_mp": {
                0: [0, -1, -1],
                1: [-1, 1],
                2: [-1, 1],
                3: [-1, 1],
            },
            "mp_dp": {
                0: [1, -1, -1],
                1: [-1, 0],
                2: [-1, 0],
                3: [-1, 0],
            },
            "mp": {0: [-1, -1, -1], 1: [-1, 0], 2: [-1, 0], 3: [-1, 0]},
            "dp": {
                0: [0, -1, -1],
                1: [-1, -1],
                2: [-1, -1],
                3: [-1, -1],
            },
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 4
```
```python
@register_pattern
class RowMatmulPattern(BasePattern):
    """Row matmul pattern defined by GPT model in PaddleFleetX."""

    name = "row_matmul"

    def __init__(self):
        super().__init__()

    def build(self):
        # define reshape input
        input = self.add_node(0, **{"type": "var"})

        # define reshape
        reshape = self.add_node(1, **{"type": "reshape2"})

        # define reshape input edge
        x_edge = self.add_edge(input.id, reshape.id, **{"input_name": "X"})

        # define reshape output
        output = self.add_node(2, **{"type": "var"})

        # define reshape output edge
        out_edge = self.add_edge(
            reshape.id, output.id, **{"output_name": "Out"}
        )

        # define matmul_v2 weight
        weight = self.add_node(3, **{"dim": 2, "type": "param"})

        # define matmul_v2
        matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})

        # define input edges
        x_edge = self.add_edge(output.id, matmul_v2.id, **{"input_name": "X"})
        y_edge = self.add_edge(weight.id, matmul_v2.id, **{"input_name": "Y"})

        # define matmul_v2 output
        output = self.add_node(5, **{"type": "var"})

        # define output edge
        out_edge = self.add_edge(
            matmul_v2.id, output.id, **{"output_name": "Out"}
        )

        # define shard_spec
        shard_spec = {
            "dp_mp": {
                3: [1, -1],
            },
            "mp_dp": {
                3: [0, -1],
            },
            "mp": {3: [0, -1]},
            "dp": {
                3: [-1, -1],
            },
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 1
```
```python
@register_pattern
class FFNPattrern(BasePattern):
    """FFN pattern defined by GPT model in PaddleFleetX."""

    name = "ffn"

    def __init__(self):
        super().__init__()

    def build(self):
        x = self.add_node(0, **{"type": "var"})

        w1_weight = self.add_node(1, **{"dim": 2, "type": "param"})
        w1_matmul = self.add_node(2, **{"type": "matmul_v2"})

        w1_x = self.add_edge(0, 2, **{"input_name": "X"})
        w1_y = self.add_edge(1, 2, **{"input_name": "Y"})

        out1 = self.add_node(3, **{"type": "var"})
        w1_out = self.add_edge(2, 3, **{"output_name": "Out"})

        w1_b = self.add_node(4, **{"dim": 1, "type": "param"})
        add1 = self.add_node(5, **{"type": "elementwise_add"})
        add1_x = self.add_edge(3, 5, **{"input_name": "X"})
        add1_y = self.add_edge(4, 5, **{"input_name": "Y"})

        out2 = self.add_node(6, **{"type": "var"})
        add1_out = self.add_edge(5, 6, **{"output_name": "Out"})

        gelu = self.add_node(7, **{"type": "gelu"})
        gelu_x = self.add_edge(6, 7, **{"input_name": "X"})
        out3 = self.add_node(8, **{"type": "var"})
        gelu_out = self.add_edge(7, 8, **{"output_name": "Out"})

        w2_weight = self.add_node(9, **{"dim": 2, "type": "param"})
        w2_matmul = self.add_node(10, **{"type": "matmul_v2"})
        w1_x = self.add_edge(8, 10, **{"input_name": "X"})
        w1_y = self.add_edge(9, 10, **{"input_name": "Y"})
        out4 = self.add_node(11, **{"type": "var"})
        w2_out = self.add_edge(10, 11, **{"output_name": "Out"})

        w2_b = self.add_node(12, **{"dim": 1, "type": "param"})
        add2 = self.add_node(13, **{"type": "elementwise_add"})
        add2_x = self.add_edge(11, 13, **{"input_name": "X"})
        add2_y = self.add_edge(12, 13, **{"input_name": "Y"})
        out5 = self.add_node(14, **{"type": "var"})
        add2_out = self.add_edge(13, 14, **{"output_name": "Out"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1, -1], 1: [-1, 1], 9: [1, -1]},
            "mp_dp": {0: [1, -1, -1], 1: [-1, 0], 9: [0, -1]},
            "mp": {1: [-1, 0], 9: [0, -1]},
            "dp": {0: [0, -1, -1], 1: [-1, -1], 9: [-1, -1]},
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 2
```
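Note that this spec is the classic Megatron-style column-then-row split: under `"mp"`, `w1` (node 1, `[-1, 0]`) is column-parallel with its output axis sharded over the model-parallel mesh axis, and `w2` (node 9, `[0, -1]`) is row-parallel with its input axis sharded, so the intermediate activations between the two matmuls need no resharding and the block requires only a single all-reduce on its output.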
```python
@register_pattern
class SharedWordEmbeddingPattern(BasePattern):
    """Sharded word embedding pattern defined by GPT model in PaddleFleetX."""

    name = "shared_word_embedding"

    def __init__(self):
        super().__init__()

    def build(self):
        # define embedding input
        tokens = self.add_node(0, **{"type": "data"})
        word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})

        # define embedding
        embedding = self.add_node(2, **{"type": "lookup_table_v2"})

        # define embedding input edges
        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
        w = self.add_edge(1, 2, **{"input_name": "W"})

        # define embedding output
        out = self.add_node(3, **{"type": "var"})

        # define embedding output edge
        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})

        # define matmul_v2 input
        x = self.add_node(4, **{"type": "var"})

        # define matmul_v2
        matmul = self.add_node(5, **{"type": "matmul_v2"})

        # define matmul_v2 input edges
        x_edge = self.add_edge(4, 5, **{"input_name": "X"})
        y_edge = self.add_edge(1, 5, **{"input_name": "Y"})

        # define matmul_v2 output
        out = self.add_node(6, **{"type": "var"})

        # define matmul_v2 output edge
        out_edge = self.add_edge(5, 6, **{"output_name": "Out"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1], 1: [1, -1], 4: [0, -1, -1]},
            "mp_dp": {0: [1, -1], 1: [0, -1], 4: [1, -1, -1]},
            "mp": {0: [-1, -1], 1: [0, -1], 4: [-1, -1, -1]},
            "dp": {0: [0, -1], 1: [-1, -1], 4: [0, -1, -1]},
        }
        self.attrs["shard_spec"] = shard_spec
        self.attrs["sharded_tensors"] = 3
```
```python
@register_pattern
class PositionEmbeddingPattern(BasePattern):
    """Position embedding pattern defined by GPT model in PaddleFleetX."""

    name = "position_embedding"

    def __init__(self):
        super().__init__()

    def build(self):
        # define embedding input
        tokens = self.add_node(0, **{"type": "data"})
        word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})

        # define embedding
        embedding = self.add_node(2, **{"type": "lookup_table_v2"})

        # define embedding input edges
        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
        w = self.add_edge(1, 2, **{"input_name": "W"})

        # define embedding output
        out = self.add_node(3, **{"type": "var"})

        # define embedding output edge
        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1], 1: [-1, -1], 3: [-1, -1, -1]},
            "mp_dp": {0: [1, -1], 1: [-1, -1], 3: [1, -1, -1]},
            "mp": {0: [-1, -1], 1: [-1, -1], 3: [-1, -1, -1]},
            "dp": {0: [0, -1], 1: [-1, -1], 3: [0, -1, -1]},
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 1
```
```python
@register_pattern
class UnsqueezeDataPattern(BasePattern):
    """Unsqueeze data pattern defined by GPT model in PaddleFleetX."""

    name = "unsqueeze_data"

    def __init__(self):
        super().__init__()

    def build(self):
        # define unsqueeze input
        tokens = self.add_node(0, **{"type": "data"})

        # define unsqueeze
        unsqueeze = self.add_node(1, **{"type": "unsqueeze2"})

        # define unsqueeze input edge
        x_edge = self.add_edge(0, 1, **{"input_name": "X"})

        # pattern: pure mp or hybrid dp+mp
        shard_spec = {
            "dp_mp": {0: [0, -1]},
            "mp_dp": {0: [1, -1]},
            "mp": {0: [-1, -1]},
            "dp": {0: [0, -1]},
        }
        self.attrs["shard_spec"] = shard_spec
        self.attrs["sharded_tensors"] = 1
```
```python
@register_pattern
class ReshapeDataPattern(BasePattern):
    """Reshape data pattern defined by GPT model in PaddleFleetX."""

    name = "reshape_data"

    def __init__(self):
        super().__init__()

    def build(self):
        # define reshape input
        data = self.add_node(0, **{"type": "data"})

        # define reshape
        reshape = self.add_node(1, **{"type": "reshape2"})

        # define reshape input edge
        x_edge = self.add_edge(0, 1, **{"input_name": "X"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1]},
            "mp_dp": {0: [1, -1]},
            "mp": {0: [-1, -1]},
            "dp": {0: [0, -1]},
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 1
```
The conversion and matching helpers now live on a `GraphUtil` class (this code spans the second hunk, `@@ -137,201 +465,268 @@`). Compared with the old module-level `convert_to_graph(ops, block)`, the static method takes a block directly, records `id_to_var_desc_id` and `id_to_var_name` instead of the old `id_to_var` map, and additionally tags input variables that are feed data with type `"data"`:

```python
class GraphUtil:
    """Graph util is used to convert ops to graph or match pattern for graph."""

    @staticmethod
    def convert_to_graph(block):
        """Convert ops to graph."""
        graph = Graph()
        graph.attrs["var_to_id"] = {}  # {var_name: node_id}
        graph.attrs["id_to_var_desc_id"] = {}  # {node_id: var_desc_id}
        graph.attrs["id_to_var_name"] = {}
        graph.attrs["op_to_id"] = {}  # {op_id: node_id}
        graph.attrs["id_to_op"] = {}  # {node_id: op}

        ops = block.ops
        node_id = -1
        for op in ops:
            attrs = op.all_attrs()
            attrs["type"] = op.type
            node_id += 1

            # create op node
            op_node = graph.add_node(node_id, **attrs)
            graph.attrs["op_to_id"][op.desc.id()] = op_node.id
            graph.attrs["id_to_op"][op_node.id] = op
            graph._attr_to_nodes[op_node.id] = {}
            for input_name in op.input_names:
                graph._attr_to_nodes[op_node.id][input_name] = []
                for var_name in op.input(input_name):
                    if var_name not in graph.attrs["var_to_id"]:
                        # create var node
                        node_id += 1
                        var_node = graph.add_node(node_id)
                        var = block._var_recursive(var_name)
                        if var.is_parameter:
                            var_node.attrs["type"] = "param"
                            var_node.attrs["dim"] = len(var.shape)
                        elif var.is_data:
                            var_node.attrs["type"] = "data"
                            var_node.attrs["dim"] = len(var.shape)
                        else:
                            var_node.attrs["type"] = "var"
                        graph.attrs["var_to_id"][var_name] = var_node.id
                        graph.attrs["id_to_var_desc_id"][
                            var_node.id
                        ] = var.desc.original_id()
                        graph.attrs["id_to_var_name"][var_node.id] = var_name
                    else:
                        var_node_id = graph.attrs["var_to_id"][var_name]
                        var_node = graph._nodes[var_node_id]

                    # create edge that input -> op
                    input_edge = graph.add_edge(var_node.id, op_node.id)
                    input_edge.attrs["input_name"] = input_name
                    graph._attr_to_nodes[op_node.id][input_name].append(
                        var_node
                    )

            for output_name in op.output_names:
                graph._attr_to_nodes[op_node.id][output_name] = []
                for var_name in op.output(output_name):
                    if var_name not in graph.attrs["var_to_id"]:
                        # create var node
                        node_id += 1
                        var_node = graph.add_node(node_id)
                        var = block._var_recursive(var_name)
                        if var.is_parameter:
                            var_node.attrs["type"] = "param"
                        else:
                            var_node.attrs["type"] = "var"
                        graph.attrs["var_to_id"][var_name] = var_node.id
                        graph.attrs["id_to_var_desc_id"][
                            var_node.id
                        ] = var.desc.original_id()
                        graph.attrs["id_to_var_name"][var_node.id] = var_name
                    else:
                        var_node_id = graph.attrs["var_to_id"][var_name]
                        var_node = graph._nodes[var_node_id]

                    # create edge that op -> output
                    output_edge = graph.add_edge(op_node.id, var_node.id)
                    output_edge.attrs["output_name"] = output_name
                    graph._attr_to_nodes[op_node.id][output_name].append(
                        var_node
                    )

        return graph
```
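Usage mirrors the unit tests at the bottom of this commit; conversion now takes the block rather than an op list:

```python
# As exercised in test_pattern_match.py below; train_program is the
# paddle.static.Program built by get_gpt_model().
graph = GraphUtil.convert_to_graph(train_program.global_block())
print(graph.attrs["id_to_var_name"])  # node id -> variable name
```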
The old module-level `match(pattern, graph)`, which tracked a single `has_matched` set and returned only `results`, becomes `GraphUtil.match_pattern`. The new method also tracks which graph op nodes have already been claimed, so overlapping matches of the same pattern are rejected, and it returns both the matches and the matched node ids:

```python
    @staticmethod
    def match_pattern(pattern, graph):
        def _is_op_node(node):
            """Judge whether node is op node."""
            if node.attrs["type"] not in ["var", "param", "data"]:
                return True
            return False

        def _compare_op_node(src, tgt):
            """Compare whether two op nodes are equivalent."""
            if src.attrs["type"] != tgt.attrs["type"]:
                return False
            return True

        def _compare_var_node(src, tgt):
            """Compare whether two var nodes are equivalent."""
            for key in src.attrs:
                if key not in tgt.attrs:
                    return False
                if src.attrs[key] != tgt.attrs[key]:
                    return False
            return True

        def _match_core(src_node, tgt_node):
            nonlocal not_matched
            if not_matched:
                return

            if _is_op_node(src_node):
                # compare whether the op nodes are equal
                if not _compare_op_node(src_node, tgt_node):
                    not_matched = True
                    return

                result[src_node.id] = tgt_node.id

                # input var nodes
                src_input_nodes = src_reverse_adjs[src_node.id]
                for node in src_input_nodes:
                    # has visited
                    if node.id in result:
                        continue
                    edge = src_edges[node.id][src_node.id]
                    input_name = edge.attrs["input_name"]

                    # NOTE: do not support one input name or output name
                    # corresponding to multiple vars
                    compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                        input_name, None
                    )
                    if not compare_nodes:
                        not_matched = True
                        return
                    _match_core(node, compare_nodes[0])

                # output var nodes
                src_output_node_ids = src_edges[src_node.id].keys()
                for node_id in src_output_node_ids:
                    # has visited
                    if node_id in result:
                        continue
                    node = src_nodes[node_id]
                    edge = src_edges[src_node.id][node_id]
                    output_name = edge.attrs["output_name"]

                    # NOTE: do not support one input name or output name
                    # corresponding to multiple vars
                    compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                        output_name, None
                    )
                    if not compare_nodes:
                        not_matched = True
                        return
                    _match_core(node, compare_nodes[0])

            else:
                # compare whether the var nodes are equal
                if not _compare_var_node(src_node, tgt_node):
                    not_matched = True
                    return

                result[src_node.id] = tgt_node.id

                # as input for op node
                src_as_input_node_ids = src_edges[src_node.id].keys()
                for node_id in src_as_input_node_ids:
                    if node_id in result:
                        continue

                    src_edge = src_edges[src_node.id][node_id]
                    input_name = src_edge.attrs["input_name"]
                    compare_node_ids = tgt_edges[tgt_node.id].keys()

                    compare_node = None
                    for compare_node_id in compare_node_ids:
                        edge = tgt_edges[tgt_node.id][compare_node_id]
                        if (
                            edge.attrs["input_name"] == input_name
                            and compare_node_id not in result.values()
                        ):
                            compare_node = tgt_nodes[compare_node_id]
                            break

                    if not compare_node:
                        not_matched = True
                        return
                    _match_core(src_nodes[node_id], compare_node)

                # as output for op node
                src_as_output_nodes = src_reverse_adjs[src_node.id]
                for node in src_as_output_nodes:
                    if node.id in result:
                        continue

                    src_edge = src_edges[node.id][src_node.id]
                    output_name = src_edge.attrs["output_name"]

                    compare_nodes = tgt_reverse_adjs[tgt_node.id]

                    compare_node = None
                    for item in compare_nodes:
                        node_id = item.id
                        edge = tgt_edges[node_id][tgt_node.id]
                        if edge.attrs["output_name"] == output_name:
                            compare_node = tgt_nodes[node_id]
                            break
                    if not compare_node:
                        not_matched = True
                        return
                    _match_core(src_nodes[node.id], compare_node)

        results = []
        matched_ids = set()
        matched_op_node_ids = set()
        result = {}
        src_nodes = pattern.nodes
        src_edges = pattern._adjs
        src_reverse_adjs = pattern._reverse_adjs

        tgt_nodes = graph.nodes
        tgt_edges = graph._adjs
        tgt_reverse_adjs = graph._reverse_adjs
        tgt_attr_to_nodes = graph._attr_to_nodes
        not_matched = False

        # starts with an op node
        src_start_node = None
        for node_id in src_nodes:
            node = src_nodes[node_id]
            if node.attrs["type"] not in ["var", "param", "data"]:
                src_start_node = node
                break
        assert src_start_node is not None

        for node_id in tgt_nodes:
            node = tgt_nodes[node_id]
            if node.attrs["type"] == src_start_node.attrs["type"]:
                _match_core(src_start_node, node)
                if not not_matched:
                    need_to_append = True
                    for value in result.values():
                        if value in matched_op_node_ids:
                            result = {}
                            need_to_append = False
                            break
                    if need_to_append:
                        results.append(result)
                        for value in result.values():
                            matched_ids.add(value)
                            if value in graph.attrs["id_to_op"].keys():
                                matched_op_node_ids.add(value)
                        result = {}
                else:
                    not_matched = False
                    result = {}
        return results, matched_ids
```
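`match_pattern` returns one dict per match, mapping pattern node ids to graph node ids, plus the set of all matched graph node ids. A hedged usage sketch, assuming `graph` was produced by `convert_to_graph` above:

```python
results, matched_ids = GraphUtil.match_pattern(_PATTERNS["qkv"], graph)
for result in results:
    # e.g. {0: 17, 1: 9, 4: 18, ...}: pattern node id -> graph node id
    print(result)
```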
Also new on `GraphUtil`:

```python
    @staticmethod
    def match_all_patterns(graph):
        # matched_results maps pattern names to lists of matches, each match
        # being a pattern-node-id to graph-node-id mapping, e.g.
        # {"pattern_name": [{pattern_node_id: graph_node_id}, ...]}
        matched_results = {}
        matched_ids = set()
        for pattern_name in _PATTERNS:
            pattern = _PATTERNS[pattern_name]
            results, matched = GraphUtil.match_pattern(pattern, graph)
            for result in results:
                has_matched = False
                for id in result:
                    if result[id] in matched_ids:
                        has_matched = True
                        break
                if not has_matched:
                    for item in result:
                        matched_ids.add(result[item])
                    if pattern.name not in matched_results:
                        matched_results[pattern.name] = []
                    matched_results[pattern.name].append(result)

        return matched_results
```
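Because `_PATTERNS` is kept sorted by `sharded_tensors`, greedier patterns claim graph nodes first and later overlapping matches are dropped. The return value groups the surviving matches by pattern name; illustrative shape only, the ids below are made up:

```python
matched_results = GraphUtil.match_all_patterns(graph)
# {
#     "qkv": [{0: 12, 1: 5, 2: 6, 3: 7, 4: 13, ...}, ...],
#     "reshape_data": [{0: 2, 1: 3}],
# }
```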
```diff
 class OperatorClusteringUtil:
+    """Operator clustering util is used to cluster operators to layers."""
+
     common_starts = ["layer_norm", "matmul_v2", "matmul"]

     @staticmethod
```
@@ -506,6 +901,8 @@ class OperatorClusteringUtil:
```diff
 class ClusterPartitionUtil:
+    """Cluster partition util is used to get device meshes and process meshes."""
+
     @staticmethod
     def factorization(num):
         factors = []
```
@@ -535,13 +932,11 @@ class ClusterPartitionUtil:
```diff
         ],
     ) -> list:
         """
-        Partiton cluster into possible device meshes.
+        Partition cluster into possible device meshes.

         Args:
             n (int): The number of nodes.
             m (int): The number of single devices on each node.
             filter (list): Functions for filtering useful meshes.

         Returns:
             device_meshes (list): The possible device meshes.
         """
```
@@ -573,10 +968,8 @@
```python
    def convert_to_process_meshes(device_mesh: list) -> list:
        """
        Transfer device meshes into possible process meshes.

        Args:
            device_mesh (list): [n, m], one device mesh.

        Returns:
            process_meshes (list): The possible process meshes.
        """
```

...
python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py

@@ -95,7 +95,7 @@ def get_gpt_model(
```diff
     return train_program, start_program, loss, gen_data


-class TestGroupOperators(unittest.TestCase):
+class TestGroupOperatorsAndPatterns(unittest.TestCase):
     def test_gpt(self):
         modeling.init_global()
         train_program = static.Program()
```

@@ -117,17 +117,30 @@ class TestGroupOperators(unittest.TestCase):
```diff
         )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
             _PATTERNS,
+            GraphUtil,
             RuleBasedTuner,
-            convert_to_graph,
         )

         dist_context = DistributedContext()
         tuner = RuleBasedTuner(dist_context)
         layers = tuner.cluster_operators(train_program.global_block().ops)
-        layer = layers[0]
-        graph = convert_to_graph(layer, train_program.global_block())
+        graph = GraphUtil.convert_to_graph(train_program.global_block())
         print("graph: ", graph)
         print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
+        print("row_matmul: ", _PATTERNS["row_matmul"].attrs["shard_spec"])
+        print("ffn: ", _PATTERNS["ffn"].attrs["shard_spec"])
+        print(
+            "shared_word_embedding: ",
+            _PATTERNS["shared_word_embedding"].attrs["shard_spec"],
+        )
+        print(
+            "position_embedding: ",
+            _PATTERNS["position_embedding"].attrs["shard_spec"],
+        )
+        print(
+            "unsqueeze_data: ", _PATTERNS["unsqueeze_data"].attrs["shard_spec"]
+        )
+        print("reshape_data: ", _PATTERNS["reshape_data"].attrs["shard_spec"])


 if __name__ == "__main__":
```

...
浏览文件 @
0bb7c003
...
@@ -95,7 +95,7 @@ def get_gpt_model(
...
@@ -95,7 +95,7 @@ def get_gpt_model(
```diff
     return train_program, start_program, loss, gen_data


-class TestGroupOperators(unittest.TestCase):
+class TestPatternMatch(unittest.TestCase):
     def test_gpt(self):
         modeling.init_global()
         train_program = static.Program()
```

@@ -116,26 +116,15 @@ class TestGroupOperators(unittest.TestCase):
```diff
             DistributedContext,
         )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
-            _PATTERNS,
+            GraphUtil,
             RuleBasedTuner,
-            convert_to_graph,
-            match,
         )

         dist_context = DistributedContext()
         tuner = RuleBasedTuner(dist_context)
-        layers = tuner.cluster_operators(train_program.global_block().ops)
-        layer = layers[0]
-        graph = convert_to_graph(layer, train_program.global_block())
-        results = match(_PATTERNS["qkv"], graph)
-        shard_tensor_infos = _PATTERNS["qkv"].attrs["shard_spec"]
-        tensor_ids = shard_tensor_infos[0][0]
-        if results:
-            for result in results:
-                for node_id in result:
-                    if node_id in tensor_ids:
-                        print(graph.attrs["id_to_var"][result[node_id]])
-            print("shard_spec: ", shard_tensor_infos[0][1])
+        graph = GraphUtil.convert_to_graph(train_program.global_block())
+        results = GraphUtil.match_all_patterns(graph)
+        print(results)


 if __name__ == "__main__":
```

...