未验证 提交 0e07f20e 编写于 作者: K kuizhiqing 提交者: GitHub

py2 to py3 bug and iface fix for pslib (#36102)

上级 53f9768d
...@@ -383,7 +383,7 @@ class MPISymetricRoleMaker(MPIRoleMaker): ...@@ -383,7 +383,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
return the current number of worker return the current number of worker
""" """
if self._check_role_generation(): if self._check_role_generation():
return self._get_size() / self._proc_per_node return int(self._get_size() / self._proc_per_node)
return 0 return 0
def _server_num(self): def _server_num(self):
...@@ -391,30 +391,30 @@ class MPISymetricRoleMaker(MPIRoleMaker): ...@@ -391,30 +391,30 @@ class MPISymetricRoleMaker(MPIRoleMaker):
return the current number of server return the current number of server
""" """
if self._check_role_generation(): if self._check_role_generation():
return self._get_size() / self._proc_per_node return int(self._get_size() / self._proc_per_node)
else: else:
self.generate_role() self.generate_role()
return self._get_size() / self._proc_per_node return int(self._get_size() / self._proc_per_node)
def worker_index(self): def worker_index(self):
""" """
return the index of worker return the index of worker
""" """
if self._check_role_generation(): if self._check_role_generation():
return self._rank / self._proc_per_node return int(self._rank / self._proc_per_node)
else: else:
self.generate_role() self.generate_role()
return self._get_size() / 2 return int(self._get_size() / 2)
def server_index(self): def server_index(self):
""" """
return the index of server return the index of server
""" """
if self._check_role_generation(): if self._check_role_generation():
return self._rank / self._proc_per_node return int(self._rank / self._proc_per_node)
else: else:
self.generate_role() self.generate_role()
return self._get_size() / self._proc_per_node return int(self._get_size() / self._proc_per_node)
def _all_reduce(self, input, output, mode="sum"): def _all_reduce(self, input, output, mode="sum"):
""" """
...@@ -612,6 +612,7 @@ class GeneralRoleMaker(RoleMakerBase): ...@@ -612,6 +612,7 @@ class GeneralRoleMaker(RoleMakerBase):
# set running status of http server # set running status of http server
self._http_server_d["running"] = False self._http_server_d["running"] = False
self._iface = self.__get_default_iface() self._iface = self.__get_default_iface()
self._iface = "" if self._iface == "lo" else self._iface
# this environment variable can be empty # this environment variable can be empty
self._prefix = os.getenv("SYS_JOB_ID", "") self._prefix = os.getenv("SYS_JOB_ID", "")
......
...@@ -270,6 +270,7 @@ class PSLib(Fleet): ...@@ -270,6 +270,7 @@ class PSLib(Fleet):
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
if self._role_maker.is_first_worker(): if self._role_maker.is_first_worker():
self._fleet_ptr.stop_server() self._fleet_ptr.stop_server()
if self._heter_ptr:
self._heter_ptr.stop_xpu_service() self._heter_ptr.stop_xpu_service()
self._role_maker._barrier_worker() self._role_maker._barrier_worker()
self._role_maker._barrier_all() self._role_maker._barrier_all()
......
...@@ -846,7 +846,7 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -846,7 +846,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
"user_define_dump_filename", "") "user_define_dump_filename", "")
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["dump_param"] = strategy.get("dump_param", []) opt_info["dump_param"] = strategy.get("dump_param", [])
gpus_env = os.getenv("FLAGS_selected_gpus") gpus_env = os.getenv("FLAGS_selected_gpus", "0")
opt_info["worker_places"] = [int(s) for s in gpus_env.split(",")] opt_info["worker_places"] = [int(s) for s in gpus_env.split(",")]
opt_info["use_ps_gpu"] = strategy.get("use_ps_gpu", False) opt_info["use_ps_gpu"] = strategy.get("use_ps_gpu", False)
if server._server.downpour_server_param.downpour_table_param[ if server._server.downpour_server_param.downpour_table_param[
......
...@@ -25,8 +25,8 @@ import errno ...@@ -25,8 +25,8 @@ import errno
import time import time
import logging import logging
import six import six
from . import fs #from . import fs
from .fs import FS, LocalFS, FSFileExistsError, FSFileNotExistsError, ExecuteError, FSTimeOut, FSShellCmdAborted from paddle.distributed.fleet.utils.fs import FS, LocalFS, FSFileExistsError, FSFileNotExistsError, ExecuteError, FSTimeOut, FSShellCmdAborted
from paddle.fluid import core from paddle.fluid import core
import functools import functools
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册