diff --git a/paddle/fluid/distributed/collective/process_group_gloo.cc b/paddle/fluid/distributed/collective/process_group_gloo.cc
index 8a87008484d99340030defb6a454502d973863fb..05ae7ed745623ef45e40ffaa2fa751f67e60f1ce 100644
--- a/paddle/fluid/distributed/collective/process_group_gloo.cc
+++ b/paddle/fluid/distributed/collective/process_group_gloo.cc
@@ -24,16 +24,12 @@
 #include
 #endif
 
-#include
-#include
 #include
-#include
 
 #include "paddle/fluid/distributed/collective/common.h"
-#include "paddle/fluid/distributed/collective/gloo_send_recv.h"
 #include "paddle/fluid/distributed/collective/process_group_gloo.h"
-#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/core/distributed/comm_context_manager.h"
 
 namespace paddle {
 namespace distributed {
@@ -99,28 +95,6 @@ namespace distributed {
 }
 #endif
-typedef void (*reduce_func)(void*, const void*, const void*, size_t);
-
-template <typename T>
-reduce_func get_function(const ReduceOp& r) {
-  switch (r) {
-    case ReduceOp::SUM:
-      return reduce_func(&::gloo::sum<T>);
-    case ReduceOp::PRODUCT:
-      return reduce_func(&::gloo::product<T>);
-    case ReduceOp::MIN:
-      return reduce_func(&::gloo::min<T>);
-    case ReduceOp::MAX:
-      return reduce_func(&::gloo::max<T>);
-    case ReduceOp::AVG:
-      VLOG(0) << "Error: Unsupported ReduceOp::AVG.";
-      exit(-1);
-  }
-
-  VLOG(0) << "Error: Unknown ReduceOp.";
-  exit(-1);
-}
-
 template <typename T>
 T* get_data(phi::DenseTensor& tensor) {  // NOLINT
   return reinterpret_cast<T*>(tensor.data());
 }
@@ -188,21 +162,19 @@ ProcessGroupGloo::ProcessGroupGloo(
       _tag(0),
       _store(new GlooStore(store)) {
   _context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
-  auto prefix_store =
-      ::gloo::rendezvous::PrefixStore(std::to_string(gid), *_store);
-  _context->connectFullMesh(prefix_store, options->device);
+  _context->connectFullMesh(*_store, options->device);
 }
 
 class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
  public:
-  BroadcastGlooTask(const std::shared_ptr<gloo::rendezvous::Context>& context,
+  BroadcastGlooTask(phi::distributed::GlooCommContext* comm_context,
                     std::vector<phi::DenseTensor>& inputs,   // NOLINT
                     std::vector<phi::DenseTensor>& outputs,  // NOLINT
                     int rank,
                     int root,
                     uint32_t tag)
       : ProcessGroupGloo::GlooTask(rank, inputs, CommType::BROADCAST),
-        _context(context),
+        _comm_context(comm_context),
         _root(root),
         _inputs(inputs),
         _outputs(outputs),
@@ -211,22 +183,14 @@ class BroadcastGlooTask : public ProcessGroupGloo::GlooTask {
   void Run() override { _do_broadcast(_inputs[0], _outputs[0]); }
 
  private:
-  std::shared_ptr<gloo::rendezvous::Context> _context;
+  phi::distributed::GlooCommContext* _comm_context;
   const int _root;
   std::vector<phi::DenseTensor> _inputs{};
   std::vector<phi::DenseTensor> _outputs{};
   const uint32_t _tag;
 
   void _do_broadcast(phi::DenseTensor& in, phi::DenseTensor& out) {  // NOLINT
-    gloo::BroadcastOptions opts(_context);
-    const auto& dtype = in.dtype();
-    if (rank_ == _root) {
-      GENERATE_FUNC(dtype, set_input, opts, in);
-    }
-    GENERATE_FUNC(dtype, set_output, opts, out);
-    opts.setRoot(_root);
-    opts.setTag(_tag);
-    gloo::broadcast(opts);
+    _comm_context->Broadcast(&(out), in, _root, _tag);
   }
 };
 
@@ -256,22 +220,22 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Broadcast(
   auto root = opts.source_rank;
   std::unique_ptr<BroadcastGlooTask> task;
   auto tag = next_tag();
-  auto context = get_context();
+  auto comm_context = this->GetCommContext();
   task = std::make_unique<BroadcastGlooTask>(
-      context, inputs, outputs, rank_, root, tag);
+      comm_context, inputs, outputs, rank_, root, tag);
   task->Run();
   return task;
 }
 
 class SendGlooTask : public ProcessGroupGloo::GlooTask {
  public:
-  SendGlooTask(const std::shared_ptr<gloo::rendezvous::Context>& context,
+  SendGlooTask(phi::distributed::GlooCommContext* comm_context,
               std::vector<phi::DenseTensor>* inputs,
               int rank,
               int dst_rank,
               uint32_t tag)
       : ProcessGroupGloo::GlooTask(rank, *inputs, CommType::SEND),
-        _context(context),
+        _comm_context(comm_context),
         _inputs(*inputs),
         _dst(dst_rank),
         _tag(tag) {}
 
@@ -279,20 +243,13 @@ class SendGlooTask : public ProcessGroupGloo::GlooTask {
   void Run() override { _do_send(_inputs); }
 
  private:
-  std::shared_ptr<gloo::rendezvous::Context> _context;
+  phi::distributed::GlooCommContext* _comm_context;
   std::vector<phi::DenseTensor> _inputs;
   int _dst;
   uint32_t _tag;
 
   void _do_send(std::vector<phi::DenseTensor>& in) {  // NOLINT
-    SendRecvOptions opts(_context);
-    const auto& dtype = in[0].dtype();
-    GENERATE_FUNC(dtype, set_input, opts, in[0]);
-
-    opts.setSrc(_context.get()->rank);
-    opts.setDst(_dst);
-    opts.setTag(_tag);
-    send_recv(&opts);
+    _comm_context->Send(in[0], _dst, _tag);
   }
 };
 
@@ -306,8 +263,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Send(
     std::vector<phi::DenseTensor>& inputs, int dst_rank) {
   std::unique_ptr<SendGlooTask> task;
   auto tag = next_tag();
-  auto context = get_context();
-  task = std::make_unique<SendGlooTask>(context, &inputs, rank_, dst_rank, tag);
+  auto comm_context = this->GetCommContext();
+  task = std::make_unique<SendGlooTask>(
+      comm_context, &inputs, rank_, dst_rank, tag);
 
   task->Run();
   return task;
@@ -315,13 +273,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Send(
 
 class RecvGlooTask : public ProcessGroupGloo::GlooTask {
  public:
-  RecvGlooTask(const std::shared_ptr<gloo::rendezvous::Context>& context,
+  RecvGlooTask(phi::distributed::GlooCommContext* comm_context,
               std::vector<phi::DenseTensor>* outputs,
               int rank,
               int src_rank,
               uint32_t tag)
       : ProcessGroupGloo::GlooTask(rank, *outputs, CommType::RECV),
-        _context(context),
+        _comm_context(comm_context),
         _outputs(*outputs),
         _src(src_rank),
         _tag(tag) {}
 
@@ -329,20 +287,13 @@ class RecvGlooTask : public ProcessGroupGloo::GlooTask {
   void Run() override { _do_recv(_outputs); }
 
  private:
-  std::shared_ptr<gloo::rendezvous::Context> _context;
+  phi::distributed::GlooCommContext* _comm_context;
   std::vector<phi::DenseTensor> _outputs;
   const int _src;
   const uint32_t _tag;
 
   void _do_recv(std::vector<phi::DenseTensor>& out) {  // NOLINT
-    SendRecvOptions opts(_context);
-    const auto& dtype = out[0].dtype();
-    GENERATE_FUNC(dtype, set_output, opts, out[0]);
-
-    opts.setSrc(_src);
-    opts.setDst(_context.get()->rank);
-    opts.setTag(_tag);
-    send_recv(&opts);
+    _comm_context->Recv(&(out[0]), _src, _tag);
  }
 };
 
@@ -356,10 +307,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Recv(
     std::vector<phi::DenseTensor>& outputs, int src_rank) {
   std::unique_ptr<RecvGlooTask> task;
   auto tag = next_tag();
-  auto context = get_context();
+  auto comm_context = this->GetCommContext();
 
-  task =
-      std::make_unique<RecvGlooTask>(context, &outputs, rank_, src_rank, tag);
+  task = std::make_unique<RecvGlooTask>(
+      comm_context, &outputs, rank_, src_rank, tag);
   task->Run();
   return task;
 }
@@ -367,13 +318,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Recv(
 class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
  public:
   AllreduceGlooTask(int rank,
-                    const std::shared_ptr<gloo::rendezvous::Context>& context,
+                    phi::distributed::GlooCommContext* comm_context,
                     std::vector<phi::DenseTensor>& inputs,   // NOLINT
                     std::vector<phi::DenseTensor>& outputs,  // NOLINT
                     ReduceOp reduce_op,
                     uint32_t tag)
       : ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLREDUCE),
-        _context(context),
+        _comm_context(comm_context),
         _inputs(inputs),
         _outputs(outputs),
         _reduce_op(reduce_op),
@@ -382,34 +333,16 @@ class AllreduceGlooTask : public ProcessGroupGloo::GlooTask {
   void Run() override { _do_allreduce(_inputs, _outputs); }
 
  private:
-  std::shared_ptr<gloo::rendezvous::Context> _context;
+  phi::distributed::GlooCommContext* _comm_context;
   std::vector<phi::DenseTensor> _inputs;
   std::vector<phi::DenseTensor> _outputs;
   const ReduceOp _reduce_op;
   uint32_t _tag;
 
-  gloo::AllreduceOptions::Func _get_function(const phi::DataType type,
-                                             const ReduceOp op) {
-    gloo::AllreduceOptions::Func fn;
-    GENERATE_FUNC(type, _get_function_impl, fn, op);
-    return fn;
-  }
-
-  template <typename T>
-  void _get_function_impl(gloo::AllreduceOptions::Func& fn,  // NOLINT
-                          const ReduceOp op) {
-    fn = get_function<T>(op);
-  }
-
   void _do_allreduce(std::vector<phi::DenseTensor>& ins,     // NOLINT
                      std::vector<phi::DenseTensor>& outs) {  // NOLINT
-    const auto& dtype = ins[0].dtype();
-    gloo::AllreduceOptions opts(_context);
-    GENERATE_FUNC(dtype, set_inputs, opts, ins);
-    GENERATE_FUNC(dtype, set_outputs, opts, outs);
-    opts.setReduceFunction(_get_function(dtype, _reduce_op));
-    opts.setTag(_tag);
-    gloo::allreduce(opts);
+    _comm_context->AllReduce(
+        &(outs[0]), ins[0], static_cast<int>(_reduce_op), _tag);
   }
 };
 
@@ -437,36 +370,33 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
     bool sync_op) {
   auto tag = next_tag();
   std::shared_ptr<AllreduceGlooTask> task;
-  auto context = get_context();
+  auto comm_context = this->GetCommContext();
   task = std::make_shared<AllreduceGlooTask>(
-      rank_, context, inputs, outputs, opts.reduce_op, tag);
+      rank_, comm_context, inputs, outputs, opts.reduce_op, tag);
   task->Run();
   return task;
 }
 
 class BarrierGlooTask : public ProcessGroupGloo::GlooTask {
  public:
-  BarrierGlooTask(int rank, const std::shared_ptr<gloo::rendezvous::Context>& context)
+  BarrierGlooTask(int rank, phi::distributed::GlooCommContext* comm_context)
       : ProcessGroupGloo::GlooTask(
             rank, std::vector<phi::DenseTensor>{}, CommType::BARRIER),
-        _context(context) {}
+        _comm_context(comm_context) {}
 
   void Run() override { _do_barrier(); }
 
  private:
-  std::shared_ptr<gloo::rendezvous::Context> _context;
+  phi::distributed::GlooCommContext* _comm_context;
 
-  void _do_barrier() {
-    gloo::BarrierOptions opts(_context);
-    gloo::barrier(opts);
-  }
+  void _do_barrier() { _comm_context->Barrier(); }
 };
 
 std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
     const BarrierOptions& opts) {
   std::shared_ptr<BarrierGlooTask> task;
-  auto context = get_context();
-  task = std::make_shared<BarrierGlooTask>(rank_, context);
+  auto comm_context = this->GetCommContext();
+  task = std::make_shared<BarrierGlooTask>(rank_, comm_context);
   task->Run();
   return task;
 }
@@ -474,12 +404,12 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Barrier(
 class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
  public:
   AllgatherGlooTask(int rank,
-                    const std::shared_ptr<gloo::rendezvous::Context>& context,
+                    phi::distributed::GlooCommContext* comm_context,
                     std::vector<phi::DenseTensor>& inputs,   // NOLINT
                     std::vector<phi::DenseTensor>& outputs,  // NOLINT
                     uint32_t tag)
       : ProcessGroupGloo::GlooTask(rank, inputs, CommType::ALLGATHER),
-        _context(context),
+        _comm_context(comm_context),
         _inputs(inputs),
         _outputs(outputs),
         _tag(tag) {}
@@ -487,19 +417,14 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
   void Run() override { _do_allgather(_inputs, _outputs); }
 
  private:
-  std::shared_ptr<gloo::rendezvous::Context> _context;
+  phi::distributed::GlooCommContext* _comm_context;
   std::vector<phi::DenseTensor> _inputs;
   std::vector<phi::DenseTensor> _outputs;
   uint32_t _tag;
 
   void _do_allgather(std::vector<phi::DenseTensor>& in,     // NOLINT
                      std::vector<phi::DenseTensor>& out) {  // NOLINT
-    const auto& dtype = in[0].dtype();
-    gloo::AllgatherOptions opts(_context);
-    GENERATE_FUNC(dtype, set_input, opts, in[0]);
-    GENERATE_FUNC(dtype, set_output, opts, out[0]);
-    opts.setTag(_tag);
-    gloo::allgather(opts);
+    _comm_context->AllGather(&(out[0]), in[0], _tag);
   }
 };
 
@@ -526,9 +451,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
     bool sync_op) {
   std::shared_ptr<AllgatherGlooTask> task;
   auto tag = next_tag();
-  auto context = get_context();
+  auto comm_context = this->GetCommContext();
   task = std::make_shared<AllgatherGlooTask>(
-      rank_, context, in_tensors, out_tensors, tag);
+      rank_, comm_context, in_tensors, out_tensors, tag);
   task->Run();
   return task;
 }
@@ -536,14 +461,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
 class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
  public:
   ReduceGlooTask(int rank,
-                 const std::shared_ptr<gloo::rendezvous::Context>& context,
+                 phi::distributed::GlooCommContext* comm_context,
                  std::vector<phi::DenseTensor>& inputs,   // NOLINT
                  std::vector<phi::DenseTensor>& outputs,  // NOLINT
                  ReduceOp reduce_op,
                  int dst,
                  uint32_t tag)
       : ProcessGroupGloo::GlooTask(rank, inputs, CommType::REDUCE),
-        _context(context),
+        _comm_context(comm_context),
         _inputs(inputs),
         _outputs(outputs),
         _reduce_op(reduce_op),
@@ -553,37 +478,18 @@ class ReduceGlooTask : public ProcessGroupGloo::GlooTask {
   void Run() override { _do_reduce(_inputs, _outputs, _dst); }
 
  private:
-  std::shared_ptr<gloo::rendezvous::Context> _context;
+  phi::distributed::GlooCommContext* _comm_context;
   std::vector<phi::DenseTensor> _inputs;
   std::vector<phi::DenseTensor> _outputs;
   const ReduceOp _reduce_op;
   int _dst;
   uint32_t _tag;
 
-  gloo::ReduceOptions::Func _get_function(const phi::DataType type,
-                                          const ReduceOp op) {
-    gloo::ReduceOptions::Func fn;
-    GENERATE_FUNC(type, _get_function_impl, fn, op);
-    return fn;
-  }
-
-  template <typename T>
-  void _get_function_impl(gloo::ReduceOptions::Func& fn,  // NOLINT
-                          const ReduceOp op) {
-    fn = get_function<T>(op);
-  }
-
   void _do_reduce(std::vector<phi::DenseTensor>& inputs,   // NOLINT
                   std::vector<phi::DenseTensor>& outputs,  // NOLINT
                   int dst) {
-    const auto& dtype = inputs[0].dtype();
-    gloo::ReduceOptions opts(_context);
-    GENERATE_FUNC(dtype, set_input, opts, inputs[0]);
-    GENERATE_FUNC(dtype, set_output, opts, outputs[0]);
-    opts.setReduceFunction(_get_function(dtype, _reduce_op));
-    opts.setTag(_tag);
-    opts.setRoot(dst);
-    gloo::reduce(opts);
+    _comm_context->Reduce(
+        &(outputs[0]), inputs[0], static_cast<int>(_reduce_op), _dst, _tag);
   }
 };
 
@@ -595,11 +501,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
     ) {
   std::shared_ptr<ReduceGlooTask> task;
   auto tag = next_tag();
-  auto context = get_context();
+  auto comm_context = this->GetCommContext();
   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
   task = std::make_shared<ReduceGlooTask>(rank_,
-                                          context,
+                                          comm_context,
                                           in_wrapper,
                                           out_wrapper,
                                           opts.reduce_op,
@@ -619,14 +525,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Reduce(
 class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
  public:
   ScatterGlooTask(int rank,
-                  const std::shared_ptr<gloo::rendezvous::Context>& context,
+                  phi::distributed::GlooCommContext* comm_context,
                   std::vector<phi::DenseTensor>& inputs,   // NOLINT
                   std::vector<phi::DenseTensor>& outputs,  // NOLINT
                   int src,
                   int size,
                   uint32_t tag)
       : ProcessGroupGloo::GlooTask(rank, inputs, CommType::SCATTER),
-        _context(context),
+        _comm_context(comm_context),
         _inputs(inputs),
         _outputs(outputs),
         _src(src),
@@ -636,7 +542,7 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
   void Run() override { _do_scatter(_inputs, _outputs, _src); }
 
  private:
-  std::shared_ptr<gloo::rendezvous::Context> _context;
+  phi::distributed::GlooCommContext* _comm_context;
   std::vector<phi::DenseTensor> _inputs;
   std::vector<phi::DenseTensor> _outputs;
   int _src;
@@ -646,15 +552,7 @@ class ScatterGlooTask : public ProcessGroupGloo::GlooTask {
   void _do_scatter(std::vector<phi::DenseTensor>& in,   // NOLINT
                    std::vector<phi::DenseTensor>& out,  // NOLINT
                    int src) {
-    const auto& dtype = in[0].dtype();
-    gloo::ScatterOptions opts(_context);
-    if (rank_ == src) {
-      GENERATE_FUNC(dtype, set_inputs_for_scatter, opts, in[0], _size);
-    }
-    GENERATE_FUNC(dtype, set_output, opts, out[0]);
-    opts.setRoot(src);
-    opts.setTag(_tag);
-    gloo::scatter(opts);
+    _comm_context->Scatter(&(out[0]), in[0], _src, _size, _tag);
   }
 };
 
@@ -665,11 +563,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
     bool sync_op) {
   std::shared_ptr<ScatterGlooTask> task;
   auto tag = next_tag();
-  auto context = get_context();
+  auto comm_context = this->GetCommContext();
   std::vector<phi::DenseTensor> in_wrapper{in_tensor};
   std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
   task = std::make_shared<ScatterGlooTask>(
-      rank_, context, in_wrapper, out_wrapper, opts.root_rank, size_, tag);
+      rank_, comm_context, in_wrapper, out_wrapper, opts.root_rank, size_, tag);
   task->Run();
   return task;
 }
 
@@ -684,13 +582,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Scatter(
 class GatherGlooTask : public ProcessGroupGloo::GlooTask {
  public:
   GatherGlooTask(int rank,
-                 const std::shared_ptr<gloo::rendezvous::Context>& context,
+                 phi::distributed::GlooCommContext* comm_context,
                  const phi::DenseTensor& input,  // NOLINT
                  phi::DenseTensor* output,       // NOLINT
                  int src,
                  uint32_t tag)
       : ProcessGroupGloo::GlooTask(rank, {input}, CommType::GATHER),
-        _context(context),
+        _comm_context(comm_context),
         _input(input),
         _output(*output),
         _src(src),
@@ -699,7 +597,7 @@ class GatherGlooTask : public ProcessGroupGloo::GlooTask {
   void Run() override { _do_gather(_input, _output, _src); }
 
  private:
-  std::shared_ptr<gloo::rendezvous::Context> _context;
+  phi::distributed::GlooCommContext* _comm_context;
   phi::DenseTensor _input;
   phi::DenseTensor _output;
   int _src;
@@ -708,16 +606,7 @@ class GatherGlooTask : public ProcessGroupGloo::GlooTask {
   void _do_gather(phi::DenseTensor& in,   // NOLINT
                   phi::DenseTensor& out,  // NOLINT
                   int src) {
-    const auto& dtype = in.dtype();
-    gloo::GatherOptions opts(_context);
-    if (rank_ == src) {
-      GENERATE_FUNC(dtype, set_output, opts, out);
-    }
-    GENERATE_FUNC(dtype, set_input, opts, in);
-
-    opts.setRoot(src);
-    opts.setTag(_tag);
-    gloo::gather(opts);
+    _comm_context->Gather(&(out), in, src, _tag);
   }
 };
 
@@ -733,9 +622,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::Gather(
       platform::errors::InvalidArgument("Gloo cannot use use_calc_stream."));
   std::shared_ptr<GatherGlooTask> task;
   auto tag = next_tag();
-  auto context = get_context();
+  auto comm_context = this->GetCommContext();
   task = std::make_shared<GatherGlooTask>(
-      rank_, context, in_tensor, out_tensor, opts.root_rank, tag);
+      rank_, comm_context, in_tensor, out_tensor, opts.root_rank, tag);
   task->Run();
   return task;
 }
@@ -804,11 +693,24 @@ std::shared_ptr<ProcessGroupGloo> ProcessGroupGloo::CreateProcessGroupGloo(
   } else {
     opts->device = ProcessGroupGloo::createDefaultDevice();
   }
+  phi::distributed::CommContextManager::CreateGlooCommContext(
+      store, gid, rank, size);
   auto process_group =
       std::make_shared<ProcessGroupGloo>(store, rank, size, gid, opts);
   ProcessGroupIdMap::GetInstance().emplace(gid, process_group);
   return process_group;
 }
 
+phi::distributed::GlooCommContext* ProcessGroupGloo::GetCommContext() {
+  const auto& comm_context_manager =
+      phi::distributed::CommContextManager::GetInstance();
+  auto comm_context = static_cast<phi::distributed::GlooCommContext*>(
+      comm_context_manager.Get(this->gid_));
+  PADDLE_ENFORCE_NE(comm_context,
+                    nullptr,
+                    phi::errors::Unavailable("GlooCommContext is nullptr"));
+  return comm_context;
+}
+
 }  // namespace distributed
 }  // namespace paddle
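After this change every Gloo task in process_group_gloo.cc holds only a raw GlooCommContext pointer; the context is registered once per group in CreateProcessGroupGloo and looked up on demand through GetCommContext. A minimal sketch of that call path follows (illustration only, not part of the patch; the SketchAllReduce helper, its parameter types, and the store/rank/gid placeholders are assumptions paraphrased from the calls above):

    // Sketch only: mirrors CreateGlooCommContext / GetCommContext / AllReduce above.
    #include "paddle/phi/core/distributed/comm_context_manager.h"
    #include "paddle/phi/core/distributed/gloo_comm_context.h"

    void SketchAllReduce(const std::shared_ptr<phi::distributed::Store>& store,
                         int rank, int world_size, int gid,
                         phi::DenseTensor* out, const phi::DenseTensor& in) {
      // Registered once per group (CreateProcessGroupGloo does this above).
      phi::distributed::CommContextManager::CreateGlooCommContext(
          store, gid, rank, world_size);
      // Looked up later, the same way ProcessGroupGloo::GetCommContext() does.
      auto* comm_context = static_cast<phi::distributed::GlooCommContext*>(
          phi::distributed::CommContextManager::GetInstance().Get(gid));
      // reduce_type 0 corresponds to ReduceOp::SUM (see the types.h change below).
      comm_context->AllReduce(out, in, /*reduce_type=*/0, /*tag=*/0);
    }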
diff --git a/paddle/fluid/distributed/collective/process_group_gloo.h b/paddle/fluid/distributed/collective/process_group_gloo.h
index c45b3e74d84938563efa96431775929f95a19ceb..29407c907dc87870b3490071cb8d39a5c014710f 100644
--- a/paddle/fluid/distributed/collective/process_group_gloo.h
+++ b/paddle/fluid/distributed/collective/process_group_gloo.h
@@ -20,6 +20,7 @@
 
 #include "paddle/fluid/distributed/collective/process_group.h"
 #include "paddle/fluid/distributed/collective/process_group_without_stream.h"
+#include "paddle/phi/core/distributed/gloo_comm_context.h"
 #include "paddle/phi/core/distributed/store/store.h"
 #include "paddle/phi/core/distributed/store/tcp_store.h"
 
@@ -225,6 +226,8 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
     return GetDeviceContext(place);
   }
 
+  phi::distributed::GlooCommContext* GetCommContext();
+
   // Helper functions for Gloo.
   static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
       const std::string& hostname);
diff --git a/paddle/fluid/distributed/collective/types.h b/paddle/fluid/distributed/collective/types.h
index 3bafa53727c7217c3cd0a0512ca3a679cb7ca134..433f645db994b243ea309efdc711e167abe1d001 100644
--- a/paddle/fluid/distributed/collective/types.h
+++ b/paddle/fluid/distributed/collective/types.h
@@ -24,7 +24,7 @@ namespace paddle {
 namespace distributed {
 
 // TODO(shenliang03): To support AVG for reduce
-enum class ReduceOp : std::uint8_t { SUM = 0, AVG, MAX, MIN, PRODUCT };
+enum class ReduceOp : std::uint8_t { SUM = 0, MAX, MIN, PRODUCT, AVG };
 
 struct AllreduceOptions {
   ReduceOp reduce_op = ReduceOp::SUM;
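The ReduceOp reorder above matters because the process-group code now forwards static_cast<int>(_reduce_op) straight into GlooCommContext::AllReduce/Reduce, whose reduce_type argument is interpreted by SetReduceFunc in gloo_utils.h. A small sketch of the assumed integer mapping after the reorder (illustration only, not part of the patch):

    // Illustration: integer values carried by ReduceOp after the reorder; these
    // are the values GlooCommContext receives via static_cast<int>(op).
    #include <cstdint>

    enum class ReduceOp : std::uint8_t { SUM = 0, MAX, MIN, PRODUCT, AVG };

    static_assert(static_cast<int>(ReduceOp::SUM) == 0, "SUM -> 0");
    static_assert(static_cast<int>(ReduceOp::MAX) == 1, "MAX -> 1");
    static_assert(static_cast<int>(ReduceOp::MIN) == 2, "MIN -> 2");
    static_assert(static_cast<int>(ReduceOp::PRODUCT) == 3, "PRODUCT -> 3");
    static_assert(static_cast<int>(ReduceOp::AVG) == 4, "AVG -> 4");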
diff --git a/paddle/phi/core/distributed/gloo_comm_context.cc b/paddle/phi/core/distributed/gloo_comm_context.cc
index 7c956185ef430587a2732218b2eec7ee82d38889..098bc851bf11c3732ec1729221a825148d3b104f 100644
--- a/paddle/phi/core/distributed/gloo_comm_context.cc
+++ b/paddle/phi/core/distributed/gloo_comm_context.cc
@@ -17,8 +17,11 @@
 #include
 #include
+#include
 #include
+#include
 #include
+#include
 #include
 
 #include "paddle/phi/common/data_type.h"
@@ -41,7 +44,8 @@ GlooCommContext::GlooCommContext(
 
 void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
                                 const phi::DenseTensor& in_tensor,
-                                int root) {
+                                int root,
+                                uint32_t tag) {
   // gloo only uses CPU now
   CommStaticCheck::SameShape(*out_tensor,
                              in_tensor,
@@ -56,15 +60,18 @@ void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
     GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
   }
   opts.setRoot(root);
+  opts.setTag(tag);
   gloo::broadcast(opts);
 }
 
 void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
-                                const phi::DenseTensor& in_tensor) {
+                                const phi::DenseTensor& in_tensor,
+                                uint32_t tag) {
   // gloo only uses CPU now
   gloo::AllgatherOptions opts(gloo_context_);
   const auto& dtype = in_tensor.dtype();
+  opts.setTag(tag);
   GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
   GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
   gloo::allgather(opts);
@@ -72,8 +79,10 @@ void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
 
 void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
                                 const phi::DenseTensor& in_tensor,
-                                int reduce_type) {
+                                int reduce_type,
+                                uint32_t tag) {
   gloo::AllreduceOptions opts(gloo_context_);
+  opts.setTag(tag);
   const auto& dtype = in_tensor.dtype();
   GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
   GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
@@ -84,9 +93,11 @@ void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
 
 void GlooCommContext::Reduce(phi::DenseTensor* out_tensor,
                              const phi::DenseTensor& in_tensor,
                              int reduce_type,
-                             int root) {
+                             int root,
+                             uint32_t tag) {
   gloo::ReduceOptions opts(gloo_context_);
   opts.setRoot(root);
+  opts.setTag(tag);
   const auto& dtype = in_tensor.dtype();
   GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
   GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
@@ -94,5 +105,65 @@ void GlooCommContext::Reduce(phi::DenseTensor* out_tensor,
   gloo::reduce(opts);
 }
 
+void GlooCommContext::Gather(phi::DenseTensor* out_tensor,
+                             const phi::DenseTensor& in_tensor,
+                             int src,
+                             uint32_t tag) {
+  gloo::GatherOptions opts(gloo_context_);
+  const auto& dtype = in_tensor.dtype();
+  opts.setTag(tag);
+  opts.setRoot(src);
+  GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
+  if (rank_ == src) {
+    GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
+  }
+  gloo::gather(opts);
+}
+
+void GlooCommContext::Scatter(phi::DenseTensor* out_tensor,
+                              const phi::DenseTensor& in_tensor,
+                              int src,
+                              int size,
+                              uint32_t tag) {
+  gloo::ScatterOptions opts(gloo_context_);
+  const auto& dtype = in_tensor.dtype();
+  if (rank_ == src) {
+    GENERATE_FUNC(dtype, SetInputForScatter, &opts, in_tensor, size);
+  }
+  GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
+  opts.setRoot(src);
+  opts.setTag(tag);
+  gloo::scatter(opts);
+}
+
+void GlooCommContext::Barrier() {
+  gloo::BarrierOptions opts(gloo_context_);
+  gloo::barrier(opts);
+}
+
+void GlooCommContext::Send(const phi::DenseTensor& in_tensor,
+                           int dst,
+                           uint32_t tag) {
+  SendRecvOptions opts(gloo_context_);
+  const auto& dtype = in_tensor.dtype();
+  GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
+  opts.setSrc(gloo_context_.get()->rank);
+  opts.setDst(dst);
+  opts.setTag(tag);
+  send_recv(&opts);
+}
+
+void GlooCommContext::Recv(phi::DenseTensor* out_tensor,
+                           int src,
+                           uint32_t tag) {
+  SendRecvOptions opts(gloo_context_);
+  const auto& dtype = out_tensor->dtype();
+  GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
+  opts.setSrc(src);
+  opts.setDst(gloo_context_.get()->rank);
+  opts.setTag(tag);
+  send_recv(&opts);
+}
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/gloo_comm_context.h b/paddle/phi/core/distributed/gloo_comm_context.h
index b8db0431c250429606373fbcfb364a4938c5b92e..50e996c93000933f338deb64aad4e2949a3f45de 100644
--- a/paddle/phi/core/distributed/gloo_comm_context.h
+++ b/paddle/phi/core/distributed/gloo_comm_context.h
@@ -35,17 +35,38 @@ class GlooCommContext final : public CommContext {
 
   void Broadcast(phi::DenseTensor* out_tensor,
                  const phi::DenseTensor& in_tensor,
-                 int root);
+                 int root,
+                 uint32_t tag = 0);
 
   void AllReduce(phi::DenseTensor* out_tensor,
                  const phi::DenseTensor& in_tensor,
-                 int reduce_type);
+                 int reduce_type,
+                 uint32_t tag = 0);
 
   void Reduce(phi::DenseTensor* out_tensor,
               const phi::DenseTensor& in_tensor,
               int reduce_type,
-              int root);
+              int root,
+              uint32_t tag = 0);
 
   void AllGather(phi::DenseTensor* out_tensor,
-                 const phi::DenseTensor& in_tensor);
+                 const phi::DenseTensor& in_tensor,
+                 uint32_t tag = 0);
+
+  void Gather(phi::DenseTensor* out_tensor,
+              const phi::DenseTensor& in_tensor,
+              int src,
+              uint32_t tag = 0);
+
+  void Scatter(phi::DenseTensor* out_tensor,
+               const phi::DenseTensor& in_tensor,
+               int src,
+               int size,
+               uint32_t tag = 0);
+
+  void Barrier();
+
+  void Send(const phi::DenseTensor& in_tensor, int dst, uint32_t tag = 0);
+
+  void Recv(phi::DenseTensor* out_tensor, int src, uint32_t tag = 0);
 
  private:
   DISABLE_COPY_AND_ASSIGN(GlooCommContext);
diff --git a/paddle/phi/core/distributed/gloo_utils.cc b/paddle/phi/core/distributed/gloo_utils.cc
index 4d451b930a74774e7d6f507643ea39ccfbeab68d..312681384a1996d5716975cd8845fed0193ef9e6 100644
--- a/paddle/phi/core/distributed/gloo_utils.cc
+++ b/paddle/phi/core/distributed/gloo_utils.cc
@@ -91,5 +91,20 @@ std::shared_ptr<gloo::transport::Device> CreateGlooDevice() {
   }
 }
 
+void send_recv(SendRecvOptions* opts) {
+  const auto& context = opts->context;
+  gloo::transport::UnboundBuffer* in = opts->in.get();
+  gloo::transport::UnboundBuffer* out = opts->out.get();
+  const auto slot = gloo::Slot::build(kSendRecvSlotPrefix, opts->tag);
+
+  if (context->rank == opts->src) {
+    in->send(opts->dst, slot);
+    in->waitSend(opts->timeout);
+  } else if (context->rank == opts->dst) {
+    out->recv(opts->src, slot);
+    out->waitRecv(opts->timeout);
+  }
+}
+
 }  // namespace distributed
 }  // namespace phi
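The send_recv helper added to gloo_utils.cc pairs one unbound send buffer with one unbound receive buffer on a common slot built from kSendRecvSlotPrefix and the tag. A rough sending-side sketch (illustration only; it mirrors GlooCommContext::Send from this patch and assumes the options are constructed from a connected gloo context, as that code does):

    // Sketch only: drive send_recv() directly for a raw float buffer.
    void SketchSend(const std::shared_ptr<gloo::Context>& ctx,
                    float* data, size_t n, int dst, uint32_t tag) {
      phi::distributed::SendRecvOptions opts(ctx);
      opts.setInput(data, n);   // registers the unbound send buffer
      opts.setSrc(ctx->rank);   // this rank is the sender
      opts.setDst(dst);         // the peer rank calls setOutput() with the same tag
      opts.setTag(tag);
      phi::distributed::send_recv(&opts);  // sends here, receives on the dst rank
    }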
diff --git a/paddle/phi/core/distributed/gloo_utils.h b/paddle/phi/core/distributed/gloo_utils.h
index 57e029c17ac9b6cb9e1133c1f31888c269f3351e..89455ea83b0335c57c06c654cb67a1acc996f7f5 100644
--- a/paddle/phi/core/distributed/gloo_utils.h
+++ b/paddle/phi/core/distributed/gloo_utils.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include
 #include
 #include
 #include
@@ -103,6 +104,19 @@ void SetInput(P* opts, const phi::DenseTensor& tensor) {
                  tensor.numel());
 }
 
+template <typename T, typename P>
+void SetInputForScatter(P* opts, const phi::DenseTensor& tensor, int nranks) {
+  std::vector<T*> ret;
+  ret.reserve(nranks);
+  T* raw_pointer = reinterpret_cast<T*>(const_cast<void*>(tensor.data()));
+  size_t offset = 0;
+  for (int i = 0; i < nranks; i++) {
+    ret.push_back(raw_pointer + offset);
+    offset += tensor.numel() / nranks;
+  }
+  opts->setInputs(ret, tensor.numel() / nranks);
+}
+
 template <typename T, typename P>
 void SetReduceFunc(P* opts, int reduce_type) {
   // gloo only support mutable data input
@@ -136,5 +150,55 @@ void SetReduceFunc(P* opts, int reduce_type) {
 // env preparation
 std::shared_ptr<gloo::transport::Device> CreateGlooDevice();
 
+constexpr uint8_t kSendRecvSlotPrefix = 0x08;
+
+class SendRecvOptions {
+ public:
+  explicit SendRecvOptions(const std::shared_ptr<gloo::Context>& context)
+      : context(context), timeout(context->getTimeout()) {}
+
+  template <typename T>
+  void setInput(T* ptr, size_t elements) {
+    this->in = context->createUnboundBuffer(ptr, elements * sizeof(T));
+  }
+
+  template <typename T>
+  void setOutput(T* ptr, size_t elements) {
+    this->out = context->createUnboundBuffer(ptr, elements * sizeof(T));
+  }
+
+  void setSrc(int src) { this->src = src; }
+
+  void setDst(int dst) { this->dst = dst; }
+
+  void setTag(uint32_t tag) { this->tag = tag; }
+
+  void setTimeout(std::chrono::milliseconds timeout) {
+    this->timeout = timeout;
+  }
+
+ protected:
+  std::shared_ptr<gloo::Context> context;
+  std::unique_ptr<gloo::transport::UnboundBuffer> in;
+  std::unique_ptr<gloo::transport::UnboundBuffer> out;
+
+  // Rank of process to send_recv from.
+  int src = -1;
+
+  // Rank of process to send_recv to.
+  int dst = -1;
+
+  // Tag for this operation.
+  // Must be unique across operations executing in parallel.
+  uint32_t tag = 0;
+
+  // End-to-end timeout for this operation.
+  std::chrono::milliseconds timeout;
+
+  friend void send_recv(SendRecvOptions*);
+};
+
+void send_recv(SendRecvOptions* opts);
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/test/collective/collective_allgather_api.py b/test/collective/collective_allgather_api.py
index 3a7ed15ab11fe5b209b91e6bab89e656b18b1d6d..761703cb49497d43d00b184935538475a5423be9 100644
--- a/test/collective/collective_allgather_api.py
+++ b/test/collective/collective_allgather_api.py
@@ -14,7 +14,7 @@
 
 import os
 
-import test_collective_api_base as test_base
+import legacy_test.test_collective_api_base as test_base
 
 import paddle
 import paddle.distributed as dist
diff --git a/test/collective/collective_allreduce_api.py b/test/collective/collective_allreduce_api.py
index a93718643ed3cf70365e0f7dd58902e8711ba575..b1f2770205518c1547680fc94b5db8beeaa5a9b1 100644
--- a/test/collective/collective_allreduce_api.py
+++ b/test/collective/collective_allreduce_api.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
+from legacy_test.test_collective_api_base import (
+    TestCollectiveAPIRunnerBase,
+    runtime_main,
+)
 
 import paddle
 import paddle.distributed as dist