Commit dbbcc43c authored by MrChengmo

fix

Parent 1c57d554
@@ -97,6 +97,7 @@ message AsyncConfig {
   optional int32 thread_pool_size = 6 [ default = 1 ];
   optional int32 send_wait_times = 7 [ default = 1 ];
   optional bool runtime_split_send_recv = 8 [ default = false ];
+  optional string worker_device = 9 [ default = 'cpu' ];
 }
 message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
......
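Note on the new `worker_device` field: the other AsyncConfig fields are exposed in Python through `fleet.DistributedStrategy`, so presumably this one is reachable the same way (an assumption — the Python-side plumbing is not shown in this diff). A minimal sketch:

```python
# Hypothetical usage sketch -- assumes the new AsyncConfig field is exposed
# through DistributedStrategy.a_sync_configs like its sibling fields.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.a_sync = True
# 'worker_device' selects where worker ops run; the proto default is 'cpu'.
strategy.a_sync_configs = {"worker_device": "cpu"}
```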
@@ -736,7 +736,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         elif training_role == "HETER_TRAINER":
             role = Role.HETER_WORKER
             cur_port = os.getenv("PADDLE_PORT", None)
-            assert port != None
+            assert cur_port != None
             cur_ip = os.getenv("POD_IP", None)
             assert cur_ip != None
             curr_endpoint = ":".join([cur_ip, cur_port])
......
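The old assert referenced `port`, a name that does not exist in this scope, so the check raised NameError instead of validating the environment variable read into `cur_port`. A standalone sketch of the corrected pattern (the helper name is illustrative, not from the source):

```python
# Sketch of the fixed environment-variable validation for HETER_TRAINER.
import os

def heter_trainer_endpoint():
    cur_port = os.getenv("PADDLE_PORT", None)
    assert cur_port is not None, "PADDLE_PORT must be set"
    cur_ip = os.getenv("POD_IP", None)
    assert cur_ip is not None, "POD_IP must be set"
    return ":".join([cur_ip, cur_port])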
@@ -621,15 +621,15 @@ class ParameterServerLauncher(object):
         self.worker_num = 0
         self.heter_worker_num = 0
-        self.server_endpoints = []
+        self.server_endpoints = ""
         self.server_endpoints_ips = []
         self.server_endpoints_port = []
-        self.worker_endpoints = []
+        self.worker_endpoints = ""
         self.worker_endpoints_ips = []
         self.worker_endpoints_port = []
-        self.heter_worker_endpoints = []
+        self.heter_worker_endpoints = ""
         self.heter_worker_endpoints_ips = []
         self.heter_worker_endpoints_port = []
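The `*_endpoints` fields change from lists to comma-separated strings, while the parsed pieces apparently stay in the `*_ips` / `*_port` lists (inferred from the surrounding field names; the parsing code is not in this hunk). A sketch of the presumed layout, with illustrative values:

```python
# Presumed layout after this change: endpoints as one comma-separated string,
# parsed ip/port pieces kept in parallel lists.
server_endpoints = "10.0.0.1:6170,10.0.0.2:6170"
server_endpoints_ips = [ep.split(":")[0] for ep in server_endpoints.split(",")]
server_endpoints_port = [ep.split(":")[1] for ep in server_endpoints.split(",")]
print(server_endpoints_ips)   # ['10.0.0.1', '10.0.0.2']
print(server_endpoints_port)  # ['6170', '6170']
```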
@@ -730,7 +730,11 @@ class ParameterServerLauncher(object):
             self.current_node_ip = self.node_ips[0]
         else:
             self.is_local = False
-            _, self.current_node_ip = get_host_name_ip()
+            pod_ip = os.getenv("POD_IP", None)
+            if pod_ip == None:
+                _, self.current_node_ip = get_host_name_ip()
+            else:
+                self.current_node_ip = pod_ip
             assert self.current_node_ip in self.node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
                 % (self.current_node_ip, self.node_ips)
         self.node_rank = self.node_ips.index(self.current_node_ip)
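The launcher now prefers the `POD_IP` environment variable (commonly injected into Kubernetes pods) and only falls back to a hostname lookup when it is unset. A minimal self-contained sketch of that resolution order, using a socket-based lookup as a stand-in for the real `get_host_name_ip()`:

```python
# Sketch: prefer POD_IP, fall back to hostname resolution.
import os
import socket

def resolve_node_ip():
    pod_ip = os.getenv("POD_IP", None)
    if pod_ip is not None:
        return pod_ip
    return socket.gethostbyname(socket.gethostname())
```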
@@ -803,16 +807,18 @@ class ParameterServerLauncher(object):
                 logger.info(
                     "all workers exit, going to finish parameter server and heter_worker"
                 )
-                if len(self.procs["server"]) > 0:
-                    for i, proc in enumerate(self.procs["server"]):
-                        self.log_fns["server"][i].close()
-                        self.procs["server"][i].proc.terminate()
-                    logger.info("all parameter server are killed")
                 if len(self.procs["heter_worker"]) > 0:
                     for i, proc in enumerate(self.procs["heter_worker"]):
                         self.log_fns["heter_worker"][i].close()
                         self.procs["heter_worker"][i].proc.terminate()
                     logger.info("all heter_worker are killed")
+                if len(self.procs["server"]) > 0:
+                    for i, proc in enumerate(self.procs["server"]):
+                        self.log_fns["server"][i].close()
+                        self.procs["server"][i].proc.terminate()
+                    logger.info("all parameter server are killed")
             else:
                 # if node has no worker procs,
                 # block the training process
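The cleanup blocks are reordered so that heter_worker processes are terminated before the parameter servers. An illustrative, runnable sketch of the new kill order (subprocesses stand in for the launched roles; `procs` mimics `self.procs`):

```python
# Illustrative shutdown-order sketch: heter_workers first, then servers.
import subprocess

procs = {
    "heter_worker": [subprocess.Popen(["sleep", "60"])],
    "server": [subprocess.Popen(["sleep", "60"])],
}
for role in ("heter_worker", "server"):  # new kill order
    for p in procs[role]:
        p.terminate()
    print("all {} are killed".format(role))
```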
@@ -909,8 +915,8 @@ class ParameterServerLauncher(object):
                 "PADDLE_WITH_GLOO": "1",
                 "PADDLE_GLOO_RENDEZVOUS": "2",
                 "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
-                "FLAGS_selected_gpus": 0,
-                "FLAGS_selected_xpus": 0,
+                "FLAGS_selected_gpus": "0",
+                "FLAGS_selected_xpus": "0",
                 "CUDA_VISIBLE_DEVICES": device_id,
                 "XPU_VISIBLE_DEVICES": device_id,
             }
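Quoting the flag values matters because this dict is merged into the child process environment, and `os.environ` only accepts string values — an int raises TypeError. A quick demonstration:

```python
import os

# os.environ rejects non-str values, which is why "0" replaces the bare 0.
os.environ.update({"FLAGS_selected_gpus": "0"})    # OK
try:
    os.environ.update({"FLAGS_selected_gpus": 0})  # int value
except TypeError as e:
    print("rejected:", e)
```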
@@ -978,8 +984,8 @@ class ParameterServerLauncher(object):
                 "PADDLE_WITH_GLOO": "1",
                 "PADDLE_GLOO_RENDEZVOUS": "2",
                 "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
-                "FLAGS_selected_gpus": device_id,
-                "FLAGS_selected_xpus": device_id,
+                "FLAGS_selected_gpus": "0",
+                "FLAGS_selected_xpus": "0",
                 "CUDA_VISIBLE_DEVICES": device_id,
                 "XPU_VISIBLE_DEVICES": device_id,
             }
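Here the selected-device flags switch from `device_id` to `"0"`: once `CUDA_VISIBLE_DEVICES` (or `XPU_VISIBLE_DEVICES`) narrows visibility to a single physical card, the process addresses that card as logical device 0, so selecting the physical id again would be wrong. A sketch of the numbering (the value of `device_id` is illustrative):

```python
# Once CUDA_VISIBLE_DEVICES exposes only one physical card, it is addressed
# as logical device 0 inside the process.
device_id = "3"
proc_env = {
    "CUDA_VISIBLE_DEVICES": device_id,  # expose only physical GPU 3
    "FLAGS_selected_gpus": "0",         # logical index within the visible set
}
print(proc_env)
```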
@@ -991,9 +997,9 @@ class ParameterServerLauncher(object):
             if idx == 0:
                 logger.info(
-                    "Local server start {} processes. First process distributed "
+                    "Local heter_worker start {} processes. First process distributed "
                     "environment info (Only For Debug): {}".format(
-                        len(pod.servers),
+                        len(pod.heter_workers),
                         pretty_print_envs(proc_env, ("Distributed Envs", "Value"
                         ))))
......