From 0a5b2e79abdb5214223dd33dc5e2cb082f140676 Mon Sep 17 00:00:00 2001 From: Jianghai <72591262+CjhHa1@users.noreply.github.com> Date: Thu, 2 Feb 2023 16:28:00 +0800 Subject: [PATCH] [Auto Parallel]fix bugs in cluster to device meshes (#49892) * fix bugs in cluster to device meshes * add tests * 1 --- .../distributed/auto_parallel/tuner/rule_based_tuner.py | 5 +++++ .../tests/unittests/auto_parallel/test_cluster_partition.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py index f1508f793a..cdfc87868c 100644 --- a/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py @@ -516,6 +516,11 @@ class ClusterPartitionUtil: @staticmethod def complete_meshes(partitions: list, num: int): + if num == 2: + return [[1, 2], [2, 1]] + if num == 3: + return [[1, 2], [2, 1], [1]] + # special cases if len(partitions) == 1: partitions = ClusterPartitionUtil.factorization(num - 1) partitions.append([1]) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_cluster_partition.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_cluster_partition.py index 2223724c29..9071b481eb 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_cluster_partition.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_cluster_partition.py @@ -17,7 +17,7 @@ import unittest class TestClusterPartition(unittest.TestCase): def test_cluster_partition(self): - clusters = [(5, 8), (1, 8), (4, 8), (16, 8)] + clusters = [(5, 8), (1, 8), (4, 8), (16, 8), (2, 8), (3, 8)] from paddle.distributed.auto_parallel.tuner.rule_based_tuner import ( ClusterPartitionUtil, ) -- GitLab