未验证 提交 e8f146a9 编写于 作者: B Bo Liu 提交者: GitHub

[CPU] Enable barrier op upon gloo (#34671)

上级 17188e8d
...@@ -27,6 +27,21 @@ void GlooParallelContext::Init() { ...@@ -27,6 +27,21 @@ void GlooParallelContext::Init() {
strategy_.scope); strategy_.scope);
gloo_ptr->Init(); gloo_ptr->Init();
} }
void GlooParallelContext::Barrier() {
auto gloo_ptr = paddle::framework::GlooWrapper::GetInstance();
PADDLE_ENFORCE_EQ(gloo_ptr->IsInitialized(), true,
paddle::platform::errors::Unavailable(
"Gloo context is not initialized."));
gloo_ptr->Barrier();
}
void GlooParallelContext::ReleaseContext() {
auto gloo_ptr = paddle::framework::GlooWrapper::GetInstance();
if (gloo_ptr->IsInitialized() == true) {
gloo_ptr.reset();
}
}
#endif #endif
} // namespace platform } // namespace platform
......
...@@ -41,6 +41,10 @@ class GlooParallelContext { ...@@ -41,6 +41,10 @@ class GlooParallelContext {
virtual void Init(); virtual void Init();
virtual void Barrier();
virtual void ReleaseContext();
protected: protected:
GlooParallelStrategy strategy_; GlooParallelStrategy strategy_;
}; };
......
...@@ -93,7 +93,11 @@ void BindGlooContext(py::module *m) { ...@@ -93,7 +93,11 @@ void BindGlooContext(py::module *m) {
py::class_<platform::GlooParallelContext> gloo_ctx(*m, "GlooParallelContext"); py::class_<platform::GlooParallelContext> gloo_ctx(*m, "GlooParallelContext");
gloo_ctx.def(py::init<const platform::GlooParallelStrategy &>()) gloo_ctx.def(py::init<const platform::GlooParallelStrategy &>())
.def("init", [](platform::GlooParallelContext &self) { self.Init(); }); .def("init", [](platform::GlooParallelContext &self) { self.Init(); })
.def("barrier",
[](platform::GlooParallelContext &self) { self.Barrier(); })
.def("release",
[](platform::GlooParallelContext &self) { self.ReleaseContext(); });
#endif #endif
} }
......
...@@ -18,6 +18,10 @@ from .parallel import init_parallel_env # noqa: F401 ...@@ -18,6 +18,10 @@ from .parallel import init_parallel_env # noqa: F401
from .parallel import get_rank # noqa: F401 from .parallel import get_rank # noqa: F401
from .parallel import get_world_size # noqa: F401 from .parallel import get_world_size # noqa: F401
from .parallel_with_gloo import gloo_init_parallel_env
from .parallel_with_gloo import gloo_barrier
from .parallel_with_gloo import gloo_release
from paddle.distributed.fleet.dataset import InMemoryDataset # noqa: F401 from paddle.distributed.fleet.dataset import InMemoryDataset # noqa: F401
from paddle.distributed.fleet.dataset import QueueDataset # noqa: F401 from paddle.distributed.fleet.dataset import QueueDataset # noqa: F401
...@@ -60,6 +64,9 @@ __all__ = [ #noqa ...@@ -60,6 +64,9 @@ __all__ = [ #noqa
"ParallelEnv", "ParallelEnv",
"new_group", "new_group",
"init_parallel_env", "init_parallel_env",
"gloo_init_parallel_env",
"gloo_barrier",
"gloo_release",
"QueueDataset", "QueueDataset",
"split", "split",
"CountFilterEntry", "CountFilterEntry",
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except jin compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import warnings
from multiprocessing import Process, Manager
# deprecated module import
from paddle.fluid import core
from paddle.distributed.fleet.base.private_helper_function import wait_server_ready
__all__ = []
_global_gloo_ctx = None
def _start_kv_server(port, http_server_d, size):
from paddle.distributed.fleet.utils.http_server import KVServer
http_server = KVServer(int(port), size=size)
http_server.start()
wait_seconds = 3
while http_server_d.get("running", False) or not http_server.should_stop():
time.sleep(wait_seconds)
http_server.stop()
def gloo_init_parallel_env(rank_id, rank_num, server_endpoint):
"""
Initialize parallel environment with gloo for cpu only.
Args:
- rank_id(int, required) - the index of current rank;
- rank_num (int, required) - the number of ranks in this parallel env;
- server_endpoint (str, required) - endpoint of server to init gloo context in ip:port format;
Returns:
None
Examples:
.. code-block:: python
import paddle
import multiprocessing
from contextlib import closing
import socket
port_set = set()
def find_free_port():
def _free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
while True:
port = _free_port()
if port not in port_set:
port_set.add(port)
return port
def test_gloo_init(id, rank_num, server_endpoint):
paddle.distributed.gloo_init_parallel_env(
id, rank_num, server_endpoint)
def test_gloo_init_with_multiprocess(num_of_ranks):
jobs = []
server_endpoint = "127.0.0.1:%s" % (find_free_port())
for id in range(num_of_ranks):
p = multiprocessing.Process(
target=test_gloo_init,
args=(id, num_of_ranks, server_endpoint))
jobs.append(p)
p.start()
for proc in jobs:
proc.join()
if __name__ == '__main__':
# Arg: number of ranks (processes)
test_gloo_init_with_multiprocess(2)
"""
assert (rank_num < 2) is False, \
"rank_num should greater than or equal to 2 for parallel environment initialzation."
# init gloo context
manager = Manager()
# global dict to store status
http_server_status = manager.dict()
http_server_status["running"] = False
if rank_id == 0:
# The scope for worker used by http server is '_worker'
size = {'_worker': rank_num}
http_server_proc = Process(
target=_start_kv_server,
args=(int(server_endpoint.split(":")[1]), http_server_status, size))
http_server_proc.daemon = True
http_server_status["running"] = True
http_server_proc.start()
# all processes in this parallel environment should wait until server is ready
wait_server_ready([server_endpoint])
gloo_strategy = core.GlooParallelStrategy()
gloo_strategy.rank = rank_id
gloo_strategy.rank_num = rank_num
gloo_strategy.ip_address = server_endpoint.split(":")[0]
gloo_strategy.ip_port = int(server_endpoint.split(":")[1])
# default_init_timeout_seconds
gloo_strategy.init_seconds = 3600
# default_run_timeout_seconds
gloo_strategy.run_seconds = 9999999
global _global_gloo_ctx
_global_gloo_ctx = core.GlooParallelContext(gloo_strategy)
_global_gloo_ctx.init()
if rank_id == 0:
http_server_status["running"] = False
http_server_proc.join()
def gloo_barrier():
"""
Call barrier function with initialized gloo context.
Args:
None
Returns:
None
Examples:
.. code-block:: python
import paddle
import multiprocessing
from contextlib import closing
import socket
port_set = set()
def find_free_port():
def _free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
while True:
port = _free_port()
if port not in port_set:
port_set.add(port)
return port
def test_gloo_barrier(id, rank_num, server_endpoint):
paddle.distributed.gloo_init_parallel_env(
id, rank_num, server_endpoint)
paddle.distributed.gloo_barrier()
def test_gloo_barrier_with_multiprocess(num_of_ranks):
jobs = []
server_endpoint = "127.0.0.1:%s" % (find_free_port())
for id in range(num_of_ranks):
p = multiprocessing.Process(
target=test_gloo_barrier,
args=(id, num_of_ranks, server_endpoint))
jobs.append(p)
p.start()
for proc in jobs:
proc.join()
if __name__ == '__main__':
# Arg: number of ranks (processes)
test_gloo_barrier_with_multiprocess(2)
"""
assert _global_gloo_ctx is not None, "gloo context is not initialzed."
_global_gloo_ctx.barrier()
def gloo_release():
"""
Release the parallel environment initialized by gloo
Args:
None
Returns:
None
Examples:
.. code-block:: python
import paddle
import multiprocessing
from contextlib import closing
import socket
port_set = set()
def find_free_port():
def _free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
while True:
port = _free_port()
if port not in port_set:
port_set.add(port)
return port
def test_gloo_release(id, rank_num, server_endpoint):
paddle.distributed.gloo_init_parallel_env(
id, rank_num, server_endpoint)
paddle.distributed.gloo_barrier()
paddle.distributed.gloo_release()
def test_gloo_release_with_multiprocess(num_of_ranks):
jobs = []
server_endpoint = "127.0.0.1:%s" % (find_free_port())
for id in range(num_of_ranks):
p = multiprocessing.Process(
target=test_gloo_release,
args=(id, num_of_ranks, server_endpoint))
jobs.append(p)
p.start()
for proc in jobs:
proc.join()
if __name__ == '__main__':
# Arg: number of ranks (processes)
test_gloo_release_with_multiprocess(2)
"""
if _global_gloo_ctx is not None:
_global_gloo_ctx.release()
...@@ -145,6 +145,7 @@ if(NOT WITH_DISTRIBUTE OR WIN32) ...@@ -145,6 +145,7 @@ if(NOT WITH_DISTRIBUTE OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_fleet_ps) LIST(REMOVE_ITEM TEST_OPS test_fleet_ps)
LIST(REMOVE_ITEM TEST_OPS test_fleet_rolemaker_2) LIST(REMOVE_ITEM TEST_OPS test_fleet_rolemaker_2)
LIST(REMOVE_ITEM TEST_OPS test_fleet_utils) LIST(REMOVE_ITEM TEST_OPS test_fleet_utils)
LIST(REMOVE_ITEM TEST_OPS test_collective_cpu_barrier_with_gloo)
# TODO: Fix these unittests failed on Windows # TODO: Fix these unittests failed on Windows
list(REMOVE_ITEM TEST_OPS test_fake_init_op) list(REMOVE_ITEM TEST_OPS test_fake_init_op)
...@@ -740,6 +741,7 @@ endif() ...@@ -740,6 +741,7 @@ endif()
if (WITH_DISTRIBUTE AND NOT WIN32) if (WITH_DISTRIBUTE AND NOT WIN32)
set_tests_properties(test_fleet_utils PROPERTIES TIMEOUT 120) set_tests_properties(test_fleet_utils PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_cpu_barrier_with_gloo PROPERTIES TIMEOUT 40)
endif() endif()
if (WITH_DISTRIBUTE) if (WITH_DISTRIBUTE)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import time
import multiprocessing
from contextlib import closing
import socket
import paddle
import paddle.fluid as fluid
port_set = set()
paddle.enable_static()
class CollectiveCPUBarrierWithGlooTest(unittest.TestCase):
def find_free_port(self):
def _free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
while True:
port = _free_port()
if port not in port_set:
port_set.add(port)
return port
def barrier_func(self, id, rank_num, server_endpoint, out_dict, sleep_time):
try:
paddle.distributed.gloo_init_parallel_env(id, rank_num,
server_endpoint)
# 1st barrier
# Run barrier to synchronize processes after starting
paddle.distributed.gloo_barrier()
# 2nd barrier
# Let rank 0 sleep for one second and check that all processes
# saw that artificial delay through the barrier
start = time.time()
if (id == 0):
time.sleep(sleep_time)
paddle.distributed.gloo_barrier()
end = time.time()
out_dict[id] = end - start
# Release
paddle.distributed.gloo_release()
except:
out_dict[id] = 0
def barrier_op(self, id, rank_num, server_endpoint, out_dict, sleep_time):
try:
main_prog = fluid.Program()
startup_prog = fluid.Program()
paddle.distributed.gloo_init_parallel_env(id, rank_num,
server_endpoint)
place = fluid.CPUPlace()
with fluid.program_guard(main_prog, startup_prog):
paddle.distributed.barrier()
exe = fluid.Executor(place)
# Run barrier to synchronize processes after starting
exe.run(main_prog)
# Let rank 0 sleep for one second and check that all processes
# saw that artificial delay through the barrier
start = time.time()
if (id == 0):
time.sleep(sleep_time)
exe.run(main_prog)
end = time.time()
out_dict[id] = end - start
# Release
paddle.distributed.gloo_release()
except:
out_dict[id] = 0
def test_barrier_func_with_multiprocess(self):
num_of_ranks = 4
sleep_time = 1
# create endpoints
ep_str = "127.0.0.1:%s" % (self.find_free_port())
# call barrier op inside each process
manager = multiprocessing.Manager()
procs_out_dict = manager.dict()
jobs = []
for id in range(num_of_ranks):
p = multiprocessing.Process(
target=self.barrier_func,
args=(id, num_of_ranks, ep_str, procs_out_dict, sleep_time))
jobs.append(p)
p.start()
for proc in jobs:
proc.join()
for _, v in procs_out_dict.items():
self.assertTrue(v > sleep_time)
def test_barrier_op_with_multiprocess(self):
num_of_ranks = 4
sleep_time = 1
# create endpoints
ep_str = "127.0.0.1:%s" % (self.find_free_port())
# call barrier op inside each process
manager = multiprocessing.Manager()
procs_out_dict = manager.dict()
jobs = []
for id in range(num_of_ranks):
p = multiprocessing.Process(
target=self.barrier_op,
args=(id, num_of_ranks, ep_str, procs_out_dict, sleep_time))
jobs.append(p)
p.start()
for proc in jobs:
proc.join()
for _, v in procs_out_dict.items():
self.assertTrue(v > sleep_time)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册