Unverified commit 0bb7c003, authored by caozhou, committed by GitHub

[Auto Parallel] Add patterns of rule based tuner (#51859)

* add patterns

* add unittest
Parent: cdefcd00
@@ -22,24 +22,43 @@ _PATTERNS = {}

def register_pattern(cls):
    """Register pattern for rule-based tuner."""

    def register():
        global _PATTERNS
        pattern = cls()
        _PATTERNS[pattern.name] = pattern
        # Sort patterns according to the number of sharded tensors; when a tensor
        # can be matched by multiple patterns, its dist attr is set by the first one.
        _PATTERNS = dict(
            sorted(
                _PATTERNS.items(), key=lambda x: -x[1].attrs["sharded_tensors"]
            )
        )

    register()
    return cls
class BasePattern(Graph):
    """
    Base class of pattern.
    BasePattern inherits from Graph; the two important additions are shard_spec and sharded_tensors.
    shard_spec indicates the shard specification of the tensor nodes in this pattern under different parallelisms.
    sharded_tensors is the number of tensors that get sharded.
    """

    _name = "base"

    def __init__(self):
        """Every pattern has its own name and build method."""
        super().__init__()
        self.build()

    @property
    def name(self):
        return self.__class__._name

    @abstractmethod
    def build(self):
        pass
@@ -47,6 +66,8 @@ class BasePattern(Graph):

@register_pattern
class QKVPattern(BasePattern):
    """The QKV pattern defined by GPT model in PaddleFleetX."""

    name = "qkv"

    def __init__(self):
@@ -55,81 +76,388 @@ class QKVPattern(BasePattern):
    def build(self):
        query = self.add_node(0, **{"type": "var"})

        # define q, k, v weight
        q_weight = self.add_node(1, **{"dim": 2, "type": "param"})
        k_weight = self.add_node(2, **{"dim": 2, "type": "param"})
        v_weight = self.add_node(3, **{"dim": 2, "type": "param"})

        # define q, k, v matmul_v2
        q_matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
        k_matmul_v2 = self.add_node(5, **{"type": "matmul_v2"})
        v_matmul_v2 = self.add_node(6, **{"type": "matmul_v2"})

        # define input edge
        q_x_edge = self.add_edge(
            query.id, q_matmul_v2.id, **{"input_name": "X"}
        )
        k_x_edge = self.add_edge(
            query.id, k_matmul_v2.id, **{"input_name": "X"}
        )
        v_x_edge = self.add_edge(
            query.id, v_matmul_v2.id, **{"input_name": "X"}
        )
        q_y_edge = self.add_edge(
            q_weight.id, q_matmul_v2.id, **{"input_name": "Y"}
        )
        k_y_edge = self.add_edge(
            k_weight.id, k_matmul_v2.id, **{"input_name": "Y"}
        )
        v_y_edge = self.add_edge(
            v_weight.id, v_matmul_v2.id, **{"input_name": "Y"}
        )

        # define q, k, v matmul_v2 output
        q = self.add_node(7, **{"type": "var"})
        k = self.add_node(8, **{"type": "var"})
        v = self.add_node(9, **{"type": "var"})

        # define output edge
        q_out_edge = self.add_edge(
            q_matmul_v2.id, q.id, **{"output_name": "Out"}
        )
        k_out_edge = self.add_edge(
            k_matmul_v2.id, k.id, **{"output_name": "Out"}
        )
        v_out_edge = self.add_edge(
            v_matmul_v2.id, v.id, **{"output_name": "Out"}
        )

        # define shard_spec
        shard_spec = {
            "dp_mp": {
                0: [0, -1, -1],
                1: [-1, 1],
                2: [-1, 1],
                3: [-1, 1],
            },
            "mp_dp": {
                0: [1, -1, -1],
                1: [-1, 0],
                2: [-1, 0],
                3: [-1, 0],
            },
            "mp": {0: [-1, -1, -1], 1: [-1, 0], 2: [-1, 0], 3: [-1, 0]},
            "dp": {
                0: [0, -1, -1],
                1: [-1, -1],
                2: [-1, -1],
                3: [-1, -1],
            },
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 4
var_node.attrs["type"] = "param" @register_pattern
var_node.attrs["dim"] = len(var.shape) class RowMatmulPattern(BasePattern):
else: """Row matmul pattern defined by GPT model in PaddleFleetX."""
var_node.attrs["type"] = "var"
graph.attrs["var_to_id"][var_name] = var_node.id name = "row_matmul"
graph.attrs["id_to_var"][var_node.id] = var_name
else: def __init__(self):
var_node_id = graph.attrs["var_to_id"][var_name] super().__init__()
var_node = graph._nodes[var_node_id]
def build(self):
# define reshape input
input = self.add_node(0, **{"type": "var"})
# define reshape
reshape = self.add_node(1, **{"type": "reshape2"})
# define reshape input egde
x_edge = self.add_edge(input.id, reshape.id, **{"input_name": "X"})
# define reshape out
output = self.add_node(2, **{"type": "var"})
# define reshape output edge
out_edge = self.add_edge(
reshape.id, output.id, **{"output_name": "Out"}
)
# define matmul_v2 weight
weight = self.add_node(3, **{"dim": 2, "type": "param"})
# define matmul_v2
matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
# define input edge
x_edge = self.add_edge(output.id, matmul_v2.id, **{"input_name": "X"})
y_edge = self.add_edge(weight.id, matmul_v2.id, **{"input_name": "Y"})
# define q, k, v matmul_v2 output
output = self.add_node(5, **{"type": "var"})
# define output edge
out_edge = self.add_edge(
matmul_v2.id, output.id, **{"output_name": "Out"}
)
# define shard_spec
shard_spec = {
"dp_mp": {
3: [1, -1],
},
"mp_dp": {
3: [0, -1],
},
"mp": {3: [0, -1]},
"dp": {
3: [-1, -1],
},
}
self.attrs["shard_spec"] = shard_spec
# define sharded_tensors
self.attrs["sharded_tensors"] = 1
@register_pattern
class FFNPattern(BasePattern):
    """FFN pattern defined by GPT model in PaddleFleetX."""

    name = "ffn"

    def __init__(self):
        super().__init__()

    def build(self):
        x = self.add_node(0, **{"type": "var"})

        w1_weight = self.add_node(1, **{"dim": 2, "type": "param"})
        w1_matmul = self.add_node(2, **{"type": "matmul_v2"})

        w1_x = self.add_edge(0, 2, **{"input_name": "X"})
        w1_y = self.add_edge(1, 2, **{"input_name": "Y"})

        out1 = self.add_node(3, **{"type": "var"})
        w1_out = self.add_edge(2, 3, **{"output_name": "Out"})

        w1_b = self.add_node(4, **{"dim": 1, "type": "param"})
        add1 = self.add_node(5, **{"type": "elementwise_add"})
        add1_x = self.add_edge(3, 5, **{"input_name": "X"})
        add1_y = self.add_edge(4, 5, **{"input_name": "Y"})

        out2 = self.add_node(6, **{"type": "var"})
        add1_out = self.add_edge(5, 6, **{"output_name": "Out"})

        gelu = self.add_node(7, **{"type": "gelu"})
        gelu_x = self.add_edge(6, 7, **{"input_name": "X"})

        out3 = self.add_node(8, **{"type": "var"})
        gelu_out = self.add_edge(7, 8, **{"output_name": "Out"})

        w2_weight = self.add_node(9, **{"dim": 2, "type": "param"})
        w2_matmul = self.add_node(10, **{"type": "matmul_v2"})

        w2_x = self.add_edge(8, 10, **{"input_name": "X"})
        w2_y = self.add_edge(9, 10, **{"input_name": "Y"})

        out4 = self.add_node(11, **{"type": "var"})
        w2_out = self.add_edge(10, 11, **{"output_name": "Out"})

        w2_b = self.add_node(12, **{"dim": 1, "type": "param"})
        add2 = self.add_node(13, **{"type": "elementwise_add"})
        add2_x = self.add_edge(11, 13, **{"input_name": "X"})
        add2_y = self.add_edge(12, 13, **{"input_name": "Y"})

        out5 = self.add_node(14, **{"type": "var"})
        add2_out = self.add_edge(13, 14, **{"output_name": "Out"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1, -1], 1: [-1, 1], 9: [1, -1]},
            "mp_dp": {0: [1, -1, -1], 1: [-1, 0], 9: [0, -1]},
            "mp": {1: [-1, 0], 9: [0, -1]},
            "dp": {0: [0, -1, -1], 1: [-1, -1], 9: [-1, -1]},
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 2
@register_pattern
class SharedWordEmbeddingPattern(BasePattern):
    """Shared word embedding pattern defined by GPT model in PaddleFleetX."""

    name = "shared_word_embedding"

    def __init__(self):
        super().__init__()

    def build(self):
        # define embedding input
        tokens = self.add_node(0, **{"type": "data"})
        word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})

        # define embedding
        embedding = self.add_node(2, **{"type": "lookup_table_v2"})

        # define embedding input edge
        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
        w = self.add_edge(1, 2, **{"input_name": "W"})

        # define embedding output
        out = self.add_node(3, **{"type": "var"})

        # define embedding output edge
        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})

        # define matmul_v2 input
        x = self.add_node(4, **{"type": "var"})

        # define matmul_v2
        matmul = self.add_node(5, **{"type": "matmul_v2"})

        # define matmul_v2 input edge
        x_edge = self.add_edge(4, 5, **{"input_name": "X"})
        y_edge = self.add_edge(1, 5, **{"input_name": "Y"})

        # define matmul_v2 output
        out = self.add_node(6, **{"type": "var"})

        # define matmul_v2 output edge
        out_edge = self.add_edge(5, 6, **{"output_name": "Out"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1], 1: [1, -1], 4: [0, -1, -1]},
            "mp_dp": {0: [1, -1], 1: [0, -1], 4: [1, -1, -1]},
            "mp": {0: [-1, -1], 1: [0, -1], 4: [-1, -1, -1]},
            "dp": {0: [0, -1], 1: [-1, -1], 4: [0, -1, -1]},
        }
        self.attrs["shard_spec"] = shard_spec
        self.attrs["sharded_tensors"] = 3
@register_pattern
class PositionEmbeddingPattern(BasePattern):
    """Position embedding pattern defined by GPT model in PaddleFleetX."""

    name = "position_embedding"

    def __init__(self):
        super().__init__()

    def build(self):
        # define embedding input
        tokens = self.add_node(0, **{"type": "data"})
        position_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})

        # define embedding
        embedding = self.add_node(2, **{"type": "lookup_table_v2"})

        # define embedding input edge
        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
        w = self.add_edge(1, 2, **{"input_name": "W"})

        # define embedding output
        out = self.add_node(3, **{"type": "var"})

        # define embedding output edge
        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1], 1: [-1, -1], 3: [-1, -1, -1]},
            "mp_dp": {0: [1, -1], 1: [-1, -1], 3: [1, -1, -1]},
            "mp": {0: [-1, -1], 1: [-1, -1], 3: [-1, -1, -1]},
            "dp": {0: [0, -1], 1: [-1, -1], 3: [0, -1, -1]},
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 1
@register_pattern
class UnsqueezeDataPattern(BasePattern):
    """Unsqueeze data pattern defined by GPT model in PaddleFleetX."""

    name = "unsqueeze_data"

    def __init__(self):
        super().__init__()

    def build(self):
        # define unsqueeze input
        tokens = self.add_node(0, **{"type": "data"})

        # define unsqueeze
        unsqueeze = self.add_node(1, **{"type": "unsqueeze2"})

        # define unsqueeze input edge
        x_edge = self.add_edge(0, 1, **{"input_name": "X"})

        # define shard_spec: pure mp or hybrid dp+mp
        shard_spec = {
            "dp_mp": {0: [0, -1]},
            "mp_dp": {0: [1, -1]},
            "mp": {0: [-1, -1]},
            "dp": {0: [0, -1]},
        }
        self.attrs["shard_spec"] = shard_spec
        self.attrs["sharded_tensors"] = 1
@register_pattern
class ReshapeDataPattern(BasePattern):
    """Reshape data pattern defined by GPT model in PaddleFleetX."""

    name = "reshape_data"

    def __init__(self):
        super().__init__()

    def build(self):
        # define reshape input
        data = self.add_node(0, **{"type": "data"})

        # define reshape
        reshape = self.add_node(1, **{"type": "reshape2"})

        # define reshape input edge
        x_edge = self.add_edge(0, 1, **{"input_name": "X"})

        # define shard_spec
        shard_spec = {
            "dp_mp": {0: [0, -1]},
            "mp_dp": {0: [1, -1]},
            "mp": {0: [-1, -1]},
            "dp": {0: [0, -1]},
        }
        self.attrs["shard_spec"] = shard_spec

        # define sharded_tensors
        self.attrs["sharded_tensors"] = 1
class GraphUtil:
    """Graph util is used to convert ops to graph or match pattern for graph."""

    @staticmethod
    def convert_to_graph(block):
        """Convert ops to graph."""
        graph = Graph()
        graph.attrs["var_to_id"] = {}  # {var_name: node_id}
        graph.attrs["id_to_var_desc_id"] = {}  # {node_id: var_desc_id}
        graph.attrs["id_to_var_name"] = {}  # {node_id: var_name}
        graph.attrs["op_to_id"] = {}  # {op_id: node_id}
        graph.attrs["id_to_op"] = {}  # {node_id: op}

        ops = block.ops
        node_id = -1
        for op in ops:
            attrs = op.all_attrs()
            attrs["type"] = op.type
            node_id += 1

            # create op node
            op_node = graph.add_node(node_id, **attrs)
            graph.attrs["op_to_id"][op.desc.id()] = op_node.id
            graph.attrs["id_to_op"][op_node.id] = op
            graph._attr_to_nodes[op_node.id] = {}
            for input_name in op.input_names:
                graph._attr_to_nodes[op_node.id][input_name] = []
                for var_name in op.input(input_name):
                    if var_name not in graph.attrs["var_to_id"]:
                        # create var node
                        node_id += 1
@@ -137,201 +465,268 @@ def convert_to_graph(ops, block):
                        var_node = graph.add_node(node_id)
                        var = block._var_recursive(var_name)
                        if var.is_parameter:
                            var_node.attrs["type"] = "param"
                            var_node.attrs["dim"] = len(var.shape)
                        elif var.is_data:
                            var_node.attrs["type"] = "data"
                            var_node.attrs["dim"] = len(var.shape)
                        else:
                            var_node.attrs["type"] = "var"
                        graph.attrs["var_to_id"][var_name] = var_node.id
                        graph.attrs["id_to_var_desc_id"][
                            var_node.id
                        ] = var.desc.original_id()
                        graph.attrs["id_to_var_name"][var_node.id] = var_name
                    else:
                        var_node_id = graph.attrs["var_to_id"][var_name]
                        var_node = graph._nodes[var_node_id]

                    # create edge that input -> op
                    input_edge = graph.add_edge(var_node.id, op_node.id)
                    input_edge.attrs["input_name"] = input_name
                    graph._attr_to_nodes[op_node.id][input_name].append(
                        var_node
                    )

            for output_name in op.output_names:
                graph._attr_to_nodes[op_node.id][output_name] = []
                for var_name in op.output(output_name):
                    if var_name not in graph.attrs["var_to_id"]:
                        # create var node
                        node_id += 1
                        var_node = graph.add_node(node_id)
                        var = block._var_recursive(var_name)
                        if var.is_parameter:
                            var_node.attrs["type"] = "param"
                        else:
                            var_node.attrs["type"] = "var"
                        graph.attrs["var_to_id"][var_name] = var_node.id
                        graph.attrs["id_to_var_desc_id"][
                            var_node.id
                        ] = var.desc.original_id()
                        graph.attrs["id_to_var_name"][var_node.id] = var_name
                    else:
                        var_node_id = graph.attrs["var_to_id"][var_name]
                        var_node = graph._nodes[var_node_id]

                    # create edge that op -> output
                    output_edge = graph.add_edge(op_node.id, var_node.id)
                    output_edge.attrs["output_name"] = output_name
                    graph._attr_to_nodes[op_node.id][output_name].append(
                        var_node
                    )

        return graph
    @staticmethod
    def match_pattern(pattern, graph):
        def _is_op_node(node):
            """Judge whether node is op node."""
            if node.attrs["type"] not in ["var", "param", "data"]:
                return True
            return False

        def _compare_op_node(src, tgt):
            """Compare whether two op nodes are equivalent."""
            if src.attrs["type"] != tgt.attrs["type"]:
                return False
            return True

        def _compare_var_node(src, tgt):
            """Compare whether two var nodes are equivalent."""
            for key in src.attrs:
                if key not in tgt.attrs:
                    return False
                if src.attrs[key] != tgt.attrs[key]:
                    return False
            return True

        def _match_core(src_node, tgt_node):
            nonlocal not_matched
            # one input/output name corresponding to multiple vars is not supported
            if not_matched:
                return

            if _is_op_node(src_node):
                # compare whether the op nodes are equal
                if not _compare_op_node(src_node, tgt_node):
                    not_matched = True
                    return

                result[src_node.id] = tgt_node.id

                # input var nodes
                src_input_nodes = src_reverse_adjs[src_node.id]
                for node in src_input_nodes:
                    # already visited
                    if node.id in result:
                        continue
                    edge = src_edges[node.id][src_node.id]
                    input_name = edge.attrs["input_name"]

                    # NOTE: one input/output name corresponding to multiple vars is not supported
                    compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                        input_name, None
                    )
                    if not compare_nodes:
                        not_matched = True
                        return
                    _match_core(node, compare_nodes[0])

                # output var nodes
                src_output_node_ids = src_edges[src_node.id].keys()
                for node_id in src_output_node_ids:
                    # already visited
                    if node_id in result:
                        continue
                    node = src_nodes[node_id]
                    edge = src_edges[src_node.id][node_id]
                    output_name = edge.attrs["output_name"]

                    # NOTE: one input/output name corresponding to multiple vars is not supported
                    compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                        output_name, None
                    )
                    if not compare_nodes:
                        not_matched = True
                        return
                    _match_core(node, compare_nodes[0])

            else:
                # compare whether the var nodes are equal
                if not _compare_var_node(src_node, tgt_node):
                    not_matched = True
                    return

                result[src_node.id] = tgt_node.id

                # as input for op node
                src_as_input_node_ids = src_edges[src_node.id].keys()
                for node_id in src_as_input_node_ids:
                    if node_id in result:
                        continue
                    src_edge = src_edges[src_node.id][node_id]
                    input_name = src_edge.attrs["input_name"]
                    compare_node_ids = tgt_edges[tgt_node.id].keys()

                    compare_node = None
                    for compare_node_id in compare_node_ids:
                        edge = tgt_edges[tgt_node.id][compare_node_id]
                        if (
                            edge.attrs["input_name"] == input_name
                            and compare_node_id not in result.values()
                        ):
                            compare_node = tgt_nodes[compare_node_id]
                            break

                    if not compare_node:
                        not_matched = True
                        return
                    _match_core(src_nodes[node_id], compare_node)

                # as output for op node
                src_as_output_nodes = src_reverse_adjs[src_node.id]
                for node in src_as_output_nodes:
                    if node.id in result:
                        continue
                    src_edge = src_edges[node.id][src_node.id]
                    output_name = src_edge.attrs["output_name"]

                    compare_nodes = tgt_reverse_adjs[tgt_node.id]

                    compare_node = None
                    for item in compare_nodes:
                        node_id = item.id
                        edge = tgt_edges[node_id][tgt_node.id]
                        if edge.attrs["output_name"] == output_name:
                            compare_node = tgt_nodes[node_id]
                            break
                    if not compare_node:
                        not_matched = True
                        return
                    _match_core(src_nodes[node.id], compare_node)

        results = []
        matched_ids = set()
        matched_op_node_ids = set()
        result = {}
        src_nodes = pattern.nodes
        src_edges = pattern._adjs
        src_reverse_adjs = pattern._reverse_adjs

        tgt_nodes = graph.nodes
        tgt_edges = graph._adjs
        tgt_reverse_adjs = graph._reverse_adjs
        tgt_attr_to_nodes = graph._attr_to_nodes

        # start from an op node
        src_start_node = None
        for node_id in src_nodes:
            node = src_nodes[node_id]
            if node.attrs["type"] not in ["var", "param", "data"]:
                src_start_node = node
                break
        assert src_start_node is not None

        for node_id in tgt_nodes:
            node = tgt_nodes[node_id]
            if node.attrs["type"] == src_start_node.attrs["type"]:
                not_matched = False
                _match_core(src_start_node, node)
                if not not_matched:
                    need_to_append = True
                    for value in result.values():
                        if value in matched_op_node_ids:
                            result = {}
                            need_to_append = False
                            break
                    if need_to_append:
                        results.append(result)
                        for value in result.values():
                            matched_ids.add(value)
                            if value in graph.attrs["id_to_op"].keys():
                                matched_op_node_ids.add(value)
                        result = {}
                else:
                    not_matched = False
                    result = {}

        return results, matched_ids
    @staticmethod
    def match_all_patterns(graph):
        # matched_results maps a pattern name to a list of node-id mappings from
        # pattern node id to graph node id, such as
        # {"pattern_name": [{pattern_node_id: graph_node_id}, ...]}
        matched_results = {}
        matched_ids = set()
        for pattern_name in _PATTERNS:
            pattern = _PATTERNS[pattern_name]
            results, matched = GraphUtil.match_pattern(pattern, graph)
            for result in results:
                has_matched = False
                for id in result:
                    if result[id] in matched_ids:
                        has_matched = True
                        break
                if not has_matched:
                    for item in result:
                        matched_ids.add(result[item])
                    if pattern.name not in matched_results:
                        matched_results[pattern.name] = []
                    matched_results[pattern.name].append(result)
        return matched_results
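A minimal end-to-end sketch of the GraphUtil API (not part of this commit; assumes paddle is installed): build a tiny static program, convert its main block to a graph, and run all registered patterns over it. With such a small block most patterns will simply not match, which is fine here since the point is the calling convention; the unit tests below do the same thing on a full GPT block.

import paddle
from paddle import static
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import GraphUtil

paddle.enable_static()
main_program, startup_program = static.Program(), static.Program()
with static.program_guard(main_program, startup_program):
    x = static.data(name="x", shape=[4, 512, 1024], dtype="float32")
    w = paddle.create_parameter(shape=[1024, 1024], dtype="float32")
    y = paddle.matmul(x, w)

graph = GraphUtil.convert_to_graph(main_program.global_block())
results = GraphUtil.match_all_patterns(graph)
for pattern_name, matches in results.items():
    for mapping in matches:
        # mapping is {pattern node id: graph node id}; var node ids can be
        # translated back to variable names via id_to_var_name.
        names = [
            graph.attrs["id_to_var_name"].get(node_id, "<op>")
            for node_id in mapping.values()
        ]
        print(pattern_name, names)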
class OperatorClusteringUtil:
    """Operator clustering util is used to cluster operators to layers."""

    common_starts = ["layer_norm", "matmul_v2", "matmul"]

    @staticmethod
@@ -506,6 +901,8 @@ class OperatorClusteringUtil:
class ClusterPartitionUtil:
    """Cluster partition util is used to get device meshes and process meshes."""

    @staticmethod
    def factorization(num):
        factors = []
@@ -535,13 +932,11 @@ class ClusterPartitionUtil:
        ],
    ) -> list:
        """
        Partition cluster into possible device meshes.
        Args:
            n (int): The number of nodes.
            m (int): The number of single devices on each node.
            filter (list): Functions for filtering useful meshes.
        Returns:
            device_meshes (list): The possible device meshes.
        """
@@ -573,10 +968,8 @@ class ClusterPartitionUtil:
    def convert_to_process_meshes(device_mesh: list) -> list:
        """
        Transfer a device mesh into possible process meshes.
        Args:
            device_mesh (list): [n, m], one device mesh.
        Returns:
            process_meshes (list): The possible process meshes.
        """
......
@@ -95,7 +95,7 @@ def get_gpt_model(
    return train_program, start_program, loss, gen_data


class TestGroupOperatorsAndPatterns(unittest.TestCase):
    def test_gpt(self):
        modeling.init_global()
        train_program = static.Program()
@@ -117,17 +117,30 @@ class TestGroupOperators(unittest.TestCase):
        )
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            _PATTERNS,
            GraphUtil,
            RuleBasedTuner,
        )

        dist_context = DistributedContext()
        tuner = RuleBasedTuner(dist_context)
        layers = tuner.cluster_operators(train_program.global_block().ops)
        graph = GraphUtil.convert_to_graph(train_program.global_block())
        print("graph: ", graph)
        print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
        print("row_matmul: ", _PATTERNS["row_matmul"].attrs["shard_spec"])
        print("ffn: ", _PATTERNS["ffn"].attrs["shard_spec"])
        print(
            "shared_word_embedding: ",
            _PATTERNS["shared_word_embedding"].attrs["shard_spec"],
        )
        print(
            "position_embedding: ",
            _PATTERNS["position_embedding"].attrs["shard_spec"],
        )
        print(
            "unsqueeze_data: ", _PATTERNS["unsqueeze_data"].attrs["shard_spec"]
        )
        print("reshape_data: ", _PATTERNS["reshape_data"].attrs["shard_spec"])


if __name__ == "__main__":
......
@@ -95,7 +95,7 @@ def get_gpt_model(
    return train_program, start_program, loss, gen_data


class TestPatternMatch(unittest.TestCase):
    def test_gpt(self):
        modeling.init_global()
        train_program = static.Program()
@@ -116,26 +116,15 @@ class TestGroupOperators(unittest.TestCase):
            DistributedContext,
        )
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            GraphUtil,
            RuleBasedTuner,
        )

        dist_context = DistributedContext()
        tuner = RuleBasedTuner(dist_context)
        graph = GraphUtil.convert_to_graph(train_program.global_block())
        results = GraphUtil.match_all_patterns(graph)
        print(results)


if __name__ == "__main__":
......