Unverified commit 1d1e5484, authored by Xing-lil, committed by GitHub

Update gloo in dygraph (#55537)

* update broadcast gloo in dygraph

* update

* update reduce gloo in dygraph

* update reduce gloo in dygraph

* update

* update allreduce allgather

* update all

* update

* update

* update
Parent 982e0a9d
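
For orientation, below is a minimal dygraph usage sketch (not part of this commit) of the collectives this change reroutes through `GlooCommContext`. It assumes a multi-process CPU launch (for example via `python -m paddle.distributed.launch --nproc_per_node=2 demo.py`) where Paddle selects the gloo backend; the script name is illustrative, and the API names are the public `paddle.distributed` ones.

```python
# demo.py -- hedged sketch: exercises the dygraph collectives that this commit
# migrates from hand-built gloo option structs to phi::distributed::GlooCommContext.
# Assumes a multi-process CPU launch where Paddle picks the gloo backend.
import paddle
import paddle.distributed as dist

paddle.set_device("cpu")      # gloo serves the CPU collective path
dist.init_parallel_env()      # sets up the process group

rank = dist.get_rank()
x = paddle.to_tensor([float(rank + 1)])

dist.broadcast(x, src=0)      # BroadcastGlooTask -> GlooCommContext::Broadcast
dist.all_reduce(x)            # AllreduceGlooTask -> GlooCommContext::AllReduce

gathered = []
dist.all_gather(gathered, x)  # AllgatherGlooTask -> GlooCommContext::AllGather

if rank == 0:
    dist.send(x, dst=1)       # SendGlooTask -> GlooCommContext::Send
elif rank == 1:
    dist.recv(x, src=0)       # RecvGlooTask -> GlooCommContext::Recv

dist.barrier()                # BarrierGlooTask -> GlooCommContext::Barrier
```
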
......@@ -24,16 +24,12 @@
#include <unistd.h>
#endif
#include <gloo/broadcast.h>
#include <gloo/gather.h>
#include <gloo/reduce.h>
#include <gloo/scatter.h>
#include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/distributed/collective/gloo_send_recv.h"
#include "paddle/fluid/distributed/collective/process_group_gloo.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/distributed/comm_context_manager.h"
namespace paddle {
namespace distributed {
......@@ -99,28 +95,6 @@ namespace distributed {
}
#endif
typedef void (*reduce_func)(void*, const void*, const void*, size_t);
template <typename T>
reduce_func get_function(const ReduceOp& r) {
switch (r) {
case ReduceOp::SUM:
return reduce_func(&::gloo::sum<T>);
case ReduceOp::PRODUCT:
return reduce_func(&::gloo::product<T>);
case ReduceOp::MIN:
return reduce_func(&::gloo::min<T>);
case ReduceOp::MAX:
return reduce_func(&::gloo::max<T>);
case ReduceOp::AVG:
VLOG(0) << "Error: Unsupported ReduceOp::AVG.";
exit(-1);
}
VLOG(0) << "Error: Unknown ReduceOp.";
exit(-1);
}
template <typename T>
T* get_data(phi::DenseTensor& tensor) { // NOLINT
return reinterpret_cast<T*>(tensor.data());
......@@ -188,21 +162,19 @@ ProcessGroupGloo::ProcessGroupGloo(
_tag(0),
_store(new GlooStore(store)) {
_context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
auto prefix_store =
::gloo::rendezvous::PrefixStore(std::to_string(gid), *_store);
_context->connectFullMesh(prefix_store, options->device);
_context->connectFullMesh(*_store, options->device);
}
class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
public:
BroadcastGlooTask(const std::shared_ptr<gloo::Context>& context,
BroadcastGlooTask(phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
int rank,
int root,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST),
_context(context),
_comm_context(comm_context),
_root(root),
_inputs(inputs),
_outputs(outputs),
......@@ -211,22 +183,14 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_broadcast(_inputs[0], _outputs[0]); }
private:
std::shared_ptr<gloo::Context> _context;
phi::distributed::GlooCommContext* _comm_context;
const int _root;
std::vector<phi::DenseTensor> _inputs{};
std::vector<phi::DenseTensor> _outputs{};
const uint32_t _tag;
void _do_broadcast(phi::DenseTensor& in, phi::DenseTensor& out) { // NOLINT
gloo::BroadcastOptions opts(_context);
const auto& dtype = in.dtype();
if (rank_ == _root) {
GENERATE_FUNC(dtype, set_input, opts, in);
}
GENERATE_FUNC(dtype, set_output, opts, out);
opts.setRoot(_root);
opts.setTag(_tag);
gloo::broadcast(opts);
_comm_context->Broadcast(&(out), in, _root, _tag);
}
};
......@@ -256,22 +220,22 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
auto root = opts.source_rank;
std::unique_ptr<BroadcastGlooTask> task;
auto tag = next_tag();
auto context = get_context();
auto comm_context = this->GetCommContext();
task = std::make_unique<BroadcastGlooTask>(
context, inputs, outputs, rank_, root, tag);
comm_context, inputs, outputs, rank_, root, tag);
task->Run();
return task;
}
class SendGlooTask : public ProcessGroupGloo::GlooTask {
public:
SendGlooTask(const std::shared_ptr<gloo::Context>& context,
SendGlooTask(phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>* inputs,
int rank,
int dst_rank,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, *inputs, CommType::SEND),
_context(context),
_comm_context(comm_context),
_inputs(*inputs),
_dst(dst_rank),
_tag(tag) {}
......@@ -279,20 +243,13 @@ class SendGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_send(_inputs); }
private:
std::shared_ptr<gloo::Context> _context;
phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs;
int _dst;
uint32_t _tag;
void _do_send(std::vector<phi::DenseTensor>& in) { // NOLINT
SendRecvOptions opts(_context);
const auto& dtype = in[0].dtype();
GENERATE_FUNC(dtype, set_input, opts, in[0]);
opts.setSrc(_context.get()->rank);
opts.setDst(_dst);
opts.setTag(_tag);
send_recv(&opts);
_comm_context->Send(in[0], _dst, _tag);
}
};
......@@ -306,8 +263,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Send(
std::vector<phi::DenseTensor>& inputs, int dst_rank) {
std::unique_ptr<SendGlooTask> task;
auto tag = next_tag();
auto context = get_context();
task = std::make_unique<SendGlooTask>(context, &inputs, rank_, dst_rank, tag);
auto comm_context = this->GetCommContext();
task = std::make_unique<SendGlooTask>(
comm_context, &inputs, rank_, dst_rank, tag);
task->Run();
return task;
......@@ -315,13 +273,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Send(
class RecvGlooTask : public ProcessGroupGloo::GlooTask {
public:
RecvGlooTask(const std::shared_ptr<gloo::Context>& context,
RecvGlooTask(phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>* outputs,
int rank,
int src_rank,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, *outputs, CommType::RECV),
_context(context),
_comm_context(comm_context),
_outputs(*outputs),
_src(src_rank),
_tag(tag) {}
......@@ -329,20 +287,13 @@ class RecvGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_recv(_outputs); }
private:
std::shared_ptr<gloo::Context> _context;
phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _outputs;
const int _src;
const uint32_t _tag;
void _do_recv(std::vector<phi::DenseTensor>& out) { // NOLINT
SendRecvOptions opts(_context);
const auto& dtype = out[0].dtype();
GENERATE_FUNC(dtype, set_output, opts, out[0]);
opts.setSrc(_src);
opts.setDst(_context.get()->rank);
opts.setTag(_tag);
send_recv(&opts);
_comm_context->Recv(&(out[0]), _src, _tag);
}
};
......@@ -356,10 +307,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Recv(
std::vector<phi::DenseTensor>& outputs, int src_rank) {
std::unique_ptr<RecvGlooTask> task;
auto tag = next_tag();
auto context = get_context();
auto comm_context = this->GetCommContext();
task =
std::make_unique<RecvGlooTask>(context, &outputs, rank_, src_rank, tag);
task = std::make_unique<RecvGlooTask>(
comm_context, &outputs, rank_, src_rank, tag);
task->Run();
return task;
}
......@@ -367,13 +318,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Recv(
class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
public:
AllreduceGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE),
_context(context),
_comm_context(comm_context),
_inputs(inputs),
_outputs(outputs),
_reduce_op(reduce_op),
......@@ -382,34 +333,16 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_allreduce(_inputs, _outputs); }
private:
std::shared_ptr<gloo::Context> _context;
phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs;
const ReduceOp _reduce_op;
uint32_t _tag;
gloo::AllreduceOptions::Func _get_function(const phi::DataType type,
const ReduceOp op) {
gloo::AllreduceOptions::Func fn;
GENERATE_FUNC(type, _get_function_impl, fn, op);
return fn;
}
template <typename T>
void _get_function_impl(gloo::AllreduceOptions::Func& fn, // NOLINT
const ReduceOp op) {
fn = get_function<T>(op);
}
void _do_allreduce(std::vector<phi::DenseTensor>& ins, // NOLINT
std::vector<phi::DenseTensor>& outs) { // NOLINT
const auto& dtype = ins[0].dtype();
gloo::AllreduceOptions opts(_context);
GENERATE_FUNC(dtype, set_inputs, opts, ins);
GENERATE_FUNC(dtype, set_outputs, opts, outs);
opts.setReduceFunction(_get_function(dtype, _reduce_op));
opts.setTag(_tag);
gloo::allreduce(opts);
_comm_context->AllReduce(
&(outs[0]), ins[0], static_cast<int>(_reduce_op), _tag);
}
};
......@@ -437,36 +370,33 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
bool sync_op) {
auto tag = next_tag();
std::shared_ptr<GlooTask> task;
auto context = get_context();
auto comm_context = this->GetCommContext();
task = std::make_shared<AllreduceGlooTask>(
rank_, context, inputs, outputs, opts.reduce_op, tag);
rank_, comm_context, inputs, outputs, opts.reduce_op, tag);
task->Run();
return task;
}
class BarrierGlooTask : public ProcessGroupGloo::GlooTask {
public:
BarrierGlooTask(int rank, const std::shared_ptr<gloo::Context>& context)
BarrierGlooTask(int rank, phi::distributed::GlooCommContext* comm_context)
: ProcessGroupGloo::GlooTask(
rank, std::vector<phi::DenseTensor>{}, CommType::BARRIER),
_context(context) {}
_comm_context(comm_context) {}
void Run() override { _do_barrier(); }
private:
std::shared_ptr<gloo::Context> _context;
phi::distributed::GlooCommContext* _comm_context;
void _do_barrier() {
gloo::BarrierOptions opts(_context);
gloo::barrier(opts);
}
void _do_barrier() { _comm_context->Barrier(); }
};
std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
const BarrierOptions& opts) {
std::shared_ptr<BarrierGlooTask> task;
auto context = get_context();
task = std::make_shared<BarrierGlooTask>(rank_, context);
auto comm_context = this->GetCommContext();
task = std::make_shared<BarrierGlooTask>(rank_, comm_context);
task->Run();
return task;
}
......@@ -474,12 +404,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
public:
AllgatherGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLGATHER),
_context(context),
_comm_context(comm_context),
_inputs(inputs),
_outputs(outputs),
_tag(tag) {}
......@@ -487,19 +417,14 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_allgather(_inputs, _outputs); }
private:
std::shared_ptr<gloo::Context> _context;
phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs;
uint32_t _tag;
void _do_allgather(std::vector<phi::DenseTensor>& in, // NOLINT
std::vector<phi::DenseTensor>& out) { // NOLINT
const auto& dtype = in[0].dtype();
gloo::AllgatherOptions opts(_context);
GENERATE_FUNC(dtype, set_input, opts, in[0]);
GENERATE_FUNC(dtype, set_output, opts, out[0]);
opts.setTag(_tag);
gloo::allgather(opts);
_comm_context->AllGather(&(out[0]), in[0], _tag);
}
};
......@@ -526,9 +451,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
bool sync_op) {
std::shared_ptr<AllgatherGlooTask> task;
auto tag = next_tag();
auto context = get_context();
auto comm_context = this->GetCommContext();
task = std::make_shared<AllgatherGlooTask>(
rank_, context, in_tensors, out_tensors, tag);
rank_, comm_context, in_tensors, out_tensors, tag);
task->Run();
return task;
}
......@@ -536,14 +461,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
public:
ReduceGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
ReduceOp reduce_op,
int dst,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::REDUCE),
_context(context),
_comm_context(comm_context),
_inputs(inputs),
_outputs(outputs),
_reduce_op(reduce_op),
......@@ -553,37 +478,18 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_reduce(_inputs, _outputs, _dst); }
private:
std::shared_ptr<gloo::Context> _context;
phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs;
const ReduceOp _reduce_op;
int _dst;
uint32_t _tag;
gloo::ReduceOptions::Func _get_function(const phi::DataType type,
const ReduceOp op) {
gloo::ReduceOptions::Func fn;
GENERATE_FUNC(type, _get_function_impl, fn, op);
return fn;
}
template <typename T>
void _get_function_impl(gloo::ReduceOptions::Func& fn, // NOLINT
const ReduceOp op) {
fn = get_function<T>(op);
}
void _do_reduce(std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
int dst) {
const auto& dtype = inputs[0].dtype();
gloo::ReduceOptions opts(_context);
GENERATE_FUNC(dtype, set_input, opts, inputs[0]);
GENERATE_FUNC(dtype, set_output, opts, outputs[0]);
opts.setReduceFunction(_get_function(dtype, _reduce_op));
opts.setTag(_tag);
opts.setRoot(dst);
gloo::reduce(opts);
_comm_context->Reduce(
&(outputs[0]), inputs[0], static_cast<int>(_reduce_op), _dst, _tag);
}
};
......@@ -595,11 +501,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
) {
std::shared_ptr<ReduceGlooTask> task;
auto tag = next_tag();
auto context = get_context();
auto comm_context = this->GetCommContext();
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
task = std::make_shared<ReduceGlooTask>(rank_,
context,
comm_context,
in_wrapper,
out_wrapper,
opts.reduce_op,
......@@ -619,14 +525,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
public:
ScatterGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
phi::distributed::GlooCommContext* comm_context,
std::vector<phi::DenseTensor>& inputs, // NOLINT
std::vector<phi::DenseTensor>& outputs, // NOLINT
int src,
int size,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, inputs, CommType::SCATTER),
_context(context),
_comm_context(comm_context),
_inputs(inputs),
_outputs(outputs),
_src(src),
......@@ -636,7 +542,7 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_scatter(_inputs, _outputs, _src); }
private:
std::shared_ptr<gloo::Context> _context;
phi::distributed::GlooCommContext* _comm_context;
std::vector<phi::DenseTensor> _inputs;
std::vector<phi::DenseTensor> _outputs;
int _src;
......@@ -646,15 +552,7 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
void _do_scatter(std::vector<phi::DenseTensor>& in, // NOLINT
std::vector<phi::DenseTensor>& out, // NOLINT
int src) {
const auto& dtype = in[0].dtype();
gloo::ScatterOptions opts(_context);
if (rank_ == src) {
GENERATE_FUNC(dtype, set_inputs_for_scatter, opts, in[0], _size);
}
GENERATE_FUNC(dtype, set_output, opts, out[0]);
opts.setRoot(src);
opts.setTag(_tag);
gloo::scatter(opts);
_comm_context->Scatter(&(out[0]), in[0], _src, _size, _tag);
}
};
......@@ -665,11 +563,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
bool sync_op) {
std::shared_ptr<ScatterGlooTask> task;
auto tag = next_tag();
auto context = get_context();
auto comm_context = this->GetCommContext();
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
task = std::make_shared<ScatterGlooTask>(
rank_, context, in_wrapper, out_wrapper, opts.root_rank, size_, tag);
rank_, comm_context, in_wrapper, out_wrapper, opts.root_rank, size_, tag);
task->Run();
return task;
}
......@@ -684,13 +582,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
class GatherGlooTask : public ProcessGroupGloo::GlooTask {
public:
GatherGlooTask(int rank,
const std::shared_ptr<gloo::Context>& context,
phi::distributed::GlooCommContext* comm_context,
const phi::DenseTensor& input, // NOLINT
phi::DenseTensor* output, // NOLINT
int src,
uint32_t tag)
: ProcessGroupGloo::GlooTask(rank, {input}, CommType::GATHER),
_context(context),
_comm_context(comm_context),
_input(input),
_output(*output),
_src(src),
......@@ -699,7 +597,7 @@ class GatherGlooTask : public ProcessGroupGloo::GlooTask {
void Run() override { _do_gather(_input, _output, _src); }
private:
std::shared_ptr<gloo::Context> _context;
phi::distributed::GlooCommContext* _comm_context;
phi::DenseTensor _input;
phi::DenseTensor _output;
int _src;
......@@ -708,16 +606,7 @@ class GatherGlooTask : public ProcessGroupGloo::GlooTask {
void _do_gather(phi::DenseTensor& in, // NOLINT
phi::DenseTensor& out, // NOLINT
int src) {
const auto& dtype = in.dtype();
gloo::GatherOptions opts(_context);
if (rank_ == src) {
GENERATE_FUNC(dtype, set_output, opts, out);
}
GENERATE_FUNC(dtype, set_input, opts, in);
opts.setRoot(src);
opts.setTag(_tag);
gloo::gather(opts);
_comm_context->Gather(&(out), in, src, _tag);
}
};
......@@ -733,9 +622,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Gather(
platform::errors::InvalidArgument("Gloo cannot use use_calc_stream."));
std::shared_ptr<GatherGlooTask> task;
auto tag = next_tag();
auto context = get_context();
auto comm_context = this->GetCommContext();
task = std::make_shared<GatherGlooTask>(
rank_, context, in_tensor, out_tensor, opts.root_rank, tag);
rank_, comm_context, in_tensor, out_tensor, opts.root_rank, tag);
task->Run();
return task;
}
......@@ -804,11 +693,24 @@ std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo(
} else {
opts->device = ProcessGroupGloo::createDefaultDevice();
}
phi::distributed::CommContextManager::CreateGlooCommContext(
store, gid, rank, size);
auto process_group =
std::make_shared<ProcessGroupGloo>(store, rank, size, gid, opts);
ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
return process_group;
}
phi::distributed::GlooCommContext* ProcessGroupGloo::GetCommContext() {
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
auto comm_context = static_cast<phi::distributed::GlooCommContext*>(
comm_context_manager.Get(this->gid_));
PADDLE_ENFORCE_NE(comm_context,
nullptr,
phi::errors::Unavailable("GlooCommContext is nullptr"));
return comm_context;
}
} // namespace distributed
} // namespace paddle
......@@ -20,6 +20,7 @@
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
......@@ -225,6 +226,8 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
return GetDeviceContext(place);
}
phi::distributed::GlooCommContext* GetCommContext();
// Helper functions for Gloo.
static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
const std::string& hostname);
......
......@@ -24,7 +24,7 @@ namespace paddle {
namespace distributed {
// TODO(shenliang03): To support AVG for reduce
enum class ReduceOp : std::uint8_t { SUM = 0, AVG, MAX, MIN, PRODUCT };
enum class ReduceOp : std::uint8_t { SUM = 0, MAX, MIN, PRODUCT, AVG };
struct AllreduceOptions {
ReduceOp reduce_op = ReduceOp::SUM;
......
......@@ -17,8 +17,11 @@
#include <gloo/allgather.h>
#include <gloo/allreduce.h>
#include <gloo/barrier.h>
#include <gloo/broadcast.h>
#include <gloo/gather.h>
#include <gloo/reduce.h>
#include <gloo/scatter.h>
#include <gloo/types.h>
#include "paddle/phi/common/data_type.h"
......@@ -41,7 +44,8 @@ GlooCommContext::GlooCommContext(
void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root) {
int root,
uint32_t tag) {
// gloo only uses CPU now
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
......@@ -56,15 +60,18 @@ void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
}
opts.setRoot(root);
opts.setTag(tag);
gloo::broadcast(opts);
}
void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor) {
const phi::DenseTensor& in_tensor,
uint32_t tag) {
// gloo only uses CPU now
gloo::AllgatherOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
opts.setTag(tag);
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
gloo::allgather(opts);
......@@ -72,8 +79,10 @@ void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type) {
int reduce_type,
uint32_t tag) {
gloo::AllreduceOptions opts(gloo_context_);
opts.setTag(tag);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
......@@ -84,9 +93,11 @@ void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
void GlooCommContext::Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type,
int root) {
int root,
uint32_t tag) {
gloo::ReduceOptions opts(gloo_context_);
opts.setRoot(root);
opts.setTag(tag);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
......@@ -94,5 +105,65 @@ void GlooCommContext::Reduce(phi::DenseTensor* out_tensor,
gloo::reduce(opts);
}
void GlooCommContext::Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
uint32_t tag) {
gloo::GatherOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
opts.setTag(tag);
opts.setRoot(src);
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
if (rank_ == src) {
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
}
gloo::gather(opts);
}
void GlooCommContext::Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
int size,
uint32_t tag) {
gloo::ScatterOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
if (rank_ == src) {
GENERATE_FUNC(dtype, SetInputForScatter, &opts, in_tensor, size);
}
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
opts.setRoot(src);
opts.setTag(tag);
gloo::scatter(opts);
}
void GlooCommContext::Barrier() {
gloo::BarrierOptions opts(gloo_context_);
gloo::barrier(opts);
}
void GlooCommContext::Send(const phi::DenseTensor& in_tensor,
int dst,
uint32_t tag) {
SendRecvOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
opts.setSrc(gloo_context_.get()->rank);
opts.setDst(dst);
opts.setTag(tag);
send_recv(&opts);
}
void GlooCommContext::Recv(phi::DenseTensor* out_tensor,
int src,
uint32_t tag) {
SendRecvOptions opts(gloo_context_);
const auto& dtype = out_tensor->dtype();
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
opts.setSrc(src);
opts.setDst(gloo_context_.get()->rank);
opts.setTag(tag);
send_recv(&opts);
}
} // namespace distributed
} // namespace phi
......@@ -35,17 +35,38 @@ class GlooCommContext final : public CommContext {
void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root);
int root,
uint32_t tag = 0);
void AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type);
int reduce_type,
uint32_t tag = 0);
void Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type,
int root);
int root,
uint32_t tag = 0);
void AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor);
const phi::DenseTensor& in_tensor,
uint32_t tag = 0);
void Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
uint32_t tag = 0);
void Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
int size,
uint32_t tag = 0);
void Barrier();
void Send(const phi::DenseTensor& in_tensor, int dst, uint32_t tag = 0);
void Recv(phi::DenseTensor* out_tensor, int src, uint32_t tag = 0);
private:
DISABLE_COPY_AND_ASSIGN(GlooCommContext);
......
......@@ -91,5 +91,20 @@ std::shared_ptr<gloo::transport::Device> CreateGlooDevice() {
}
}
void send_recv(SendRecvOptions* opts) {
const auto& context = opts->context;
gloo::transport::UnboundBuffer* in = opts->in.get();
gloo::transport::UnboundBuffer* out = opts->out.get();
const auto slot = gloo::Slot::build(kSendRecvSlotPrefix, opts->tag);
if (context->rank == opts->src) {
in->send(opts->dst, slot);
in->waitSend(opts->timeout);
} else if (context->rank == opts->dst) {
out->recv(opts->src, slot);
out->waitRecv(opts->timeout);
}
}
} // namespace distributed
} // namespace phi
......@@ -14,6 +14,7 @@
#pragma once
#include <gloo/allreduce.h>
#include <gloo/math.h>
#include <gloo/transport/tcp/device.h>
#include <gloo/types.h>
......@@ -103,6 +104,19 @@ void SetInput(P* opts, const phi::DenseTensor& tensor) {
tensor.numel());
}
template <typename T, typename P>
void SetInputForScatter(P* opts, const phi::DenseTensor& tensor, int nranks) {
std::vector<T*> ret;
ret.reserve(nranks);
T* raw_pointer = reinterpret_cast<T*>(const_cast<void*>(tensor.data()));
size_t offset = 0;
for (int i = 0; i < nranks; i++) {
ret.push_back(raw_pointer + offset);
offset += tensor.numel() / nranks;
}
opts->setInputs(ret, tensor.numel() / nranks);
}
template <typename T, typename P>
void SetReduceFunc(P* opts, int reduce_type) {
// gloo only support mutable data input
......@@ -136,5 +150,55 @@ void SetReduceFunc(P* opts, int reduce_type) {
// env preparation
std::shared_ptr<gloo::transport::Device> CreateGlooDevice();
constexpr uint8_t kSendRecvSlotPrefix = 0x08;
class SendRecvOptions {
public:
explicit SendRecvOptions(const std::shared_ptr<gloo::Context>& context)
: context(context), timeout(context->getTimeout()) {}
template <typename T>
void setInput(T* ptr, size_t elements) {
this->in = context->createUnboundBuffer(ptr, elements * sizeof(T));
}
template <typename T>
void setOutput(T* ptr, size_t elements) {
this->out = context->createUnboundBuffer(ptr, elements * sizeof(T));
}
void setSrc(int src) { this->src = src; }
void setDst(int dst) { this->dst = dst; }
void setTag(uint32_t tag) { this->tag = tag; }
void setTimeout(std::chrono::milliseconds timeout) {
this->timeout = timeout;
}
protected:
std::shared_ptr<gloo::Context> context;
std::unique_ptr<gloo::transport::UnboundBuffer> in;
std::unique_ptr<gloo::transport::UnboundBuffer> out;
// Rank of process to send_recv from.
int src = -1;
// Rank of process to send_recv to.
int dst = -1;
// Tag for this operation.
// Must be unique across operations executing in parallel.
uint32_t tag = 0;
// End-to-end timeout for this operation.
std::chrono::milliseconds timeout;
friend void send_recv(SendRecvOptions*);
};
void send_recv(SendRecvOptions* opts);
} // namespace distributed
} // namespace phi
......@@ -14,7 +14,7 @@
import os
import test_collective_api_base as test_base
import legacy_test.test_collective_api_base as test_base
import paddle
import paddle.distributed as dist
......
......@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
from legacy_test.test_collective_api_base import (
TestCollectiveAPIRunnerBase,
runtime_main,
)
import paddle
import paddle.distributed as dist
......