Unverified · Commit 0bb7c003 · Author: caozhou · Committed by: GitHub

[Auto Parallel] Add patterns of rule based tuner (#51859)

* add patterns

* add unittest
Parent: cdefcd00
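
For context, the new entry points exercised by the tests below are GraphUtil.convert_to_graph and GraphUtil.match_all_patterns, which replace the free functions convert_to_graph and match. A minimal sketch of the new flow, using only names visible in this commit; the DistributedContext import path and the pre-built train_program (the tests build a GPT model via their get_gpt_model helper) are assumptions:

    # Sketch of the new API exercised by the tests in this commit.
    # Assumptions: train_program is a pre-built paddle.static.Program,
    # and the DistributedContext import path matches Paddle's layout.
    from paddle.distributed.auto_parallel.dist_context import DistributedContext
    from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
        _PATTERNS,
        GraphUtil,
        RuleBasedTuner,
    )

    dist_context = DistributedContext()
    tuner = RuleBasedTuner(dist_context)

    # Group the block's operators into layers, convert the block into a
    # graph, then match every registered pattern against that graph.
    layers = tuner.cluster_operators(train_program.global_block().ops)
    graph = GraphUtil.convert_to_graph(train_program.global_block())
    results = GraphUtil.match_all_patterns(graph)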
(First changed test file; the file path is not captured in this view.)

@@ -95,7 +95,7 @@ def get_gpt_model(
     return train_program, start_program, loss, gen_data


-class TestGroupOperators(unittest.TestCase):
+class TestGroupOperatorsAndPatterns(unittest.TestCase):
     def test_gpt(self):
         modeling.init_global()
         train_program = static.Program()
@@ -117,17 +117,30 @@ class TestGroupOperators(unittest.TestCase):
         )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
             _PATTERNS,
+            GraphUtil,
             RuleBasedTuner,
-            convert_to_graph,
         )

         dist_context = DistributedContext()
         tuner = RuleBasedTuner(dist_context)
         layers = tuner.cluster_operators(train_program.global_block().ops)
-        layer = layers[0]
-        graph = convert_to_graph(layer, train_program.global_block())
+        graph = GraphUtil.convert_to_graph(train_program.global_block())
         print("graph: ", graph)
         print("qkv: ", _PATTERNS["qkv"].attrs["shard_spec"])
+        print("row_matmul: ", _PATTERNS["row_matmul"].attrs["shard_spec"])
+        print("ffn: ", _PATTERNS["ffn"].attrs["shard_spec"])
+        print(
+            "shared_word_embedding: ",
+            _PATTERNS["shared_word_embedding"].attrs["shard_spec"],
+        )
+        print(
+            "position_embedding: ",
+            _PATTERNS["position_embedding"].attrs["shard_spec"],
+        )
+        print(
+            "unsqueeze_data: ", _PATTERNS["unsqueeze_data"].attrs["shard_spec"]
+        )
+        print("reshape_data: ", _PATTERNS["reshape_data"].attrs["shard_spec"])


 if __name__ == "__main__":
...
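
The prints above dump the shard_spec attribute of seven registered patterns (qkv, row_matmul, ffn, shared_word_embedding, position_embedding, unsqueeze_data, reshape_data). Judging from the code removed in the second test file below, each shard_spec entry pairs a group of pattern node ids with a sharding suggestion for those tensors; a hedged reader for that layout (the exact element types are an assumption):

    # Hedged sketch: iterate shard_spec entries of a few patterns.
    # Assumption: each entry is an indexable pair (tensor_ids, spec),
    # as implied by shard_tensor_infos[0][0] / [0][1] in the old test.
    for name in ("qkv", "row_matmul", "ffn"):
        for tensor_ids, spec in _PATTERNS[name].attrs["shard_spec"]:
            print(name, tensor_ids, spec)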
(Second changed test file; the file path is not captured in this view.)

@@ -95,7 +95,7 @@ def get_gpt_model(
     return train_program, start_program, loss, gen_data


-class TestGroupOperators(unittest.TestCase):
+class TestPatternMatch(unittest.TestCase):
     def test_gpt(self):
         modeling.init_global()
         train_program = static.Program()
@@ -116,26 +116,15 @@ class TestGroupOperators(unittest.TestCase):
             DistributedContext,
         )
         from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
-            _PATTERNS,
+            GraphUtil,
             RuleBasedTuner,
-            convert_to_graph,
-            match,
         )

         dist_context = DistributedContext()
         tuner = RuleBasedTuner(dist_context)
-        layers = tuner.cluster_operators(train_program.global_block().ops)
-        layer = layers[0]
-        graph = convert_to_graph(layer, train_program.global_block())
-        results = match(_PATTERNS["qkv"], graph)
-        shard_tensor_infos = _PATTERNS["qkv"].attrs["shard_spec"]
-        tensor_ids = shard_tensor_infos[0][0]
-        if results:
-            for result in results:
-                for node_id in result:
-                    if node_id in tensor_ids:
-                        print(graph.attrs["id_to_var"][result[node_id]])
-        print("shard_spec: ", shard_tensor_infos[0][1])
+        graph = GraphUtil.convert_to_graph(train_program.global_block())
+        results = GraphUtil.match_all_patterns(graph)
+        print(results)


 if __name__ == "__main__":
...
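
The new test only prints the raw result of match_all_patterns. The removed code shows how a single-pattern match maps back to program tensors: the old match() returned a list of dicts from pattern node id to matched graph node id, and the converted graph carries an id_to_var table in its attrs. A sketch that re-applies that mapping to one pattern's matches; that match_all_patterns keys its results per pattern name is an assumption:

    # Map matches of the "qkv" pattern back to program variables.
    # Assumption (hypothetical shape): results["qkv"] is a list of
    # {pattern_node_id: graph_node_id} dicts, like the old match() output.
    shard_tensor_infos = _PATTERNS["qkv"].attrs["shard_spec"]
    tensor_ids = shard_tensor_infos[0][0]
    for result in results.get("qkv", []):
        for node_id, graph_node_id in result.items():
            if node_id in tensor_ids:
                # id_to_var resolves a graph node id to its program variable.
                print(graph.attrs["id_to_var"][graph_node_id])
    print("shard_spec: ", shard_tensor_infos[0][1])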