Unverified commit 1d1e5484, authored by Xing-lil, committed via GitHub

Update gloo in dygraph (#55537)

* update broadcast gloo in dygraph

* update

* update reduce gloo in dygraph

* update reduce gloo in dygraph

* update

* update allreduce allgather

* update all

* update

* update

* update
Parent commit: 982e0a9d
......@@ -20,6 +20,7 @@
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
#include "paddle/phi/core/distributed/gloo_comm_context.h"
#include "paddle/phi/core/distributed/store/store.h"
#include "paddle/phi/core/distributed/store/tcp_store.h"
......@@ -225,6 +226,8 @@ class ProcessGroupGloo : public ProcessGroupWithoutStream {
return GetDeviceContext(place);
}
phi::distributed::GlooCommContext* GetCommContext();
// Helper functions for Gloo.
static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
const std::string& hostname);
......
......@@ -24,7 +24,7 @@ namespace paddle {
namespace distributed {
// TODO(shenliang03): To support AVG for reduce
enum class ReduceOp : std::uint8_t { SUM = 0, AVG, MAX, MIN, PRODUCT };
enum class ReduceOp : std::uint8_t { SUM = 0, MAX, MIN, PRODUCT, AVG };
struct AllreduceOptions {
ReduceOp reduce_op = ReduceOp::SUM;
......
......@@ -17,8 +17,11 @@
#include <gloo/allgather.h>
#include <gloo/allreduce.h>
#include <gloo/barrier.h>
#include <gloo/broadcast.h>
#include <gloo/gather.h>
#include <gloo/reduce.h>
#include <gloo/scatter.h>
#include <gloo/types.h>
#include "paddle/phi/common/data_type.h"
......@@ -41,7 +44,8 @@ GlooCommContext::GlooCommContext(
void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root) {
int root,
uint32_t tag) {
// gloo only uses CPU now
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
......@@ -56,15 +60,18 @@ void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor,
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
}
opts.setRoot(root);
opts.setTag(tag);
gloo::broadcast(opts);
}
void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor) {
const phi::DenseTensor& in_tensor,
uint32_t tag) {
// gloo only uses CPU now
gloo::AllgatherOptions opts(gloo_context_);
const auto& dtype = in_tensor.dtype();
opts.setTag(tag);
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
gloo::allgather(opts);
......@@ -72,8 +79,10 @@ void GlooCommContext::AllGather(phi::DenseTensor* out_tensor,
void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type) {
int reduce_type,
uint32_t tag) {
gloo::AllreduceOptions opts(gloo_context_);
opts.setTag(tag);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
......@@ -84,9 +93,11 @@ void GlooCommContext::AllReduce(phi::DenseTensor* out_tensor,
void GlooCommContext::Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type,
int root) {
int root,
uint32_t tag) {
gloo::ReduceOptions opts(gloo_context_);
opts.setRoot(root);
opts.setTag(tag);
const auto& dtype = in_tensor.dtype();
GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
......@@ -94,5 +105,65 @@ void GlooCommContext::Reduce(phi::DenseTensor* out_tensor,
gloo::reduce(opts);
}
void GlooCommContext::Gather(phi::DenseTensor* out_tensor,
                             const phi::DenseTensor& in_tensor,
                             int src,
                             uint32_t tag) {
  // Gloo collectives run on CPU only.
  gloo::GatherOptions gather_opts(gloo_context_);
  gather_opts.setRoot(src);
  gather_opts.setTag(tag);
  const auto& dtype = in_tensor.dtype();
  // Every rank contributes its input; only the root rank receives the
  // concatenated result, so the output buffer is registered on `src` alone.
  GENERATE_FUNC(dtype, SetInput, &gather_opts, in_tensor);
  if (rank_ == src) {
    GENERATE_FUNC(dtype, SetOutput, &gather_opts, out_tensor);
  }
  gloo::gather(gather_opts);
}
void GlooCommContext::Scatter(phi::DenseTensor* out_tensor,
                              const phi::DenseTensor& in_tensor,
                              int src,
                              int size,
                              uint32_t tag) {
  // Gloo collectives run on CPU only.
  gloo::ScatterOptions scatter_opts(gloo_context_);
  scatter_opts.setRoot(src);
  scatter_opts.setTag(tag);
  const auto& dtype = in_tensor.dtype();
  // Only the root owns the full input; it is split into `size` per-rank
  // chunks. Every rank (root included) registers an output chunk.
  if (rank_ == src) {
    GENERATE_FUNC(dtype, SetInputForScatter, &scatter_opts, in_tensor, size);
  }
  GENERATE_FUNC(dtype, SetOutput, &scatter_opts, out_tensor);
  gloo::scatter(scatter_opts);
}
void GlooCommContext::Barrier() {
gloo::BarrierOptions opts(gloo_context_);
gloo::barrier(opts);
}
void GlooCommContext::Send(const phi::DenseTensor& in_tensor,
                           int dst,
                           uint32_t tag) {
  // Point-to-point send, implemented on top of gloo unbound buffers via the
  // custom send_recv() helper (gloo itself has no send/recv collective).
  SendRecvOptions opts(gloo_context_);
  const auto& dtype = in_tensor.dtype();
  GENERATE_FUNC(dtype, SetInput, &opts, in_tensor);
  // This rank is the source; `dst` is the receiving rank.
  // NOTE: shared_ptr::operator-> replaces the redundant `.get()->`.
  opts.setSrc(gloo_context_->rank);
  opts.setDst(dst);
  opts.setTag(tag);
  send_recv(&opts);
}
void GlooCommContext::Recv(phi::DenseTensor* out_tensor,
                           int src,
                           uint32_t tag) {
  // Point-to-point receive, the counterpart of Send(); the tag must match
  // the sender's tag for the transfer to pair up.
  SendRecvOptions opts(gloo_context_);
  const auto& dtype = out_tensor->dtype();
  GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor);
  // `src` is the sending rank; this rank is the destination.
  // NOTE: shared_ptr::operator-> replaces the redundant `.get()->`.
  opts.setSrc(src);
  opts.setDst(gloo_context_->rank);
  opts.setTag(tag);
  send_recv(&opts);
}
} // namespace distributed
} // namespace phi
......@@ -35,17 +35,38 @@ class GlooCommContext final : public CommContext {
void Broadcast(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int root);
int root,
uint32_t tag = 0);
void AllReduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type);
int reduce_type,
uint32_t tag = 0);
void Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int reduce_type,
int root);
int root,
uint32_t tag = 0);
void AllGather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor);
const phi::DenseTensor& in_tensor,
uint32_t tag = 0);
void Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
uint32_t tag = 0);
void Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int src,
int size,
uint32_t tag = 0);
void Barrier();
void Send(const phi::DenseTensor& in_tensor, int dst, uint32_t tag = 0);
void Recv(phi::DenseTensor* out_tensor, int src, uint32_t tag = 0);
private:
DISABLE_COPY_AND_ASSIGN(GlooCommContext);
......
......@@ -91,5 +91,20 @@ std::shared_ptr<gloo::transport::Device> CreateGlooDevice() {
}
}
void send_recv(SendRecvOptions* opts) {
  // Pair one sender (src) with one receiver (dst) over unbound buffers.
  // Ranks that are neither src nor dst fall through and do nothing.
  const auto& ctx = opts->context;
  const auto slot = gloo::Slot::build(kSendRecvSlotPrefix, opts->tag);
  if (ctx->rank == opts->src) {
    auto* input = opts->in.get();
    input->send(opts->dst, slot);
    input->waitSend(opts->timeout);
  } else if (ctx->rank == opts->dst) {
    auto* output = opts->out.get();
    output->recv(opts->src, slot);
    output->waitRecv(opts->timeout);
  }
}
} // namespace distributed
} // namespace phi
......@@ -14,6 +14,7 @@
#pragma once
#include <gloo/allreduce.h>
#include <gloo/math.h>
#include <gloo/transport/tcp/device.h>
#include <gloo/types.h>
......@@ -103,6 +104,19 @@ void SetInput(P* opts, const phi::DenseTensor& tensor) {
tensor.numel());
}
template <typename T, typename P>
void SetInputForScatter(P* opts, const phi::DenseTensor& tensor, int nranks) {
  // Split `tensor` into `nranks` equal, contiguous chunks and register one
  // input pointer per rank for gloo's scatter options.
  // NOTE(review): assumes tensor.numel() is divisible by nranks; any
  // remainder elements would be silently dropped — TODO confirm callers
  // guarantee divisibility.
  const size_t chunk_elems =
      static_cast<size_t>(tensor.numel()) / static_cast<size_t>(nranks);
  // gloo only supports mutable input buffers, hence the const_cast.
  T* base = reinterpret_cast<T*>(const_cast<void*>(tensor.data()));
  std::vector<T*> chunk_ptrs;
  chunk_ptrs.reserve(nranks);
  for (int i = 0; i < nranks; i++) {
    // Hoisted the loop-invariant division out of the loop body.
    chunk_ptrs.push_back(base + static_cast<size_t>(i) * chunk_elems);
  }
  opts->setInputs(chunk_ptrs, chunk_elems);
}
template <typename T, typename P>
void SetReduceFunc(P* opts, int reduce_type) {
// gloo only support mutable data input
......@@ -136,5 +150,55 @@ void SetReduceFunc(P* opts, int reduce_type) {
// env preparation
std::shared_ptr<gloo::transport::Device> CreateGlooDevice();
constexpr uint8_t kSendRecvSlotPrefix = 0x08;
// Option bundle for the custom point-to-point send_recv() helper, modeled
// after gloo's own *Options classes (e.g. gloo::BroadcastOptions). Buffers
// are registered lazily via setInput/setOutput; send_recv() reads the fields
// through its friend access.
class SendRecvOptions {
 public:
  explicit SendRecvOptions(const std::shared_ptr<gloo::Context>& context)
      : context(context), timeout(context->getTimeout()) {}

  // Register the send-side buffer (only needed on the src rank).
  template <typename T>
  void setInput(T* ptr, size_t elements) {
    this->in = context->createUnboundBuffer(ptr, elements * sizeof(T));
  }

  // Register the receive-side buffer (only needed on the dst rank).
  template <typename T>
  void setOutput(T* ptr, size_t elements) {
    this->out = context->createUnboundBuffer(ptr, elements * sizeof(T));
  }

  void setSrc(int src) { this->src = src; }
  void setDst(int dst) { this->dst = dst; }
  void setTag(uint32_t tag) { this->tag = tag; }
  void setTimeout(std::chrono::milliseconds timeout) {
    this->timeout = timeout;
  }

 protected:
  std::shared_ptr<gloo::Context> context;
  std::unique_ptr<gloo::transport::UnboundBuffer> in;
  std::unique_ptr<gloo::transport::UnboundBuffer> out;

  // Rank of process to send_recv from.
  int src = -1;

  // Rank of process to send_recv to.
  int dst = -1;

  // Tag for this operation.
  // Must be unique across operations executing in parallel.
  uint32_t tag = 0;

  // End-to-end timeout for this operation.
  // Defaults to the context-wide timeout set at construction.
  std::chrono::milliseconds timeout;

  friend void send_recv(SendRecvOptions*);
};
void send_recv(SendRecvOptions* opts);
} // namespace distributed
} // namespace phi
......@@ -14,7 +14,7 @@
import os
import test_collective_api_base as test_base
import legacy_test.test_collective_api_base as test_base
import paddle
import paddle.distributed as dist
......
......@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from test_collective_api_base import TestCollectiveAPIRunnerBase, runtime_main
from legacy_test.test_collective_api_base import (
TestCollectiveAPIRunnerBase,
runtime_main,
)
import paddle
import paddle.distributed as dist
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.