Unverified commit 0bb7c003, authored by caozhou, committed by GitHub

[Auto Parallel] Add patterns of rule based tuner (#51859)

* add patterns

* add unittest
Parent cdefcd00
@@ -22,24 +22,43 @@ _PATTERNS = {}
def register_pattern(cls):
    """Register a pattern for the rule-based tuner."""

    def register():
        global _PATTERNS
        pattern = cls()
        _PATTERNS[pattern.name] = pattern
        # Sort patterns in descending order of the number of sharded tensors.
        # When a tensor can be matched by multiple patterns, its dist attr is
        # set by the first one.
        _PATTERNS = dict(
            sorted(
                _PATTERNS.items(), key=lambda x: -x[1].attrs["sharded_tensors"]
            )
        )

    register()
    return cls
class BasePattern(Graph):
    """
    Base class for patterns.

    BasePattern inherits from Graph; the two important additions are
    shard_spec and sharded_tensors. shard_spec describes how each tensor
    node in the pattern is sharded under different parallelisms, and
    sharded_tensors is the number of tensors that are sharded.
    """

    _name = "base"

    def __init__(self):
        """Every pattern has its own name and build method."""
        super().__init__()
        self.build()

    @property
    def name(self):
        return self.__class__._name

    @abstractmethod
    def build(self):
        pass
@@ -47,6 +66,8 @@ class BasePattern(Graph):
@register_pattern
class QKVPattern(BasePattern):
    """The QKV pattern defined by the GPT model in PaddleFleetX."""

    name = "qkv"

    def __init__(self):
@@ -55,43 +76,374 @@ class QKVPattern(BasePattern):
    def build(self):
        query = self.add_node(0, **{"type": "var"})

        # define q, k, v weights
        q_weight = self.add_node(1, **{"dim": 2, "type": "param"})
        k_weight = self.add_node(2, **{"dim": 2, "type": "param"})
        v_weight = self.add_node(3, **{"dim": 2, "type": "param"})

        # define q, k, v matmul_v2 ops
        q_matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
        k_matmul_v2 = self.add_node(5, **{"type": "matmul_v2"})
        v_matmul_v2 = self.add_node(6, **{"type": "matmul_v2"})

        # define input edges
        q_x_edge = self.add_edge(
            query.id, q_matmul_v2.id, **{"input_name": "X"}
        )
        k_x_edge = self.add_edge(
            query.id, k_matmul_v2.id, **{"input_name": "X"}
        )
        v_x_edge = self.add_edge(
            query.id, v_matmul_v2.id, **{"input_name": "X"}
        )
        q_y_edge = self.add_edge(
            q_weight.id, q_matmul_v2.id, **{"input_name": "Y"}
        )
        k_y_edge = self.add_edge(
            k_weight.id, k_matmul_v2.id, **{"input_name": "Y"}
        )
        v_y_edge = self.add_edge(
            v_weight.id, v_matmul_v2.id, **{"input_name": "Y"}
        )

        # define q, k, v matmul_v2 outputs
        q = self.add_node(7, **{"type": "var"})
        k = self.add_node(8, **{"type": "var"})
        v = self.add_node(9, **{"type": "var"})

        # define output edges
        q_out_edge = self.add_edge(
            q_matmul_v2.id, q.id, **{"output_name": "Out"}
        )
        k_out_edge = self.add_edge(
            k_matmul_v2.id, k.id, **{"output_name": "Out"}
        )
        v_out_edge = self.add_edge(
            v_matmul_v2.id, v.id, **{"output_name": "Out"}
        )
        # define shard_spec
        shard_spec = {
            "dp_mp": {
                0: [0, -1, -1],
                1: [-1, 1],
                2: [-1, 1],
                3: [-1, 1],
            },
            "mp_dp": {
                0: [1, -1, -1],
                1: [-1, 0],
                2: [-1, 0],
                3: [-1, 0],
            },
            "mp": {0: [-1, -1, -1], 1: [-1, 0], 2: [-1, 0], 3: [-1, 0]},
            "dp": {
                0: [0, -1, -1],
                1: [-1, -1],
                2: [-1, -1],
                3: [-1, -1],
            },
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 4
@register_pattern
class RowMatmulPattern(BasePattern):
    """Row matmul pattern defined by the GPT model in PaddleFleetX."""

    name = "row_matmul"

    def __init__(self):
        super().__init__()

    def build(self):
        # define reshape input
        input = self.add_node(0, **{"type": "var"})

        # define reshape
        reshape = self.add_node(1, **{"type": "reshape2"})

        # define reshape input edge
        x_edge = self.add_edge(input.id, reshape.id, **{"input_name": "X"})

        # define reshape output
        output = self.add_node(2, **{"type": "var"})

        # define reshape output edge
        out_edge = self.add_edge(
            reshape.id, output.id, **{"output_name": "Out"}
        )

        # define matmul_v2 weight
        weight = self.add_node(3, **{"dim": 2, "type": "param"})

        # define matmul_v2
        matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})

        # define input edges
        x_edge = self.add_edge(output.id, matmul_v2.id, **{"input_name": "X"})
        y_edge = self.add_edge(weight.id, matmul_v2.id, **{"input_name": "Y"})

        # define matmul_v2 output
        output = self.add_node(5, **{"type": "var"})

        # define output edge
        out_edge = self.add_edge(
            matmul_v2.id, output.id, **{"output_name": "Out"}
        )

        # define shard_spec
        shard_spec = {
            "dp_mp": {
                3: [1, -1],
            },
            "mp_dp": {
                3: [0, -1],
            },
            "mp": {3: [0, -1]},
            "dp": {
                3: [-1, -1],
            },
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 1
@register_pattern
class FFNPattern(BasePattern):
    """FFN pattern defined by the GPT model in PaddleFleetX."""

    name = "ffn"

    def __init__(self):
        super().__init__()

    def build(self):
        x = self.add_node(0, **{"type": "var"})

        w1_weight = self.add_node(1, **{"dim": 2, "type": "param"})
        w1_matmul = self.add_node(2, **{"type": "matmul_v2"})

        w1_x = self.add_edge(0, 2, **{"input_name": "X"})
        w1_y = self.add_edge(1, 2, **{"input_name": "Y"})

        out1 = self.add_node(3, **{"type": "var"})
        w1_out = self.add_edge(2, 3, **{"output_name": "Out"})

        w1_b = self.add_node(4, **{"dim": 1, "type": "param"})
        add1 = self.add_node(5, **{"type": "elementwise_add"})
        add1_x = self.add_edge(3, 5, **{"input_name": "X"})
        add1_y = self.add_edge(4, 5, **{"input_name": "Y"})

        out2 = self.add_node(6, **{"type": "var"})
        add1_out = self.add_edge(5, 6, **{"output_name": "Out"})

        gelu = self.add_node(7, **{"type": "gelu"})
        gelu_x = self.add_edge(6, 7, **{"input_name": "X"})
        out3 = self.add_node(8, **{"type": "var"})
        gelu_out = self.add_edge(7, 8, **{"output_name": "Out"})

        w2_weight = self.add_node(9, **{"dim": 2, "type": "param"})
        w2_matmul = self.add_node(10, **{"type": "matmul_v2"})
        w2_x = self.add_edge(8, 10, **{"input_name": "X"})
        w2_y = self.add_edge(9, 10, **{"input_name": "Y"})

        out4 = self.add_node(11, **{"type": "var"})
        w2_out = self.add_edge(10, 11, **{"output_name": "Out"})

        w2_b = self.add_node(12, **{"dim": 1, "type": "param"})
        add2 = self.add_node(13, **{"type": "elementwise_add"})
        add2_x = self.add_edge(11, 13, **{"input_name": "X"})
        add2_y = self.add_edge(12, 13, **{"input_name": "Y"})

        out5 = self.add_node(14, **{"type": "var"})
        add2_out = self.add_edge(13, 14, **{"output_name": "Out"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1, -1], 1: [-1, 1], 9: [1, -1]},
            "mp_dp": {0: [1, -1, -1], 1: [-1, 0], 9: [0, -1]},
            "mp": {1: [-1, 0], 9: [0, -1]},
            "dp": {0: [0, -1, -1], 1: [-1, -1], 9: [-1, -1]},
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 2
@register_pattern
class SharedWordEmbeddingPattern(BasePattern):
    """Shared word embedding pattern defined by the GPT model in PaddleFleetX."""

    name = "shared_word_embedding"

    def __init__(self):
        super().__init__()

    def build(self):
        # define embedding input
        tokens = self.add_node(0, **{"type": "data"})
        word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})

        # define embedding
        embedding = self.add_node(2, **{"type": "lookup_table_v2"})

        # define embedding input edges
        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
        w = self.add_edge(1, 2, **{"input_name": "W"})

        # define embedding output
        out = self.add_node(3, **{"type": "var"})

        # define embedding output edge
        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})

        # define matmul_v2 input
        x = self.add_node(4, **{"type": "var"})

        # define matmul_v2
        matmul = self.add_node(5, **{"type": "matmul_v2"})

        # define matmul_v2 input edges
        x_edge = self.add_edge(4, 5, **{"input_name": "X"})
        y_edge = self.add_edge(1, 5, **{"input_name": "Y"})

        # define matmul_v2 output
        out = self.add_node(6, **{"type": "var"})

        # define matmul_v2 output edge
        out_edge = self.add_edge(5, 6, **{"output_name": "Out"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1], 1: [1, -1], 4: [0, -1, -1]},
            "mp_dp": {0: [1, -1], 1: [0, -1], 4: [1, -1, -1]},
            "mp": {0: [-1, -1], 1: [0, -1], 4: [-1, -1, -1]},
            "dp": {0: [0, -1], 1: [-1, -1], 4: [0, -1, -1]},
        }
        self.attrs["shard_spec"] = shard_spec
        self.attrs["sharded_tensors"] = 3
@register_pattern
class PositionEmbeddingPattern(BasePattern):
    """Position embedding pattern defined by the GPT model in PaddleFleetX."""

    name = "position_embedding"

    def __init__(self):
        super().__init__()

    def build(self):
        # define embedding input
        tokens = self.add_node(0, **{"type": "data"})
        position_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})

        # define embedding
        embedding = self.add_node(2, **{"type": "lookup_table_v2"})

        # define embedding input edges
        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
        w = self.add_edge(1, 2, **{"input_name": "W"})

        # define embedding output
        out = self.add_node(3, **{"type": "var"})

        # define embedding output edge
        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1], 1: [-1, -1], 3: [-1, -1, -1]},
            "mp_dp": {0: [1, -1], 1: [-1, -1], 3: [1, -1, -1]},
            "mp": {0: [-1, -1], 1: [-1, -1], 3: [-1, -1, -1]},
            "dp": {0: [0, -1], 1: [-1, -1], 3: [0, -1, -1]},
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 1
@register_pattern
class UnsqueezeDataPattern(BasePattern):
    """Unsqueeze data pattern defined by the GPT model in PaddleFleetX."""

    name = "unsqueeze_data"

    def __init__(self):
        super().__init__()

    def build(self):
        # define unsqueeze input
        tokens = self.add_node(0, **{"type": "data"})

        # define unsqueeze
        unsqueeze = self.add_node(1, **{"type": "unsqueeze2"})

        # define unsqueeze input edge
        x_edge = self.add_edge(0, 1, **{"input_name": "X"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1]},
            "mp_dp": {0: [1, -1]},
            "mp": {0: [-1, -1]},
            "dp": {0: [0, -1]},
        }
        self.attrs["shard_spec"] = shard_spec
        self.attrs["sharded_tensors"] = 1
@register_pattern
class ReshapeDataPattern(BasePattern):
    """Reshape data pattern defined by the GPT model in PaddleFleetX."""

    name = "reshape_data"

    def __init__(self):
        super().__init__()

    def build(self):
        # define reshape input
        data = self.add_node(0, **{"type": "data"})

        # define reshape
        reshape = self.add_node(1, **{"type": "reshape2"})

        # define reshape input edge
        x_edge = self.add_edge(0, 1, **{"input_name": "X"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1]},
            "mp_dp": {0: [1, -1]},
            "mp": {0: [-1, -1]},
            "dp": {0: [0, -1]},
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 1
class GraphUtil:
    """Graph util is used to convert ops into a graph and to match patterns in a graph."""

    @staticmethod
    def convert_to_graph(block):
        """Convert the ops of a block into a Graph."""
        graph = Graph()
        graph.attrs["var_to_id"] = {}  # {var_name: node_id}
        graph.attrs["id_to_var_desc_id"] = {}  # {node_id: var_desc_id}
        graph.attrs["id_to_var_name"] = {}
        graph.attrs["op_to_id"] = {}  # {op_id: node_id}
        graph.attrs["id_to_op"] = {}  # {node_id: op}
        ops = block.ops

        node_id = -1
        for op in ops:
            attrs = op.all_attrs()
@@ -101,7 +453,7 @@ def convert_to_graph(ops, block):
            # create op node
            op_node = graph.add_node(node_id, **attrs)
            graph.attrs["op_to_id"][op.desc.id()] = op_node.id
            graph.attrs["id_to_op"][op_node.id] = op
            graph._attr_to_nodes[op_node.id] = {}
            for input_name in op.input_names:
                graph._attr_to_nodes[op_node.id][input_name] = []
@@ -114,10 +466,16 @@ def convert_to_graph(ops, block):
                        if var.is_parameter:
                            var_node.attrs["type"] = "param"
                            var_node.attrs["dim"] = len(var.shape)
                        elif var.is_data:
                            var_node.attrs["type"] = "data"
                            var_node.attrs["dim"] = len(var.shape)
                        else:
                            var_node.attrs["type"] = "var"
                        graph.attrs["var_to_id"][var_name] = var_node.id
                        graph.attrs["id_to_var_desc_id"][
                            var_node.id
                        ] = var.desc.original_id()
                        graph.attrs["id_to_var_name"][var_node.id] = var_name
                    else:
                        var_node_id = graph.attrs["var_to_id"][var_name]
                        var_node = graph._nodes[var_node_id]
@@ -125,7 +483,9 @@ def convert_to_graph(ops, block):
                    # create edge that input -> op
                    input_edge = graph.add_edge(var_node.id, op_node.id)
                    input_edge.attrs["input_name"] = input_name
                    graph._attr_to_nodes[op_node.id][input_name].append(
                        var_node
                    )

            for output_name in op.output_names:
                graph._attr_to_nodes[op_node.id][output_name] = []
@@ -140,7 +500,12 @@ def convert_to_graph(ops, block):
                        else:
                            var_node.attrs["type"] = "var"
                        graph.attrs["var_to_id"][var_name] = var_node.id
                        graph.attrs["id_to_var_desc_id"][
                            var_node.id
                        ] = var.desc.original_id()
                        graph.attrs["id_to_var_name"][
                            var_node.id
                        ] = var_name
                    else:
                        var_node_id = graph.attrs["var_to_id"][var_name]
                        var_node = graph._nodes[var_node_id]
@@ -155,24 +520,24 @@ def convert_to_graph(ops, block):
        return graph
    @staticmethod
    def match_pattern(pattern, graph):
        def _is_op_node(node):
            """Judge whether a node is an op node."""
            if node.attrs["type"] not in ["var", "param", "data"]:
                return True
            return False

        def _compare_op_node(src, tgt):
            """Compare whether two op nodes are equivalent."""
            if src.attrs["type"] != tgt.attrs["type"]:
                return False
            return True

        def _compare_var_node(src, tgt):
            """Compare whether two var nodes are equivalent."""
            for key in src.attrs:
                if key not in tgt.attrs:
                    return False
@@ -183,13 +548,14 @@ def match(pattern, graph):
        def _match_core(src_node, tgt_node):
            nonlocal not_matched
            # one input or output name corresponding to multiple vars
            # is not supported
            if not_matched:
                return

            if _is_op_node(src_node):
                # compare whether the op nodes are equal
                if not _compare_op_node(src_node, tgt_node):
                    not_matched = True
                    return

                result[src_node.id] = tgt_node.id
@@ -232,14 +598,14 @@ def match(pattern, graph):
                        _match_core(node, compare_nodes[0])

            else:
                # compare whether the var nodes are equal
                if not _compare_var_node(src_node, tgt_node):
                    not_matched = True
                    return

                result[src_node.id] = tgt_node.id

                # as input for op nodes
                src_as_input_node_ids = src_edges[src_node.id].keys()
                for node_id in src_as_input_node_ids:
                    if node_id in result:
@@ -264,7 +630,7 @@ def match(pattern, graph):
                        return
                    _match_core(src_nodes[node_id], compare_node)
                # as output for op nodes
                src_as_output_nodes = src_reverse_adjs[src_node.id]
                for node in src_as_output_nodes:
                    if node.id in result:
@@ -273,10 +639,11 @@ def match(pattern, graph):
                    src_edge = src_edges[node.id][src_node.id]
                    output_name = src_edge.attrs["output_name"]

                    compare_nodes = tgt_reverse_adjs[tgt_node.id]
                    compare_node = None
                    for item in compare_nodes:
                        node_id = item.id
                        edge = tgt_edges[node_id][tgt_node.id]
                        if edge.attrs["output_name"] == output_name:
                            compare_node = tgt_nodes[node_id]
@@ -284,11 +651,12 @@ def match(pattern, graph):
                    if not compare_node:
                        not_matched = True
                        return
                    _match_core(src_nodes[node.id], compare_node)
        results = []
        matched_ids = set()
        matched_op_node_ids = set()
        result = {}

        src_nodes = pattern.nodes
        src_edges = pattern._adjs
        src_reverse_adjs = pattern._reverse_adjs
@@ -297,7 +665,6 @@ def match(pattern, graph):
        tgt_edges = graph._adjs
        tgt_reverse_adjs = graph._reverse_adjs
        tgt_attr_to_nodes = graph._attr_to_nodes

        # start with an op node
        src_start_node = None
@@ -311,27 +678,55 @@ def match(pattern, graph):
        for node_id in tgt_nodes:
            node = tgt_nodes[node_id]
            if node.attrs["type"] == src_start_node.attrs["type"]:
                not_matched = False
                _match_core(src_start_node, node)
                if not not_matched:
                    need_to_append = True
                    for value in result.values():
                        if value in matched_op_node_ids:
                            result = {}
                            need_to_append = False
                            break
                    if need_to_append:
                        results.append(result)
                        for value in result.values():
                            matched_ids.add(value)
                            if value in graph.attrs["id_to_op"].keys():
                                matched_op_node_ids.add(value)
                        result = {}
                else:
                    not_matched = False
                    result = {}

        return results, matched_ids
    @staticmethod
    def match_all_patterns(graph):
        # matched_results maps a pattern name to a list of matches; each match
        # maps a pattern node id to the matched graph node id, e.g.
        # {"pattern_name": [{pattern_node_id: graph_node_id}, ...]}
        matched_results = {}
        matched_ids = set()
        for pattern_name in _PATTERNS:
            pattern = _PATTERNS[pattern_name]
            results, matched = GraphUtil.match_pattern(pattern, graph)
            for result in results:
                has_matched = False
                for node_id in result:
                    if result[node_id] in matched_ids:
                        has_matched = True
                        break
                if not has_matched:
                    for node_id in result:
                        matched_ids.add(result[node_id])
                    if pattern.name not in matched_results:
                        matched_results[pattern.name] = []
                    matched_results[pattern.name].append(result)

        return matched_results
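# NOTE (editor sketch, not part of this commit): end-to-end usage, assuming a
# static-graph Paddle program (train_program) has been built as in the unit
# tests below. The node ids in the sample output are made up for illustration.
graph = GraphUtil.convert_to_graph(train_program.global_block())
results = GraphUtil.match_all_patterns(graph)
# e.g. {"qkv": [{0: 42, 1: 43, ...}], "ffn": [{0: 58, ...}], ...}
for pattern_name, matches in results.items():
    print(pattern_name, len(matches))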
class OperatorClusteringUtil:
    """Operator clustering util is used to cluster operators into layers."""

    common_starts = ["layer_norm", "matmul_v2", "matmul"]

    @staticmethod
@@ -506,6 +901,8 @@ class OperatorClusteringUtil:
class ClusterPartitionUtil:
    """Cluster partition util is used to get device meshes and process meshes."""

    @staticmethod
    def factorization(num):
        factors = []
@@ -535,13 +932,11 @@ class ClusterPartitionUtil:
        ],
    ) -> list:
        """
        Partition a cluster into possible device meshes.

        Args:
            n (int): The number of nodes.
            m (int): The number of devices on each node.
            filter (list): Functions for filtering useful meshes.

        Returns:
            device_meshes (list): The possible device meshes.
        """
@@ -573,10 +968,8 @@ class ClusterPartitionUtil:
    def convert_to_process_meshes(device_mesh: list) -> list:
        """
        Convert a device mesh into possible process meshes.

        Args:
            device_mesh (list): [n, m], one device mesh.

        Returns:
            process_meshes (list): The possible process meshes.
        """
...
@@ -95,7 +95,7 @@ def get_gpt_model(
    return train_program, start_program, loss, gen_data


class TestGroupOperatorsAndPatterns(unittest.TestCase):
    def test_gpt(self):
        modeling.init_global()
        train_program = static.Program()
@@ -117,17 +117,30 @@ class TestGroupOperators(unittest.TestCase):
        )
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            _PATTERNS,
            GraphUtil,
            RuleBasedTuner,
        )

        dist_context = DistributedContext()
        tuner = RuleBasedTuner(dist_context)
        layers = tuner.cluster_operators(train_program.global_block().ops)
        graph = GraphUtil.convert_to_graph(train_program.global_block())
        print("graph: ", graph)
        print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
print("row_matmul: ", _PATTERNS["row_matmul"].attrs["shard_spec"])
print("ffn: ", _PATTERNS["ffn"].attrs["shard_spec"])
print(
"shared_word_embedding: ",
_PATTERNS["shared_word_embedding"].attrs["shard_spec"],
)
print(
"position_embedding: ",
_PATTERNS["position_embedding"].attrs["shard_spec"],
)
print(
"unsqueeze_data: ", _PATTERNS["unsqueeze_data"].attrs["shard_spec"]
)
print("reshape_data: ", _PATTERNS["reshape_data"].attrs["shard_spec"])
if __name__ == "__main__":
...
@@ -95,7 +95,7 @@ def get_gpt_model(
    return train_program, start_program, loss, gen_data


class TestPatternMatch(unittest.TestCase):
    def test_gpt(self):
        modeling.init_global()
        train_program = static.Program()
@@ -116,26 +116,15 @@ class TestGroupOperators(unittest.TestCase):
            DistributedContext,
        )
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            GraphUtil,
            RuleBasedTuner,
        )

        dist_context = DistributedContext()
        tuner = RuleBasedTuner(dist_context)
        graph = GraphUtil.convert_to_graph(train_program.global_block())
        results = GraphUtil.match_all_patterns(graph)
        print(results)
if __name__ == "__main__":
...