Commit be70c94e authored by MrChengmo

fix

Parent 50ada3a4
......@@ -121,11 +121,13 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
if (rpc_ctx.origin_varnames.size() == 1 &&
rpc_ctx.splited_varnames.size() == 1) {
auto varname = rpc_ctx.origin_varnames[0];
+    const auto place =
+        scope.FindVar(varname)->Get<framework::LoDTensor>().place();
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &ctx = *pool.Get(place);
+    VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0] << " in gpu? "
+            << platform::is_gpu_place(place);
-    rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
+    rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], ctx,
scope, varname, varname));
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(
......
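The hunk above switches the recv path from the fixed CPU context to the device context that matches the destination variable's place. A minimal Python sketch of that selection rule (hypothetical names, not Paddle's C++ API):

```python
# Hypothetical sketch: the receive call should run on the device context
# registered for the destination tensor's place, so a GPU-resident parameter
# is received with the GPU context instead of the CPU one.
def pick_recv_context(var_place, context_pool):
    # context_pool plays the role of DeviceContextPool: placement -> context.
    return context_pool[var_place]

pool = {"CPUPlace": "cpu_ctx", "CUDAPlace(0)": "gpu_ctx_0"}
assert pick_recv_context("CUDAPlace(0)", pool) == "gpu_ctx_0"
```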
......@@ -657,88 +657,83 @@ class PaddleCloudRoleMaker(RoleMakerBase):
return self._role == Role.HETER_WORKER
def _ps_env(self):
-        try:
-            # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
-            # format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
-            self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
-            if self._server_endpoints is None:
-                # back to non_distributed execution.
-                self._server_endpoints = ""
-                self._trainers_num = 1
-                self._role = Role.WORKER
-                self._current_id = 0
-                self._nodes_num = 1
-                self._heter_trainers_num = 0
-                self._heter_trainer_endpoints = None
-                self._non_distributed = True
-                return
-            self._server_endpoints = self._server_endpoints.split(",")
-            self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
-            if self._worker_endpoints:
-                self._worker_endpoints = self._worker_endpoints.split(",")
-            else:
-                self._worker_endpoints = []
-            trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
-            training_role = os.environ["TRAINING_ROLE"]
-            if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
-                raise ValueError(
-                    "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.".
-                    format(training_role))
-            # For heter parameter server env setting
-            heter_trainer_eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST",
-                                             "")
-            heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE", "")
-            if heter_trainer_eplist != "" and heter_trainer_device != "":
-                try:
-                    heter_trainer_eplist = os.environ[
-                        "PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",")
-                except:
-                    raise ValueError(
-                        "Can not Find PADDLE_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
-                    )
-                self._is_heter_parameter_server_mode = True
-                heter_trainers_num = len(heter_trainer_eplist)
-                current_node_device = heter_trainer_device.upper()
-                if current_node_device not in ["CPU", "GPU", "XPU"]:
-                    raise ValueError(
-                        "Heter Trainer doesn't support {} device now, please use CPU / GPU / XPU(KunLun)".
-                        format(heter_trainer_device))
-                self._heter_trainer_device = current_node_device
-            else:
-                self._is_heter_parameter_server_mode = False
-                heter_trainers_num = 0
-            if training_role == "TRAINER":
-                role = Role.WORKER
-                current_id = int(os.environ["PADDLE_TRAINER_ID"])
-                if len(self._worker_endpoints) > 0:
-                    self._cur_endpoint = self._worker_endpoints[current_id]
-            elif training_role == "PSERVER":
-                role = Role.SERVER
-                port = os.environ["PADDLE_PORT"]
-                ip = os.environ["POD_IP"]
-                self._cur_endpoint = ip + ":" + port
-                current_id = self._server_endpoints.index(self._cur_endpoint)
-            elif training_role == "HETER_TRAINER":
-                role = Role.HETER_WORKER
-                cur_ip = os.environ["POD_IP"]
-                cur_port = os.environ["PADDLE_PORT"]
-                curr_endpoint = ":".join([cur_ip, cur_port])
-                current_id = heter_trainer_eplist.index(curr_endpoint)
-            else:
-                raise ValueError(
-                    "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER")
-        except ValueError as e:
-            raise ValueError(
-                "Something wrong with PaddleCloud, please check environment")
+        # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
+        # format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
+        self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)
+        if self._server_endpoints is None:
+            # back to non_distributed execution.
+            self._server_endpoints = ""
+            self._trainers_num = 1
+            self._role = Role.WORKER
+            self._current_id = 0
+            self._nodes_num = 1
+            self._heter_trainers_num = 0
+            self._heter_trainer_endpoints = None
+            self._non_distributed = True
+            return
+        self._server_endpoints = self._server_endpoints.split(",")
+        self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
+        if self._worker_endpoints != None:
+            self._worker_endpoints = self._worker_endpoints.split(",")
+        else:
+            self._worker_endpoints = []
+        trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
+        training_role = os.environ["TRAINING_ROLE"]
+        if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
+            raise ValueError(
+                "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.".
+                format(training_role))
+        # For heter parameter server env setting
+        heter_trainer_eplist = os.getenv(
+            "PADDLE_HETER_TRAINER_IP_PORT_LIST", None)
+        heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE",
+                                         None)
+        if heter_trainer_eplist and heter_trainer_device:
+            try:
+                heter_trainer_eplist = os.environ[
+                    "PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",")
+            except:
+                raise ValueError(
+                    "Can not Find PADDLE_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
+                )
+            self._is_heter_parameter_server_mode = True
+            heter_trainers_num = len(heter_trainer_eplist)
+            current_node_device = heter_trainer_device.upper()
+            if current_node_device not in ["CPU", "GPU", "XPU"]:
+                raise ValueError(
+                    "Heter Trainer doesn't support {} device now, please use CPU / GPU / XPU(KunLun)".
+                    format(heter_trainer_device))
+            self._heter_trainer_device = current_node_device
+        else:
+            self._is_heter_parameter_server_mode = False
+            heter_trainers_num = 0
+        if training_role == "TRAINER":
+            role = Role.WORKER
+            current_id = int(os.environ["PADDLE_TRAINER_ID"])
+            if len(self._worker_endpoints) > 0:
+                self._cur_endpoint = self._worker_endpoints[current_id]
+        elif training_role == "PSERVER":
+            role = Role.SERVER
+            port = os.environ["PADDLE_PORT"]
+            ip = os.environ["POD_IP"]
+            self._cur_endpoint = ip + ":" + port
+            current_id = self._server_endpoints.index(self._cur_endpoint)
+        elif training_role == "HETER_TRAINER":
+            role = Role.HETER_WORKER
+            cur_ip = os.environ["POD_IP"]
+            cur_port = os.environ["PADDLE_PORT"]
+            curr_endpoint = ":".join([cur_ip, cur_port])
+            current_id = heter_trainer_eplist.index(curr_endpoint)
+        else:
+            raise ValueError(
+                "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER")
        self._trainers_num = trainers_num
        self._role = role
......
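The rewritten `_ps_env` derives everything from environment variables. A minimal, standard-library sketch of that contract (the `resolve_role` helper and the sample values are illustrative, not Paddle API):

```python
# Sketch of the env-var contract _ps_env relies on: endpoints come as
# comma-separated ip:port lists, TRAINING_ROLE picks the branch, and the
# process rank is the index of its own endpoint in the matching list.
def resolve_role(env):
    servers = env.get("PADDLE_PSERVERS_IP_PORT_LIST", "").split(",")
    heter = env.get("PADDLE_HETER_TRAINER_IP_PORT_LIST", "")
    heter = heter.split(",") if heter else []
    role = env["TRAINING_ROLE"]
    if role == "TRAINER":
        return "WORKER", int(env["PADDLE_TRAINER_ID"])
    endpoint = env["POD_IP"] + ":" + env["PADDLE_PORT"]
    if role == "PSERVER":
        return "SERVER", servers.index(endpoint)
    if role == "HETER_TRAINER":
        return "HETER_WORKER", heter.index(endpoint)
    raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER")

# Example: the second pserver of a two-server, one-heter-worker job.
env = {
    "PADDLE_PSERVERS_IP_PORT_LIST": "127.0.0.1:6001,127.0.0.1:6002",
    "PADDLE_HETER_TRAINER_IP_PORT_LIST": "127.0.0.1:7001",
    "TRAINING_ROLE": "PSERVER",
    "POD_IP": "127.0.0.1",
    "PADDLE_PORT": "6002",
}
print(resolve_role(env))  # ('SERVER', 1)
```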
......@@ -114,6 +114,8 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
"followed by all the arguments for the "
"training script")
base_group.add_argument('training_script_args', nargs=REMAINDER)
+    # Optional arguments for the launch helper
+    # for collective
collective_group = parser.add_argument_group("Collective Parameters")
......
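The launch helper keeps its own options in named argument groups and uses `nargs=REMAINDER` so everything after the training script is forwarded to it untouched. A small argparse sketch of that layout (option names other than `training_script`/`training_script_args` are illustrative):

```python
from argparse import ArgumentParser, REMAINDER

parser = ArgumentParser(description="distributed launch helper (sketch)")
base_group = parser.add_argument_group("Base Parameters")
base_group.add_argument("training_script", type=str,
                        help="full path of the single-node training program")
# REMAINDER swallows the rest of the command line, options included.
base_group.add_argument("training_script_args", nargs=REMAINDER)
collective_group = parser.add_argument_group("Collective Parameters")
collective_group.add_argument("--ips", type=str, default="127.0.0.1")

args = parser.parse_args(["train.py", "--use_gpu", "1", "--lr", "0.01"])
print(args.training_script)       # train.py
print(args.training_script_args)  # ['--use_gpu', '1', '--lr', '0.01']
```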
......@@ -739,8 +739,7 @@ class ParameterServerLauncher(object):
if ip == self.heter_worker_endpoints_ips[k]:
heter_worker = Trainer()
heter_worker.endpoint = "%s:%s" % (
-                    ip,
-                    self.endpoints_dict["heter_worker_endpoints_port"][k])
+                    ip, self.heter_worker_endpoints_port[k])
heter_worker.rank = heter_worker_rank
heter_worker_rank += 1
pod.heter_workers.append(heter_worker)
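In the hunk above, the port for a heter worker is looked up by the same index as its IP, and ranks are handed out in the order endpoints are matched. A simplified single-node sketch (names and the rank counter starting at 0 are illustrative; in the launcher the counter spans all nodes):

```python
# Pair each heter-worker IP on this node with the port at the same index and
# assign ranks in discovery order.
def collect_heter_workers(node_ip, heter_ips, heter_ports, start_rank=0):
    workers, rank = [], start_rank
    for k, ip in enumerate(heter_ips):
        if ip == node_ip:
            workers.append({"endpoint": "%s:%s" % (ip, heter_ports[k]), "rank": rank})
            rank += 1
    return workers

print(collect_heter_workers("10.0.0.2",
                            ["10.0.0.1", "10.0.0.2", "10.0.0.2"],
                            ["7001", "7002", "7003"]))
# [{'endpoint': '10.0.0.2:7002', 'rank': 0}, {'endpoint': '10.0.0.2:7003', 'rank': 1}]
```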
......@@ -770,9 +769,9 @@ class ParameterServerLauncher(object):
self.procs[i].proc.wait()
if len(self.log_fns) > 0:
self.log_fns[i].close()
-            print(
-                "all workers exit, going to finish parameter server and heter_worker",
-                file=sys.stderr)
+            logger.info(
+                "all workers exit, going to finish parameter server and heter_worker"
+            )
for i in range(
len(pod.servers + pod.workers),
......@@ -780,13 +779,13 @@ class ParameterServerLauncher(object):
if len(self.log_fns) > 0:
self.log_fns[i].close()
self.procs[i].proc.terminate()
print("all heter worker are killed", file=sys.stderr)
logger.info("all heter worker are killed")
for i in range(len(pod.servers)):
if len(self.log_fns) > 0:
self.log_fns[i].close()
self.procs[i].proc.terminate()
print("all parameter server are killed", file=sys.stderr)
logger.info("all parameter server are killed", file=sys.stderr)
if os.path.exists(self.gloo_rendezvous_dir):
shutil.rmtree(self.gloo_rendezvous_dir)
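The teardown above follows a fixed order: block until the training workers exit, then terminate heter workers, then parameter servers, and finally close the log files and remove the gloo rendezvous directory. A compact sketch of that order (hypothetical containers, not the launcher's actual fields):

```python
import shutil

def shutdown(workers, heter_workers, servers, log_files, gloo_dir, logger):
    for proc in workers:
        proc.wait()  # the trainers decide when the job is finished
    logger.info("all workers exit, going to finish parameter server and heter_worker")
    for proc in heter_workers:
        proc.terminate()
    logger.info("all heter worker are killed")
    for proc in servers:
        proc.terminate()
    logger.info("all parameter server are killed")
    for f in log_files:
        f.close()
    if gloo_dir:
        shutil.rmtree(gloo_dir, ignore_errors=True)
```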
......@@ -857,6 +856,7 @@ class ParameterServerLauncher(object):
heter_device_num = fluid.core.get_xpu_device_count()
for idx, cur_worker in enumerate(pod.workers):
+            device_id = str(idx % heter_device_num)
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
......@@ -869,10 +869,10 @@ class ParameterServerLauncher(object):
"PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": idx % heter_device_num,
"FLAGS_selected_xpus": idx % heter_device_num,
"CUDA_VISIBLE_DEVICES": idx % heter_device_num,
"XPU_VISIBLE_DEVICES": idx % heter_device_num,
"FLAGS_selected_gpus": 0,
"FLAGS_selected_xpus": 0,
"CUDA_VISIBLE_DEVICES": device_id,
"XPU_VISIBLE_DEVICES": device_id,
}
current_env.update(proc_env)
......@@ -921,6 +921,7 @@ class ParameterServerLauncher(object):
assert heter_device_num != 0
for idx, cur_heter_worker in enumerate(pod.heter_workers):
+            device_id = str(idx % heter_device_num)
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
......@@ -934,10 +935,10 @@ class ParameterServerLauncher(object):
"PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": idx % heter_device_num,
"FLAGS_selected_xpus": idx % heter_device_num,
"CUDA_VISIBLE_DEVICES": idx % heter_device_num,
"XPU_VISIBLE_DEVICES": idx % heter_device_num,
"FLAGS_selected_gpus": device_id,
"FLAGS_selected_xpus": device_id,
"CUDA_VISIBLE_DEVICES": device_id,
"XPU_VISIBLE_DEVICES": device_id,
}
current_env.update(proc_env)
......
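Both proc_env blocks above pin each process to a single physical device chosen round-robin from the devices on the node. The visible-devices variables always carry the physical id; `FLAGS_selected_gpus`/`FLAGS_selected_xpus` become 0 for ordinary workers (the masked device is renumbered to 0 inside the process), while heter workers keep the physical id there. A small sketch of that mapping (hypothetical helper, values as in the diff):

```python
# Device ids are assigned round-robin over heter_device_num devices;
# CUDA/XPU_VISIBLE_DEVICES expose exactly one physical device, and the
# FLAGS_selected_* value is 0 for plain workers but device_id for heter workers.
def device_env(idx, heter_device_num, is_heter_worker):
    device_id = str(idx % heter_device_num)
    selected = device_id if is_heter_worker else "0"
    return {
        "FLAGS_selected_gpus": selected,
        "FLAGS_selected_xpus": selected,
        "CUDA_VISIBLE_DEVICES": device_id,
        "XPU_VISIBLE_DEVICES": device_id,
    }

# Two devices on the node: worker 3 and heter worker 3 are both masked to
# physical device 1, but only the heter worker selects index 1.
print(device_env(3, 2, is_heter_worker=False))
print(device_env(3, 2, is_heter_worker=True))
```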