Commit be70c94e authored by MrChengmo

fix

Parent 50ada3a4
@@ -121,11 +121,13 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
   if (rpc_ctx.origin_varnames.size() == 1 &&
       rpc_ctx.splited_varnames.size() == 1) {
     auto varname = rpc_ctx.origin_varnames[0];
+    const auto place =
+        scope.FindVar(varname)->Get<framework::LoDTensor>().place();
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &ctx = *pool.Get(place);
     VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0] << " in gpu? "
             << platform::is_gpu_place(place);
-    rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
+    rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], ctx,
                                                     scope, varname, varname));
    for (size_t i = 0; i < rets.size(); i++) {
      PADDLE_ENFORCE_NE(
...
@@ -657,10 +657,9 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         return self._role == Role.HETER_WORKER
 
     def _ps_env(self):
-        try:
         # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
         # format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
-        self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
+        self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)
         if self._server_endpoints is None:
             # back to non_distributed execution.
@@ -676,8 +675,8 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         self._server_endpoints = self._server_endpoints.split(",")
 
-        self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
-        if self._worker_endpoints:
+        self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
+        if self._worker_endpoints != None:
             self._worker_endpoints = self._worker_endpoints.split(",")
         else:
             self._worker_endpoints = []
@@ -691,11 +690,10 @@ class PaddleCloudRoleMaker(RoleMakerBase):
                 format(training_role))
 
         # For heter parameter server env setting
-        heter_trainer_eplist = os.getenv(
-            "PADDLE_HETER_TRAINER_IP_PORT_LIST", None)
-        heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE",
-                                         None)
-        if heter_trainer_eplist and heter_trainer_device:
+        heter_trainer_eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST",
+                                         "")
+        heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE", "")
+        if heter_trainer_eplist != "" and heter_trainer_device != "":
             try:
                 heter_trainer_eplist = os.environ[
                     "PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",")
@@ -736,9 +734,6 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         else:
             raise ValueError(
                 "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER")
-        except ValueError as e:
-            raise ValueError(
-                "Something wrong with PaddleCloud, please check environment")
 
         self._trainers_num = trainers_num
         self._role = role
...
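For reference, a minimal sketch (illustrative only, not the fleet code itself) of how the endpoint lists read above are typically provided and split; the sample addresses below are made up:

import os

# Made-up sample values; in a real job the launcher exports these variables.
os.environ.setdefault("PADDLE_PSERVERS_IP_PORT_LIST", "127.0.0.1:6001,127.0.0.1:6002")
os.environ.setdefault("PADDLE_TRAINER_ENDPOINTS", "127.0.0.1:6101,127.0.0.1:6102")
os.environ.setdefault("PADDLE_HETER_TRAINER_IP_PORT_LIST", "")

server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)
server_endpoints = server_endpoints.split(",") if server_endpoints else []

worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
worker_endpoints = worker_endpoints.split(",") if worker_endpoints else []

# Heter trainer endpoints default to "", so the emptiness check stays explicit.
heter_trainer_eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST", "")
heter_trainer_eplist = heter_trainer_eplist.split(",") if heter_trainer_eplist != "" else []

print(server_endpoints, worker_endpoints, heter_trainer_eplist)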
@@ -114,6 +114,8 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
         "followed by all the arguments for the "
         "training script")
 
+    base_group.add_argument('training_script_args', nargs=REMAINDER)
+
     # Optional arguments for the launch helper
     # for collective
     collective_group = parser.add_argument_group("Collective Parameters")
...
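A self-contained sketch of what the newly registered `training_script_args` argument with `nargs=REMAINDER` does; the parser below is a stand-in, not the launcher's real one:

import argparse
from argparse import REMAINDER

parser = argparse.ArgumentParser()
parser.add_argument("training_script")
# REMAINDER collects everything after the script path, unparsed,
# so it can be forwarded verbatim to the training script.
parser.add_argument("training_script_args", nargs=REMAINDER)

args = parser.parse_args(["train.py", "--epochs", "10", "--use_gpu"])
print(args.training_script)       # train.py
print(args.training_script_args)  # ['--epochs', '10', '--use_gpu']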
@@ -739,8 +739,7 @@ class ParameterServerLauncher(object):
                 if ip == self.heter_worker_endpoints_ips[k]:
                     heter_worker = Trainer()
                     heter_worker.endpoint = "%s:%s" % (
-                        ip,
-                        self.endpoints_dict["heter_worker_endpoints_port"][k])
+                        ip, self.heter_worker_endpoints_port[k])
                     heter_worker.rank = heter_worker_rank
                     heter_worker_rank += 1
                     pod.heter_workers.append(heter_worker)
@@ -770,9 +769,9 @@ class ParameterServerLauncher(object):
                     self.procs[i].proc.wait()
                     if len(self.log_fns) > 0:
                         self.log_fns[i].close()
 
-                print(
-                    "all workers exit, going to finish parameter server and heter_worker",
-                    file=sys.stderr)
+                logger.info(
+                    "all workers exit, going to finish parameter server and heter_worker"
+                )
                 for i in range(
                         len(pod.servers + pod.workers),
@@ -780,13 +779,13 @@ class ParameterServerLauncher(object):
                     if len(self.log_fns) > 0:
                         self.log_fns[i].close()
                     self.procs[i].proc.terminate()
-                print("all heter worker are killed", file=sys.stderr)
+                logger.info("all heter worker are killed")
 
                 for i in range(len(pod.servers)):
                     if len(self.log_fns) > 0:
                         self.log_fns[i].close()
                     self.procs[i].proc.terminate()
-                print("all parameter server are killed", file=sys.stderr)
+                logger.info("all parameter server are killed", file=sys.stderr)
 
                 if os.path.exists(self.gloo_rendezvous_dir):
                     shutil.rmtree(self.gloo_rendezvous_dir)
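The `logger` used above is assumed to be a standard `logging` logger that the launcher configures elsewhere; a minimal stand-in that still writes to stderr, like the `print(..., file=sys.stderr)` calls it replaces, might look like this:

import logging
import sys

# Hypothetical logger name; the real launcher sets up its own logger.
logger = logging.getLogger("ps_launcher_demo")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stderr)  # StreamHandler defaults to stderr anyway
handler.setFormatter(logging.Formatter("%(levelname)s %(asctime)s %(message)s"))
logger.addHandler(handler)

logger.info("all workers exit, going to finish parameter server and heter_worker")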
@@ -857,6 +856,7 @@ class ParameterServerLauncher(object):
             heter_device_num = fluid.core.get_xpu_device_count()
 
         for idx, cur_worker in enumerate(pod.workers):
+            device_id = str(idx % heter_device_num)
             proc_env = {
                 "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
                 "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
@@ -869,10 +869,10 @@ class ParameterServerLauncher(object):
                 "PADDLE_WITH_GLOO": "1",
                 "PADDLE_GLOO_RENDEZVOUS": "2",
                 "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
-                "FLAGS_selected_gpus": idx % heter_device_num,
-                "FLAGS_selected_xpus": idx % heter_device_num,
-                "CUDA_VISIBLE_DEVICES": idx % heter_device_num,
-                "XPU_VISIBLE_DEVICES": idx % heter_device_num,
+                "FLAGS_selected_gpus": 0,
+                "FLAGS_selected_xpus": 0,
+                "CUDA_VISIBLE_DEVICES": device_id,
+                "XPU_VISIBLE_DEVICES": device_id,
             }
             current_env.update(proc_env)
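A small illustrative sketch (worker and device counts made up) of the `idx % heter_device_num` round-robin assignment built into `proc_env` above. Once `CUDA_VISIBLE_DEVICES` / `XPU_VISIBLE_DEVICES` restrict a child process to a single physical card, that card appears inside the process as logical device 0, which is presumably what the `FLAGS_selected_gpus` / `FLAGS_selected_xpus` values of 0 refer to:

# Hypothetical setup: 8 workers sharing 4 accelerator cards.
heter_device_num = 4
num_workers = 8

for idx in range(num_workers):
    device_id = str(idx % heter_device_num)  # physical card, round-robin
    proc_env = {
        # The child process only sees this one card ...
        "CUDA_VISIBLE_DEVICES": device_id,
        "XPU_VISIBLE_DEVICES": device_id,
        # ... so from inside the child it is logical device 0.
        "FLAGS_selected_gpus": "0",
        "FLAGS_selected_xpus": "0",
    }
    print("worker", idx, "-> card", proc_env["CUDA_VISIBLE_DEVICES"])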
@@ -921,6 +921,7 @@ class ParameterServerLauncher(object):
         assert heter_device_num != 0
 
         for idx, cur_heter_worker in enumerate(pod.heter_workers):
+            device_id = str(idx % heter_device_num)
             proc_env = {
                 "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
                 "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
@@ -934,10 +935,10 @@ class ParameterServerLauncher(object):
                 "PADDLE_WITH_GLOO": "1",
                 "PADDLE_GLOO_RENDEZVOUS": "2",
                 "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
-                "FLAGS_selected_gpus": idx % heter_device_num,
-                "FLAGS_selected_xpus": idx % heter_device_num,
-                "CUDA_VISIBLE_DEVICES": idx % heter_device_num,
-                "XPU_VISIBLE_DEVICES": idx % heter_device_num,
+                "FLAGS_selected_gpus": device_id,
+                "FLAGS_selected_xpus": device_id,
+                "CUDA_VISIBLE_DEVICES": device_id,
+                "XPU_VISIBLE_DEVICES": device_id,
             }
             current_env.update(proc_env)
...