Unverified · Commit 5f6376b7 authored by yuehuayingxueluo, committed by GitHub

Add Gloo Gather Function (#52334)

* add gloo gather

* add gloo_tools

* fix CI bug

* use gloo gather

* remove redundant code

* fix process_group_gloo.py

* rename send_recv

* fix conflict

* fix conflict

* fix codestyle

* fix CI bug

* add PADDLE_ENFORCE_NE
Parent e6e62342
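
This commit adds a Gloo-based gather to ProcessGroupGloo and exposes it through the existing "gather" Python binding, so the CPU backend accepts the same call shape the new unit test exercises. Below is a minimal sketch of driving it from Python. The group setup (init_parallel_env plus new_group(backend="gloo").process_group) is an assumption about the surrounding Paddle API and may differ across versions; only the pg.gather(input, output_list, root, sync_op) call mirrors the test added in this commit.

    # Hedged sketch: gather on the Gloo (CPU) backend, mirroring the new unit test.
    # The launch/group setup below is an assumption, not part of this commit.
    import numpy as np
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()                  # assumes a CPU / Gloo launch
    group = dist.new_group(backend="gloo")    # hypothetical: explicit Gloo backend
    pg = group.process_group                  # hypothetical handle to the C++ ProcessGroup

    root = 0
    shape = [4, 8]
    in_tensor = paddle.to_tensor(np.random.random(shape).astype("float32"))
    gather_list = [paddle.zeros(shape, dtype="float32") for _ in range(pg.size())]

    # Same calling convention as the new test: (input, output_list, root, sync_op).
    task = pg.gather(in_tensor, gather_list, root, True)
    task.wait()
    if pg.rank() == root:
        # Only the root rank ends up holding every rank's tensor in gather_list.
        assert len(gather_list) == pg.size()
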
@@ -11,7 +11,7 @@ cc_library(
 if(WITH_DISTRIBUTE)
   cc_library(
     process_group_gloo
-    SRCS process_group_gloo.cc send_recv.cc
+    SRCS process_group_gloo.cc gloo_send_recv.cc
     DEPS phi_api eager_api gloo_wrapper tcp_store)
 endif()
......
@@ -18,7 +18,7 @@
 #include "gloo/common/logging.h"
 #include "gloo/math.h"
 #include "gloo/types.h"
-#include "paddle/fluid/distributed/collective/send_recv.h"
+#include "paddle/fluid/distributed/collective/gloo_send_recv.h"
 
 namespace paddle {
 namespace distributed {
......
@@ -25,12 +25,13 @@
 #endif
 #include <gloo/broadcast.h>
+#include <gloo/gather.h>
 #include <gloo/reduce.h>
 #include <gloo/scatter.h>
 
 #include "paddle/fluid/distributed/collective/common.h"
+#include "paddle/fluid/distributed/collective/gloo_send_recv.h"
 #include "paddle/fluid/distributed/collective/process_group_gloo.h"
-#include "paddle/fluid/distributed/collective/send_recv.h"
 #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -680,6 +681,65 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
   return Scatter(&out_tensors[0], in_tensors[0], opts, true);
 }
 
+class GatherGlooTask : public ProcessGroupGloo::GlooTask {
+ public:
+  GatherGlooTask(int rank,
+                 const std::shared_ptr<gloo::Context>& context,
+                 const phi::DenseTensor& input,  // NOLINT
+                 phi::DenseTensor* output,       // NOLINT
+                 int src,
+                 uint32_t tag)
+      : ProcessGroupGloo::GlooTask(rank, {input}, CommType::GATHER),
+        _context(context),
+        _input(input),
+        _output(*output),
+        _src(src),
+        _tag(tag) {}
+
+  void Run() override { _do_gather(_input, _output, _src); }
+
+ private:
+  std::shared_ptr<gloo::Context> _context;
+  phi::DenseTensor _input;
+  phi::DenseTensor _output;
+  int _src;
+  uint32_t _tag;
+
+  void _do_gather(phi::DenseTensor& in,   // NOLINT
+                  phi::DenseTensor& out,  // NOLINT
+                  int src) {
+    const auto& dtype = in.dtype();
+    gloo::GatherOptions opts(_context);
+    if (rank_ == src) {
+      GENERATE_FUNC(dtype, set_output, opts, out);
+    }
+    GENERATE_FUNC(dtype, set_input, opts, in);
+
+    opts.setRoot(src);
+    opts.setTag(_tag);
+    gloo::gather(opts);
+  }
+};
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Gather(
+    phi::DenseTensor* out_tensor,
+    const phi::DenseTensor& in_tensor,
+    const GatherOptions& opts,
+    bool sync_op,
+    bool use_calc_stream) {
+  PADDLE_ENFORCE_NE(
+      use_calc_stream,
+      true,
+      platform::errors::InvalidArgument("Gloo cannot use use_calc_stream."));
+  std::shared_ptr<GatherGlooTask> task;
+  auto tag = next_tag();
+  auto context = get_context();
+  task = std::make_shared<GatherGlooTask>(
+      rank_, context, in_tensor, out_tensor, opts.root_rank, tag);
+  task->Run();
+  return task;
+}
+
 std::shared_ptr<::gloo::transport::Device>
 ProcessGroupGloo::createDeviceForInterface(const std::string& ifname) {
   ::gloo::transport::tcp::attr attr;
......
@@ -150,6 +150,12 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
                                               const ScatterOptions& opts,
                                               bool sync_op) override;
 
+  std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor,
+                                             const phi::DenseTensor& in_tensor,
+                                             const GatherOptions& opts,
+                                             bool sync_op,
+                                             bool use_calc_stream) override;
+
   // TODO(sunyilun): methods below will be removed later
   std::shared_ptr<ProcessGroup::Task> Broadcast(
       std::vector<phi::DenseTensor>& inputs,
@@ -210,6 +216,15 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
     return platform::DeviceContextPool::Instance().Get(place);
   }
 
+  phi::DeviceContext* GetDeviceContext(const Place& place,
+                                       bool use_calc_stream) const override {
+    PADDLE_ENFORCE_NE(
+        use_calc_stream,
+        true,
+        platform::errors::InvalidArgument("Gloo cannot use use_calc_stream."));
+    return GetDeviceContext(place);
+  }
+
   // Helper functions for Gloo.
   static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
       const std::string& hostname);
......
@@ -499,7 +499,6 @@ void BindDistributed(py::module *m) {
               py::arg("src"),
               py::arg("sync_op"),
               py::call_guard<py::gil_scoped_release>())
           .def(
               "scatter_tensor",
               [](distributed::ProcessGroup &self,
@@ -547,11 +546,12 @@ void BindDistributed(py::module *m) {
                 auto *dev_ctx =
                     self.GetDeviceContext(in_tensor.place(), use_calc_stream);
-                distributed::GatherOptions gather_ops{dst};
+                distributed::GatherOptions gather_opts{dst};
                 auto task = self.Gather(
-                    out_dense, in_dense, gather_ops, sync_op, use_calc_stream);
+                    out_dense, in_dense, gather_opts, sync_op, use_calc_stream);
                 SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
-                if (!use_calc_stream) {
+                if (!use_calc_stream &&
+                    dev_ctx->GetPlace() != platform::CPUPlace()) {
                   // calculate stream will wait comm stream
                   task->UpdateWaitChain(*dev_ctx);
                 }
@@ -561,7 +561,7 @@ void BindDistributed(py::module *m) {
               py::arg("out"),
               py::arg("dst"),
               py::arg("sync_op"),
-              py::arg("use_calc_stream"),
+              py::arg("use_calc_stream") = false,
               py::call_guard<py::gil_scoped_release>())
           .def(
               "barrier",
......
@@ -113,15 +113,15 @@ class TestProcessGroupFp32(unittest.TestCase):
         send_recv_result_1 = paddle.assign(tensor_x)
         send_recv_result_2 = paddle.assign(tensor_y_2)
         if pg.rank() == 0:
-            task = pg.send(tensor_x, 1, True)
-        else:
+            task = pg.send(tensor_x, pg.size() - 1, True)
+        elif pg.rank() == pg.size() - 1:
             task = pg.recv(tensor_y_1, 0, True)
             assert np.array_equal(send_recv_result_1, tensor_y_1)
 
         if pg.rank() == 0:
-            task = pg.recv(tensor_x, 1, True)
+            task = pg.recv(tensor_x, pg.size() - 1, True)
             assert np.array_equal(send_recv_result_2, tensor_x)
-        else:
+        elif pg.rank() == pg.size() - 1:
             task = pg.send(tensor_y_2, 0, True)
         print("test send_recv api ok")
@@ -204,6 +204,30 @@ class TestProcessGroupFp32(unittest.TestCase):
         assert np.array_equal(tensor_y, out2)
         print("test scatter api ok\n")
 
+        # test Gather
+        def test_gather(root):
+            tensor_x = [
+                paddle.zeros(self.shape).astype(self.dtype)
+                for _ in range(pg.size())
+            ]
+            tensor_y = [
+                paddle.to_tensor(
+                    np.random.random(self.shape).astype(self.dtype)
+                )
+                for _ in range(pg.size())
+            ]
+            if pg.rank() == root:
+                task = pg.gather(tensor_y[root], tensor_x, root, True)
+                task.wait()
+                assert np.array_equal(tensor_x, tensor_y)
+            else:
+                task = pg.gather(tensor_y[pg.rank()], tensor_x, root, True)
+                task.wait()
+
+        test_gather(0)
+        test_gather(pg.size() - 1)
+        print("test gather api ok\n")
+
 
 if __name__ == "__main__":
     unittest.main()