From 11c26f26d3437438264452fd3121b0c2cc9a9f8b Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 5 Jul 2023 10:47:34 +0800 Subject: [PATCH] refine dygraph_sharding_optimizer.py by sorting parameters (#55059) * refine dygraph_sharding_optimizer.py by sorting parameters * Update dygraph_sharding_optimizer.py Make FLAGS_sharding_sort_parameters=1 by default. --- .../dygraph_sharding_optimizer.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py index b16eab734a5..29656e89828 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py @@ -14,6 +14,8 @@ ###### +import os +from distutils.util import strtobool from functools import reduce import paddle @@ -105,7 +107,17 @@ class DygraphShardingOptimizer: for rank_ in range(self._sharding_world_size): mapping[rank_] = [] sizes = [0] * self._sharding_world_size - for param in self._parameter_list: + + parameters = list(self._parameter_list) + need_sort_parameters = strtobool( + os.getenv('FLAGS_sharding_sort_parameters', '1') + ) + if need_sort_parameters: + parameters.sort( + key=lambda p: reduce(lambda x, y: x * y, p.shape), reverse=True + ) + + for param in parameters: rank = sizes.index(min(sizes)) mapping[rank].append(param) numel = reduce(lambda x, y: x * y, param.shape) -- GitLab