Commit dbbcc43c authored by MrChengmo

fix

Parent 1c57d554
@@ -97,6 +97,7 @@ message AsyncConfig {
   optional int32 thread_pool_size = 6 [ default = 1 ];
   optional int32 send_wait_times = 7 [ default = 1 ];
   optional bool runtime_split_send_recv = 8 [ default = false ];
+  optional string worker_device = 9 [ default = 'cpu' ];
 }
 message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
...
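This hunk adds a `worker_device` field to `AsyncConfig`, so the async runtime can record which device type each worker runs on (defaulting to `'cpu'`). A minimal sketch of how the field might be set from user code, assuming it surfaces through `DistributedStrategy.a_sync_configs` the way existing `AsyncConfig` fields (e.g. `k_steps`) do; the key name is an assumption based on the proto field:

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.a_sync = True
# Assumption: the new proto field is exposed as an a_sync_configs key,
# mirroring existing AsyncConfig fields such as k_steps.
strategy.a_sync_configs = {"worker_device": "gpu"}
```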
@@ -736,7 +736,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         elif training_role == "HETER_TRAINER":
             role = Role.HETER_WORKER
             cur_port = os.getenv("PADDLE_PORT", None)
-            assert port != None
+            assert cur_port != None
             cur_ip = os.getenv("POD_IP", None)
             assert cur_ip != None
             curr_endpoint = ":".join([cur_ip, cur_port])
...
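The old assertion referenced an undefined name `port`, so a missing `PADDLE_PORT` raised a `NameError` instead of a meaningful assertion failure. A self-contained sketch of the corrected pattern (the assertion messages are illustrative additions):

```python
import os

cur_port = os.getenv("PADDLE_PORT", None)
assert cur_port is not None, "PADDLE_PORT must be set for HETER_TRAINER"
cur_ip = os.getenv("POD_IP", None)
assert cur_ip is not None, "POD_IP must be set for HETER_TRAINER"
curr_endpoint = ":".join([cur_ip, cur_port])  # e.g. "10.0.0.2:6170"
```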
@@ -621,15 +621,15 @@ class ParameterServerLauncher(object):
         self.worker_num = 0
         self.heter_worker_num = 0
-        self.server_endpoints = []
+        self.server_endpoints = ""
         self.server_endpoints_ips = []
         self.server_endpoints_port = []
-        self.worker_endpoints = []
+        self.worker_endpoints = ""
         self.worker_endpoints_ips = []
         self.worker_endpoints_port = []
-        self.heter_worker_endpoints = []
+        self.heter_worker_endpoints = ""
         self.heter_worker_endpoints_ips = []
         self.heter_worker_endpoints_port = []
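The `*_endpoints` attributes are re-initialized as empty strings rather than lists, which suggests they are held as comma-separated endpoint strings while the `*_ips`/`*_port` attributes remain lists derived from them. A short sketch of that assumed convention:

```python
# Assumed convention: endpoints kept as one comma-separated string,
# with ip and port lists derived on demand.
server_endpoints = "127.0.0.1:6170,127.0.0.1:6171"
server_endpoints_ips = [ep.split(":")[0] for ep in server_endpoints.split(",")]
server_endpoints_port = [ep.split(":")[1] for ep in server_endpoints.split(",")]
print(server_endpoints_ips)   # ['127.0.0.1', '127.0.0.1']
print(server_endpoints_port)  # ['6170', '6171']
```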
@@ -730,7 +730,11 @@ class ParameterServerLauncher(object):
             self.current_node_ip = self.node_ips[0]
         else:
             self.is_local = False
-            _, self.current_node_ip = get_host_name_ip()
+            pod_ip = os.getenv("POD_IP", None)
+            if pod_ip == None:
+                _, self.current_node_ip = get_host_name_ip()
+            else:
+                self.current_node_ip = pod_ip
         assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
             % (self.current_node_ip, self.node_ips)
         self.node_rank = self.node_ips.index(self.current_node_ip)
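The launcher now prefers the `POD_IP` environment variable (commonly injected by container schedulers such as Kubernetes) and only falls back to a hostname lookup when it is absent. A self-contained sketch of the lookup order; the `socket`-based fallback stands in for Paddle's `get_host_name_ip` helper and is an assumption for illustration:

```python
import os
import socket

pod_ip = os.getenv("POD_IP")
if pod_ip is None:
    # Fallback: resolve this host's ip from its hostname.
    current_node_ip = socket.gethostbyname(socket.gethostname())
else:
    current_node_ip = pod_ip
```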
@@ -803,16 +807,18 @@ class ParameterServerLauncher(object):
                 logger.info(
                     "all workers exit, going to finish parameter server and heter_worker"
                 )
-                if len(self.procs["server"]) > 0:
-                    for i, proc in enumerate(self.procs["server"]):
-                        self.log_fns["server"][i].close()
-                        self.procs["server"][i].proc.terminate()
-                    logger.info("all parameter server are killed")
                 if len(self.procs["heter_worker"]) > 0:
                     for i, proc in enumerate(self.procs["heter_worker"]):
                         self.log_fns["heter_worker"][i].close()
                         self.procs["heter_worker"][i].proc.terminate()
                     logger.info("all heter_worker are killed")
+                if len(self.procs["server"]) > 0:
+                    for i, proc in enumerate(self.procs["server"]):
+                        self.log_fns["server"][i].close()
+                        self.procs["server"][i].proc.terminate()
+                    logger.info("all parameter server are killed")
             else:
                 # if node has not worker procs
                 # blocking training process
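The hunk only reorders the shutdown: heter workers are now terminated before parameter servers, so no heter worker is left talking to an already-killed server. A condensed sketch of the resulting order, where `procs` and `log_fns` mirror the launcher's bookkeeping dicts from the diff:

```python
def terminate_roles(procs, log_fns, roles=("heter_worker", "server")):
    """Terminate child processes role by role: heter workers first,
    then parameter servers, mirroring the reordered shutdown above."""
    for role in roles:
        for i, entry in enumerate(procs[role]):
            log_fns[role][i].close()  # close the per-process log file
            entry.proc.terminate()    # ask the child process to exit
```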
@@ -909,8 +915,8 @@ class ParameterServerLauncher(object):
             "PADDLE_WITH_GLOO": "1",
             "PADDLE_GLOO_RENDEZVOUS": "2",
             "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
-            "FLAGS_selected_gpus": 0,
-            "FLAGS_selected_xpus": 0,
+            "FLAGS_selected_gpus": "0",
+            "FLAGS_selected_xpus": "0",
             "CUDA_VISIBLE_DEVICES": device_id,
             "XPU_VISIBLE_DEVICES": device_id,
         }
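The integer `0` values become the string `"0"` because every value in an environment mapping handed to a child process must be a string; an `int` value makes `subprocess.Popen` raise `TypeError`. A minimal demonstration:

```python
import os
import subprocess

proc_env = {"FLAGS_selected_gpus": "0", "FLAGS_selected_xpus": "0"}
env = {**os.environ, **proc_env}  # all values must be str
subprocess.Popen(
    ["python", "-c", "import os; print(os.environ['FLAGS_selected_gpus'])"],
    env=env,
).wait()
# With an int value, e.g. {"FLAGS_selected_gpus": 0}, Popen raises TypeError.
```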
@@ -978,8 +984,8 @@ class ParameterServerLauncher(object):
             "PADDLE_WITH_GLOO": "1",
             "PADDLE_GLOO_RENDEZVOUS": "2",
             "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
-            "FLAGS_selected_gpus": device_id,
-            "FLAGS_selected_xpus": device_id,
+            "FLAGS_selected_gpus": "0",
+            "FLAGS_selected_xpus": "0",
             "CUDA_VISIBLE_DEVICES": device_id,
             "XPU_VISIBLE_DEVICES": device_id,
         }
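Here the selected-device flags switch from `device_id` to `"0"`: once `CUDA_VISIBLE_DEVICES` (or `XPU_VISIBLE_DEVICES`) restricts the process to a single physical device, that device is re-indexed as 0 inside the process, so selecting the physical id again would be out of range. A sketch with illustrative values only:

```python
# Hypothetical physical device id assigned to this heter worker.
device_id = "3"
proc_env = {
    "CUDA_VISIBLE_DEVICES": device_id,  # only physical GPU 3 is visible
    "FLAGS_selected_gpus": "0",         # ...and it appears as logical index 0
}
```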
@@ -991,9 +997,9 @@ class ParameterServerLauncher(object):
             if idx == 0:
                 logger.info(
-                    "Local server start {} processes. First process distributed "
+                    "Local heter_worker start {} processes. First process distributed "
                     "environment info (Only For Debug): {}".format(
-                        len(pod.servers),
+                        len(pod.heter_workers),
                         pretty_print_envs(proc_env, ("Distributed Envs", "Value"
                         ))))
...