From 383f1c4f969e13a9d65835d9e28eccd2783f938c Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Sat, 5 Nov 2022 20:04:37 +0800 Subject: [PATCH] update the split logic for uniform (#47670) --- .../fleet/meta_parallel/parallel_layers/pp_layers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index a40e2ec5d08..663cd7d2814 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -171,8 +171,10 @@ class SegmentLayers(object): def uniform(self, num_items, num_parts): result = [0 for _ in range(num_parts + 1)] part_size = math.floor(num_items / num_parts) - for i in range(num_parts): - result[i] = int(min(part_size * i, num_items)) + extra_layers = num_items % num_parts + for i in range(1, num_parts): + offset = 1 if i > (num_parts - extra_layers) else 0 + result[i] = int(min(result[i - 1] + part_size + offset, num_items)) result[num_parts] = num_items return result -- GitLab