From dbbcc43c18c65f9174957fa913fe9ecb1a48015f Mon Sep 17 00:00:00 2001
From: MrChengmo
Date: Wed, 23 Sep 2020 16:32:18 +0800
Subject: [PATCH] fix

---
 .../framework/distributed_strategy.proto     |  1 +
 .../distributed/fleet/base/role_maker.py     |  2 +-
 .../paddle/distributed/fleet/launch_utils.py | 36 +++++++++++--------
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index df482f4334..2a7d97f353 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -97,6 +97,7 @@ message AsyncConfig {
   optional int32 thread_pool_size = 6 [ default = 1 ];
   optional int32 send_wait_times = 7 [ default = 1 ];
   optional bool runtime_split_send_recv = 8 [ default = false ];
+  optional string worker_device = 9 [ default = 'cpu' ];
 }
 
 message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py
index a7aad92fff..dd76d14284 100644
--- a/python/paddle/distributed/fleet/base/role_maker.py
+++ b/python/paddle/distributed/fleet/base/role_maker.py
@@ -736,7 +736,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         elif training_role == "HETER_TRAINER":
             role = Role.HETER_WORKER
             cur_port = os.getenv("PADDLE_PORT", None)
-            assert port != None
+            assert cur_port != None
             cur_ip = os.getenv("POD_IP", None)
             assert cur_ip != None
             curr_endpoint = ":".join([cur_ip, cur_port])
diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py
index 9e40f2ac60..708ef39693 100644
--- a/python/paddle/distributed/fleet/launch_utils.py
+++ b/python/paddle/distributed/fleet/launch_utils.py
@@ -621,15 +621,15 @@ class ParameterServerLauncher(object):
         self.worker_num = 0
         self.heter_worker_num = 0
 
-        self.server_endpoints = []
+        self.server_endpoints = ""
         self.server_endpoints_ips = []
         self.server_endpoints_port = []
 
-        self.worker_endpoints = []
+        self.worker_endpoints = ""
         self.worker_endpoints_ips = []
         self.worker_endpoints_port = []
 
-        self.heter_worker_endpoints = []
+        self.heter_worker_endpoints = ""
         self.heter_worker_endpoints_ips = []
         self.heter_worker_endpoints_port = []
 
@@ -730,7 +730,11 @@ class ParameterServerLauncher(object):
             self.current_node_ip = self.node_ips[0]
         else:
             self.is_local = False
-            _, self.current_node_ip = get_host_name_ip()
+            pod_ip = os.getenv("POD_IP", None)
+            if pod_ip == None:
+                _, self.current_node_ip = get_host_name_ip()
+            else:
+                self.current_node_ip = pod_ip
         assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
             % (self.current_node_ip, self.node_ips)
         self.node_rank = self.node_ips.index(self.current_node_ip)
@@ -803,16 +807,18 @@ class ParameterServerLauncher(object):
                 logger.info(
                     "all workers exit, going to finish parameter server and heter_worker"
                 )
-                if len(self.procs["server"]) > 0:
-                    for i, proc in enumerate(self.procs["server"]):
-                        self.log_fns["server"][i].close()
-                        self.procs["server"][i].proc.terminate()
-                    logger.info("all parameter server are killed")
                 if len(self.procs["heter_worker"]) > 0:
                     for i, proc in enumerate(self.procs["heter_worker"]):
                         self.log_fns["heter_worker"][i].close()
                         self.procs["heter_worker"][i].proc.terminate()
                     logger.info("all heter_worker are killed")
+
+                if len(self.procs["server"]) > 0:
+                    for i, proc in enumerate(self.procs["server"]):
+                        self.log_fns["server"][i].close()
+                        self.procs["server"][i].proc.terminate()
+                    logger.info("all parameter server are killed")
+
             else:
                 # if node has not worker procs
                 # blocking training process
@@ -909,8 +915,8 @@ class ParameterServerLauncher(object):
                 "PADDLE_WITH_GLOO": "1",
                 "PADDLE_GLOO_RENDEZVOUS": "2",
                 "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
-                "FLAGS_selected_gpus": 0,
-                "FLAGS_selected_xpus": 0,
+                "FLAGS_selected_gpus": "0",
+                "FLAGS_selected_xpus": "0",
                 "CUDA_VISIBLE_DEVICES": device_id,
                 "XPU_VISIBLE_DEVICES": device_id,
             }
@@ -978,8 +984,8 @@ class ParameterServerLauncher(object):
                 "PADDLE_WITH_GLOO": "1",
                 "PADDLE_GLOO_RENDEZVOUS": "2",
                 "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
-                "FLAGS_selected_gpus": device_id,
-                "FLAGS_selected_xpus": device_id,
+                "FLAGS_selected_gpus": "0",
+                "FLAGS_selected_xpus": "0",
                 "CUDA_VISIBLE_DEVICES": device_id,
                 "XPU_VISIBLE_DEVICES": device_id,
             }
@@ -991,9 +997,9 @@ class ParameterServerLauncher(object):
 
             if idx == 0:
                 logger.info(
-                    "Local server start {} processes. First process distributed "
+                    "Local heter_worker start {} processes. First process distributed "
                     "environment info (Only For Debug): {}".format(
-                        len(pod.servers),
+                        len(pod.heter_workers),
                         pretty_print_envs(proc_env, ("Distributed Envs", "Value"
                                                      ))))
-- 
GitLab