From 0bb7c00382a96a0fcdd89ff28ed657a019cefa47 Mon Sep 17 00:00:00 2001
From: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Date: Tue, 21 Mar 2023 11:12:40 +0800
Subject: [PATCH] [Auto Parallel] Add patterns of rule based tuner (#51859)

* add patterns
* add unittest
---
 .../auto_parallel/tuner/rule_based_tuner.py   | 855 +++++++++++++-----
 .../unittests/auto_parallel/test_pattern.py   |  21 +-
 .../auto_parallel/test_pattern_match.py       |  21 +-
 3 files changed, 646 insertions(+), 251 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
index 6c08dd4d206..86038d97d22 100644
--- a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
+++ b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
@@ -22,24 +22,43 @@ _PATTERNS = {}
 
 def register_pattern(cls):
     """Register pattern for rule-based tuner."""
-    name = cls.name
 
-    def register(name):
+    def register():
         global _PATTERNS
-        _PATTERNS[name] = cls()
+        pattern = cls()
+        _PATTERNS[pattern.name] = pattern
+        # sort patterns according to the number of sharded tensors, so that
+        # when a tensor can be matched by multiple patterns, its dist attr
+        # is set by the first one
+        _PATTERNS = dict(
+            sorted(
+                _PATTERNS.items(), key=lambda x: -x[1].attrs["sharded_tensors"]
+            )
+        )
 
-    register(name)
+    register()
     return cls
 
 
 class BasePattern(Graph):
-    name = "base"
+    """
+    Base class of patterns.
+    BasePattern inherits from Graph; the two important additions are
+    shard_spec and sharded_tensors. shard_spec gives the shard
+    specification of each tensor node in the pattern under different
+    parallelisms, and sharded_tensors is the number of tensors that
+    are sharded.
+    """
+
+    _name = "base"
 
     def __init__(self):
        """Every pattern has its own name and build method."""
        super().__init__()
        self.build()
 
+    @property
+    def name(self):
+        return self.__class__._name
+
     @abstractmethod
     def build(self):
         pass
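For orientation, here is a minimal sketch of how a new pattern would be declared with the machinery above; the class name, node ids, and shard values are hypothetical and not part of the patch, while the decorator, the name attribute, build(), and the two attrs keys follow the definitions in this file.

@register_pattern
class SketchPattern(BasePattern):
    """Hypothetical pattern: a single var feeding one matmul_v2 op."""

    name = "sketch"

    def build(self):
        # nodes carry a "type" attr; op nodes use the op type string
        x = self.add_node(0, **{"type": "var"})
        matmul = self.add_node(1, **{"type": "matmul_v2"})
        # edges carry the op's input slot name
        self.add_edge(x.id, matmul.id, **{"input_name": "X"})
        # shard x's first dim on mesh axis 0 under pure data parallelism
        self.attrs["shard_spec"] = {"dp": {0: [0, -1]}}
        self.attrs["sharded_tensors"] = 1

Because register() re-sorts _PATTERNS by descending sharded_tensors after every registration and sorted() is stable, the registry built by the classes below ends up ordered qkv (4), shared_word_embedding (3), ffn (2), then row_matmul, position_embedding, unsqueeze_data, and reshape_data (1 each, in registration order).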
@@ -47,6 +66,8 @@ class BasePattern(Graph):
 
 @register_pattern
 class QKVPattern(BasePattern):
+    """The QKV pattern defined by GPT model in PaddleFleetX."""
+
     name = "qkv"
 
     def __init__(self):
@@ -55,81 +76,388 @@ class QKVPattern(BasePattern):
     def build(self):
         query = self.add_node(0, **{"type": "var"})
 
+        # define q, k, v weight
         q_weight = self.add_node(1, **{"dim": 2, "type": "param"})
         k_weight = self.add_node(2, **{"dim": 2, "type": "param"})
         v_weight = self.add_node(3, **{"dim": 2, "type": "param"})
-
-        q_matmul = self.add_node(4, **{"type": "matmul_v2"})
-        k_matmul = self.add_node(5, **{"type": "matmul_v2"})
-        v_matmul = self.add_node(6, **{"type": "matmul_v2"})
-
-        q_x = self.add_edge(0, 4, **{"input_name": "X"})
-        k_x = self.add_edge(0, 5, **{"input_name": "X"})
-        v_x = self.add_edge(0, 6, **{"input_name": "X"})
-        q_y = self.add_edge(1, 4, **{"input_name": "Y"})
-        k_y = self.add_edge(2, 5, **{"input_name": "Y"})
-        v_y = self.add_edge(3, 6, **{"input_name": "Y"})
-
+        # define q, k, v matmul_v2
+        q_matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
+        k_matmul_v2 = self.add_node(5, **{"type": "matmul_v2"})
+        v_matmul_v2 = self.add_node(6, **{"type": "matmul_v2"})
+        # define input edge
+        q_x_edge = self.add_edge(
+            query.id, q_matmul_v2.id, **{"input_name": "X"}
+        )
+        k_x_edge = self.add_edge(
+            query.id, k_matmul_v2.id, **{"input_name": "X"}
+        )
+        v_x_edge = self.add_edge(
+            query.id, v_matmul_v2.id, **{"input_name": "X"}
+        )
+        q_y_edge = self.add_edge(
+            q_weight.id, q_matmul_v2.id, **{"input_name": "Y"}
+        )
+        k_y_edge = self.add_edge(
+            k_weight.id, k_matmul_v2.id, **{"input_name": "Y"}
+        )
+        v_y_edge = self.add_edge(
+            v_weight.id, v_matmul_v2.id, **{"input_name": "Y"}
+        )
+        # define q, k, v matmul_v2 output
         q = self.add_node(7, **{"type": "var"})
         k = self.add_node(8, **{"type": "var"})
         v = self.add_node(9, **{"type": "var"})
-        q_out = self.add_edge(4, 7, **{"output_name": "Out"})
-        k_out = self.add_edge(5, 8, **{"output_name": "Out"})
-        v_out = self.add_edge(6, 9, **{"output_name": "Out"})
-
-        # Pattern
-        self.attrs["shard_spec"] = [
-            [(1, 2, 3), [[-1, 0], [-1, 1]]],
-        ]  # 2-tuple list such as [(tensor_id, shard_spec)]
-
-
-def convert_to_graph(ops, block):
-    """Convert ops to graph."""
-    graph = Graph()
-    graph.attrs["var_to_id"] = {}  # {var_name: node_id}
-    graph.attrs["id_to_var"] = {}  # {node_id: var_name}
-    graph.attrs["op_to_id"] = {}  # {op_id: node_id}
-    graph.attrs["id_to_op"] = {}  # {node_id: op_id}
-
-    node_id = -1
-    for op in ops:
-        attrs = op.all_attrs()
-        attrs["type"] = op.type
-        node_id += 1
-
-        # create op node
-        op_node = graph.add_node(node_id, **attrs)
-        graph.attrs["op_to_id"][op.desc.id()] = op_node.id
-        graph.attrs["id_to_op"][op_node.id] = op.desc.id()
-        graph._attr_to_nodes[op_node.id] = {}
-        for input_name in op.input_names:
-            graph._attr_to_nodes[op_node.id][input_name] = []
-            for var_name in op.input(input_name):
-                if var_name not in graph.attrs["var_to_id"]:
-                    # create var node
-                    node_id += 1
-                    var_node = graph.add_node(node_id)
-                    var = block._var_recursive(var_name)
-                    if var.is_parameter:
-                        var_node.attrs["type"] = "param"
-                        var_node.attrs["dim"] = len(var.shape)
-                    else:
-                        var_node.attrs["type"] = "var"
-                    graph.attrs["var_to_id"][var_name] = var_node.id
-                    graph.attrs["id_to_var"][var_node.id] = var_name
-                else:
-                    var_node_id = graph.attrs["var_to_id"][var_name]
-                    var_node = graph._nodes[var_node_id]
+        # define output edge
+        q_out_edge = self.add_edge(
+            q_matmul_v2.id, q.id, **{"output_name": "Out"}
+        )
+        k_out_edge = self.add_edge(
+            k_matmul_v2.id, k.id, **{"output_name": "Out"}
+        )
+        v_out_edge = self.add_edge(
+            v_matmul_v2.id, v.id, **{"output_name": "Out"}
+        )
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {
+                0: [0, -1, -1],
+                1: [-1, 1],
+                2: [-1, 1],
+                3: [-1, 1],
+            },
+            "mp_dp": {
+                0: [1, -1, -1],
+                1: [-1, 0],
+                2: [-1, 0],
+                3: [-1, 0],
+            },
+            "mp": {0: [-1, -1, -1], 1: [-1, 0], 2: [-1, 0], 3: [-1, 0]},
+            "dp": {
+                0: [0, -1, -1],
+                1: [-1, -1],
+                2: [-1, -1],
+                3: [-1, -1],
+            },
+        }
+        self.attrs["shard_spec"] = shard_spec
+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 4
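A note on reading shard_spec (an interpretation based on the usual dims_mapping convention in auto parallel; the patch itself does not spell this out): each key names a parallelism strategy, and each entry maps a pattern tensor node id to a dims mapping in which -1 leaves that dim replicated and a non-negative value shards it along that process mesh axis. Under "dp_mp" above, node 0 (the query) is sharded on its batch dim along mesh axis 0, while nodes 1-3 (the Q/K/V weights) are sharded column-wise along mesh axis 1:

# quick inspection of the registered spec; the printed values come from above
qkv = _PATTERNS["qkv"]
for node_id, dims_mapping in qkv.attrs["shard_spec"]["dp_mp"].items():
    print(node_id, dims_mapping)  # 0 [0, -1, -1], 1 [-1, 1], 2 [-1, 1], 3 [-1, 1]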
+
+
+@register_pattern
+class RowMatmulPattern(BasePattern):
+    """Row matmul pattern defined by GPT model in PaddleFleetX."""
+
+    name = "row_matmul"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define reshape input
+        input = self.add_node(0, **{"type": "var"})
+
+        # define reshape
+        reshape = self.add_node(1, **{"type": "reshape2"})
+
+        # define reshape input edge
+        x_edge = self.add_edge(input.id, reshape.id, **{"input_name": "X"})
+
+        # define reshape output
+        output = self.add_node(2, **{"type": "var"})
+
+        # define reshape output edge
+        out_edge = self.add_edge(
+            reshape.id, output.id, **{"output_name": "Out"}
+        )
+
+        # define matmul_v2 weight
+        weight = self.add_node(3, **{"dim": 2, "type": "param"})
+
+        # define matmul_v2
+        matmul_v2 = self.add_node(4, **{"type": "matmul_v2"})
+
+        # define input edge
+        x_edge = self.add_edge(output.id, matmul_v2.id, **{"input_name": "X"})
+        y_edge = self.add_edge(weight.id, matmul_v2.id, **{"input_name": "Y"})
+
+        # define matmul_v2 output
+        output = self.add_node(5, **{"type": "var"})
+
+        # define output edge
+        out_edge = self.add_edge(
+            matmul_v2.id, output.id, **{"output_name": "Out"}
+        )
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {
+                3: [1, -1],
+            },
+            "mp_dp": {
+                3: [0, -1],
+            },
+            "mp": {3: [0, -1]},
+            "dp": {
+                3: [-1, -1],
+            },
+        }
+        self.attrs["shard_spec"] = shard_spec
+
+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 1
+
+
+@register_pattern
+class FFNPattern(BasePattern):
+    """FFN pattern defined by GPT model in PaddleFleetX."""
+
+    name = "ffn"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        x = self.add_node(0, **{"type": "var"})
+
+        w1_weight = self.add_node(1, **{"dim": 2, "type": "param"})
+        w1_matmul = self.add_node(2, **{"type": "matmul_v2"})
+
+        w1_x = self.add_edge(0, 2, **{"input_name": "X"})
+        w1_y = self.add_edge(1, 2, **{"input_name": "Y"})
+
+        out1 = self.add_node(3, **{"type": "var"})
+        w1_out = self.add_edge(2, 3, **{"output_name": "Out"})
+
+        w1_b = self.add_node(4, **{"dim": 1, "type": "param"})
+        add1 = self.add_node(5, **{"type": "elementwise_add"})
+
+        add1_x = self.add_edge(3, 5, **{"input_name": "X"})
+        add1_y = self.add_edge(4, 5, **{"input_name": "Y"})
+
+        out2 = self.add_node(6, **{"type": "var"})
+        add1_out = self.add_edge(5, 6, **{"output_name": "Out"})
+
+        gelu = self.add_node(7, **{"type": "gelu"})
+
+        gelu_x = self.add_edge(6, 7, **{"input_name": "X"})
+        out3 = self.add_node(8, **{"type": "var"})
+        gelu_out = self.add_edge(7, 8, **{"output_name": "Out"})
+
+        w2_weight = self.add_node(9, **{"dim": 2, "type": "param"})
+        w2_matmul = self.add_node(10, **{"type": "matmul_v2"})
+
+        w2_x = self.add_edge(8, 10, **{"input_name": "X"})
+        w2_y = self.add_edge(9, 10, **{"input_name": "Y"})
+
+        out4 = self.add_node(11, **{"type": "var"})
+        w2_out = self.add_edge(10, 11, **{"output_name": "Out"})
+
+        w2_b = self.add_node(12, **{"dim": 1, "type": "param"})
+        add2 = self.add_node(13, **{"type": "elementwise_add"})
+
+        add2_x = self.add_edge(11, 13, **{"input_name": "X"})
+        add2_y = self.add_edge(12, 13, **{"input_name": "Y"})
+
+        out5 = self.add_node(14, **{"type": "var"})
+        add2_out = self.add_edge(13, 14, **{"output_name": "Out"})
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {0: [0, -1, -1], 1: [-1, 1], 9: [1, -1]},
+            "mp_dp": {0: [1, -1, -1], 1: [-1, 0], 9: [0, -1]},
+            "mp": {1: [-1, 0], 9: [0, -1]},
+            "dp": {0: [0, -1, -1], 1: [-1, -1], 9: [-1, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+
+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 2
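To make the op chain concrete, the fragment below is a sketch of a static-graph program that lowers to exactly the sequence FFNPattern encodes (matmul_v2, elementwise_add, gelu, matmul_v2, elementwise_add); the names and shapes are illustrative, assuming the usual GPT FFN composition rather than actual PaddleFleetX code. Note that the "mp" spec above shards w1 (node 1) column-wise and w2 (node 9) row-wise, the classic Megatron-style split.

import paddle
import paddle.nn.functional as F

paddle.enable_static()
x = paddle.static.data("x", [8, 512, 768], "float32")
w1 = paddle.static.create_parameter([768, 3072], "float32")
b1 = paddle.static.create_parameter([3072], "float32")
w2 = paddle.static.create_parameter([3072, 768], "float32")
b2 = paddle.static.create_parameter([768], "float32")
h = F.gelu(paddle.add(paddle.matmul(x, w1), b1))  # matmul_v2, elementwise_add, gelu
y = paddle.add(paddle.matmul(h, w2), b2)  # matmul_v2, elementwise_add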
+@register_pattern
+class SharedWordEmbeddingPattern(BasePattern):
+    """Shared word embedding pattern defined by GPT model in PaddleFleetX."""
+
+    name = "shared_word_embedding"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define embedding input
+        tokens = self.add_node(0, **{"type": "data"})
+        word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})
+
+        # define embedding
+        embedding = self.add_node(2, **{"type": "lookup_table_v2"})
+
+        # define embedding input edge
+        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
+        w = self.add_edge(1, 2, **{"input_name": "W"})
+
+        # define embedding output
+        out = self.add_node(3, **{"type": "var"})
+
+        # define embedding output edge
+        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})
+
+        # define matmul_v2 input
+        x = self.add_node(4, **{"type": "var"})
+
+        # define matmul_v2
+        matmul = self.add_node(5, **{"type": "matmul_v2"})
+
+        # define matmul_v2 input edge
+        x_edge = self.add_edge(4, 5, **{"input_name": "X"})
+        y_edge = self.add_edge(1, 5, **{"input_name": "Y"})
+
+        # define matmul_v2 output
+        out = self.add_node(6, **{"type": "var"})
+
+        # define matmul_v2 output edge
+        out_edge = self.add_edge(5, 6, **{"output_name": "Out"})
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {0: [0, -1], 1: [1, -1], 4: [0, -1, -1]},
+            "mp_dp": {0: [1, -1], 1: [0, -1], 4: [1, -1, -1]},
+            "mp": {0: [-1, -1], 1: [0, -1], 4: [-1, -1, -1]},
+            "dp": {0: [0, -1], 1: [-1, -1], 4: [0, -1, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+        self.attrs["sharded_tensors"] = 3
+
+
+@register_pattern
+class PositionEmbeddingPattern(BasePattern):
+    """Position embedding pattern defined by GPT model in PaddleFleetX."""
+
+    name = "position_embedding"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define embedding input
+        tokens = self.add_node(0, **{"type": "data"})
+        word_embeddings = self.add_node(1, **{"dim": 2, "type": "param"})
+
+        # define embedding
+        embedding = self.add_node(2, **{"type": "lookup_table_v2"})
+
+        # define embedding input edge
+        ids = self.add_edge(0, 2, **{"input_name": "Ids"})
+        w = self.add_edge(1, 2, **{"input_name": "W"})
 
-            # create edge that input -> op
-            input_edge = graph.add_edge(var_node.id, op_node.id)
-            input_edge.attrs["input_name"] = input_name
-            graph._attr_to_nodes[op_node.id][input_name].append(var_node)
+        # define embedding output
+        out = self.add_node(3, **{"type": "var"})
 
-        for output_name in op.output_names:
-            graph._attr_to_nodes[op_node.id][output_name] = []
-            for var_name in op.output(output_name):
+        # define embedding output edge
+        out_edge = self.add_edge(2, 3, **{"output_name": "Out"})
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {0: [0, -1], 1: [-1, -1], 3: [-1, -1, -1]},
+            "mp_dp": {0: [1, -1], 1: [-1, -1], 3: [1, -1, -1]},
+            "mp": {0: [-1, -1], 1: [-1, -1], 3: [-1, -1, -1]},
+            "dp": {0: [0, -1], 1: [-1, -1], 3: [0, -1, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+
+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 1
+
+
+@register_pattern
+class UnsqueezeDataPattern(BasePattern):
+    """Unsqueeze data pattern defined by GPT model in PaddleFleetX."""
+
+    name = "unsqueeze_data"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define unsqueeze input
+        tokens = self.add_node(0, **{"type": "data"})
+        # define unsqueeze
+        unsqueeze = self.add_node(1, **{"type": "unsqueeze2"})
+        # define unsqueeze input edge
+        x_edge = self.add_edge(0, 1, **{"input_name": "X"})
+        # pattern: pure mp or hybrid dp+mp
+        shard_spec = {
+            "dp_mp": {0: [0, -1]},
+            "mp_dp": {0: [1, -1]},
+            "mp": {0: [-1, -1]},
+            "dp": {0: [0, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+        self.attrs["sharded_tensors"] = 1
+
+
+@register_pattern
+class ReshapeDataPattern(BasePattern):
+    """Reshape data pattern defined by GPT model in PaddleFleetX."""
+
+    name = "reshape_data"
+
+    def __init__(self):
+        super().__init__()
+
+    def build(self):
+        # define reshape input
+        data = self.add_node(0, **{"type": "data"})
+
+        # define reshape
+        reshape = self.add_node(1, **{"type": "reshape2"})
+
+        # define reshape input edge
+        x_edge = self.add_edge(0, 1, **{"input_name": "X"})
+
+        # define shard_spec
+        shard_spec = {
+            "dp_mp": {0: [0, -1]},
+            "mp_dp": {0: [1, -1]},
+            "mp": {0: [-1, -1]},
+            "dp": {0: [0, -1]},
+        }
+        self.attrs["shard_spec"] = shard_spec
+
+        # define sharded_tensors
+        self.attrs["sharded_tensors"] = 1
+
+
+class GraphUtil:
+    """Graph util is used to convert ops to a graph and to match patterns
for graph.""" + + @staticmethod + def convert_to_graph(block): + """Convert ops to graph.""" + graph = Graph() + graph.attrs["var_to_id"] = {} # {var_name: node_id} + graph.attrs["id_to_var_desc_id"] = {} # {node_id: var_desc_id} + graph.attrs["id_to_var_name"] = {} + graph.attrs["op_to_id"] = {} # {op_id: node_id} + graph.attrs["id_to_op"] = {} # {node_id: op} + + ops = block.ops + node_id = -1 + for op in ops: + attrs = op.all_attrs() + attrs["type"] = op.type + node_id += 1 + + # create op node + op_node = graph.add_node(node_id, **attrs) + graph.attrs["op_to_id"][op.desc.id()] = op_node.id + graph.attrs["id_to_op"][op_node.id] = op + graph._attr_to_nodes[op_node.id] = {} + for input_name in op.input_names: + graph._attr_to_nodes[op_node.id][input_name] = [] + for var_name in op.input(input_name): if var_name not in graph.attrs["var_to_id"]: # create var node node_id += 1 @@ -137,201 +465,268 @@ def convert_to_graph(ops, block): var = block._var_recursive(var_name) if var.is_parameter: var_node.attrs["type"] = "param" + var_node.attrs["dim"] = len(var.shape) + elif var.is_data: + var_node.attrs["type"] = "data" + var_node.attrs["dim"] = len(var.shape) else: var_node.attrs["type"] = "var" graph.attrs["var_to_id"][var_name] = var_node.id - graph.attrs["id_to_var"][var_node.id] = var_name + graph.attrs["id_to_var_desc_id"][ + var_node.id + ] = var.desc.original_id() + graph.attrs["id_to_var_name"][var_node.id] = var_name else: var_node_id = graph.attrs["var_to_id"][var_name] var_node = graph._nodes[var_node_id] - # create edge that op -> output - output_edge = graph.add_edge(op_node.id, var_node.id) - output_edge.attrs["output_name"] = output_name - - graph._attr_to_nodes[op_node.id][output_name].append( + # create edge that input -> op + input_edge = graph.add_edge(var_node.id, op_node.id) + input_edge.attrs["input_name"] = input_name + graph._attr_to_nodes[op_node.id][input_name].append( var_node ) - return graph + for output_name in op.output_names: + graph._attr_to_nodes[op_node.id][output_name] = [] + for var_name in op.output(output_name): + if var_name not in graph.attrs["var_to_id"]: + # create var node + node_id += 1 + var_node = graph.add_node(node_id) + var = block._var_recursive(var_name) + if var.is_parameter: + var_node.attrs["type"] = "param" + else: + var_node.attrs["type"] = "var" + graph.attrs["var_to_id"][var_name] = var_node.id + graph.attrs["id_to_var_desc_id"][ + var_node.id + ] = var.desc.original_id() + graph.attrs["id_to_var_name"][ + var_node.id + ] = var_name + else: + var_node_id = graph.attrs["var_to_id"][var_name] + var_node = graph._nodes[var_node_id] + # create edge that op -> output + output_edge = graph.add_edge(op_node.id, var_node.id) + output_edge.attrs["output_name"] = output_name -def match(pattern, graph): - def _is_op_node(node): - """Judge whether node is op node""" - if node.attrs["type"] not in ["var", "param", "data"]: - return True + graph._attr_to_nodes[op_node.id][output_name].append( + var_node + ) - return False + return graph - def _compare_op_node(src, tgt): - """Compare whether two op nodes are equal""" - if src.attrs["type"] != tgt.attrs["type"]: - return False + @staticmethod + def match_pattern(pattern, graph): + def _is_op_node(node): + """Judge whether node is op node.""" + if node.attrs["type"] not in ["var", "param", "data"]: + return True - return True + return False - def _compare_var_node(src, tgt): - """Compare whether two var nodes are equal""" - for key in src.attrs: - if key not in tgt.attrs: - return False - if 
src.attrs[key] != tgt.attrs[key]: + def _compare_op_node(src, tgt): + """Compare whether two op nodes are equivalent.""" + if src.attrs["type"] != tgt.attrs["type"]: return False - return True - - def _match_core(src_node, tgt_node): - nonlocal not_matched - # do not support one input name or output name corresponding to multiple vars - if not_matched: - return - - if _is_op_node(src_node): - # compare op node whether equal - if not _compare_op_node(src_node, tgt_node): - return - - result[src_node.id] = tgt_node.id + return True - # input var nodes - src_input_nodes = src_reverse_adjs[src_node.id] - for node in src_input_nodes: - # has visited - if node.id in result: - continue - edge = src_edges[node.id][src_node.id] - input_name = edge.attrs["input_name"] + def _compare_var_node(src, tgt): + """Compare whether two var nodes are equivalent.""" + for key in src.attrs: + if key not in tgt.attrs: + return False + if src.attrs[key] != tgt.attrs[key]: + return False - # NOTE: do not support one input name or output name corresponding to multiple vars - compare_nodes = tgt_attr_to_nodes[tgt_node.id].get( - input_name, None - ) - if not compare_nodes: - not_matched = True - return - _match_core(node, compare_nodes[0]) + return True - # output var nodes - src_output_node_ids = src_edges[src_node.id].keys() - for node_id in src_output_node_ids: - # has visited - if node_id in result: - continue - node = src_nodes[node_id] - edge = src_edges[src_node.id][node_id] - output_name = edge.attrs["output_name"] + def _match_core(src_node, tgt_node): + nonlocal not_matched + # not support one input name or output name corresponding to multiple vars + if not_matched: + return - # NOTE: do not support one input name or output name corresponding to multiple vars - compare_nodes = tgt_attr_to_nodes[tgt_node.id].get( - output_name, None - ) - if not compare_nodes: + if _is_op_node(src_node): + # compare op node whether equal + if not _compare_op_node(src_node, tgt_node): not_matched = True return - _match_core(node, compare_nodes[0]) - - else: - # compare var node whether equal - if not _compare_var_node(src_node, tgt_node): - not_matched = True - return - - result[src_node.id] = tgt_node.id - # as input for op nodes - src_as_input_node_ids = src_edges[src_node.id].keys() - for node_id in src_as_input_node_ids: - if node_id in result: - continue + result[src_node.id] = tgt_node.id - src_edge = src_edges[src_node.id][node_id] - input_name = src_edge.attrs["input_name"] - compare_node_ids = tgt_edges[tgt_node.id].keys() + # input var nodes + src_input_nodes = src_reverse_adjs[src_node.id] + for node in src_input_nodes: + # has visited + if node.id in result: + continue + edge = src_edges[node.id][src_node.id] + input_name = edge.attrs["input_name"] - compare_node = None - for compare_node_id in compare_node_ids: - edge = tgt_edges[tgt_node.id][compare_node_id] - if ( - edge.attrs["input_name"] == input_name - and compare_node_id not in result.values() - ): - compare_node = tgt_nodes[compare_node_id] - break + # NOTE: do not support one input name or output name corresponding to multiple vars + compare_nodes = tgt_attr_to_nodes[tgt_node.id].get( + input_name, None + ) + if not compare_nodes: + not_matched = True + return + _match_core(node, compare_nodes[0]) + + # output var nodes + src_output_node_ids = src_edges[src_node.id].keys() + for node_id in src_output_node_ids: + # has visited + if node_id in result: + continue + node = src_nodes[node_id] + edge = src_edges[src_node.id][node_id] + output_name = 
edge.attrs["output_name"] + + # NOTE: do not support one input name or output name corresponding to multiple vars + compare_nodes = tgt_attr_to_nodes[tgt_node.id].get( + output_name, None + ) + if not compare_nodes: + not_matched = True + return + _match_core(node, compare_nodes[0]) - if not compare_node: + else: + # compare var nodes whether equal + if not _compare_var_node(src_node, tgt_node): not_matched = True return - _match_core(src_nodes[node_id], compare_node) - # as output for nodes - src_as_output_nodes = src_reverse_adjs[src_node.id] - for node in src_as_output_nodes: - if node.id in result: - continue + result[src_node.id] = tgt_node.id + + # as input for op node + src_as_input_node_ids = src_edges[src_node.id].keys() + for node_id in src_as_input_node_ids: + if node_id in result: + continue + + src_edge = src_edges[src_node.id][node_id] + input_name = src_edge.attrs["input_name"] + compare_node_ids = tgt_edges[tgt_node.id].keys() + + compare_node = None + for compare_node_id in compare_node_ids: + edge = tgt_edges[tgt_node.id][compare_node_id] + if ( + edge.attrs["input_name"] == input_name + and compare_node_id not in result.values() + ): + compare_node = tgt_nodes[compare_node_id] + break - src_edge = src_edges[node.id][src_node.id] - output_name = src_edge.attrs["output_name"] + if not compare_node: + not_matched = True + return + _match_core(src_nodes[node_id], compare_node) - compare_node_ids = tgt_reverse_adjs[tgt_node.id] + # as output for op node + src_as_output_nodes = src_reverse_adjs[src_node.id] + for node in src_as_output_nodes: + if node.id in result: + continue - compare_node = None - for node_id in compare_node_ids: - edge = tgt_edges[node_id][tgt_node.id] - if edge.attrs["output_name"] == output_name: - compare_node = tgt_nodes[node_id] - break - if not compare_node: - not_matched = True - return - _match_core(src_nodes[node_id], compare_node) - - results = [] - result = {} - has_matched = set() - src_nodes = pattern.nodes - src_edges = pattern._adjs - src_reverse_adjs = pattern._reverse_adjs - - tgt_nodes = graph.nodes - tgt_edges = graph._adjs - tgt_reverse_adjs = graph._reverse_adjs - tgt_attr_to_nodes = graph._attr_to_nodes - not_matched = False - - # starts with a op node - src_start_node = None - for node_id in src_nodes: - node = src_nodes[node_id] - if node.attrs["type"] not in ["var", "param", "data"]: - src_start_node = node - break - assert src_start_node is not None - - for node_id in tgt_nodes: - node = tgt_nodes[node_id] - if node.attrs["type"] == src_start_node.attrs["type"]: - _match_core(src_start_node, node) - if not not_matched: - need_to_append = True - for value in result.values(): - if value in has_matched: - result = {} - need_to_append = False - break - if need_to_append: - results.append(result) + src_edge = src_edges[node.id][src_node.id] + output_name = src_edge.attrs["output_name"] + + compare_nodes = tgt_reverse_adjs[tgt_node.id] + + compare_node = None + for item in compare_nodes: + node_id = item.id + edge = tgt_edges[node_id][tgt_node.id] + if edge.attrs["output_name"] == output_name: + compare_node = tgt_nodes[node_id] + break + if not compare_node: + not_matched = True + return + _match_core(src_nodes[node.id], compare_node) + + results = [] + matched_ids = set() + matched_op_node_ids = set() + result = {} + src_nodes = pattern.nodes + src_edges = pattern._adjs + src_reverse_adjs = pattern._reverse_adjs + + tgt_nodes = graph.nodes + tgt_edges = graph._adjs + tgt_reverse_adjs = graph._reverse_adjs + tgt_attr_to_nodes = 
graph._attr_to_nodes
+
+        # starts with an op node
+        src_start_node = None
+        for node_id in src_nodes:
+            node = src_nodes[node_id]
+            if node.attrs["type"] not in ["var", "param", "data"]:
+                src_start_node = node
+                break
+        assert src_start_node is not None
+
+        for node_id in tgt_nodes:
+            node = tgt_nodes[node_id]
+            if node.attrs["type"] == src_start_node.attrs["type"]:
+                not_matched = False
+                _match_core(src_start_node, node)
+                if not not_matched:
+                    need_to_append = True
+                    for value in result.values():
+                        if value in matched_op_node_ids:
+                            result = {}
+                            need_to_append = False
+                            break
+                    if need_to_append:
+                        results.append(result)
+                        for value in result.values():
+                            matched_ids.add(value)
+                            if value in graph.attrs["id_to_op"].keys():
+                                matched_op_node_ids.add(value)
+                        result = {}
+            else:
+                not_matched = False
+                result = {}
+
+        return results, matched_ids
+
+    @staticmethod
+    def match_all_patterns(graph):
+        # matched_results maps pattern_name to a list of matched results,
+        # each mapping pattern node id to graph node id,
+        # such as {"pattern_name": [{pattern_node_id: graph_node_id}, ]}
+        matched_results = {}
+        matched_ids = set()
+        for pattern_name in _PATTERNS:
+            pattern = _PATTERNS[pattern_name]
+            results, matched = GraphUtil.match_pattern(pattern, graph)
+            for result in results:
+                has_matched = False
+                for id in result:
+                    if result[id] in matched_ids:
+                        has_matched = True
+                        break
+                if not has_matched:
+                    for item in result:
+                        matched_ids.add(result[item])
+                    if pattern.name not in matched_results:
+                        matched_results[pattern.name] = []
+                    matched_results[pattern.name].append(result)
+
+        return matched_results
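End to end, the two entry points are used as in the updated unit tests below; train_program stands for any built static Program (the tests use the PaddleFleetX GPT model), and the ids in the sketched output are illustrative since they depend on the concrete program.

graph = GraphUtil.convert_to_graph(train_program.global_block())
matched_results = GraphUtil.match_all_patterns(graph)
# e.g. {"qkv": [{0: 17, 1: 5, ...}], "ffn": [...]}; each inner dict maps a
# pattern node id to the matched graph node id
for pattern_name, matches in matched_results.items():
    print(pattern_name, len(matches))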
 class OperatorClusteringUtil:
+    """Operator clustering util is used to cluster operators into layers."""
+
     common_starts = ["layer_norm", "matmul_v2", "matmul"]
 
     @staticmethod
@@ -506,6 +901,8 @@ class OperatorClusteringUtil:
 
 
 class ClusterPartitionUtil:
+    """Cluster partition util is used to get device meshes and process meshes."""
+
     @staticmethod
     def factorization(num):
         factors = []
@@ -535,13 +932,11 @@ class ClusterPartitionUtil:
         ],
     ) -> list:
         """
-        Partition cluster into possible device meshes.
-
+        Partition cluster into possible device meshes.
         Args:
             n (int): The number of nodes.
             m (int): The number of single devices on each node.
             filter (list): Functions for filtering useful meshes
-
         Returns:
             device_meshed (list) : The possible device meshes.
         """
@@ -573,10 +968,8 @@ class ClusterPartitionUtil:
     def convert_to_process_meshes(device_mesh: list) -> list:
         """
         Transfer device_meshes into possible process meshes.
-
         Args:
             device meshes (list): [n,m], one device mesh.
-
         Returns:
             process_meshes (list): Possible process_meshes
         """

diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py
index dca87bdf9ce..047b9c7507f 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py
@@ -95,7 +95,7 @@ def get_gpt_model(
     return train_program, start_program, loss, gen_data
 
 
-class TestGroupOperators(unittest.TestCase):
+class TestGroupOperatorsAndPatterns(unittest.TestCase):
     def test_gpt(self):
         modeling.init_global()
         train_program = static.Program()
@@ -117,17 +117,30 @@ class TestGroupOperators(unittest.TestCase):
         )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
             _PATTERNS,
+            GraphUtil,
             RuleBasedTuner,
-            convert_to_graph,
         )
 
         dist_context = DistributedContext()
         tuner = RuleBasedTuner(dist_context)
         layers = tuner.cluster_operators(train_program.global_block().ops)
-        layer = layers[0]
-        graph = convert_to_graph(layer, train_program.global_block())
+        graph = GraphUtil.convert_to_graph(train_program.global_block())
         print("graph: ", graph)
         print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
+        print("row_matmul: ", _PATTERNS["row_matmul"].attrs["shard_spec"])
+        print("ffn: ", _PATTERNS["ffn"].attrs["shard_spec"])
+        print(
+            "shared_word_embedding: ",
+            _PATTERNS["shared_word_embedding"].attrs["shard_spec"],
+        )
+        print(
+            "position_embedding: ",
+            _PATTERNS["position_embedding"].attrs["shard_spec"],
+        )
+        print(
+            "unsqueeze_data: ", _PATTERNS["unsqueeze_data"].attrs["shard_spec"]
+        )
+        print("reshape_data: ", _PATTERNS["reshape_data"].attrs["shard_spec"])
 
 
 if __name__ == "__main__":

diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py
index e18298b890a..72c19a588dc 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py
@@ -95,7 +95,7 @@ def get_gpt_model(
     return train_program, start_program, loss, gen_data
 
 
-class TestGroupOperators(unittest.TestCase):
+class TestPatternMatch(unittest.TestCase):
     def test_gpt(self):
         modeling.init_global()
         train_program = static.Program()
@@ -116,26 +116,15 @@ class TestGroupOperators(unittest.TestCase):
             DistributedContext,
         )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
-            _PATTERNS,
+            GraphUtil,
             RuleBasedTuner,
-            convert_to_graph,
-            match,
         )
 
         dist_context = DistributedContext()
         tuner = RuleBasedTuner(dist_context)
-        layers = tuner.cluster_operators(train_program.global_block().ops)
-        layer = layers[0]
-        graph = convert_to_graph(layer, train_program.global_block())
-        results = match(_PATTERNS["qkv"], graph)
-        shard_tensor_infos = _PATTERNS["qkv"].attrs["shard_spec"]
-        tensor_ids = shard_tensor_infos[0][0]
-        if results:
-            for result in results:
-                for node_id in result:
-                    if node_id in tensor_ids:
-                        print(graph.attrs["id_to_var"][result[node_id]])
-            print("shard_spec: ", shard_tensor_infos[0][1])
+        graph = GraphUtil.convert_to_graph(train_program.global_block())
+        results = GraphUtil.match_all_patterns(graph)
+        print(results)
 
 
 if __name__ == "__main__":
-- 
GitLab