未验证 提交 0a5b2e79 编写于 作者: J Jianghai 提交者: GitHub

[Auto Parallel]fix bugs in cluster to device meshes (#49892)

* fix bugs in cluster to device meshes

* add tests

* 1
上级 636780b5
......@@ -516,6 +516,11 @@ class ClusterPartitionUtil:
@staticmethod
def complete_meshes(partitions: list, num: int):
if num == 2:
return [[1, 2], [2, 1]]
if num == 3:
return [[1, 2], [2, 1], [1]]
# special cases
if len(partitions) == 1:
partitions = ClusterPartitionUtil.factorization(num - 1)
partitions.append([1])
......
......@@ -17,7 +17,7 @@ import unittest
class TestClusterPartition(unittest.TestCase):
def test_cluster_partition(self):
clusters = [(5, 8), (1, 8), (4, 8), (16, 8)]
clusters = [(5, 8), (1, 8), (4, 8), (16, 8), (2, 8), (3, 8)]
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
ClusterPartitionUtil,
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册