From f3eded6cbd011b26f6defe478badfcc7fa1b35ca Mon Sep 17 00:00:00 2001 From: Chitsing KUI Date: Thu, 8 Jun 2023 17:21:24 +0800 Subject: [PATCH] enable sort ip in launch (#54435) --- .../distributed/launch/context/args_envs.py | 8 +++++++ .../distributed/launch/controllers/master.py | 21 +++++++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/launch/context/args_envs.py b/python/paddle/distributed/launch/context/args_envs.py index 9e0565b1c45..e7005102f00 100644 --- a/python/paddle/distributed/launch/context/args_envs.py +++ b/python/paddle/distributed/launch/context/args_envs.py @@ -24,6 +24,7 @@ env_args_mapping = { 'PADDLE_RUN_MODE': 'run_mode', 'PADDLE_LOG_LEVEL': 'log_level', 'PADDLE_LOG_OVERWRITE': 'log_overwrite', + 'PADDLE_SORT_IP': 'sort_ip', 'PADDLE_NPROC_PER_NODE': 'nproc_per_node', 'PADDLE_JOB_ID': 'job_id', 'PADDLE_RANK': 'rank', @@ -81,6 +82,13 @@ def parse_args(): help="overwrite exits logfiles. Default False", ) + base_group.add_argument( + "--sort_ip", + type=strtobool, + default=False, + help="rank node by ip. Default False", + ) + base_group.add_argument( "--nnodes", type=str, diff --git a/python/paddle/distributed/launch/controllers/master.py b/python/paddle/distributed/launch/controllers/master.py index fb6016d9e40..fc1f2937247 100644 --- a/python/paddle/distributed/launch/controllers/master.py +++ b/python/paddle/distributed/launch/controllers/master.py @@ -13,6 +13,8 @@ # limitations under the License. import copy +import ipaddress +import json import random import sys import threading @@ -24,6 +26,12 @@ from paddle.distributed.launch.utils.kv_server import KVServer ETCD_PROTOCAL = 'etcd://' +def _cmp_by_ip(x): + x = json.loads(x) + ip_x = x.get('candidate', '127.0.0.1:8080').split(':')[0] + return int(ipaddress.IPv4Address(ip_x)) + + class Master: ''' Master is a distributed store design to exchange info among nodes @@ -156,7 +164,11 @@ class HTTPMaster(Master): rjson = self.client.get_prefix(prefix) self.ctx.logger.debug(f"sync peers {rjson}") if rjson and len(rjson) == size: - if rank < 0: + if self.ctx.args.sort_ip: + ret = sorted(rjson.values(), key=_cmp_by_ip) + idx = ret.index(value) + return ret, idx + elif rank < 0: keys = list(rjson.keys()) keys.sort() ret = [rjson[k] for k in keys] @@ -210,7 +222,12 @@ class ETCDMaster(Master): self.ctx.logger.debug(f"sync peers {result}") if len(result) == size: - if rank < 0: + if self.ctx.args.sort_ip: + values = [i[0].decode() for i in result] + ret = sorted(values, key=_cmp_by_ip) + idx = ret.index(value) + return ret, idx + elif rank < 0: keys = [i[1].key.decode() for i in result] sorted_keys = [i[1].key.decode() for i in result] sorted_keys.sort() -- GitLab