未验证 提交 11c26f26 编写于 作者: S sneaxiy 提交者: GitHub

refine dygraph_sharding_optimizer.py by sorting parameters (#55059)

* refine dygraph_sharding_optimizer.py by sorting parameters

* Update dygraph_sharding_optimizer.py

Make FLAGS_sharding_sort_parameters=1 by default.
上级 93c58390
......@@ -14,6 +14,8 @@
######
import os
from distutils.util import strtobool
from functools import reduce
import paddle
......@@ -105,7 +107,17 @@ class DygraphShardingOptimizer:
for rank_ in range(self._sharding_world_size):
mapping[rank_] = []
sizes = [0] * self._sharding_world_size
for param in self._parameter_list:
parameters = list(self._parameter_list)
need_sort_parameters = strtobool(
os.getenv('FLAGS_sharding_sort_parameters', '1')
)
if need_sort_parameters:
parameters.sort(
key=lambda p: reduce(lambda x, y: x * y, p.shape), reverse=True
)
for param in parameters:
rank = sizes.index(min(sizes))
mapping[rank].append(param)
numel = reduce(lambda x, y: x * y, param.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册