Unverified commit 0e07f20e authored by kuizhiqing, committed by GitHub

py2 to py3 bug and iface fix for pslib (#36102)

Parent 53f9768d
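
Background for the division changes below (not part of the commit): in Python 2 the / operator between two integers performed floor division, while in Python 3 it returns a float, so the rank and size computations in MPISymetricRoleMaker started returning floats. Wrapping the result in int() restores the integer contract callers expect. A minimal illustration:

# Minimal illustration of the Python 2 -> Python 3 division change
# that the int(...) wrappers in the hunks below work around.
size, proc_per_node = 8, 2
print(size / proc_per_node)        # Python 3: 4.0 (float); Python 2: 4
print(int(size / proc_per_node))   # 4 (int), matching what callers expect
print(size // proc_per_node)       # 4, the explicit floor-division alternative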
@@ -383,7 +383,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return the current number of worker
         """
         if self._check_role_generation():
-            return self._get_size() / self._proc_per_node
+            return int(self._get_size() / self._proc_per_node)
         return 0

     def _server_num(self):
@@ -391,30 +391,30 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         return the current number of server
         """
         if self._check_role_generation():
-            return self._get_size() / self._proc_per_node
+            return int(self._get_size() / self._proc_per_node)
         else:
             self.generate_role()
-            return self._get_size() / self._proc_per_node
+            return int(self._get_size() / self._proc_per_node)

     def worker_index(self):
         """
         return the index of worker
         """
         if self._check_role_generation():
-            return self._rank / self._proc_per_node
+            return int(self._rank / self._proc_per_node)
         else:
             self.generate_role()
-            return self._get_size() / 2
+            return int(self._get_size() / 2)

     def server_index(self):
         """
         return the index of server
         """
         if self._check_role_generation():
-            return self._rank / self._proc_per_node
+            return int(self._rank / self._proc_per_node)
         else:
             self.generate_role()
-            return self._get_size() / self._proc_per_node
+            return int(self._get_size() / self._proc_per_node)

     def _all_reduce(self, input, output, mode="sum"):
         """
@@ -612,6 +612,7 @@ class GeneralRoleMaker(RoleMakerBase):
         # set running status of http server
         self._http_server_d["running"] = False
         self._iface = self.__get_default_iface()
+        self._iface = "" if self._iface == "lo" else self._iface
         # this environment variable can be empty
         self._prefix = os.getenv("SYS_JOB_ID", "")
......
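
The iface change above discards the loopback device: when __get_default_iface() only finds "lo", the role maker falls back to an empty string instead of advertising 127.0.0.1, which other trainers presumably cannot reach. A hedged sketch of the same guard in isolation (helper name is hypothetical, not from the PR):

# Hypothetical helper: normalize the detected interface,
# treating loopback as "no usable interface".
def normalize_iface(iface):
    return "" if iface == "lo" else iface

assert normalize_iface("lo") == ""        # loopback is dropped
assert normalize_iface("eth0") == "eth0"  # real NICs pass through unchanged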
@@ -270,6 +270,7 @@ class PSLib(Fleet):
         self._role_maker._barrier_worker()
         if self._role_maker.is_first_worker():
             self._fleet_ptr.stop_server()
+        if self._heter_ptr:
             self._heter_ptr.stop_xpu_service()
         self._role_maker._barrier_worker()
         self._role_maker._barrier_all()
......
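
The new if self._heter_ptr: guard makes stop_worker safe when no heterogeneous (XPU) parameter-server handle was created; before the fix, stop_xpu_service() could be called on a handle that was never initialized. A minimal sketch of the pattern, with hypothetical names:

# Hypothetical sketch (names are illustrative): only call into the optional
# heter handle when it was actually initialized.
class Worker:
    def __init__(self, heter_ptr=None):
        self._heter_ptr = heter_ptr   # stays None when the XPU service is unused

    def stop(self):
        if self._heter_ptr:           # guard added by the fix
            self._heter_ptr.stop_xpu_service()

Worker().stop()  # no longer raises when _heter_ptr is None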
@@ -846,7 +846,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
             "user_define_dump_filename", "")
         opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
         opt_info["dump_param"] = strategy.get("dump_param", [])
-        gpus_env = os.getenv("FLAGS_selected_gpus")
+        gpus_env = os.getenv("FLAGS_selected_gpus", "0")
         opt_info["worker_places"] = [int(s) for s in gpus_env.split(",")]
         opt_info["use_ps_gpu"] = strategy.get("use_ps_gpu", False)
         if server._server.downpour_server_param.downpour_table_param[
......
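
The "0" default keeps the next line from crashing when FLAGS_selected_gpus is unset: os.getenv would return None, and None has no split method, so worker_places now falls back to GPU 0 instead. Illustration (not from the commit):

import os

os.environ.pop("FLAGS_selected_gpus", None)   # simulate the variable being unset

gpus_env = os.getenv("FLAGS_selected_gpus")   # -> None
# [int(s) for s in gpus_env.split(",")]       # AttributeError: 'NoneType' object has no attribute 'split'

gpus_env = os.getenv("FLAGS_selected_gpus", "0")
print([int(s) for s in gpus_env.split(",")])  # [0]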
@@ -25,8 +25,8 @@ import errno
 import time
 import logging
 import six
-from . import fs
-from .fs import FS, LocalFS, FSFileExistsError, FSFileNotExistsError, ExecuteError, FSTimeOut, FSShellCmdAborted
+#from . import fs
+from paddle.distributed.fleet.utils.fs import FS, LocalFS, FSFileExistsError, FSFileNotExistsError, ExecuteError, FSTimeOut, FSShellCmdAborted
 from paddle.fluid import core
 import functools
......
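
The hunk above switches from a relative import of fs to the absolute module path. A relative import only resolves when the file is loaded as part of its package; the absolute form works regardless of how the module is imported, assuming paddle is installed. Illustrative sketch (the print check is not from the commit):

# Illustrative only: the absolute form resolves the same classes without
# depending on the importing package's context (requires paddle installed).
#   from .fs import LocalFS          # fails when run outside the package
from paddle.distributed.fleet.utils.fs import LocalFS, FSFileNotExistsError

print(LocalFS, FSFileNotExistsError)  # confirms the absolute import resolves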