Unverified commit 30315ac9, authored by caozhou, committed by GitHub

[Auto Parallel] Add pattern match (#48464)

* add pattern match

* add unittest
Parent 41946522
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License

from collections import OrderedDict


class Node:
    def __init__(self, id, **attrs):
@@ -100,6 +102,8 @@ class Graph:
        # Attributes for Graph
        self._attrs = {}
        self._attrs.update(attrs)
        self._reverse_adjs = {}
        self._attr_to_nodes = {}

    @property
    def nodes(self):
@@ -120,6 +124,7 @@ class Graph:
            node = Node(node_id, **attrs)
            self._nodes[node_id] = node
            self._adjs[node_id] = {}
            self._reverse_adjs[node_id] = []
        else:
            self._nodes[node_id].attrs.update(attrs)
@@ -134,14 +139,21 @@ class Graph:
        if src_id not in self._nodes:
            src_node = Node(src_id)
            self._nodes[src_id] = src_node
            # for one tensor to multiple ops
            self._adjs[src_id] = OrderedDict()
            self._reverse_adjs[src_id] = []

        if tgt_id not in self._nodes:
            tgt_node = Node(tgt_id)
            self._nodes[tgt_id] = tgt_node
            # for one tensor to multiple ops
            self._adjs[tgt_id] = OrderedDict()
            self._reverse_adjs[tgt_id] = []

        # add the edge
        edge = Edge(src_id, tgt_id, **attrs)
        self._adjs[src_id][tgt_id] = edge

        # add the reverse adj
        self._reverse_adjs[tgt_id].append(self.nodes[src_id])
        return edge

    def __len__(self):
......
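Note (illustrative only, not part of the commit): a minimal sketch of how the extended Graph bookkeeping behaves. The import path is an assumption inferred from the relative import `from ..graph import Graph` used in the tuner module below.

from paddle.distributed.auto_parallel.graph import Graph  # path assumed

g = Graph()
g.add_node(0, type="var")          # a tensor node
g.add_node(1, type="matmul_v2")    # an op node
g.add_edge(0, 1, input_name="X")   # tensor 0 feeds op 1 through input slot "X"

# _adjs maps src -> {tgt: Edge}; the new _reverse_adjs maps tgt -> [src nodes],
# which is what the pattern matcher walks to reach an op's input vars.
print(list(g._adjs[0].keys()))              # [1]
print([n.id for n in g._reverse_adjs[1]])   # [0]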
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod

from ..graph import Graph
@@ -32,6 +32,57 @@ def register_pattern(cls):
    return cls
class BasePattern(Graph):
    name = "base"

    def __init__(self):
        super().__init__()
        self.build()

    @abstractmethod
    def build(self):
        pass


@register_pattern
class QKVPattern(BasePattern):
    name = "qkv"

    def __init__(self):
        super().__init__()

    def build(self):
        query = self.add_node(0, **{"type": "var"})

        q_weight = self.add_node(1, **{"dim": 2, "type": "param"})
        k_weight = self.add_node(2, **{"dim": 2, "type": "param"})
        v_weight = self.add_node(3, **{"dim": 2, "type": "param"})

        q_matmul = self.add_node(4, **{"type": "matmul_v2"})
        k_matmul = self.add_node(5, **{"type": "matmul_v2"})
        v_matmul = self.add_node(6, **{"type": "matmul_v2"})

        q_x = self.add_edge(0, 4, **{"input_name": "X"})
        k_x = self.add_edge(0, 5, **{"input_name": "X"})
        v_x = self.add_edge(0, 6, **{"input_name": "X"})
        q_y = self.add_edge(1, 4, **{"input_name": "Y"})
        k_y = self.add_edge(2, 5, **{"input_name": "Y"})
        v_y = self.add_edge(3, 6, **{"input_name": "Y"})

        q = self.add_node(7, **{"type": "var"})
        k = self.add_node(8, **{"type": "var"})
        v = self.add_node(9, **{"type": "var"})

        q_out = self.add_edge(4, 7, **{"output_name": "Out"})
        k_out = self.add_edge(5, 8, **{"output_name": "Out"})
        v_out = self.add_edge(6, 9, **{"output_name": "Out"})

        # Pattern
        self.attrs["shard_spec"] = [
            [(1, 2, 3), [[-1, 0], [-1, 1]]],
        ]  # list of 2-tuples such as [(tensor_ids, shard_specs)]
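A quick illustrative aside (not part of the diff): register_pattern appears to store an instance of each pattern in _PATTERNS keyed by its name, which is how the new unit test below retrieves the shard spec:

from paddle.distributed.auto_parallel.tuner.rule_based_tuner import _PATTERNS

qkv = _PATTERNS["qkv"]          # a QKVPattern instance, itself a Graph
print(qkv.attrs["shard_spec"])  # [[(1, 2, 3), [[-1, 0], [-1, 1]]]]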
def convert_to_graph(ops, block):
    """Convert ops to graph."""
    graph = Graph()
@@ -50,7 +101,9 @@ def convert_to_graph(ops, block):
        op_node = graph.add_node(node_id, **attrs)
        graph.attrs["op_to_id"][op.desc.id()] = op_node.id
        graph.attrs["id_to_op"][op_node.id] = op.desc.id()
        graph._attr_to_nodes[op_node.id] = {}
        for input_name in op.input_names:
            graph._attr_to_nodes[op_node.id][input_name] = []
            for var_name in op.input(input_name):
                if var_name not in graph.attrs["var_to_id"]:
                    # create var node
@@ -59,6 +112,7 @@ def convert_to_graph(ops, block):
                    var = block._var_recursive(var_name)
                    if var.is_parameter:
                        var_node.attrs["type"] = "param"
                        var_node.attrs["dim"] = len(var.shape)
                    else:
                        var_node.attrs["type"] = "var"
                    graph.attrs["var_to_id"][var_name] = var_node.id
@@ -70,8 +124,10 @@ def convert_to_graph(ops, block):
                # create edge that input -> op
                input_edge = graph.add_edge(var_node.id, op_node.id)
                input_edge.attrs["input_name"] = input_name
                graph._attr_to_nodes[op_node.id][input_name].append(var_node)

        for output_name in op.output_names:
            graph._attr_to_nodes[op_node.id][output_name] = []
            for var_name in op.output(output_name):
                if var_name not in graph.attrs["var_to_id"]:
                    # create var node
@@ -92,64 +148,189 @@ def convert_to_graph(ops, block):
                output_edge = graph.add_edge(op_node.id, var_node.id)
                output_edge.attrs["output_name"] = output_name
                graph._attr_to_nodes[op_node.id][output_name].append(
                    var_node
                )

    return graph
def match(pattern, graph):
    def _is_op_node(node):
        """Judge whether node is op node"""
        if node.attrs["type"] not in ["var", "param", "data"]:
            return True
        return False

    def _compare_op_node(src, tgt):
        """Compare whether two op nodes are equal"""
        if src.attrs["type"] != tgt.attrs["type"]:
            return False
        return True

    def _compare_var_node(src, tgt):
        """Compare whether two var nodes are equal"""
        for key in src.attrs:
            if key not in tgt.attrs:
                return False
            if src.attrs[key] != tgt.attrs[key]:
                return False
        return True

    def _match_core(src_node, tgt_node):
        nonlocal not_matched
        # do not support one input name or output name corresponding to multiple vars
        if not_matched:
            return

        if _is_op_node(src_node):
            # compare op node whether equal
            if not _compare_op_node(src_node, tgt_node):
                return

            result[src_node.id] = tgt_node.id

            # input var nodes
            src_input_nodes = src_reverse_adjs[src_node.id]
            for node in src_input_nodes:
                # has visited
                if node.id in result:
                    continue
                edge = src_edges[node.id][src_node.id]
                input_name = edge.attrs["input_name"]

                # NOTE: do not support one input name or output name corresponding to multiple vars
                compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                    input_name, None
                )
                if not compare_nodes:
                    not_matched = True
                    return
                _match_core(node, compare_nodes[0])

            # output var nodes
            src_output_node_ids = src_edges[src_node.id].keys()
            for node_id in src_output_node_ids:
                # has visited
                if node_id in result:
                    continue
                node = src_nodes[node_id]
                edge = src_edges[src_node.id][node_id]
                output_name = edge.attrs["output_name"]

                # NOTE: do not support one input name or output name corresponding to multiple vars
                compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                    output_name, None
                )
                if not compare_nodes:
                    not_matched = True
                    return
                _match_core(node, compare_nodes[0])

        else:
            # compare var node whether equal
            if not _compare_var_node(src_node, tgt_node):
                not_matched = True
                return

            result[src_node.id] = tgt_node.id

            # as input for op nodes
            src_as_input_node_ids = src_edges[src_node.id].keys()
            for node_id in src_as_input_node_ids:
                if node_id in result:
                    continue
                src_edge = src_edges[src_node.id][node_id]
                input_name = src_edge.attrs["input_name"]
                compare_node_ids = tgt_edges[tgt_node.id].keys()

                compare_node = None
                for compare_node_id in compare_node_ids:
                    edge = tgt_edges[tgt_node.id][compare_node_id]
                    if (
                        edge.attrs["input_name"] == input_name
                        and compare_node_id not in result.values()
                    ):
                        compare_node = tgt_nodes[compare_node_id]
                        break

                if not compare_node:
                    not_matched = True
                    return
                _match_core(src_nodes[node_id], compare_node)

            # as output for nodes
            src_as_output_nodes = src_reverse_adjs[src_node.id]
            for node in src_as_output_nodes:
                if node.id in result:
                    continue
                src_edge = src_edges[node.id][src_node.id]
                output_name = src_edge.attrs["output_name"]
                compare_node_ids = tgt_reverse_adjs[tgt_node.id]

                compare_node = None
                for node_id in compare_node_ids:
                    edge = tgt_edges[node_id][tgt_node.id]
                    if edge.attrs["output_name"] == output_name:
                        compare_node = tgt_nodes[node_id]
                        break
                if not compare_node:
                    not_matched = True
                    return
                _match_core(src_nodes[node_id], compare_node)

    results = []
    result = {}
    has_matched = set()

    src_nodes = pattern.nodes
    src_edges = pattern._adjs
    src_reverse_adjs = pattern._reverse_adjs
    tgt_nodes = graph.nodes
    tgt_edges = graph._adjs
    tgt_reverse_adjs = graph._reverse_adjs
    tgt_attr_to_nodes = graph._attr_to_nodes
    not_matched = False

    # starts with a op node
    src_start_node = None
    for node_id in src_nodes:
        node = src_nodes[node_id]
        if node.attrs["type"] not in ["var", "param", "data"]:
            src_start_node = node
            break
    assert src_start_node is not None

    for node_id in tgt_nodes:
        node = tgt_nodes[node_id]
        if node.attrs["type"] == src_start_node.attrs["type"]:
            _match_core(src_start_node, node)
            if not not_matched:
                need_to_append = True
                for value in result.values():
                    if value in has_matched:
                        result = {}
                        need_to_append = False
                        break
                if need_to_append:
                    results.append(result)
                    for value in result.values():
                        has_matched.add(value)
                    result = {}
            else:
                not_matched = False
                result = {}

    return results
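For orientation, a self-contained toy example (illustrative only, not part of the diff) of what match returns. The Graph import path is assumed from the relative import above, and _attr_to_nodes is filled by hand here because convert_to_graph normally populates it.

from paddle.distributed.auto_parallel.graph import Graph  # path assumed
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import match

# A toy pattern: var --X--> matmul_v2 --Out--> var
pattern = Graph()
pattern.add_node(0, type="var")
pattern.add_node(1, type="matmul_v2")
pattern.add_node(2, type="var")
pattern.add_edge(0, 1, input_name="X")
pattern.add_edge(1, 2, output_name="Out")

# A toy target graph with the same shape (node ids chosen arbitrarily).
graph = Graph()
graph.add_node(10, type="var")
graph.add_node(11, type="matmul_v2")
graph.add_node(12, type="var")
graph.add_edge(10, 11, input_name="X")
graph.add_edge(11, 12, output_name="Out")
# convert_to_graph fills _attr_to_nodes; for a hand-built graph we do it here.
graph._attr_to_nodes[11] = {"X": [graph.nodes[10]], "Out": [graph.nodes[12]]}

print(match(pattern, graph))  # e.g. [{1: 11, 0: 10, 2: 12}], pattern id -> graph id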
class OperatorClusteringUtil:
common_starts = ["layer_norm", "matmul_v2", "matmul"] common_starts = ["layer_norm", "matmul_v2", "matmul"]
@staticmethod @staticmethod
...@@ -257,7 +438,10 @@ class OperatorGroupUtil: ...@@ -257,7 +438,10 @@ class OperatorGroupUtil:
min_index = min(index_group) min_index = min(index_group)
if max_index - min_index >= k: if max_index - min_index >= k:
longest_sub_seq = seq[min_index : min_index + k] longest_sub_seq = seq[min_index : min_index + k]
if longest_sub_seq[0] in OperatorGroupUtil.common_starts: if (
longest_sub_seq[0]
in OperatorClusteringUtil.common_starts
):
return longest_sub_seq return longest_sub_seq
if longest_sub_seq is not None: if longest_sub_seq is not None:
return longest_sub_seq return longest_sub_seq
@@ -325,9 +509,9 @@ class RuleBasedTuner:
        self._dist_context = dist_context
        self._mode = mode

    def cluster_operators(self, ops):
        """
        Cluster operators to layers.

        Args:
            ops (list): An operator list.
@@ -337,7 +521,7 @@ class RuleBasedTuner:
        """
        seq = [op.type for op in ops]

        while not OperatorClusteringUtil.stop_replace(seq):
            to_replace_seq = []
            to_replace_idxes = []
            has_append = False
@@ -351,11 +535,15 @@ class RuleBasedTuner:
                elif isinstance(seq, list) and has_append:
                    break

            ranks = OperatorClusteringUtil.get_ranks(to_replace_seq)
            suffixes = OperatorClusteringUtil.get_suffixes(ranks)
            heights = OperatorClusteringUtil.get_heights(
                suffixes, to_replace_seq
            )
            longest_sub_seq = (
                OperatorClusteringUtil.get_longest_repeated_sub_seq(
                    suffixes, heights, to_replace_seq
                )
            )
            has_merged = False
            if longest_sub_seq is None:
@@ -374,10 +562,10 @@ class RuleBasedTuner:
                    seq = [to_replace_seq]
                    break

            decomposed_sub_seq = OperatorClusteringUtil.get_decomposed_sub_seq(
                longest_sub_seq
            )
            to_replace_seq = OperatorClusteringUtil.replace_by_decomposed_seq(
                decomposed_sub_seq, to_replace_seq
            )
            result = seq[: to_replace_idxes[0]]
......
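Aside (illustrative only, not the commit's implementation): cluster_operators works by repeatedly finding the longest repeated sub-sequence of op types, so the ops of one decoder layer collapse into a reusable unit. The real code does this with suffix ranks and heights (OperatorClusteringUtil.get_ranks / get_suffixes / get_heights); the brute-force sketch below only conveys the idea.

def longest_repeated_block(seq):
    """Return the longest contiguous block of op types that repeats without overlap."""
    n = len(seq)
    for k in range(n // 2, 0, -1):  # try the longest length first
        seen = {}
        for i in range(n - k + 1):
            block = tuple(seq[i : i + k])
            if block in seen and i - seen[block] >= k:  # non-overlapping repeat
                return list(block)
            seen.setdefault(block, i)
    return None


ops = ["matmul_v2", "elementwise_add", "softmax",
       "matmul_v2", "elementwise_add", "softmax"]
print(longest_repeated_block(ops))  # ['matmul_v2', 'elementwise_add', 'softmax']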
@@ -120,4 +120,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_fp16_assign MODULES test_fp16_assign)
  py_test_modules(test_group_operators MODULES test_group_operators)
  py_test_modules(test_pattern MODULES test_pattern)
  py_test_modules(test_pattern_match MODULES test_pattern_match)
endif()
@@ -121,7 +121,7 @@ class TestGroupOperators(unittest.TestCase):
        dist_context = DistributedContext()
        tuner = RuleBasedTuner(dist_context)
        layers = tuner.cluster_operators(train_program.global_block().ops)
        op_types = []
        for layer in layers:
            tmp = []
......
@@ -14,6 +14,7 @@
import sys
import unittest

import numpy as np

import paddle
@@ -22,8 +23,8 @@ import paddle.static as static
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
    GPTForPretraining,
    GPTModel,
    GPTPretrainingCriterion,
)
@@ -111,22 +112,22 @@ class TestGroupOperators(unittest.TestCase):
            sequence_len,
            vocab_size,
        )
        from paddle.distributed.auto_parallel.dist_context import (
            DistributedContext,
        )
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            _PATTERNS,
            RuleBasedTuner,
            convert_to_graph,
        )

        dist_context = DistributedContext()
        tuner = RuleBasedTuner(dist_context)
        layers = tuner.cluster_operators(train_program.global_block().ops)
        layer = layers[0]
        graph = convert_to_graph(layer, train_program.global_block())
        print("graph: ", graph)
        print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])

if __name__ == "__main__":
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
import paddle.static as static
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
GPTForPretraining,
GPTModel,
GPTPretrainingCriterion,
)
def get_gpt_model(
    train_program, start_program, place, batch_size, sequence_len, vocab_size
):
    with static.program_guard(train_program, start_program):
        tokens = paddle.static.data(
            name="tokens", shape=[batch_size, sequence_len], dtype='int64'
        )
        position_ids = paddle.static.data(
            name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
        )
        attention_mask = paddle.static.data(
            name="attention_mask",
            shape=[batch_size, 1, sequence_len, sequence_len],
            dtype='float32',
        )
        labels = paddle.static.data(
            name="labels", shape=[batch_size, sequence_len], dtype='int64'
        )
        loss_mask = paddle.static.data(
            name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
        )

        gpt = GPTModel(
            vocab_size=1000,
            hidden_size=64,
            num_hidden_layers=2,
            num_attention_heads=8,
            intermediate_size=256,
            hidden_act="gelu",
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
            max_position_embeddings=1024,
            type_vocab_size=1,
            initializer_range=0.02,
            pad_token_id=0,
            eos_token_id=7,
            bos_token_id=0,
            eol_token_id=3,
        )

        model = GPTForPretraining(
            gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
        )
        preds = model(tokens, position_ids, attention_mask)
        criterion = GPTPretrainingCriterion()
        loss = criterion(preds, labels, loss_mask)

    def gen_data():
        np.random.seed(2021)
        tokens = []
        position_ids = []
        attention_mask = []
        labels = []
        loss_mask = []
        for _ in range(batch_size):
            tokens.append(np.random.randint(vocab_size, size=sequence_len))
            position_ids.append(np.arange(sequence_len))
            attention_mask.append([np.tril(np.ones(sequence_len))])
            labels.append(np.random.randint(vocab_size, size=sequence_len))
            loss_mask.append(np.ones(sequence_len))
        return tokens, position_ids, attention_mask, labels, loss_mask

    return train_program, start_program, loss, gen_data


class TestGroupOperators(unittest.TestCase):
    def test_gpt(self):
        modeling.init_global()
        train_program = static.Program()
        start_program = static.Program()
        place = paddle.set_device("gpu")
        batch_size = 8
        sequence_len = 512
        vocab_size = 1000
        train_program, start_program, loss, gen_data = get_gpt_model(
            train_program,
            start_program,
            place,
            batch_size,
            sequence_len,
            vocab_size,
        )
        from paddle.distributed.auto_parallel.dist_context import (
            DistributedContext,
        )
        from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
            _PATTERNS,
            RuleBasedTuner,
            convert_to_graph,
            match,
        )

        dist_context = DistributedContext()
        tuner = RuleBasedTuner(dist_context)
        layers = tuner.cluster_operators(train_program.global_block().ops)
        layer = layers[0]
        graph = convert_to_graph(layer, train_program.global_block())
        results = match(_PATTERNS["qkv"], graph)
        shard_tensor_infos = _PATTERNS["qkv"].attrs["shard_spec"]
        tensor_ids = shard_tensor_infos[0][0]
        if results:
            for result in results:
                for node_id in result:
                    if node_id in tensor_ids:
                        print(graph.attrs["id_to_var"][result[node_id]])
            print("shard_spec: ", shard_tensor_infos[0][1])


if __name__ == "__main__":
    unittest.main()