未验证 提交 f3eded6c 编写于 作者: C Chitsing KUI 提交者: GitHub

enable sort ip in launch (#54435)

上级 56fd25b8
......@@ -24,6 +24,7 @@ env_args_mapping = {
'PADDLE_RUN_MODE': 'run_mode',
'PADDLE_LOG_LEVEL': 'log_level',
'PADDLE_LOG_OVERWRITE': 'log_overwrite',
'PADDLE_SORT_IP': 'sort_ip',
'PADDLE_NPROC_PER_NODE': 'nproc_per_node',
'PADDLE_JOB_ID': 'job_id',
'PADDLE_RANK': 'rank',
......@@ -81,6 +82,13 @@ def parse_args():
help="overwrite exits logfiles. Default False",
)
base_group.add_argument(
"--sort_ip",
type=strtobool,
default=False,
help="rank node by ip. Default False",
)
base_group.add_argument(
"--nnodes",
type=str,
......
......@@ -13,6 +13,8 @@
# limitations under the License.
import copy
import ipaddress
import json
import random
import sys
import threading
......@@ -24,6 +26,12 @@ from paddle.distributed.launch.utils.kv_server import KVServer
ETCD_PROTOCAL = 'etcd://'
def _cmp_by_ip(x):
x = json.loads(x)
ip_x = x.get('candidate', '127.0.0.1:8080').split(':')[0]
return int(ipaddress.IPv4Address(ip_x))
class Master:
'''
Master is a distributed store design to exchange info among nodes
......@@ -156,7 +164,11 @@ class HTTPMaster(Master):
rjson = self.client.get_prefix(prefix)
self.ctx.logger.debug(f"sync peers {rjson}")
if rjson and len(rjson) == size:
if rank < 0:
if self.ctx.args.sort_ip:
ret = sorted(rjson.values(), key=_cmp_by_ip)
idx = ret.index(value)
return ret, idx
elif rank < 0:
keys = list(rjson.keys())
keys.sort()
ret = [rjson[k] for k in keys]
......@@ -210,7 +222,12 @@ class ETCDMaster(Master):
self.ctx.logger.debug(f"sync peers {result}")
if len(result) == size:
if rank < 0:
if self.ctx.args.sort_ip:
values = [i[0].decode() for i in result]
ret = sorted(values, key=_cmp_by_ip)
idx = ret.index(value)
return ret, idx
elif rank < 0:
keys = [i[1].key.decode() for i in result]
sorted_keys = [i[1].key.decode() for i in result]
sorted_keys.sort()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册