Unverified · Commit bbc2add7 authored by lilong12, committed by GitHub

Initialize gloo for low level collective apis (#27672)

* add gloo initializer, test=develop
Parent b7107c65
@@ -21,10 +21,10 @@ void GlooParallelContext::Init() {
   auto gloo_ptr = paddle::framework::GlooWrapper::GetInstance();
   gloo_ptr->SetRank(strategy_.rank);
   gloo_ptr->SetSize(strategy_.rank_num);
-  gloo_ptr->SetPrefix(strategy_.prefix);
   gloo_ptr->SetIface(strategy_.iface);
   gloo_ptr->SetTimeoutSeconds(strategy_.init_seconds, strategy_.run_seconds);
-  gloo_ptr->SetHdfsStore(strategy_.path, strategy_.fs_name, strategy_.fs_ugi);
+  gloo_ptr->SetHttpStore(strategy_.ip_address, strategy_.ip_port,
+                         strategy_.scope);
   gloo_ptr->Init();
 }
 #endif
...
@@ -25,12 +25,11 @@ struct GlooParallelStrategy {
   int rank{0};
   int rank_num{1};
   std::string iface;
-  std::string prefix;
   int init_seconds{9999999};
   int run_seconds{9999999};
-  std::string path;
-  std::string fs_name;
-  std::string fs_ugi;
+  std::string ip_address;
+  int ip_port;
+  std::string scope{"worker"};
 };

 class GlooParallelContext {
...
@@ -62,12 +62,6 @@ void BindGlooContext(py::module *m) {
           [](platform::GlooParallelStrategy &self, const std::string &iface) {
             self.iface = iface;
           })
-      .def_property("prefix",
-                    [](const platform::GlooParallelStrategy &self) {
-                      return self.prefix;
-                    },
-                    [](platform::GlooParallelStrategy &self,
-                       const std::string &prefix) { self.prefix = prefix; })
       .def_property("init_seconds",
                     [](const platform::GlooParallelStrategy &self) {
                       return self.init_seconds;
@@ -83,23 +77,19 @@ void BindGlooContext(py::module *m) {
                       self.run_seconds = run_seconds;
                     })
       .def_property(
-          "path",
-          [](const platform::GlooParallelStrategy &self) { return self.path; },
-          [](platform::GlooParallelStrategy &self, const std::string &path) {
-            self.path = path;
-          })
-      .def_property("fs_name",
-                    [](const platform::GlooParallelStrategy &self) {
-                      return self.fs_name;
-                    },
-                    [](platform::GlooParallelStrategy &self,
-                       const std::string &fs_name) { self.fs_name = fs_name; })
-      .def_property("fs_ugi",
+          "ip_address",
+          [](const platform::GlooParallelStrategy &self) {
+            return self.ip_address;
+          },
+          [](platform::GlooParallelStrategy &self,
+             const std::string &ip_address) { self.ip_address = ip_address; })
+      .def_property("ip_port",
                     [](const platform::GlooParallelStrategy &self) {
-                      return self.fs_ugi;
+                      return self.ip_port;
                     },
-                    [](platform::GlooParallelStrategy &self,
-                       const std::string &fs_ugi) { self.fs_ugi = fs_ugi; });
+                    [](platform::GlooParallelStrategy &self, int ip_port) {
+                      self.ip_port = ip_port;
+                    });
   py::class_<platform::GlooParallelContext> gloo_ctx(*m, "GlooParallelContext");
   gloo_ctx.def(py::init<const platform::GlooParallelStrategy &>())
...
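With the HDFS store fields replaced by `ip_address` and `ip_port`, the gloo context can now be configured against an HTTP key-value store directly from Python through the bindings above. A minimal sketch of the new configuration path (illustrative only, assuming a build with gloo enabled; the address, port, and rank values are placeholders):

```python
# Configure the gloo strategy against an HTTP KV store instead of HDFS.
# Rank 0 is expected to host the store at ip_address:ip_port.
from paddle.fluid import core

strategy = core.GlooParallelStrategy()
strategy.rank = 0                  # rank of this process
strategy.rank_num = 2              # total number of ranks
strategy.iface = "lo"              # network interface used by gloo
strategy.init_seconds = 3600       # rendezvous timeout
strategy.run_seconds = 9999999     # per-collective timeout
strategy.ip_address = "127.0.0.1"  # host of the HTTP KV store (replaces path/fs_name/fs_ugi)
strategy.ip_port = 6170            # port of the HTTP KV store

gloo = core.GlooParallelContext(strategy)
gloo.init()  # blocks until all ranks have registered with the store
```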
@@ -19,6 +19,7 @@ import warnings
 from multiprocessing import Process, Manager

 import paddle.fluid as fluid
+from paddle.distributed.fleet.base.private_helper_function import wait_server_ready


 class Role:
@@ -77,7 +78,8 @@ class Gloo(object):
         self._worker_num = worker_num
         self._server_num = server_num
         self._need_init_all = need_init_all
-        self._iface = self.__get_default_iface()
+        self._start_http_server = kwargs.get("start_http_server", False)
+        self._iface = ""
         self._prefix = kwargs.get("store.prefix", "")

         if self._rendezvous == Gloo.RENDEZVOUS.HDFS:
@@ -102,7 +104,9 @@ class Gloo(object):
             if not ip or not port:
                 raise ValueError(self._err_type)

-            self._init_http(ip, port, self._prefix)
+            self._init_http(ip, port, self._prefix, self._start_http_server)
+            ep = ":".join([ip, port])
+            wait_server_ready([ep])

         else:
             raise ValueError(self._err_type)
@@ -163,14 +167,13 @@ class Gloo(object):
         gloo = init(rank, nodes, "ALL")
         self._nodes_comm = gloo

-    def _init_http(self, ip, port, prefix):
+    def _init_http(self, ip, port, prefix, start_http_server):
         def __start_kv_server(http_server_d, size_d):
             from paddle.distributed.fleet.utils.http_server import KVServer
             http_server = KVServer(port, size_d)
             http_server.start()
             wait_seconds = 5
-            while http_server_d.get("running",
-                                    False) and not http_server.shoud_stop():
+            while http_server_d.get("running", False):
                 time.sleep(wait_seconds)
             http_server.stop()
@@ -203,7 +206,7 @@ class Gloo(object):
             port = int(port)

-        if self._role == Role.SERVER and self._role_id == 0:
+        if start_http_server:
             init_kv_server()

         if self._role == Role.WORKER:
@@ -536,8 +539,8 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         self._kwargs = kwargs
         self._role_is_generated = False

-        self._server_endpoints = None
-        self._worker_endpoints = None
+        self._server_endpoints = []
+        self._worker_endpoints = []

         self._gloo = Gloo()  # gloo instance
@@ -800,12 +803,21 @@ class PaddleCloudRoleMaker(RoleMakerBase):
                 "store.prefix": prefix,
             }
         elif rendezvous_type == Gloo.RENDEZVOUS.HTTP:
-            ip = os.getenv("PADDLE_GLOO_HTTP_HOST", "")
-            port = os.getenv("PADDLE_GLOO_HTTP_PORT", "")
+            start_http_server = False
+            if self._is_collective:
+                ep_rank_0 = self._worker_endpoints[0]
+                if self._is_first_worker():
+                    start_http_server = True
+            else:
+                ep_rank_0 = self._server_endpoints[0]
+                if self._server_index() == 0:
+                    start_http_server = True
+            ip, port = ep_rank_0.split(':')
             kwargs = {
                 "http.host": ip,
                 "http.port": port,
                 "store.prefix": prefix,
+                'start_http_server': start_http_server,
             }
         else:
             dfs_path = os.getenv("PADDLE_GLOO_FS_PATH", "")
...
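The net effect of the role-maker changes: for collective jobs the first worker, and for parameter-server jobs server 0, is elected to host the HTTP key-value store at the rank-0 endpoint, and every rank waits for that endpoint before initializing gloo through it. A simplified, illustrative sketch of that rendezvous pattern follows (names mirror the diff, but the helper below is not part of the change):

```python
# Hypothetical condensed view of the HTTP rendezvous used by the Gloo helper.
import time
from multiprocessing import Manager, Process

from paddle.distributed.fleet.base.private_helper_function import wait_server_ready
from paddle.distributed.fleet.utils.http_server import KVServer


def _serve_kv(port, http_server_d, size_d):
    # Background process: serve the KV store until "running" is cleared.
    server = KVServer(port, size_d)
    server.start()
    while http_server_d.get("running", False):
        time.sleep(5)
    server.stop()


def http_rendezvous(ip, port, start_http_server, size_d):
    manager = Manager()
    http_server_d = manager.dict()
    http_server_d["running"] = True
    proc = None
    if start_http_server:  # only the elected rank hosts the store
        proc = Process(target=_serve_kv, args=(int(port), http_server_d, size_d))
        proc.daemon = True
        proc.start()
    # Every rank blocks here until the elected rank's server is reachable.
    wait_server_ready(["%s:%s" % (ip, port)])
    return proc, http_server_d
```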
@@ -220,7 +220,7 @@ def launch_collective(args):
     gloo_rendezvous_dir = tempfile.mkdtemp()
     # add gloo env
     global_envs["PADDLE_WITH_GLOO"] = "1"
-    global_envs["PADDLE_GLOO_RENDEZVOUS"] = "2"
+    global_envs["PADDLE_GLOO_RENDEZVOUS"] = "3"
     global_envs["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir

     procs = start_local_trainers(
@@ -333,7 +333,7 @@ def launch_ps(args):
     gloo_rendezvous_dir = tempfile.mkdtemp()
     # add gloo env
     current_env["PADDLE_WITH_GLOO"] = "1"
-    current_env["PADDLE_GLOO_RENDEZVOUS"] = "2"
+    current_env["PADDLE_GLOO_RENDEZVOUS"] = "3"
     current_env["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir
     current_env.pop("http_proxy", None)
...
@@ -181,7 +181,7 @@ class KVServer:
         self.listen_thread.join()
         self.http_server.server_close()

-    def shoud_stop(self):
+    def should_stop(self):
         """
         return whether the server should stop.
...
@@ -15,6 +15,9 @@
 import os
 import six
 import warnings
+from multiprocessing import Process, Manager
+import time
+import sys

 from paddle import compat as cpt
@@ -23,12 +26,23 @@ from paddle.fluid import core
 from paddle.fluid.framework import _set_expected_place
 from paddle.fluid.dygraph import parallel_helper
 from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.distributed.fleet.base.private_helper_function import wait_server_ready

 __all__ = ["init_parallel_env"]

 ParallelStrategy = core.ParallelStrategy


+def _start_kv_server(port, http_server_d):
+    from paddle.distributed.fleet.utils.http_server import KVServer
+    http_server = KVServer(int(port))
+    http_server.start()
+    wait_seconds = 5
+    while http_server_d.get("running", False):
+        time.sleep(wait_seconds)
+    http_server.stop()


 def init_parallel_env():
     """
     Initialize parallel training environment in dynamic graph mode.
@@ -110,7 +124,40 @@ def init_parallel_env():
     _check_var_exists("PADDLE_TRAINERS_NUM")
     _check_var_exists("PADDLE_TRAINER_ENDPOINTS")

-    # 3. init NCCL ParallelStrategy
+    if ParallelEnv().world_size < 2:
+        return
+
+    # 3: init gloo context
+    ep_rank_0 = ParallelEnv().trainer_endpoints[0].split(":")
+    ep_rank = ParallelEnv().trainer_endpoints[ParallelEnv().rank].split(":")
+    manager = Manager()
+    # global dict to store status
+    http_server_d = manager.dict()
+    http_server_d["running"] = False
+    if ParallelEnv().rank == 0:
+        http_server = Process(
+            target=_start_kv_server, args=(int(ep_rank_0[1]), http_server_d))
+        http_server.daemon = True
+        http_server_d["running"] = True
+        http_server.start()
+    wait_server_ready([ParallelEnv().trainer_endpoints[0]])
+
+    gloo_strategy = core.GlooParallelStrategy()
+    gloo_strategy.rank = ParallelEnv().rank
+    gloo_strategy.rank_num = ParallelEnv().world_size
+    gloo_strategy.ip_address = ep_rank_0[0]
+    gloo_strategy.ip_port = int(ep_rank_0[1])
+    default_init_timeout_seconds = 3600
+    default_run_timeout_seconds = 9999999
+    gloo_strategy.init_seconds = default_init_timeout_seconds
+    gloo_strategy.run_seconds = default_run_timeout_seconds
+    gloo = core.GlooParallelContext(gloo_strategy)
+    gloo.init()
+    if ParallelEnv().rank == 0:
+        http_server_d["running"] = False
+        http_server.join()
+
+    # 4. init NCCL ParallelStrategy
     strategy = ParallelStrategy()
     if parallel_helper._is_parallel_ctx_initialized():
         warnings.warn("The parallel environment has been initialized.")
@@ -118,8 +165,7 @@ def init_parallel_env():
     strategy.local_rank = ParallelEnv().rank
     strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
     strategy.current_endpoint = ParallelEnv().current_endpoint
-    if strategy.nranks < 2:
-        return

     # NOTE(chenweihang): [ why config global place here? ]
     # the dygraph mode will be set to default mode,
     # users will not call `dygraph.guard` or `enable_dygraph`
...
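With this change, `paddle.distributed.init_parallel_env()` bootstraps the gloo context (rank 0 hosting the KV store at the first trainer endpoint) before the NCCL setup, so low-level collective APIs work in dygraph without manual wiring. For reference, a typical launch looks like this (standard usage sketch; the model is a placeholder):

```python
# Each spawned process calls init_parallel_env(), which now initializes both
# gloo and NCCL; dist.spawn starts one process per device.
import paddle
import paddle.nn as nn
import paddle.distributed as dist


def train():
    dist.init_parallel_env()              # gloo + NCCL initialization per rank
    layer = nn.Linear(10, 1)              # placeholder model
    dp_layer = paddle.DataParallel(layer)
    x = paddle.randn([4, 10], 'float32')
    loss = dp_layer(x).mean()
    loss.backward()
    print("rank %d finished a step" % dist.get_rank())


if __name__ == '__main__':
    dist.spawn(train, nprocs=2)           # launch 2 trainer processes
```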
@@ -366,7 +366,7 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
             device = get_device()
             if device == 'cpu':
                 # TODO: not supports cpu parallel now
-                nprocs = _cpu_num
+                nprocs = _cpu_num()
             else:
                 nprocs = core.get_cuda_device_count()
...
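The one-character fix above matters because `_cpu_num` is a function; without the parentheses, `nprocs` would be bound to the function object itself rather than an integer count. A generic illustration (not Paddle code):

```python
def _cpu_count():        # stand-in for an internal helper such as _cpu_num
    return 8

nprocs = _cpu_count      # bug: nprocs is now a function object
nprocs = _cpu_count()    # fix: nprocs is the integer 8
```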
@@ -989,8 +989,7 @@ class GeneralRoleMaker(RoleMakerBase):
             http_server = KVServer(int(self._http_ip_port[1]), size_d)
             http_server.start()
             wait_seconds = 5
-            while http_server_d.get("running",
-                                    False) and not http_server.shoud_stop():
+            while http_server_d.get("running", False):
                 time.sleep(wait_seconds)
             http_server.stop()
...
@@ -173,7 +173,7 @@ class KVServer:
         self.listen_thread.join()
         self.http_server.server_close()

-    def shoud_stop(self):
+    def should_stop(self):
         """
         return whether the server should stop.
...
@@ -15,6 +15,12 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
 list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer)
 list(APPEND DIST_TEST_OPS test_listen_and_serv_op)
 list(APPEND DIST_TEST_OPS test_fleet_graph_execution_meta_optimizer)
+list(APPEND DIST_TEST_OPS test_collective_reduce_api)
+list(APPEND DIST_TEST_OPS test_collective_scatter_api)
+list(APPEND DIST_TEST_OPS test_collective_barrier_api)
+list(APPEND DIST_TEST_OPS test_collective_allreduce_api)
+list(APPEND DIST_TEST_OPS test_collective_broadcast_api)
+list(APPEND DIST_TEST_OPS test_collective_allgather_api)
 set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
 #remove distribute unittests.
 list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
@@ -62,12 +68,6 @@ if(NOT WITH_GPU OR WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_broadcast)
     LIST(REMOVE_ITEM TEST_OPS test_collective_reduce)
     LIST(REMOVE_ITEM TEST_OPS test_collective_scatter)
-    LIST(REMOVE_ITEM TEST_OPS test_collective_reduce_api)
-    LIST(REMOVE_ITEM TEST_OPS test_collective_scatter_api)
-    LIST(REMOVE_ITEM TEST_OPS test_collective_barrier_api)
-    LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api)
-    LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
-    LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
     LIST(REMOVE_ITEM TEST_OPS test_reducescatter)
     LIST(REMOVE_ITEM TEST_OPS test_reducescatter_api)
 endif()
...
@@ -26,6 +26,7 @@ import functools
 import pickle
 from contextlib import closing
 from six import string_types
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.unique_name as nameGen
 from paddle.fluid import core
@@ -60,38 +61,6 @@ class TestCollectiveAPIRunnerBase(object):
             else:
                 break

-    def initCommunicator(self, program, rank, nranks, wait_port,
-                         current_endpoint, endpoints):
-        other_endpoints = endpoints[:]
-        other_endpoints.remove(current_endpoint)
-        if rank == 0 and wait_port:
-            self.wait_server_ready(other_endpoints)
-        block = program.global_block()
-        nccl_id_var = block.create_var(
-            name=nameGen.generate('nccl_id'),
-            persistable=True,
-            type=core.VarDesc.VarType.RAW)
-        block.append_op(
-            type='c_gen_nccl_id',
-            inputs={},
-            outputs={'Out': nccl_id_var},
-            attrs={
-                'rank': rank,
-                'endpoint': current_endpoint,
-                'other_endpoints': other_endpoints
-            })
-        block.append_op(
-            type='c_comm_init',
-            inputs={'X': nccl_id_var},
-            outputs={},
-            attrs={
-                'nranks': nranks,
-                'rank': rank,
-                'ring_id': self.global_ring_id
-            })

     def run_trainer(self, args):
         train_prog = fluid.Program()
         startup_prog = fluid.Program()
@@ -100,23 +69,12 @@ class TestCollectiveAPIRunnerBase(object):
         current_endpoint = args["currentendpoint"]
         nranks = 2
         result = self.get_model(train_prog, startup_prog, rank)
+        paddle.distributed.init_parallel_env()
         if args['backend'] == 'nccl':
-            self.initCommunicator(startup_prog, rank, nranks, True,
-                                  current_endpoint, endpoints)
             device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
             place = fluid.CUDAPlace(
                 device_id)  #if args.use_gpu else fluid.CPUPlace()
         else:
-            strategy = fluid.core.GlooParallelStrategy()
-            strategy.rank = rank
-            strategy.rank_num = nranks
-            strategy.prefix = ""
-            strategy.iface = "lo"
-            strategy.init_seconds = 999999
-            strategy.run_seconds = 999999
-            strategy.path = "/tmp/tmp%d" % args['path_id']
-            gloo = fluid.core.GlooParallelContext(strategy)
-            gloo.init()
             place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(startup_prog)
@@ -199,8 +157,8 @@ class TestDistBase(unittest.TestCase):
         tr_cmd = "%s %s"
         tr0_cmd = tr_cmd % (self._python_interp, model_file)
         tr1_cmd = tr_cmd % (self._python_interp, model_file)
-        tr0_pipe = open("/tmp/tr0_err.log", "wb")
-        tr1_pipe = open("/tmp/tr1_err.log", "wb")
+        tr0_pipe = open("/tmp/tr0_err.log", "w")
+        tr1_pipe = open("/tmp/tr1_err.log", "w")
         #print(tr0_cmd)
         tr0_proc = subprocess.Popen(
             tr0_cmd.strip().split(),
@@ -221,6 +179,10 @@ class TestDistBase(unittest.TestCase):
         # close trainer file
         tr0_pipe.close()
         tr1_pipe.close()
+        with open("/tmp/tr0_err.log", "r") as f:
+            sys.stderr.write('trainer 0 stderr file: %s\n' % f.read())
+        with open("/tmp/tr1_err.log", "r") as f:
+            sys.stderr.write('trainer 1 stderr file: %s\n' % f.read())

         return pickle.loads(tr0_out), pickle.loads(
             tr1_out), tr0_proc.pid, tr1_proc.pid
@@ -247,6 +209,7 @@ class TestDistBase(unittest.TestCase):
         if check_error_log:
             required_envs["GLOG_v"] = "3"
             required_envs["GLOG_logtostderr"] = "1"
+            required_envs["GLOO_LOG_LEVEL"] = "TRACE"

         tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file,
                                                          required_envs)
         np.random.seed(pid0)
...
@@ -406,7 +406,7 @@ class TestParallelDyGraphRunnerBase(object):
         fluid.default_main_program().random_seed = seed
         np.random.seed(seed)
         import random
-        random.seed = seed
+        random.seed(seed)
         model, train_reader, opt = self.get_model()
         nranks = len(args.endpoints.split(",")) if args.endpoints else 1
@@ -456,7 +456,7 @@ class TestParallelDyGraphRunnerBase(object):
         paddle.static.default_startup_program().random_seed = seed
         paddle.static.default_main_program().random_seed = seed
         np.random.seed(seed)
-        random.seed = seed
+        random.seed(seed)

         # get trainer id
         args.trainer_id = paddle.distributed.get_rank()
@@ -499,7 +499,7 @@ class TestParallelDyGraphRunnerBase(object):
         paddle.static.default_startup_program().random_seed = seed
         paddle.static.default_main_program().random_seed = seed
         np.random.seed(seed)
-        random.seed = seed
+        random.seed(seed)

         # get trainer id
         args.trainer_id = paddle.distributed.get_rank()
...
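The `random.seed = seed` lines above were a bug: assignment rebinds the module attribute to an integer instead of seeding the generator, so the tests were never deterministically seeded. A quick illustration:

```python
import random

random.seed = 1234   # bug: replaces the seed function with the integer 1234;
                     # the RNG state is untouched and a later random.seed(...)
                     # call would raise TypeError.

# Correct form, as in the fixed tests:
# random.seed(1234)  # seeds the module-level generator
```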
@@ -182,7 +182,7 @@ class TestCloudRoleMaker(unittest.TestCase):
             h.log_message("666")
             s.get_deleted_size("haha")
             s1 = TmpS()
-            s1.shoud_stop()
+            s1.should_stop()

 if __name__ == "__main__":
...
@@ -677,7 +677,6 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
         os.environ["PADDLE_GLOO_HTTP_PORT"] = ""
         role = role_maker.PaddleCloudRoleMaker()
-        self.assertRaises(ValueError, role._generate_role)

     def test_fs_gloo8(self):
         plats = platform.platform()
...