diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index a40e2ec5d08cb056f3714e8f59c143651853da73..663cd7d28140abe77b0a9c61ee515ecc58131948 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -171,8 +171,10 @@ class SegmentLayers(object): def uniform(self, num_items, num_parts): result = [0 for _ in range(num_parts + 1)] part_size = math.floor(num_items / num_parts) - for i in range(num_parts): - result[i] = int(min(part_size * i, num_items)) + extra_layers = num_items % num_parts + for i in range(1, num_parts): + offset = 1 if i > (num_parts - extra_layers) else 0 + result[i] = int(min(result[i - 1] + part_size + offset, num_items)) result[num_parts] = num_items return result