未验证 提交 0a5b2e79 编写于 作者: J Jianghai 提交者: GitHub

[Auto Parallel] Fix bugs in cluster to device meshes (#49892)

* fix bugs in cluster to device meshes

* add tests

* 1
上级 636780b5
...@@ -516,6 +516,11 @@ class ClusterPartitionUtil: ...@@ -516,6 +516,11 @@ class ClusterPartitionUtil:
@staticmethod @staticmethod
def complete_meshes(partitions: list, num: int): def complete_meshes(partitions: list, num: int):
if num == 2:
return [[1, 2], [2, 1]]
if num == 3:
return [[1, 2], [2, 1], [1]]
# special cases
if len(partitions) == 1: if len(partitions) == 1:
partitions = ClusterPartitionUtil.factorization(num - 1) partitions = ClusterPartitionUtil.factorization(num - 1)
partitions.append([1]) partitions.append([1])
......
...@@ -17,7 +17,7 @@ import unittest ...@@ -17,7 +17,7 @@ import unittest
class TestClusterPartition(unittest.TestCase): class TestClusterPartition(unittest.TestCase):
def test_cluster_partition(self): def test_cluster_partition(self):
clusters = [(5, 8), (1, 8), (4, 8), (16, 8)] clusters = [(5, 8), (1, 8), (4, 8), (16, 8), (2, 8), (3, 8)]
from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( from paddle.distributed.auto_parallel.tuner.rule_based_tuner import (
ClusterPartitionUtil, ClusterPartitionUtil,
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册