Unverified · Commit 99626502 · authored by tangwei12 · committed by GitHub

【paddle.fleet】gloo and util (#27213)

* fix worker endpoints

* fix gloo wrapper for hdfs

* GPU fleetrun support gloo

* parameterserver fleetrun support gloo

* fix get server endpoint
Parent a5ef246c
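For orientation before reading the diff, the sketch below summarizes the user-facing util API change, reconstructed from the updated unit tests further down. The `UserDefinedUtil` class and the `UtilBase` import are illustrative assumptions rather than part of this commit; only the `fleet.set_util` / `fleet.util()` / `comm_world`-string signatures come from the tests.

# --- Illustrative sketch (not part of the diff): util API after this commit ---
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet import UtilBase  # assumed export for illustration

class UserDefinedUtil(UtilBase):               # hypothetical user-defined util
    def get_user_id(self):
        return 10

role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)

# before: fleet.util = my_util            after: fleet.set_util(my_util)
# before: fleet.util.get_user_id()        after: fleet.util().get_user_id()
my_util = UserDefinedUtil()
fleet.set_util(my_util)
assert fleet.util().get_user_id() == 10

# Collective helper signatures after this commit (comm_world is a plain string
# in {"worker", "server", "all"} instead of an internal gloo communicator):
#   util.all_reduce(input, mode="sum", comm_world="worker")
#   util.all_gather(input, comm_world="worker")
#   util.barrier(comm_world="worker")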
......@@ -19,6 +19,8 @@ limitations under the License. */
namespace gloo {
namespace rendezvous {
constexpr int kNodeSize = 136;
HdfsStore::HdfsStore(const std::string& path) {
path_ = path;
wait_sleep_ms_ = 10000;
......@@ -213,12 +215,14 @@ void ParallelConnectContext::connectFullMesh(
storeKey << rank;
store.set(storeKey.str(), allBytes);
auto total_add_size = kNodeSize * (size - 1);
std::vector<std::shared_ptr<std::thread>> connect_threads(thread_num_);
// Connect every pair
for (uint32_t i = 0; i < connect_threads.size(); ++i) {
connect_threads[i].reset(new std::thread(
[&store, &transportContext, this](size_t thread_idx,
size_t thread_num) -> void {
[&store, &transportContext, total_add_size, this](
size_t thread_idx, size_t thread_num) -> void {
for (int i = thread_idx; i < size; i += thread_num) {
if (i == rank) {
continue;
......@@ -226,8 +230,23 @@ void ParallelConnectContext::connectFullMesh(
// Wait for address of other side of this pair to become available
std::string key = std::to_string(i);
store.wait({key}, getTimeout());
std::vector<char> allAddrs;
auto max_retry_times = 5;
// Connect to other side of this pair
auto allAddrs = store.get(key);
while (max_retry_times > 0) {
allAddrs = store.get(key);
VLOG(3) << "store get all address size: " << allAddrs.size()
<< " except: " << total_add_size;
if (allAddrs.size() == static_cast<size_t>(total_add_size)) {
break;
}
--max_retry_times;
}
auto addr = extractAddress(allAddrs, i);
transportContext->getPair(i)->connect(addr);
}
......
......@@ -39,6 +39,7 @@ server_num = fleet.server_num
server_index = fleet.server_index
server_endpoints = fleet.server_endpoints
is_server = fleet.is_server
set_util = fleet.set_util
util = fleet.util
barrier_worker = fleet.barrier_worker
init_worker = fleet.init_worker
......
......@@ -180,6 +180,8 @@ class Fleet(object):
raise ValueError(
"`role_maker` should be subclass of `RoleMakerBase`, but got {}".
format(type(role_maker)))
self._role_maker.generate_role()
self.strategy_compiler = StrategyCompiler()
if paddle.fluid.framework.in_dygraph_mode():
if parallel_helper._is_parallel_ctx_initialized():
......@@ -187,7 +189,6 @@ class Fleet(object):
"The dygraph parallel environment has been initialized.")
else:
paddle.distributed.init_parallel_env()
return None
def is_first_worker(self):
"""
......@@ -275,13 +276,10 @@ class Fleet(object):
fleet.worker_endpoints()
"""
'''
if to_string:
return ",".join(self._role_maker.get_trainer_endpoints())
else:
return self._role_maker.get_trainer_endpoints()
'''
return ["127.0.0.1:1001", "127.0.0.1:1002"]
def server_num(self):
"""
......@@ -355,7 +353,9 @@ class Fleet(object):
return self._role_maker.is_server(
) or self._role_maker._is_heter_worker()
@property
def set_util(self, util):
self._util = util
def util(self):
"""
Utility functions that can be used under certain runtime
......@@ -376,16 +376,6 @@ class Fleet(object):
"""
return self._util
@util.setter
def util(self, util):
"""
Set Utility functions for user-defined runtime
Returns:
None
"""
self._util = util
def barrier_worker(self):
"""
barrier all workers
......@@ -393,7 +383,7 @@ class Fleet(object):
Returns:
None
"""
self._role_maker.barrier_worker()
self._role_maker._barrier("worker")
@is_non_distributed_check
@inited_runtime_handler
......
......@@ -57,34 +57,7 @@ class UtilBase(object):
), "fs_client must be the instance of paddle.distributed.fleet.utils.FS"
self.fs_client = fs_client
def __check_comm_world(self, comm_world="worker"):
if not self.role_maker._role_is_generated:
self.role_maker.generate_role()
_comm_world = None
comm_world_upper = comm_world.upper()
if comm_world_upper == "WORKER":
if not self.role_maker.is_worker():
print(
"warning: current role is not worker in collective_func(comm_world=\"worker\")"
)
_comm_world = self.role_maker._node_type_comm
elif comm_world_upper == "SERVER":
if not self.role_maker.is_server():
print(
"warning: current role is not server in collective_func(comm_world=\"server\")"
)
_comm_world = self.role_maker._node_type_comm
elif comm_world_upper == "ALL":
_comm_world = self.role_maker._all_comm
else:
raise ValueError(
"not support comm_world, please choose one from [worker, server, all]"
)
return _comm_world
def all_reduce(self, input, mode, comm_world="worker"):
def all_reduce(self, input, mode="sum", comm_world="worker"):
"""
All reduce `input` between specified collection. This is a distributed API.
......@@ -130,8 +103,7 @@ class UtilBase(object):
if __name__ == "__main__":
train()
"""
_comm_world = self.__check_comm_world(comm_world)
return self.role_maker._all_reduce(_comm_world, input, mode)
return self.role_maker._all_reduce(input, mode, comm_world)
def barrier(self, comm_world="worker"):
"""
......@@ -170,8 +142,7 @@ class UtilBase(object):
if __name__ == "__main__":
train()
"""
_comm_world = self.__check_comm_world(comm_world)
self.role_maker._barrier(_comm_world)
self.role_maker._barrier(comm_world)
def all_gather(self, input, comm_world="worker"):
"""
......@@ -219,8 +190,8 @@ class UtilBase(object):
if __name__ == "__main__":
train()
"""
_comm_world = self.__check_comm_world(comm_world)
return self.role_maker._all_gather(_comm_world, input)
return self.role_maker._all_gather(input, comm_world)
def _broadcast(self):
pass
......
......@@ -55,7 +55,10 @@ launch a process on each of the given gpu card or cpu machine.
"""
from __future__ import print_function
import shutil
import sys
import tempfile
from sys import version
import subprocess
import os
......@@ -213,12 +216,20 @@ def launch_collective(args):
cluster, pod = get_cluster_from_args(args, gpus)
logger.debug("get cluster from args:{}".format(cluster))
global_envs = copy.copy(os.environ.copy())
gloo_rendezvous_dir = tempfile.mkdtemp()
# add gloo env
global_envs["PADDLE_WITH_GLOO"] = "1"
global_envs["PADDLE_GLOO_RENDEZVOUS"] = "2"
global_envs["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir
procs = start_local_trainers(
cluster,
pod,
training_script=args.training_script,
training_script_args=args.training_script_args,
log_dir=args.log_dir)
log_dir=args.log_dir,
envs=global_envs)
while True:
alive = watch_local_trainers(procs, cluster.trainers_nranks())
......@@ -230,6 +241,9 @@ def launch_collective(args):
time.sleep(3)
if os.path.exists(gloo_rendezvous_dir):
shutil.rmtree(gloo_rendezvous_dir)
def launch_ps(args):
ports = None
......@@ -315,6 +329,13 @@ def launch_ps(args):
default_env = os.environ.copy()
current_env = copy.copy(default_env)
gloo_rendezvous_dir = tempfile.mkdtemp()
# add gloo env
current_env["PADDLE_WITH_GLOO"] = "1"
current_env["PADDLE_GLOO_RENDEZVOUS"] = "2"
current_env["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir
current_env.pop("http_proxy", None)
current_env.pop("https_proxy", None)
procs = []
......@@ -419,6 +440,9 @@ def launch_ps(args):
procs[i].proc.terminate()
print("all parameter server are killed", file=sys.stderr)
if os.path.exists(gloo_rendezvous_dir):
shutil.rmtree(gloo_rendezvous_dir)
def launch():
args = _parse_args()
......
......@@ -398,8 +398,14 @@ def start_local_trainers(cluster,
pod,
training_script,
training_script_args,
log_dir=None):
current_env = copy.copy(os.environ.copy())
log_dir=None,
envs=None):
if envs is None:
current_env = copy.copy(os.environ.copy())
else:
current_env = copy.copy(envs)
#paddle broadcast ncclUniqueId use socket, and
#proxy maybe make trainers unreachable, so delete them.
#if we set them to "", grpc will log error message "bad uri"
......
......@@ -27,7 +27,7 @@ class TestFleetBase(unittest.TestCase):
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
os.environ["PADDLE_TRAINERS_NUM"] = "2"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \
"127.0.0.1:36001,127.0.0.2:36001"
"127.0.0.1:36001,127.0.0.2:36001"
def test_init(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
......@@ -88,7 +88,7 @@ class TestFleetBase(unittest.TestCase):
def test_util(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
self.assertEqual(fleet.util, None)
self.assertEqual(fleet.util(), None)
def test_barrier_worker(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
......@@ -99,20 +99,17 @@ class TestFleetBase(unittest.TestCase):
def test_init_worker(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
if fleet.is_worker():
fleet.init_worker()
def test_run_server(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
if fleet.is_worker():
fleet.run_worker()
with self.assertRaises(ValueError):
if fleet.is_worker():
fleet.init_worker()
def test_stop_worker(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
if fleet.is_worker():
fleet.stop_worker()
with self.assertRaises(ValueError):
if fleet.is_worker():
fleet.stop_worker()
def test_distributed_optimizer(self):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
......
......@@ -15,7 +15,11 @@
from __future__ import print_function
import os
import platform
import shutil
import tempfile
import unittest
import paddle
import paddle.distributed.fleet.base.role_maker as role_maker
......@@ -42,9 +46,9 @@ class TestRoleMakerBase(unittest.TestCase):
self.assertTrue(len(pserver_endpoints) == 0)
print(role.to_string())
self.assertTrue(role._all_gather(role._node_type_comm, 1) is None)
self.assertTrue(role._all_reduce(role._node_type_comm, 1) is None)
role._barrier(role._node_type_comm)
self.assertTrue(role._all_gather(1, "worker") is None)
self.assertTrue(role._all_reduce(1, "sum", "worker") is None)
role._barrier("worker")
class TestCloudRoleMaker(unittest.TestCase):
......@@ -72,8 +76,8 @@ class TestCloudRoleMaker(unittest.TestCase):
print("warning: no netifaces, skip test_tr_rolemaker")
return
ro = role_maker.PaddleCloudRoleMaker(
is_collective=False, init_gloo=False)
ro = role_maker.PaddleCloudRoleMaker(is_collective=False)
self.assertTrue(ro.is_worker())
self.assertFalse(ro.is_server())
self.assertEqual(ro.worker_num(), 2)
......@@ -108,8 +112,9 @@ class TestCloudRoleMaker(unittest.TestCase):
self.assertEqual(ro.server_num(), 2)
pserver_endpoints = ro.get_pserver_endpoints()
self.assertEqual(pserver_endpoints[0], '127.0.0.1:36001')
self.assertTrue(ro._all_gather(ro._all_comm, 1) is None)
self.assertTrue(ro._all_reduce(ro._all_comm, 1) is None)
self.assertEqual(ro._all_gather(1, "worker"), 1)
self.assertEqual(ro._all_reduce(1, "sum", "worker"), 1)
def test_traing_role(self):
"""Test training role."""
......@@ -142,7 +147,7 @@ class TestUserDefinedRoleMaker(unittest.TestCase):
ro = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=False,
server_endpoints="127.0.0.1:36001,127.0.0.1:36001",
server_endpoints=["127.0.0.1:36001", "127.0.0.1:36001"],
role=role_maker.Role.SERVER,
current_id=0,
worker_num=2)
......@@ -161,14 +166,274 @@ class TestUserDefinedRoleMaker(unittest.TestCase):
ro = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=False,
server_endpoints="127.0.0.1:36001,127.0.0.1:36001",
server_endpoints=["127.0.0.1:36001", "127.0.0.1:36001"],
role=role_maker.Role.WORKER,
current_id=0,
worker_num=2)
self.assertIn("127.0.0.1:36001", ro.get_pserver_endpoints())
self.assertTrue(ro.is_worker())
self.assertEqual(ro.role_id(), 0)
class TestGlooWithCloudRoleMaker(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINERS_NUM"] = "1"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_TRAINER_ID"] = "0"
def case(self, role, comm_world):
role._barrier(comm_world)
gather = role._all_gather(1, comm_world)
self.assertEqual(gather[0], 1)
all_reduce = role._all_reduce(1, "sum", comm_world)
self.assertEqual(1, all_reduce)
def mkdir(self):
tmp = tempfile.mkdtemp()
return tmp
def clean(self, tmp):
shutil.rmtree(tmp)
def test_hdfs_gloo(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip gloo UT on MacOS/Win")
return
tmp = self.mkdir()
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
os.environ["PADDLE_WITH_GLOO"] = "1"
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "1"
os.environ["PADDLE_GLOO_FS_NAME"] = "NULL"
os.environ["PADDLE_GLOO_FS_UGI"] = "NULL"
os.environ["PADDLE_GLOO_FS_PATH"] = tmp
role = role_maker.PaddleCloudRoleMaker()
role.generate_role()
self.case(role, "worker")
self.clean(tmp)
def test_fs_gloo(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip gloo UT on MacOS/Win")
return
tmp = self.mkdir()
os.environ["TRAINING_ROLE"] = "TRAINER"
os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
os.environ["PADDLE_WITH_GLOO"] = "1"
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "2"
os.environ["PADDLE_GLOO_FS_PATH"] = tmp
role = role_maker.PaddleCloudRoleMaker()
role.generate_role()
self.case(role, "worker")
self.clean(tmp)
def test_fs_gloo2(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip gloo UT on MacOS/Win")
return
tmp = self.mkdir()
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
os.environ["PADDLE_WITH_GLOO"] = "1"
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "2"
os.environ["PADDLE_GLOO_FS_PATH"] = tmp
role = role_maker.PaddleCloudRoleMaker()
role.generate_role()
self.case(role, "server")
self.clean(tmp)
def test_fs_gloo3(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip gloo UT on MacOS/Win")
return
tmp = self.mkdir()
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
os.environ["PADDLE_WITH_GLOO"] = "1"
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "1"
os.environ["PADDLE_GLOO_FS_NAME"] = "NULL"
os.environ["PADDLE_GLOO_FS_UGI"] = "NULL"
os.environ["PADDLE_GLOO_FS_PATH"] = tmp
role = role_maker.PaddleCloudRoleMaker()
role.generate_role()
self.case(role, "server")
self.clean(tmp)
def test_fs_gloo4(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip gloo UT on MacOS/Win")
return
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
os.environ["PADDLE_WITH_GLOO"] = "1"
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "3"
os.environ["PADDLE_GLOO_HTTP_HOST"] = "127.0.0.1"
os.environ["PADDLE_GLOO_HTTP_PORT"] = "30019"
role = role_maker.PaddleCloudRoleMaker()
role.generate_role()
import time
time.sleep(3)
def test_fs_gloo5(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip gloo UT on MacOS/Win")
return
tmp = self.mkdir()
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["PADDLE_TRAINERS_NUM"] = "0"
os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
os.environ["PADDLE_WITH_GLOO"] = "2"
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "2"
os.environ["PADDLE_GLOO_FS_PATH"] = tmp
role = role_maker.PaddleCloudRoleMaker()
role.generate_role()
self.case(role, "server")
self.case(role, "all")
self.clean(tmp)
def test_fs_gloo6(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip gloo UT on MacOS/Win")
return
tmp = self.mkdir()
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["PADDLE_TRAINERS_NUM"] = "0"
os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
os.environ["PADDLE_WITH_GLOO"] = "2"
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "1"
os.environ["PADDLE_GLOO_FS_NAME"] = "NULL"
os.environ["PADDLE_GLOO_FS_UGI"] = "NULL"
os.environ["PADDLE_GLOO_FS_PATH"] = tmp
role = role_maker.PaddleCloudRoleMaker()
role.generate_role()
self.case(role, "server")
self.case(role, "all")
self.clean(tmp)
def test_fs_gloo7(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip gloo UT on MacOS/Win")
return
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["PADDLE_TRAINERS_NUM"] = "0"
os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
os.environ["PADDLE_WITH_GLOO"] = "1"
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "5"
role = role_maker.PaddleCloudRoleMaker()
self.assertRaises(ValueError, role.generate_role)
def test_fs_gloo8(self):
plats = platform.platform()
if 'Linux' not in plats:
print("skip gloo UT on MacOS/Win")
return
tmp = self.mkdir()
os.environ["TRAINING_ROLE"] = "PSERVER"
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "36001"
os.environ["PADDLE_TRAINERS_NUM"] = "0"
os.environ["SYS_JOB_ID"] = "gloo_for_cluster"
os.environ["PADDLE_WITH_GLOO"] = "2"
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "1"
os.environ["PADDLE_GLOO_FS_NAME"] = "NULL"
os.environ["PADDLE_GLOO_FS_UGI"] = "NULL"
os.environ["PADDLE_GLOO_FS_PATH"] = tmp
def net():
x = paddle.fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = paddle.fluid.layers.fc(input=x, size=1, act=None)
y = paddle.fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = paddle.fluid.layers.square_error_cost(
input=y_predict, label=y)
avg_cost = paddle.fluid.layers.mean(cost)
return avg_cost
from paddle.distributed import fleet
role = role_maker.PaddleCloudRoleMaker()
fleet.init(role)
avg_cost = net()
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = False
optimizer = paddle.optimizer.SGD(0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost)
comm_world = "server"
fleet.util().barrier(comm_world)
gather = fleet.util().all_gather(1, comm_world)
self.assertEqual(gather[0], 1)
all_reduce = fleet.util().all_reduce(1, "sum", comm_world)
self.assertEqual(1, all_reduce)
self.clean(tmp)
if __name__ == "__main__":
unittest.main()
......@@ -59,7 +59,7 @@ class TestFleetUtil(unittest.TestCase):
import paddle.distributed.fleet.base.role_maker as role_maker
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
default_util = fleet.util
default_util = fleet.util()
self.assertEqual(default_util, None)
def test_set_user_defined_util(self):
......@@ -76,8 +76,8 @@ class TestFleetUtil(unittest.TestCase):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
my_util = UserDefinedUtil()
fleet.util = my_util
user_id = fleet.util.get_user_id()
fleet.set_util(my_util)
user_id = fleet.util().get_user_id()
self.assertEqual(user_id, 10)
def test_fs(self):
......@@ -88,97 +88,6 @@ class TestFleetUtil(unittest.TestCase):
self.assertFalse(fs.need_upload_download())
fleet_util._set_file_system(fs)
def test_barrier(self):
try:
import netifaces
except:
print("warning: no netifaces, skip test_barrier")
return
gloo = fluid.core.Gloo()
gloo.set_rank(0)
gloo.set_size(1)
gloo.set_prefix("123")
gloo.set_iface("lo")
gloo.set_hdfs_store("./tmp_test_fleet_barrier", "", "")
gloo.init()
role = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=False,
current_id=0,
role=role_maker.Role.SERVER,
worker_endpoints=["127.0.0.1:6003"],
server_endpoints=["127.0.0.1:6001"])
role._node_type_comm = gloo
role._role_is_generated = True
fleet_util._set_role_maker(role)
fleet_util.barrier("worker")
def test_all_reduce(self):
try:
import netifaces
except:
print("warning: no netifaces, skip test_all_reduce")
return
gloo = fluid.core.Gloo()
gloo.set_rank(0)
gloo.set_size(1)
gloo.set_prefix("123")
gloo.set_iface("lo")
gloo.set_hdfs_store("./tmp_test_fleet_reduce", "", "")
gloo.init()
role = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=False,
current_id=0,
role=role_maker.Role.WORKER,
worker_endpoints=["127.0.0.1:6003"],
server_endpoints=["127.0.0.1:6001"])
role._node_type_comm = gloo
role._role_is_generated = True
fleet_util._set_role_maker(role)
output = fleet_util.all_reduce(1, "sum", comm_world="server")
print(output)
# self.assertEqual(output, 1)
def test_all_gather(self):
try:
import netifaces
except:
print("warning: no netifaces, skip test_all_gather")
return
gloo = fluid.core.Gloo()
gloo.set_rank(0)
gloo.set_size(1)
gloo.set_prefix("123")
gloo.set_iface("lo")
gloo.set_hdfs_store("./tmp_test_fleet_reduce", "", "")
gloo.init()
role = role_maker.UserDefinedRoleMaker(
is_collective=False,
init_gloo=False,
current_id=0,
role=role_maker.Role.SERVER,
worker_endpoints=["127.0.0.1:6003"],
server_endpoints=["127.0.0.1:6001"])
role._node_type_comm = gloo
role._all_comm = gloo
role._role_is_generated = True
fleet_util._set_role_maker(role)
output = fleet_util.all_gather(1, comm_world="all")
print(output)
# self.assertTrue(len(output) == 1 and output[0] == 1)
self.assertRaises(Exception, fleet_util.all_gather, 1, "test")
def download_files(self):
path = download(self.proto_data_url, self.module_name,
self.proto_data_md5)
......
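The launcher and role-maker changes above all hinge on a small set of gloo rendezvous environment variables. The sketch below restates them in one place; the meaning of the PADDLE_GLOO_RENDEZVOUS values (1: HDFS store, 2: local file store, 3: HTTP store) is inferred from the tests in this commit and should be read as illustrative, not as a specification.

# --- Sketch: gloo rendezvous environment set up by the new launch path / tests ---
import os
import tempfile

gloo_rendezvous_dir = tempfile.mkdtemp()

os.environ["PADDLE_WITH_GLOO"] = "1"        # enable gloo-backed barrier/collectives
os.environ["PADDLE_GLOO_RENDEZVOUS"] = "2"  # 2: rendezvous through a shared file path
os.environ["PADDLE_GLOO_FS_PATH"] = gloo_rendezvous_dir

# HDFS rendezvous (as in test_hdfs_gloo): PADDLE_GLOO_RENDEZVOUS = "1" plus
#   PADDLE_GLOO_FS_NAME, PADDLE_GLOO_FS_UGI, PADDLE_GLOO_FS_PATH
# HTTP rendezvous (as in test_fs_gloo4):  PADDLE_GLOO_RENDEZVOUS = "3" plus
#   PADDLE_GLOO_HTTP_HOST, PADDLE_GLOO_HTTP_PORT
# Any other value raises ValueError when the role is generated (test_fs_gloo7).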