diff --git a/python/paddle/distributed/auto_parallel/graph.py b/python/paddle/distributed/auto_parallel/graph.py index be27bd50867d735533704e1cc749267c6b049afb..0ccb93412abcac7cced3692b7f7e7735caa78d85 100644 --- a/python/paddle/distributed/auto_parallel/graph.py +++ b/python/paddle/distributed/auto_parallel/graph.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License +from collections import OrderedDict + class Node: def __init__(self, id, **attrs): @@ -100,6 +102,8 @@ class Graph: # Attributes for Graph self._attrs = {} self._attrs.update(attrs) + self._reverse_adjs = {} + self._attr_to_nodes = {} @property def nodes(self): @@ -120,6 +124,7 @@ class Graph: node = Node(node_id, **attrs) self._nodes[node_id] = node self._adjs[node_id] = {} + self._reverse_adjs[node_id] = [] else: self._nodes[node_id].attrs.update(attrs) @@ -134,14 +139,21 @@ class Graph: if src_id not in self._nodes: src_node = Node(src_id) self._nodes[src_id] = src_node - self._adjs[src_id] = {} + # for one tensor to multiple ops + self._adjs[src_id] = OrderedDict() + self._reverse_adjs[src_id] = [] if tgt_id not in self._nodes: tgt_node = Node(tgt_id) self._nodes[tgt_id] = tgt_node - self._adjs[tgt_id] = {} + # for one tensor to multiple ops + self._adjs[tgt_id] = OrderedDict() + self._reverse_adjs[tgt_id] = [] # add the edge edge = Edge(src_id, tgt_id, **attrs) self._adjs[src_id][tgt_id] = edge + + # add the reverse adj + self._reverse_adjs[tgt_id].append(self.nodes[src_id]) return edge def __len__(self): diff --git a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py index f6e855f71ffb04aa0767762175b993345efd11e2..e00efcb15323a06b69829f1dc8067191cdab4980 100644 --- a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod +from abc import abstractmethod from ..graph import Graph @@ -32,6 +32,57 @@ def register_pattern(cls): return cls +class BasePattern(Graph): + name = "base" + + def __init__(self): + super().__init__() + self.build() + + @abstractmethod + def build(self): + pass + + +@register_pattern +class QKVPattern(BasePattern): + name = "qkv" + + def __init__(self): + super().__init__() + + def build(self): + query = self.add_node(0, **{"type": "var"}) + + q_weight = self.add_node(1, **{"dim": 2, "type": "param"}) + k_weight = self.add_node(2, **{"dim": 2, "type": "param"}) + v_weight = self.add_node(3, **{"dim": 2, "type": "param"}) + + q_matmul = self.add_node(4, **{"type": "matmul_v2"}) + k_matmul = self.add_node(5, **{"type": "matmul_v2"}) + v_matmul = self.add_node(6, **{"type": "matmul_v2"}) + + q_x = self.add_edge(0, 4, **{"input_name": "X"}) + k_x = self.add_edge(0, 5, **{"input_name": "X"}) + v_x = self.add_edge(0, 6, **{"input_name": "X"}) + q_y = self.add_edge(1, 4, **{"input_name": "Y"}) + k_y = self.add_edge(2, 5, **{"input_name": "Y"}) + v_y = self.add_edge(3, 6, **{"input_name": "Y"}) + + q = self.add_node(7, **{"type": "var"}) + k = self.add_node(8, **{"type": "var"}) + v = self.add_node(9, **{"type": "var"}) + + q_out = self.add_edge(4, 7, **{"output_name": "Out"}) + k_out = self.add_edge(5, 8, **{"output_name": "Out"}) + v_out = self.add_edge(6, 9, **{"output_name": "Out"}) + + # Pattern + self.attrs["shard_spec"] = [ + [(1, 2, 3), [[-1, 0], [-1, 1]]], + ] # 2-tuple list such as [(tensor_id, shard_sepc)] + + def convert_to_graph(ops, block): """Convert ops to graph.""" graph = Graph() @@ -50,7 +101,9 @@ def convert_to_graph(ops, block): op_node = graph.add_node(node_id, **attrs) graph.attrs["op_to_id"][op.desc.id()] = op_node.id graph.attrs["id_to_op"][op_node.id] = op.desc.id() + graph._attr_to_nodes[op_node.id] = {} for input_name in op.input_names: + graph._attr_to_nodes[op_node.id][input_name] = [] for var_name in op.input(input_name): if var_name not in graph.attrs["var_to_id"]: # create var node @@ -59,6 +112,7 @@ def convert_to_graph(ops, block): var = block._var_recursive(var_name) if var.is_parameter: var_node.attrs["type"] = "param" + var_node.attrs["dim"] = len(var.shape) else: var_node.attrs["type"] = "var" graph.attrs["var_to_id"][var_name] = var_node.id @@ -70,8 +124,10 @@ def convert_to_graph(ops, block): # create edge that input -> op input_edge = graph.add_edge(var_node.id, op_node.id) input_edge.attrs["input_name"] = input_name + graph._attr_to_nodes[op_node.id][input_name].append(var_node) for output_name in op.output_names: + graph._attr_to_nodes[op_node.id][output_name] = [] for var_name in op.output(output_name): if var_name not in graph.attrs["var_to_id"]: # create var node @@ -92,64 +148,189 @@ def convert_to_graph(ops, block): output_edge = graph.add_edge(op_node.id, var_node.id) output_edge.attrs["output_name"] = output_name + graph._attr_to_nodes[op_node.id][output_name].append( + var_node + ) + return graph -class BasePattern(ABC): - name = "base" +def match(pattern, graph): + def _is_op_node(node): + """Judge whether node is op node""" + if node.attrs["type"] not in ["var", "param", "data"]: + return True - def __init__(self): - self.graph = None - self.build() + return False - @abstractmethod - def build(self): - pass + def _compare_op_node(src, tgt): + """Compare whether two op nodes are equal""" + if src.attrs["type"] != tgt.attrs["type"]: + return False + return True -@register_pattern -class QKVPattern(BasePattern): - name = "qkv" + def _compare_var_node(src, tgt): + """Compare whether two var nodes are equal""" + for key in src.attrs: + if key not in tgt.attrs: + return False + if src.attrs[key] != tgt.attrs[key]: + return False - def __init__(self): - super().__init__() + return True - def build(self): - self.graph = Graph() + def _match_core(src_node, tgt_node): + nonlocal not_matched + # do not support one input name or output name corresponding to multiple vars + if not_matched: + return - query = self.graph.add_node(0, **{"type": "var"}) + if _is_op_node(src_node): + # compare op node whether equal + if not _compare_op_node(src_node, tgt_node): + return - q_weight = self.graph.add_node(1, **{"dim": 2, "type": "param"}) - k_weight = self.graph.add_node(2, **{"dim": 2, "type": "param"}) - v_weight = self.graph.add_node(3, **{"dim": 2, "type": "param"}) + result[src_node.id] = tgt_node.id - q_matmul = self.graph.add_node(4, **{"type": "matmul_v2"}) - k_matmul = self.graph.add_node(5, **{"type": "matmul_v2"}) - v_matmul = self.graph.add_node(6, **{"type": "matmul_v2"}) + # input var nodes + src_input_nodes = src_reverse_adjs[src_node.id] + for node in src_input_nodes: + # has visited + if node.id in result: + continue + edge = src_edges[node.id][src_node.id] + input_name = edge.attrs["input_name"] + + # NOTE: do not support one input name or output name corresponding to multiple vars + compare_nodes = tgt_attr_to_nodes[tgt_node.id].get( + input_name, None + ) + if not compare_nodes: + not_matched = True + return + _match_core(node, compare_nodes[0]) + + # output var nodes + src_output_node_ids = src_edges[src_node.id].keys() + for node_id in src_output_node_ids: + # has visited + if node_id in result: + continue + node = src_nodes[node_id] + edge = src_edges[src_node.id][node_id] + output_name = edge.attrs["output_name"] + + # NOTE: do not support one input name or output name corresponding to multiple vars + compare_nodes = tgt_attr_to_nodes[tgt_node.id].get( + output_name, None + ) + if not compare_nodes: + not_matched = True + return + _match_core(node, compare_nodes[0]) - q_x = self.graph.add_edge(0, 4, **{"input_name": "X"}) - k_x = self.graph.add_edge(0, 5, **{"input_name": "X"}) - v_x = self.graph.add_edge(0, 6, **{"input_name": "X"}) - q_y = self.graph.add_edge(1, 4, **{"input_name": "Y"}) - k_y = self.graph.add_edge(2, 5, **{"input_name": "Y"}) - v_y = self.graph.add_edge(3, 6, **{"input_name": "Y"}) + else: + # compare var node whether equal + if not _compare_var_node(src_node, tgt_node): + not_matched = True + return - q = self.graph.add_node(7, **{"type": "var"}) - k = self.graph.add_node(8, **{"type": "var"}) - v = self.graph.add_node(9, **{"type": "var"}) + result[src_node.id] = tgt_node.id - q_out = self.graph.add_edge(7, 4, **{"output_name": "Out"}) - k_out = self.graph.add_edge(8, 5, **{"output_name": "Out"}) - v_out = self.graph.add_edge(9, 6, **{"output_name": "Out"}) + # as input for op nodes + src_as_input_node_ids = src_edges[src_node.id].keys() + for node_id in src_as_input_node_ids: + if node_id in result: + continue - # Pattern - self.graph.attrs["shard_tensor"] = [ - (1, 2, 3), - [[-1, 0], [-1, 1]], - ] # 2-tuple such as (tensor_id, patterns) + src_edge = src_edges[src_node.id][node_id] + input_name = src_edge.attrs["input_name"] + compare_node_ids = tgt_edges[tgt_node.id].keys() + + compare_node = None + for compare_node_id in compare_node_ids: + edge = tgt_edges[tgt_node.id][compare_node_id] + if ( + edge.attrs["input_name"] == input_name + and compare_node_id not in result.values() + ): + compare_node = tgt_nodes[compare_node_id] + break + + if not compare_node: + not_matched = True + return + _match_core(src_nodes[node_id], compare_node) + + # as output for nodes + src_as_output_nodes = src_reverse_adjs[src_node.id] + for node in src_as_output_nodes: + if node.id in result: + continue + + src_edge = src_edges[node.id][src_node.id] + output_name = src_edge.attrs["output_name"] + compare_node_ids = tgt_reverse_adjs[tgt_node.id] -class OperatorGroupUtil: + compare_node = None + for node_id in compare_node_ids: + edge = tgt_edges[node_id][tgt_node.id] + if edge.attrs["output_name"] == output_name: + compare_node = tgt_nodes[node_id] + break + if not compare_node: + not_matched = True + return + _match_core(src_nodes[node_id], compare_node) + + results = [] + result = {} + has_matched = set() + src_nodes = pattern.nodes + src_edges = pattern._adjs + src_reverse_adjs = pattern._reverse_adjs + + tgt_nodes = graph.nodes + tgt_edges = graph._adjs + tgt_reverse_adjs = graph._reverse_adjs + tgt_attr_to_nodes = graph._attr_to_nodes + not_matched = False + + # starts with a op node + src_start_node = None + for node_id in src_nodes: + node = src_nodes[node_id] + if node.attrs["type"] not in ["var", "param", "data"]: + src_start_node = node + break + assert src_start_node is not None + + for node_id in tgt_nodes: + node = tgt_nodes[node_id] + if node.attrs["type"] == src_start_node.attrs["type"]: + _match_core(src_start_node, node) + if not not_matched: + need_to_append = True + for value in result.values(): + if value in has_matched: + result = {} + need_to_append = False + break + if need_to_append: + results.append(result) + for value in result.values(): + has_matched.add(value) + result = {} + else: + not_matched = False + result = {} + + return results + + +class OperatorClusteringUtil: common_starts = ["layer_norm", "matmul_v2", "matmul"] @staticmethod @@ -257,7 +438,10 @@ class OperatorGroupUtil: min_index = min(index_group) if max_index - min_index >= k: longest_sub_seq = seq[min_index : min_index + k] - if longest_sub_seq[0] in OperatorGroupUtil.common_starts: + if ( + longest_sub_seq[0] + in OperatorClusteringUtil.common_starts + ): return longest_sub_seq if longest_sub_seq is not None: return longest_sub_seq @@ -325,9 +509,9 @@ class RuleBasedTuner: self._dist_context = dist_context self._mode = mode - def group_operators(self, ops): + def cluster_operators(self, ops): """ - Group operators to layers. + Cluster operators to layers. Args: ops (list): A operator list. @@ -337,7 +521,7 @@ class RuleBasedTuner: """ seq = [op.type for op in ops] - while not OperatorGroupUtil.stop_replace(seq): + while not OperatorClusteringUtil.stop_replace(seq): to_replace_seq = [] to_replace_idxes = [] has_append = False @@ -351,11 +535,15 @@ class RuleBasedTuner: elif isinstance(seq, list) and has_append: break - ranks = OperatorGroupUtil.get_ranks(to_replace_seq) - suffixes = OperatorGroupUtil.get_suffixes(ranks) - heights = OperatorGroupUtil.get_heights(suffixes, to_replace_seq) - longest_sub_seq = OperatorGroupUtil.get_longest_repeated_sub_seq( - suffixes, heights, to_replace_seq + ranks = OperatorClusteringUtil.get_ranks(to_replace_seq) + suffixes = OperatorClusteringUtil.get_suffixes(ranks) + heights = OperatorClusteringUtil.get_heights( + suffixes, to_replace_seq + ) + longest_sub_seq = ( + OperatorClusteringUtil.get_longest_repeated_sub_seq( + suffixes, heights, to_replace_seq + ) ) has_merged = False if longest_sub_seq is None: @@ -374,10 +562,10 @@ class RuleBasedTuner: seq = [to_replace_seq] break - decomposed_sub_seq = OperatorGroupUtil.get_decomposed_sub_seq( + decomposed_sub_seq = OperatorClusteringUtil.get_decomposed_sub_seq( longest_sub_seq ) - to_replace_seq = OperatorGroupUtil.replace_by_decomposed_seq( + to_replace_seq = OperatorClusteringUtil.replace_by_decomposed_seq( decomposed_sub_seq, to_replace_seq ) result = seq[: to_replace_idxes[0]] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 8486056984cf0d2bdf61edee0e433d8a097b5a26..18fad917b683979571de47232ee9d2aece2d4aaf 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -120,4 +120,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_fp16_assign MODULES test_fp16_assign) py_test_modules(test_group_operators MODULES test_group_operators) py_test_modules(test_pattern MODULES test_pattern) + py_test_modules(test_pattern_match MODULES test_pattern_match) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py index e6353dadb947eb5fb600e217107a4806674499d0..2823d4d9a318c637058e5569381cfa9321aeb646 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_group_operators.py @@ -121,7 +121,7 @@ class TestGroupOperators(unittest.TestCase): dist_context = DistributedContext() tuner = RuleBasedTuner(dist_context) - layers = tuner.group_operators(train_program.global_block().ops) + layers = tuner.cluster_operators(train_program.global_block().ops) op_types = [] for layer in layers: tmp = [] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py index 159def7617a2fdc7f89b1d79f55f5dc3e81fee1e..dca87bdf9ce248e4f43488c94238a08aa60952ee 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern.py @@ -14,6 +14,7 @@ import sys import unittest + import numpy as np import paddle @@ -22,8 +23,8 @@ import paddle.static as static sys.path.append("..") import auto_parallel_gpt_model as modeling from auto_parallel_gpt_model import ( - GPTModel, GPTForPretraining, + GPTModel, GPTPretrainingCriterion, ) @@ -111,22 +112,22 @@ class TestGroupOperators(unittest.TestCase): sequence_len, vocab_size, ) + from paddle.distributed.auto_parallel.dist_context import ( + DistributedContext, + ) from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( + _PATTERNS, RuleBasedTuner, convert_to_graph, - _PATTERNS, - ) - from paddle.distributed.auto_parallel.dist_context import ( - DistributedContext, ) dist_context = DistributedContext() tuner = RuleBasedTuner(dist_context) - layers = tuner.group_operators(train_program.global_block().ops) + layers = tuner.cluster_operators(train_program.global_block().ops) layer = layers[0] graph = convert_to_graph(layer, train_program.global_block()) - print(graph) - print("qkv: ", _PATTERNS["qkv"].graph) + print("graph: ", graph) + print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"]) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py new file mode 100644 index 0000000000000000000000000000000000000000..e18298b890a5855be337b39c763ad2da7ada9726 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pattern_match.py @@ -0,0 +1,142 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +import numpy as np + +import paddle +import paddle.static as static + +sys.path.append("..") +import auto_parallel_gpt_model as modeling +from auto_parallel_gpt_model import ( + GPTForPretraining, + GPTModel, + GPTPretrainingCriterion, +) + + +def get_gpt_model( + train_program, start_program, place, batch_size, sequence_len, vocab_size +): + with static.program_guard(train_program, start_program): + tokens = paddle.static.data( + name="tokens", shape=[batch_size, sequence_len], dtype='int64' + ) + position_ids = paddle.static.data( + name="position_ids", shape=[batch_size, sequence_len], dtype='int64' + ) + attention_mask = paddle.static.data( + name="attention_mask", + shape=[batch_size, 1, sequence_len, sequence_len], + dtype='float32', + ) + labels = paddle.static.data( + name="labels", shape=[batch_size, sequence_len], dtype='int64' + ) + loss_mask = paddle.static.data( + name="loss_mask", shape=[batch_size, sequence_len], dtype='float32' + ) + + gpt = GPTModel( + vocab_size=1000, + hidden_size=64, + num_hidden_layers=2, + num_attention_heads=8, + intermediate_size=256, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + max_position_embeddings=1024, + type_vocab_size=1, + initializer_range=0.02, + pad_token_id=0, + eos_token_id=7, + bos_token_id=0, + eol_token_id=3, + ) + + model = GPTForPretraining( + gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02 + ) + preds = model(tokens, position_ids, attention_mask) + criterion = GPTPretrainingCriterion() + loss = criterion(preds, labels, loss_mask) + + def gen_data(): + np.random.seed(2021) + tokens = [] + position_ids = [] + attention_mask = [] + labels = [] + loss_mask = [] + for _ in range(batch_size): + tokens.append(np.random.randint(vocab_size, size=sequence_len)) + position_ids.append(np.arange(sequence_len)) + attention_mask.append([np.tril(np.ones(sequence_len))]) + labels.append(np.random.randint(vocab_size, size=sequence_len)) + loss_mask.append(np.ones(sequence_len)) + + return tokens, position_ids, attention_mask, labels, loss_mask + + return train_program, start_program, loss, gen_data + + +class TestGroupOperators(unittest.TestCase): + def test_gpt(self): + modeling.init_global() + train_program = static.Program() + start_program = static.Program() + place = paddle.set_device("gpu") + batch_size = 8 + sequence_len = 512 + vocab_size = 1000 + train_program, start_program, loss, gen_data = get_gpt_model( + train_program, + start_program, + place, + batch_size, + sequence_len, + vocab_size, + ) + from paddle.distributed.auto_parallel.dist_context import ( + DistributedContext, + ) + from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( + _PATTERNS, + RuleBasedTuner, + convert_to_graph, + match, + ) + + dist_context = DistributedContext() + tuner = RuleBasedTuner(dist_context) + layers = tuner.cluster_operators(train_program.global_block().ops) + layer = layers[0] + graph = convert_to_graph(layer, train_program.global_block()) + results = match(_PATTERNS["qkv"], graph) + shard_tensor_infos = _PATTERNS["qkv"].attrs["shard_spec"] + tensor_ids = shard_tensor_infos[0][0] + if results: + for result in results: + for node_id in result: + if node_id in tensor_ids: + print(graph.attrs["id_to_var"][result[node_id]]) + print("shard_spec: ", shard_tensor_infos[0][1]) + + +if __name__ == "__main__": + unittest.main()