Commit be70c94e authored by MrChengmo

fix

Parent 50ada3a4
......@@ -121,11 +121,13 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
if (rpc_ctx.origin_varnames.size() == 1 &&
rpc_ctx.splited_varnames.size() == 1) {
auto varname = rpc_ctx.origin_varnames[0];
const auto place =
scope.FindVar(varname)->Get<framework::LoDTensor>().place();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(place);
VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0] << " in gpu? "
<< platform::is_gpu_place(place);
rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], ctx,
scope, varname, varname));
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE_NE(
......
......@@ -657,10 +657,9 @@ class PaddleCloudRoleMaker(RoleMakerBase):
return self._role == Role.HETER_WORKER
def _ps_env(self):
try:
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)
if self._server_endpoints is None:
# back to non_distributed execution.
......@@ -676,8 +675,8 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self._server_endpoints = self._server_endpoints.split(",")
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
if self._worker_endpoints:
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
if self._worker_endpoints != None:
self._worker_endpoints = self._worker_endpoints.split(",")
else:
self._worker_endpoints = []
......@@ -691,11 +690,10 @@ class PaddleCloudRoleMaker(RoleMakerBase):
format(training_role))
# For heter parameter server env setting
heter_trainer_eplist = os.getenv(
"PADDLE_HETER_TRAINER_IP_PORT_LIST", None)
heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE",
None)
if heter_trainer_eplist and heter_trainer_device:
heter_trainer_eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST",
"")
heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE", "")
if heter_trainer_eplist != "" and heter_trainer_device != "":
try:
heter_trainer_eplist = os.environ[
"PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",")
......@@ -736,9 +734,6 @@ class PaddleCloudRoleMaker(RoleMakerBase):
else:
raise ValueError(
"TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER")
except ValueError as e:
raise ValueError(
"Something wrong with PaddleCloud, please check environment")
self._trainers_num = trainers_num
self._role = role
......
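The `_ps_env` hunks above standardize on explicit `os.getenv` defaults and explicit comparisons instead of truthiness checks. A minimal, self-contained sketch of that parsing pattern is below; the environment variable names follow the diff, but the helper function itself is hypothetical and not part of Paddle:

```python
import os

def parse_ps_env():
    """Hypothetical helper mirroring the env parsing pattern in _ps_env."""
    # Endpoints are comma-separated "ip:port" lists, e.g. "127.0.0.1:6001,127.0.0.1:6002".
    server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)
    if server_endpoints is None:
        # No parameter servers configured: fall back to non-distributed execution.
        return None
    server_endpoints = server_endpoints.split(",")

    worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", None)
    worker_endpoints = worker_endpoints.split(",") if worker_endpoints is not None else []

    # Heter trainer settings default to "" so the check is an explicit string
    # comparison rather than a truthiness test, as in the diff.
    heter_eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST", "")
    heter_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE", "")
    heter_workers = heter_eplist.split(",") if heter_eplist != "" and heter_device != "" else []

    return server_endpoints, worker_endpoints, heter_workers
```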
......@@ -114,6 +114,8 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
"followed by all the arguments for the "
"training script")
base_group.add_argument('training_script_args', nargs=REMAINDER)
# Optional arguments for the launch helper
# for collective
collective_group = parser.add_argument_group("Collective Parameters")
......
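For context, `nargs=REMAINDER` makes argparse collect everything after the training script path as that script's own arguments. A small stand-alone sketch of that behavior; the parser and arguments here are illustrative, not the launcher's exact interface:

```python
import argparse
from argparse import REMAINDER

# Illustrative parser: a positional script path followed by all remaining args.
parser = argparse.ArgumentParser(description="launch helper sketch")
parser.add_argument("training_script", type=str,
                    help="the full path of the training script")
parser.add_argument("training_script_args", nargs=REMAINDER,
                    help="arguments passed through to the training script")

args = parser.parse_args(["train.py", "--epochs", "10", "--lr", "0.001"])
print(args.training_script)        # train.py
print(args.training_script_args)   # ['--epochs', '10', '--lr', '0.001']
```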
......@@ -739,8 +739,7 @@ class ParameterServerLauncher(object):
if ip == self.heter_worker_endpoints_ips[k]:
heter_worker = Trainer()
heter_worker.endpoint = "%s:%s" % (
ip,
self.endpoints_dict["heter_worker_endpoints_port"][k])
ip, self.heter_worker_endpoints_port[k])
heter_worker.rank = heter_worker_rank
heter_worker_rank += 1
pod.heter_workers.append(heter_worker)
......@@ -770,9 +769,9 @@ class ParameterServerLauncher(object):
self.procs[i].proc.wait()
if len(self.log_fns) > 0:
self.log_fns[i].close()
print(
"all workers exit, going to finish parameter server and heter_worker",
file=sys.stderr)
logger.info(
"all workers exit, going to finish parameter server and heter_worker"
)
for i in range(
len(pod.servers + pod.workers),
......@@ -780,13 +779,13 @@ class ParameterServerLauncher(object):
if len(self.log_fns) > 0:
self.log_fns[i].close()
self.procs[i].proc.terminate()
print("all heter worker are killed", file=sys.stderr)
logger.info("all heter worker are killed")
for i in range(len(pod.servers)):
if len(self.log_fns) > 0:
self.log_fns[i].close()
self.procs[i].proc.terminate()
print("all parameter server are killed", file=sys.stderr)
logger.info("all parameter server are killed", file=sys.stderr)
if os.path.exists(self.gloo_rendezvous_dir):
shutil.rmtree(self.gloo_rendezvous_dir)
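The shutdown path above swaps `print(..., file=sys.stderr)` for `logger.info(...)` (note that `logging` calls do not take a `file=` keyword). A minimal sketch of the assumed logger setup; the launcher's real configuration may differ:

```python
import logging
import sys

# Assumed setup: a module-level logger writing to stderr, replacing bare print calls.
logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(logging.Formatter("%(levelname)s %(asctime)s %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)

logger.info("all workers exit, going to finish parameter server and heter_worker")
```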
......@@ -857,6 +856,7 @@ class ParameterServerLauncher(object):
heter_device_num = fluid.core.get_xpu_device_count()
for idx, cur_worker in enumerate(pod.workers):
device_id = str(idx % heter_device_num)
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
......@@ -869,10 +869,10 @@ class ParameterServerLauncher(object):
"PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": idx % heter_device_num,
"FLAGS_selected_xpus": idx % heter_device_num,
"CUDA_VISIBLE_DEVICES": idx % heter_device_num,
"XPU_VISIBLE_DEVICES": idx % heter_device_num,
"FLAGS_selected_gpus": 0,
"FLAGS_selected_xpus": 0,
"CUDA_VISIBLE_DEVICES": device_id,
"XPU_VISIBLE_DEVICES": device_id,
}
current_env.update(proc_env)
......@@ -921,6 +921,7 @@ class ParameterServerLauncher(object):
assert heter_device_num != 0
for idx, cur_heter_worker in enumerate(pod.heter_workers):
device_id = str(idx % heter_device_num)
proc_env = {
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
......@@ -934,10 +935,10 @@ class ParameterServerLauncher(object):
"PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": idx % heter_device_num,
"FLAGS_selected_xpus": idx % heter_device_num,
"CUDA_VISIBLE_DEVICES": idx % heter_device_num,
"XPU_VISIBLE_DEVICES": idx % heter_device_num,
"FLAGS_selected_gpus": device_id,
"FLAGS_selected_xpus": device_id,
"CUDA_VISIBLE_DEVICES": device_id,
"XPU_VISIBLE_DEVICES": device_id,
}
current_env.update(proc_env)
......
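The last two hunks replace the raw modulo expressions with a precomputed string `device_id`, so `CUDA_VISIBLE_DEVICES` / `XPU_VISIBLE_DEVICES` mask exactly one device per process. A simplified sketch of that round-robin assignment, following the heter-worker hunk; process launching is stubbed out and only the env construction mirrors the diff:

```python
import os

def build_heter_worker_envs(num_workers, heter_device_num):
    """Round-robin one visible device per heter worker process (sketch)."""
    assert heter_device_num != 0
    envs = []
    for idx in range(num_workers):
        device_id = str(idx % heter_device_num)
        proc_env = dict(os.environ)
        proc_env.update({
            # Restrict the process to a single physical device...
            "CUDA_VISIBLE_DEVICES": device_id,
            "XPU_VISIBLE_DEVICES": device_id,
            # ...and tell Paddle which of the visible devices to select.
            "FLAGS_selected_gpus": device_id,
            "FLAGS_selected_xpus": device_id,
        })
        envs.append(proc_env)
    return envs

# Example: 5 heter workers on a 2-device machine -> devices 0, 1, 0, 1, 0
for env in build_heter_worker_envs(5, 2):
    print(env["CUDA_VISIBLE_DEVICES"])
```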