未验证 提交 0bb7c003 编写于 作者: C caozhou 提交者: GitHub

[Auto Parallel] Add patterns of rule based tuner (#51859)

* add patterns

* add unittest
上级 cdefcd00
......@@ -95,7 +95,7 @@ def get_gpt_model(
return train_program, start_program, loss, gen_data
class TestGroupOperators(unittest.TestCase):
class TestGroupOperatorsAndPatterns(unittest.TestCase):
def test_gpt(self):
modeling.init_global()
train_program = static.Program()
......@@ -117,17 +117,30 @@ class TestGroupOperators(unittest.TestCase):
)
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
_PATTERNS,
GraphUtil,
RuleBasedTuner,
convert_to_graph,
)
dist_context = DistributedContext()
tuner = RuleBasedTuner(dist_context)
layers = tuner.cluster_operators(train_program.global_block().ops)
layer = layers[0]
graph = convert_to_graph(layer, train_program.global_block())
graph = GraphUtil.convert_to_graph(train_program.global_block())
print("graph: ", graph)
print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
print("row_matmul: ", _PATTERNS["row_matmul"].attrs["shard_spec"])
print("ffn: ", _PATTERNS["ffn"].attrs["shard_spec"])
print(
"shared_word_embedding: ",
_PATTERNS["shared_word_embedding"].attrs["shard_spec"],
)
print(
"position_embedding: ",
_PATTERNS["position_embedding"].attrs["shard_spec"],
)
print(
"unsqueeze_data: ", _PATTERNS["unsqueeze_data"].attrs["shard_spec"]
)
print("reshape_data: ", _PATTERNS["reshape_data"].attrs["shard_spec"])
if __name__ == "__main__":
......
......@@ -95,7 +95,7 @@ def get_gpt_model(
return train_program, start_program, loss, gen_data
class TestGroupOperators(unittest.TestCase):
class TestPatternMatch(unittest.TestCase):
def test_gpt(self):
modeling.init_global()
train_program = static.Program()
......@@ -116,26 +116,15 @@ class TestGroupOperators(unittest.TestCase):
DistributedContext,
)
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
_PATTERNS,
GraphUtil,
RuleBasedTuner,
convert_to_graph,
match,
)
dist_context = DistributedContext()
tuner = RuleBasedTuner(dist_context)
layers = tuner.cluster_operators(train_program.global_block().ops)
layer = layers[0]
graph = convert_to_graph(layer, train_program.global_block())
results = match(_PATTERNS["qkv"], graph)
shard_tensor_infos = _PATTERNS["qkv"].attrs["shard_spec"]
tensor_ids = shard_tensor_infos[0][0]
if results:
for result in results:
for node_id in result:
if node_id in tensor_ids:
print(graph.attrs["id_to_var"][result[node_id]])
print("shard_spec: ", shard_tensor_infos[0][1])
graph = GraphUtil.convert_to_graph(train_program.global_block())
results = GraphUtil.match_all_patterns(graph)
print(results)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册