Unverified commit 0bb7c003 authored by caozhou, committed by GitHub

[Auto Parallel] Add patterns of rule based tuner (#51859)

* add patterns

* add unittest
Parent cdefcd00
......@@ -22,24 +22,43 @@ _PATTERNS = {}
def register_pattern(cls):
"""Register pattern for rule-based tuner."""
name = cls.name
def register(name):
def register():
global _PATTERNS
_PATTERNS[name] = cls()
pattern = cls()
_PATTERNS[pattern.name] = pattern
# sort patterns according to the number of sharded tensors
# when a tensor can be matched by multiple patterns, its dist attr is set by the first matched pattern.
_PATTERNS = dict(
sorted(
_PATTERNS.items(), key=lambda x: -x[1].attrs["sharded_tensors"]
)
)
register(name)
register()
return cls
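As a quick illustration of what the sorting above buys (a standalone sketch using dummy pattern objects, not the real classes registered below): patterns with more sharded tensors sort first, so a tensor that several patterns could match takes its dist attr from the strongest match.
# Standalone sketch, plain Python only; _DemoPattern is a hypothetical stand-in.
_DEMO_PATTERNS = {}

class _DemoPattern:
    def __init__(self, name, sharded_tensors):
        self.name = name
        self.attrs = {"sharded_tensors": sharded_tensors}

def _demo_register(pattern):
    global _DEMO_PATTERNS
    _DEMO_PATTERNS[pattern.name] = pattern
    # same ordering rule as register(): sort by sharded_tensors, descending
    _DEMO_PATTERNS = dict(
        sorted(_DEMO_PATTERNS.items(), key=lambda x: -x[1].attrs["sharded_tensors"])
    )

_demo_register(_DemoPattern("reshape_data", 1))
_demo_register(_DemoPattern("qkv", 4))
_demo_register(_DemoPattern("ffn", 2))
print(list(_DEMO_PATTERNS))  # ['qkv', 'ffn', 'reshape_data']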
class BasePattern(Graph):
name = "base"
"""
Base class of pattern.
BasePattern inherits from Graph; the two important additions are shard_spec and sharded_tensors.
shard_spec gives the shard specification of each tensor node in this pattern under different parallelisms.
sharded_tensors is the number of tensors that are sharded.
"""
_name = "base"
def __init__(self):
"""Every pattern has its own name and build method."""
super().__init__()
self.build()
@property
def name(self):
return self.__class__._name
@abstractmethod
def build(self):
pass
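For orientation, a minimal subclass might look like the sketch below; it is illustrative only (a hypothetical "demo_gelu" pattern, not part of this commit) and assumes the add_node/add_edge API used by the real patterns that follow.
# Hypothetical example pattern, for illustration only.
@register_pattern
class DemoGeluPattern(BasePattern):
    """A lone gelu op fed by one 3-D var (not a real registered pattern)."""

    name = "demo_gelu"

    def __init__(self):
        super().__init__()

    def build(self):
        x = self.add_node(0, **{"type": "var"})      # input var
        gelu = self.add_node(1, **{"type": "gelu"})  # op node
        out = self.add_node(2, **{"type": "var"})    # output var
        self.add_edge(x.id, gelu.id, **{"input_name": "X"})
        self.add_edge(gelu.id, out.id, **{"output_name": "Out"})
        # nothing is sharded here: replicate the input under every parallelism
        self.attrs["shard_spec"] = {
            "dp_mp": {0: [-1, -1, -1]},
            "mp_dp": {0: [-1, -1, -1]},
        }
        self.attrs["sharded_tensors"] = 0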
......@@ -47,6 +66,8 @@ class BasePattern(Graph):
@register_pattern
class QKVPattern(BasePattern):
"""The QKV pattern defined by GPT model in PaddleFleetX."""
name = "qkv"
def __init__(self):
......@@ -55,43 +76,374 @@ class QKVPattern(BasePattern):
def build(self):
query = self.add_node(0, **{"type": "var"})
# define q, k, v weight
q_weight = self.add_node(1, **{"dim": 2, "type": "param"})
k_weight = self.add_node(2, **{"dim": 2, "type": "param"})
v_weight = self.add_node(3, **{"dim": 2, "type": "param"})
q_matmul = self.add_node(4, **{"type": "matmul_v2"})
k_matmul = self.add_node(5, **{"type": "matmul_v2"})
v_matmul = self.add_node(6, **{"type": "matmul_v2"})
q_x = self.add_edge(0, 4, **{"input_name": "X"})
k_x = self.add_edge(0, 5, **{"input_name": "X"})
v_x = self.add_edge(0, 6, **{"input_name": "X"})
q_y = self.add_edge(1, 4, **{"input_name": "Y"})
k_y = self.add_edge(2, 5, **{"input_name": "Y"})
v_y = self.add_edge(3, 6, **{"input_name": "Y"})
# define q, k, v matmul_v2
q_matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
k_matmul_v2 = self.add_node(5, **{"type": "matmul_v2"})
v_matmul_v2 = self.add_node(6, **{"type": "matmul_v2"})
# define input edge
q_x_edge = self.add_edge(
query.id, q_matmul_v2.id, **{"input_name": "X"}
)
k_x_edge = self.add_edge(
query.id, k_matmul_v2.id, **{"input_name": "X"}
)
v_x_edge = self.add_edge(
query.id, v_matmul_v2.id, **{"input_name": "X"}
)
q_y_edge = self.add_edge(
q_weight.id, q_matmul_v2.id, **{"input_name": "Y"}
)
k_y_edge = self.add_edge(
k_weight.id, k_matmul_v2.id, **{"input_name": "Y"}
)
v_y_edge = self.add_edge(
v_weight.id, v_matmul_v2.id, **{"input_name": "Y"}
)
# define q, k, v matmul_v2 output
q = self.add_node(7, **{"type": "var"})
k = self.add_node(8, **{"type": "var"})
v = self.add_node(9, **{"type": "var"})
q_out = self.add_edge(4, 7, **{"output_name": "Out"})
k_out = self.add_edge(5, 8, **{"output_name": "Out"})
v_out = self.add_edge(6, 9, **{"output_name": "Out"})
# define output edge
q_out_edge = self.add_edge(
q_matmul_v2.id, q.id, **{"output_name": "Out"}
)
k_out_edge = self.add_edge(
k_matmul_v2.id, k.id, **{"output_name": "Out"}
)
v_out_edge = self.add_edge(
v_matmul_v2.id, v.id, **{"output_name": "Out"}
)
# define shard_spec
shard_spec = {
"dp_mp": {
0: [0, -1, -1],
1: [-1, 1],
2: [-1, 1],
3: [-1, 1],
},
"mp_dp": {
0: [1, -1, -1],
1: [-1, 0],
2: [-1, 0],
3: [-1, 0],
},
"mp": {0: [-1, -1, -1], 1: [-1, 0], 2: [-1, 0], 3: [-1, 0]},
"dp": {
0: [0, -1, -1],
1: [-1, -1],
2: [-1, -1],
3: [-1, -1],
},
}
self.attrs["shard_spec"] = shard_spec
# define sharded_tensors
self.attrs["sharded_tensors"] = 4
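To read these dicts: each key is a pattern node id (0 is the query var, 1-3 are the q/k/v weights) and each value is a dims_mapping, where -1 means that tensor dimension is replicated and a value k means it is sharded along mesh axis k; under "dp_mp" the input is split on its batch dimension (mesh axis 0, data parallel by the naming) and the three weights are column-sharded on mesh axis 1 (model parallel). A tiny standalone helper, not part of the tuner, that spells this out:
# Illustrative only; dp_mp_spec is copied from QKVPattern above.
def describe_dims_mapping(node_id, dims_mapping):
    parts = []
    for dim, mesh_axis in enumerate(dims_mapping):
        if mesh_axis == -1:
            parts.append("dim %d: replicated" % dim)
        else:
            parts.append("dim %d: sharded along mesh axis %d" % (dim, mesh_axis))
    return "node %d: %s" % (node_id, ", ".join(parts))

dp_mp_spec = {0: [0, -1, -1], 1: [-1, 1], 2: [-1, 1], 3: [-1, 1]}
for node_id, dims_mapping in dp_mp_spec.items():
    print(describe_dims_mapping(node_id, dims_mapping))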
@register_pattern
class RowMatmulPattern(BasePattern):
"""Row matmul pattern defined by GPT model in PaddleFleetX."""
name = "row_matmul"
def __init__(self):
super().__init__()
def build(self):
# define reshape input
input = self.add_node(0, **{"type": "var"})
# define reshape
reshape = self.add_node(1, **{"type": "reshape2"})
# define reshape input edge
x_edge = self.add_edge(input.id, reshape.id, **{"input_name": "X"})
# define reshape out
output = self.add_node(2, **{"type": "var"})
# define reshape output edge
out_edge = self.add_edge(
reshape.id, output.id, **{"output_name": "Out"}
)
# define matmul_v2 weight
weight = self.add_node(3, **{"dim": 2, "type": "param"})
# define matmul_v2
matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
# Pattern
self.attrs["shard_spec"] = [
[(1, 2, 3), [[-1, 0], [-1, 1]]],
] # 2-tuple list such as [(tensor_id, shard_spec)]
# define input edge
x_edge = self.add_edge(output.id, matmul_v2.id, **{"input_name": "X"})
y_edge = self.add_edge(weight.id, matmul_v2.id, **{"input_name": "Y"})
# define q, k, v matmul_v2 output
output = self.add_node(5, **{"type": "var"})
# define output edge
out_edge = self.add_edge(
matmul_v2.id, output.id, **{"output_name": "Out"}
)
# define shard_spec
shard_spec = {
"dp_mp": {
3: [1, -1],
},
"mp_dp": {
3: [0, -1],
},
"mp": {3: [0, -1]},
"dp": {
3: [-1, -1],
},
}
self.attrs["shard_spec"] = shard_spec
# define sharded_tensors
self.attrs["sharded_tensors"] = 1
@register_pattern
class FFNPattrern(BasePattern):
"""FFN pattern defined by GPT model in PaddleFleetX."""
name = "ffn"
def __init__(self):
super().__init__()
def build(self):
x = self.add_node(0, **{"type": "var"})
w1_weight = self.add_node(1, **{"dim": 2, "type": "param"})
w1_matmul = self.add_node(2, **{"type": "matmul_v2"})
w1_x = self.add_edge(0, 2, **{"input_name": "X"})
w1_y = self.add_edge(1, 2, **{"input_name": "Y"})
out1 = self.add_node(3, **{"type": "var"})
w1_out = self.add_edge(2, 3, **{"output_name": "Out"})
w1_b = self.add_node(4, **{"dim": 1, "type": "param"})
add1 = self.add_node(5, **{"type": "elementwise_add"})
add1_x = self.add_edge(3, 5, **{"input_name": "X"})
add1_y = self.add_edge(4, 5, **{"input_name": "Y"})
out2 = self.add_node(6, **{"type": "var"})
add1_out = self.add_edge(5, 6, **{"output_name": "Out"})
gelu = self.add_node(7, **{"type": "gelu"})
gelu_x = self.add_edge(6, 7, **{"input_name": "X"})
out3 = self.add_node(8, **{"type": "var"})
gelu_out = self.add_edge(7, 8, **{"output_name": "Out"})
w2_weight = self.add_node(9, **{"dim": 2, "type": "param"})
w2_matmul = self.add_node(10, **{"type": "matmul_v2"})
w2_x = self.add_edge(8, 10, **{"input_name": "X"})
w2_y = self.add_edge(9, 10, **{"input_name": "Y"})
out4 = self.add_node(11, **{"type": "var"})
w2_out = self.add_edge(10, 11, **{"output_name": "Out"})
w2_b = self.add_node(12, **{"dim": 1, "type": "param"})
add2 = self.add_node(13, **{"type": "elementwise_add"})
add2_x = self.add_edge(11, 13, **{"input_name": "X"})
add2_y = self.add_edge(12, 13, **{"input_name": "Y"})
out5 = self.add_node(14, **{"type": "var"})
add2_out = self.add_edge(13, 14, **{"output_name": "Out"})
# define shard_spec
shard_spec = {
"dp_mp": {0: [0, -1, -1], 1: [-1, 1], 9: [1, -1]},
"mp_dp": {0: [1, -1, -1], 1: [-1, 0], 9: [0, -1]},
"mp": {1: [-1, 0], 9: [0, -1]},
"dp": {0: [0, -1, -1], 1: [-1, -1], 9: [-1, -1]},
}
self.attrs["shard_spec"] = shard_spec
# define sharded_tensors
self.attrs["sharded_tensors"] = 2
@register_pattern
class SharedWordEmbeddingPattern(BasePattern):
"""Sharded word embedding pattern defined by GPT model in PaddleFleetX."""
name = "shared_word_embedding"
def __init__(self):
super().__init__()
def convert_to_graph(ops, block):
def build(self):
# define embedding input
tokens = self.add_node(0, **{"type": "data"})
word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})
# define embedding
embedding = self.add_node(2, **{"type": "lookup_table_v2"})
# define embedding input edge
ids = self.add_edge(0, 2, **{"input_name": "Ids"})
w = self.add_edge(1, 2, **{"input_name": "W"})
# define embedding output
out = self.add_node(3, **{"type": "var"})
# define embedding output edge
out_edge = self.add_edge(2, 3, **{"output_name": "Out"})
# define matmul_v2 input
x = self.add_node(4, **{"type": "var"})
# define matmul_v2
matmul = self.add_node(5, **{"type": "matmul_v2"})
# define matmul_v2 input edge
x_edge = self.add_edge(4, 5, **{"input_name": "X"})
y_edge = self.add_edge(1, 5, **{"input_name": "Y"})
# define matmul_v2 output
out = self.add_node(6, **{"type": "var"})
# define matmul_v2 output edge
out_edge = self.add_edge(5, 6, **{"output_name": "Out"})
# define shard_spec
shard_spec = {
"dp_mp": {0: [0, -1], 1: [1, -1], 4: [0, -1, -1]},
"mp_dp": {0: [1, -1], 1: [0, -1], 4: [1, -1, -1]},
"mp": {0: [-1, -1], 1: [0, -1], 4: [-1, -1, -1]},
"dp": {0: [0, -1], 1: [-1, -1], 4: [0, -1, -1]},
}
self.attrs["shard_spec"] = shard_spec
self.attrs["sharded_tensors"] = 3
@register_pattern
class PositionEmbeddingPattern(BasePattern):
"""Position embedding pattern defined by GPT model in PaddleFleetX."""
name = "position_embedding"
def __init__(self):
super().__init__()
def build(self):
# define embedding input
tokens = self.add_node(0, **{"type": "data"})
word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})
# define embedding
embedding = self.add_node(2, **{"type": "lookup_table_v2"})
# define embedding input edge
ids = self.add_edge(0, 2, **{"input_name": "Ids"})
w = self.add_edge(1, 2, **{"input_name": "W"})
# define embedding output
out = self.add_node(3, **{"type": "var"})
# define embedding output edge
out_edge = self.add_edge(2, 3, **{"output_name": "Out"})
# define shard_spec
shard_spec = {
"dp_mp": {0: [0, -1], 1: [-1, -1], 3: [-1, -1, -1]},
"mp_dp": {0: [1, -1], 1: [-1, -1], 3: [1, -1, -1]},
"mp": {0: [-1, -1], 1: [-1, -1], 3: [-1, -1, -1]},
"dp": {0: [0, -1], 1: [-1, -1], 3: [0, -1, -1]},
}
self.attrs["shard_spec"] = shard_spec
# define sharded_tensors
self.attrs["sharded_tensors"] = 1
@register_pattern
class UnsqueezeDataPattern(BasePattern):
"""Unsqueeze data pattern defined by GPT model in the PaddleFleetX."""
name = "unsqueeze_data"
def __init__(self):
super().__init__()
def build(self):
# define unsqueeze input
tokens = self.add_node(0, **{"type": "data"})
# define unsqueeze
unsqueeze = self.add_node(1, **{"type": "unsqueeze2"})
# define unsqueeze input edge
x_edge = self.add_edge(0, 1, **{"input_name": "X"})
# define shard_spec
shard_spec = {
"dp_mp": {0: [0, -1]},
"mp_dp": {0: [1, -1]},
"mp": {0: [-1, -1]},
"dp": {0: [0, -1]},
}
self.attrs["shard_spec"] = shard_spec
self.attrs["sharded_tensors"] = 1
@register_pattern
class ReshapeDataPattern(BasePattern):
"""Reshape data pattern defined by GPT model in PaddleFleetX."""
name = "reshape_data"
def __init__(self):
super().__init__()
def build(self):
# define reshape input
data = self.add_node(0, **{"type": "data"})
# define reshape
reshape = self.add_node(1, **{"type": "reshape2"})
# define reshape input edge
x_edge = self.add_edge(0, 1, **{"input_name": "X"})
# define shard_spec
shard_spec = {
"dp_mp": {0: [0, -1]},
"mp_dp": {0: [1, -1]},
"mp": {0: [-1, -1]},
"dp": {0: [0, -1]},
}
self.attrs["shard_spec"] = shard_spec
# define sharded_tensors
self.attrs["sharded_tensors"] = 1
class GraphUtil:
"""Graph util is used to convert ops to graph or match pattern for graph."""
@staticmethod
def convert_to_graph(block):
"""Convert ops to graph."""
graph = Graph()
graph.attrs["var_to_id"] = {} # {var_name: node_id}
graph.attrs["id_to_var"] = {} # {node_id: var_name}
graph.attrs["id_to_var_desc_id"] = {} # {node_id: var_desc_id}
graph.attrs["id_to_var_name"] = {}
graph.attrs["op_to_id"] = {} # {op_id: node_id}
graph.attrs["id_to_op"] = {} # {node_id: op_id}
graph.attrs["id_to_op"] = {} # {node_id: op}
ops = block.ops
node_id = -1
for op in ops:
attrs = op.all_attrs()
......@@ -101,7 +453,7 @@ def convert_to_graph(ops, block):
# create op node
op_node = graph.add_node(node_id, **attrs)
graph.attrs["op_to_id"][op.desc.id()] = op_node.id
graph.attrs["id_to_op"][op_node.id] = op.desc.id()
graph.attrs["id_to_op"][op_node.id] = op
graph._attr_to_nodes[op_node.id] = {}
for input_name in op.input_names:
graph._attr_to_nodes[op_node.id][input_name] = []
......@@ -114,10 +466,16 @@ def convert_to_graph(ops, block):
if var.is_parameter:
var_node.attrs["type"] = "param"
var_node.attrs["dim"] = len(var.shape)
elif var.is_data:
var_node.attrs["type"] = "data"
var_node.attrs["dim"] = len(var.shape)
else:
var_node.attrs["type"] = "var"
graph.attrs["var_to_id"][var_name] = var_node.id
graph.attrs["id_to_var"][var_node.id] = var_name
graph.attrs["id_to_var_desc_id"][
var_node.id
] = var.desc.original_id()
graph.attrs["id_to_var_name"][var_node.id] = var_name
else:
var_node_id = graph.attrs["var_to_id"][var_name]
var_node = graph._nodes[var_node_id]
......@@ -125,7 +483,9 @@ def convert_to_graph(ops, block):
# create edge that input -> op
input_edge = graph.add_edge(var_node.id, op_node.id)
input_edge.attrs["input_name"] = input_name
graph._attr_to_nodes[op_node.id][input_name].append(var_node)
graph._attr_to_nodes[op_node.id][input_name].append(
var_node
)
for output_name in op.output_names:
graph._attr_to_nodes[op_node.id][output_name] = []
......@@ -140,7 +500,12 @@ def convert_to_graph(ops, block):
else:
var_node.attrs["type"] = "var"
graph.attrs["var_to_id"][var_name] = var_node.id
graph.attrs["id_to_var"][var_node.id] = var_name
graph.attrs["id_to_var_desc_id"][
var_node.id
] = var.desc.original_id()
graph.attrs["id_to_var_name"][
var_node.id
] = var_name
else:
var_node_id = graph.attrs["var_to_id"][var_name]
var_node = graph._nodes[var_node_id]
......@@ -155,24 +520,24 @@ def convert_to_graph(ops, block):
return graph
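A rough usage sketch of the conversion above, mirroring the unit test further down; build_gpt_program() is a hypothetical stand-in for whatever constructs the static Program.
# Sketch only; assumes a Program built elsewhere (build_gpt_program is hypothetical).
train_program = build_gpt_program()
graph = GraphUtil.convert_to_graph(train_program.global_block())
for node_id in graph.nodes:
    node = graph.nodes[node_id]
    if node.attrs["type"] in ("var", "param", "data"):
        # tensor nodes record their var name and desc id in the graph attrs
        print(node_id, node.attrs["type"], graph.attrs["id_to_var_name"][node_id])
    else:
        # op nodes carry the op type, e.g. "matmul_v2" or "gelu"
        print(node_id, "op:", node.attrs["type"])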
def match(pattern, graph):
@staticmethod
def match_pattern(pattern, graph):
def _is_op_node(node):
"""Judge whether node is op node"""
"""Judge whether node is op node."""
if node.attrs["type"] not in ["var", "param", "data"]:
return True
return False
def _compare_op_node(src, tgt):
"""Compare whether two op nodes are equal"""
"""Compare whether two op nodes are equivalent."""
if src.attrs["type"] != tgt.attrs["type"]:
return False
return True
def _compare_var_node(src, tgt):
"""Compare whether two var nodes are equal"""
"""Compare whether two var nodes are equivalent."""
for key in src.attrs:
if key not in tgt.attrs:
return False
......@@ -183,13 +548,14 @@ def match(pattern, graph):
def _match_core(src_node, tgt_node):
nonlocal not_matched
# do not support one input name or output name corresponding to multiple vars
# one input name or output name corresponding to multiple vars is not supported
if not_matched:
return
if _is_op_node(src_node):
# check whether the op nodes are equal
if not _compare_op_node(src_node, tgt_node):
not_matched = True
return
result[src_node.id] = tgt_node.id
......@@ -232,14 +598,14 @@ def match(pattern, graph):
_match_core(node, compare_nodes[0])
else:
# compare var node whether equal
# check whether the var nodes are equal
if not _compare_var_node(src_node, tgt_node):
not_matched = True
return
result[src_node.id] = tgt_node.id
# as input for op nodes
# as input for op node
src_as_input_node_ids = src_edges[src_node.id].keys()
for node_id in src_as_input_node_ids:
if node_id in result:
......@@ -264,7 +630,7 @@ def match(pattern, graph):
return
_match_core(src_nodes[node_id], compare_node)
# as output for nodes
# as output for op node
src_as_output_nodes = src_reverse_adjs[src_node.id]
for node in src_as_output_nodes:
if node.id in result:
......@@ -273,10 +639,11 @@ def match(pattern, graph):
src_edge = src_edges[node.id][src_node.id]
output_name = src_edge.attrs["output_name"]
compare_node_ids = tgt_reverse_adjs[tgt_node.id]
compare_nodes = tgt_reverse_adjs[tgt_node.id]
compare_node = None
for node_id in compare_node_ids:
for item in compare_nodes:
node_id = item.id
edge = tgt_edges[node_id][tgt_node.id]
if edge.attrs["output_name"] == output_name:
compare_node = tgt_nodes[node_id]
......@@ -284,11 +651,12 @@ def match(pattern, graph):
if not compare_node:
not_matched = True
return
_match_core(src_nodes[node_id], compare_node)
_match_core(src_nodes[node.id], compare_node)
results = []
matched_ids = set()
matched_op_node_ids = set()
result = {}
has_matched = set()
src_nodes = pattern.nodes
src_edges = pattern._adjs
src_reverse_adjs = pattern._reverse_adjs
......@@ -297,7 +665,6 @@ def match(pattern, graph):
tgt_edges = graph._adjs
tgt_reverse_adjs = graph._reverse_adjs
tgt_attr_to_nodes = graph._attr_to_nodes
not_matched = False
# start with an op node
src_start_node = None
......@@ -311,27 +678,55 @@ def match(pattern, graph):
for node_id in tgt_nodes:
node = tgt_nodes[node_id]
if node.attrs["type"] == src_start_node.attrs["type"]:
not_matched = False
_match_core(src_start_node, node)
if not not_matched:
need_to_append = True
for value in result.values():
if value in has_matched:
if value in matched_op_node_ids:
result = {}
need_to_append = False
break
if need_to_append:
results.append(result)
for value in result.values():
has_matched.add(value)
matched_ids.add(value)
if value in graph.attrs["id_to_op"].keys():
matched_op_node_ids.add(value)
result = {}
else:
not_matched = False
result = {}
return results, matched_ids
return results
@staticmethod
def match_all_patterns(graph):
# matched_results maps a pattern name to a list of matches; each match maps pattern node ids to graph node ids,
# e.g. {"pattern_name": [{pattern_node_id: graph_node_id}, ...]}
matched_results = {}
matched_ids = set()
for pattern_name in _PATTERNS:
pattern = _PATTERNS[pattern_name]
results, matched = GraphUtil.match_pattern(pattern, graph)
for result in results:
has_matched = False
for id in result:
if result[id] in matched_ids:
has_matched = True
break
if not has_matched:
for item in result:
matched_ids.add(result[item])
if pattern.name not in matched_results:
matched_results[pattern.name] = []
matched_results[pattern.name].append(result)
return matched_results
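Putting the two entry points together, a hedged sketch of how a caller might turn the matches back into per-tensor shard specs; graph is assumed to come from GraphUtil.convert_to_graph, and "dp_mp" is just one of the parallelism keys the patterns define.
# Sketch only; assumes graph = GraphUtil.convert_to_graph(block) was built earlier.
matched_results = GraphUtil.match_all_patterns(graph)
parallelism = "dp_mp"
for pattern_name, results in matched_results.items():
    shard_spec = _PATTERNS[pattern_name].attrs["shard_spec"][parallelism]
    for result in results:  # {pattern_node_id: graph_node_id}
        for pattern_node_id, dims_mapping in shard_spec.items():
            graph_node_id = result[pattern_node_id]
            var_name = graph.attrs["id_to_var_name"][graph_node_id]
            print(pattern_name, var_name, dims_mapping)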
class OperatorClusteringUtil:
"""Operator clustering util is used to cluster operators to layers."""
common_starts = ["layer_norm", "matmul_v2", "matmul"]
@staticmethod
......@@ -506,6 +901,8 @@ class OperatorClusteringUtil:
class ClusterPartitionUtil:
"""Cluster partition util is used to get device meshes and process meshes."""
@staticmethod
def factorization(num):
factors = []
......@@ -535,13 +932,11 @@ class ClusterPartitionUtil:
],
) -> list:
"""
Partition cluster into possible device meshes.
Partition cluster into possible device meshes.
Args:
n (int): The number of nodes.
m (int): The number of single devices on each node.
filter (list): Functions for filtering useful meshes
Returns:
device_meshed (list) : The possible device meshes.
"""
......@@ -573,10 +968,8 @@ class ClusterPartitionUtil:
def convert_to_process_meshes(device_mesh: list) -> list:
"""
Convert a device mesh into possible process meshes.
Args:
device_mesh (list): [n, m], one device mesh.
Returns:
process_meshes (list): Possible process meshes.
"""
......
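To make the collapsed ClusterPartitionUtil docstrings above concrete, a usage sketch; the argument names follow the docstrings, and the concrete return values depend on the collapsed implementation and its filter functions, so none are asserted here.
# Sketch only; exact outputs depend on the collapsed implementation above.
# A cluster of n=2 nodes with m=8 devices each is partitioned into candidate
# device meshes, and each device mesh is then expanded into process meshes.
device_meshes = ClusterPartitionUtil.partition_cluster(n=2, m=8)
for device_mesh in device_meshes:  # each device mesh is an [n, m]-style list
    process_meshes = ClusterPartitionUtil.convert_to_process_meshes(device_mesh)
    print(device_mesh, "->", process_meshes)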
......@@ -95,7 +95,7 @@ def get_gpt_model(
return train_program, start_program, loss, gen_data
class TestGroupOperators(unittest.TestCase):
class TestGroupOperatorsAndPatterns(unittest.TestCase):
def test_gpt(self):
modeling.init_global()
train_program = static.Program()
......@@ -117,17 +117,30 @@ class TestGroupOperators(unittest.TestCase):
)
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
_PATTERNS,
GraphUtil,
RuleBasedTuner,
convert_to_graph,
)
dist_context = DistributedContext()
tuner = RuleBasedTuner(dist_context)
layers = tuner.cluster_operators(train_program.global_block().ops)
layer = layers[0]
graph = convert_to_graph(layer, train_program.global_block())
graph = GraphUtil.convert_to_graph(train_program.global_block())
print("graph: ", graph)
print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
print("row_matmul: ", _PATTERNS["row_matmul"].attrs["shard_spec"])
print("ffn: ", _PATTERNS["ffn"].attrs["shard_spec"])
print(
"shared_word_embedding: ",
_PATTERNS["shared_word_embedding"].attrs["shard_spec"],
)
print(
"position_embedding: ",
_PATTERNS["position_embedding"].attrs["shard_spec"],
)
print(
"unsqueeze_data: ", _PATTERNS["unsqueeze_data"].attrs["shard_spec"]
)
print("reshape_data: ", _PATTERNS["reshape_data"].attrs["shard_spec"])
if __name__ == "__main__":
......
......@@ -95,7 +95,7 @@ def get_gpt_model(
return train_program, start_program, loss, gen_data
class TestGroupOperators(unittest.TestCase):
class TestPatternMatch(unittest.TestCase):
def test_gpt(self):
modeling.init_global()
train_program = static.Program()
......@@ -116,26 +116,15 @@ class TestGroupOperators(unittest.TestCase):
DistributedContext,
)
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
_PATTERNS,
GraphUtil,
RuleBasedTuner,
convert_to_graph,
match,
)
dist_context = DistributedContext()
tuner = RuleBasedTuner(dist_context)
layers = tuner.cluster_operators(train_program.global_block().ops)
layer = layers[0]
graph = convert_to_graph(layer, train_program.global_block())
results = match(_PATTERNS["qkv"], graph)
shard_tensor_infos = _PATTERNS["qkv"].attrs["shard_spec"]
tensor_ids = shard_tensor_infos[0][0]
if results:
for result in results:
for node_id in result:
if node_id in tensor_ids:
print(graph.attrs["id_to_var"][result[node_id]])
print("shard_spec: ", shard_tensor_infos[0][1])
graph = GraphUtil.convert_to_graph(train_program.global_block())
results = GraphUtil.match_all_patterns(graph)
print(results)
if __name__ == "__main__":
......