Unverified  Commit 1be70bc5  Authored by: Wen Sun  Committed by: GitHub

Refactor `ProcessGroup` to support comm context migration & clang compilation (#49451)

* refactor: use base class

* fix: incorrect deps

* fix: add missing header

* refactor: update class structures

* fix: bkcl typo

* fix: remove redundant def
Parent 5949f2d7
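Note (illustrative sketch, not part of this commit): the refactor replaces the old ProcessGroupStream base with a header-only ProcessGroupWithStream and adds a ProcessGroupWithoutStream counterpart, while the throwing default implementations move into ProcessGroup itself. The skeleton below only restates that layout; the class names come from the diff, the comments describe behavior inferred from it, and the bodies are placeholders.

// Sketch only: the real classes live in process_group.h,
// process_group_with_stream.h and process_group_without_stream.h.
namespace paddle {
namespace distributed {

// Base class: declares every collective with a throwing default body.
class ProcessGroup { /* ... */ };

// For backends that keep a dedicated communication stream (NCCL, BKCL):
// forwards the short overloads to the ones taking `use_calc_stream`.
class ProcessGroupWithStream : public ProcessGroup { /* ... */ };

// For backends without a separate comm stream (Gloo, MPI, custom devices):
// forwards the short overloads to the offset/numel overloads.
class ProcessGroupWithoutStream : public ProcessGroup { /* ... */ };

}  // namespace distributed
}  // namespace paddle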
@@ -2,14 +2,11 @@ cc_library(
   process_group
   SRCS process_group.cc
   DEPS dense_tensor)
 
-cc_library(
-  process_group_stream
-  SRCS process_group_stream.cc
-  DEPS dense_tensor)
-
 cc_library(
   eager_reducer
   SRCS reducer.cc
-  DEPS eager_api process_group process_group_stream phi_api string_helper)
+  DEPS eager_api process_group phi_api string_helper)
 
 if(WITH_DISTRIBUTE)
   cc_library(
@@ -23,7 +20,6 @@ if(WITH_NCCL OR WITH_RCCL)
     process_group_nccl
     SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc
     DEPS process_group
-         process_group_stream
          place
          enforce
          collective_helper
......
@@ -15,9 +15,9 @@
 #include "paddle/fluid/distributed/collective/check.h"
 #include "paddle/fluid/distributed/collective/nccl_tools.h"
-#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/errors.h"
 #ifdef PADDLE_WITH_HIP
......
@@ -14,13 +14,16 @@
 #include "paddle/fluid/distributed/collective/nccl_tools.h"
-#include "paddle/fluid/platform/enforce.h"
+#include <unordered_map>
+
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/errors.h"
 namespace paddle {
 namespace distributed {
 ncclRedOp_t ToNCCLRedType(ReduceOp reduction) {
-  static const std::map<ReduceOp, ncclRedOp_t> red_type = {
+  static const std::unordered_map<ReduceOp, ncclRedOp_t> red_type = {
       {ReduceOp::MIN, ncclMin},
       {ReduceOp::MAX, ncclMax},
       {ReduceOp::SUM, ncclSum},
@@ -29,7 +32,7 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction) {
   auto it = red_type.find(reduction);
   PADDLE_ENFORCE_EQ(it != red_type.end(),
                     true,
-                    platform::errors::InvalidArgument(
+                    phi::errors::InvalidArgument(
                         "Invalid nccl reduction. Must be ncclMin | ncclMax | "
                         "ncclProd | ncclSum"));
   return it->second;
......
@@ -14,20 +14,15 @@
 #pragma once
-#ifdef PADDLE_WITH_CUDA
-#include <cuda_runtime.h>
-#endif
-#ifdef PADDLE_WITH_HIP
-#include <hip/hip_runtime.h>
-#endif
 #include <string>
 #include "paddle/fluid/distributed/collective/types.h"
 #ifdef PADDLE_WITH_RCCL
+#include <hip/hip_runtime.h>
 #include "paddle/phi/backends/dynload/rccl.h"
 #else
+#include <cuda_runtime.h>
 #include "paddle/phi/backends/dynload/nccl.h"
 #endif
......
@@ -17,44 +17,20 @@
 namespace paddle {
 namespace distributed {
-ProcessGroup::Task::Task(int rank, CommType comm_type, bool sync_op)
-    : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
-ProcessGroup::Task::~Task() = default;
 bool ProcessGroup::Task::IsCompleted() {
   std::lock_guard<std::mutex> lock(mutex_);
   return is_completed_;
 }
-bool ProcessGroup::Task::Wait(std::chrono::milliseconds timeout) {
-  return false;
-}
-void ProcessGroup::Task::Synchronize() {}
-void ProcessGroup::Task::UpdateWaitChain(const phi::DeviceContext& ctx) {}
 ProcessGroup::ProcessGroup(int rank, int size, int gid)
     : rank_(rank), size_(size), gid_(gid) {
-  if (gid != IGNORE_ID) {
+  if (gid != kIgnoreId) {
     auto map = ProcessGroupMapFromGid::getInstance();
     map->insert(gid_, this);
   }
 }
 // TODO(sunyilun): methods below will be removed later
-ProcessGroup::Task::Task(int rank,
-                         const std::vector<phi::DenseTensor>& inputs,
-                         CommType comm_type)
-    : rank_(rank), comm_type_(comm_type) {}
-ProcessGroup::Task::Task(int rank,
-                         const std::vector<phi::DenseTensor>& inputs,
-                         CommType comm_type,
-                         bool sync_op)
-    : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
 ProcessGroupIdMap& ProcessGroupIdMap::GetInstance() {
   static ProcessGroupIdMap instance;
   return instance;
......
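Note (illustrative sketch, not part of this commit): the Task constructors, destructor, Wait, Synchronize and UpdateWaitChain bodies deleted here are not lost; the process_group.h hunk further down defines them inline in the class. A self-contained mock of that "defaults live inline, backends override" pattern, with hypothetical names:

#include <chrono>
#include <iostream>

// Hypothetical mini version of ProcessGroup::Task with inline defaults.
class TaskBase {
 public:
  explicit TaskBase(int rank) : rank_(rank) {}
  virtual ~TaskBase() = default;
  // Non-blocking default: derived tasks override with real completion logic.
  virtual bool Wait(std::chrono::milliseconds /*timeout*/) { return false; }
  virtual void Synchronize() {}  // no-op unless a backend needs a device sync
  int rank() const { return rank_; }

 private:
  const int rank_;
};

class BlockingTask : public TaskBase {
 public:
  using TaskBase::TaskBase;
  bool Wait(std::chrono::milliseconds /*timeout*/) override { return true; }
};

int main() {
  BlockingTask task(/*rank=*/0);
  std::cout << task.Wait(std::chrono::milliseconds(0)) << '\n';  // prints 1
}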
@@ -17,21 +17,22 @@
 #include <chrono>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/distributed/collective/types.h"
-#include "paddle/fluid/eager/api/utils/tensor_utils.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/framework/variable.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/eager/api/utils/tensor_utils.h"  // NOTE: this header is required somewhere
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/device_context.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/errors.h"
 constexpr auto kWaitTimeout = std::chrono::milliseconds(0);
 namespace paddle {
 namespace distributed {
-constexpr int IGNORE_ID = -1;
+constexpr int kIgnoreId = -1;
-using Tensor = paddle::experimental::Tensor;
 enum class CommType : std::uint8_t {
   BROADCAST = 0,
@@ -53,23 +54,28 @@ class ProcessGroup {
  public:
   class Task {
    public:
-    Task(int rank, CommType comm_type, bool sync_op);
+    Task(int rank, CommType comm_type, bool sync_op)
+        : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
-    virtual ~Task();
+    virtual ~Task() = default;
     virtual bool IsCompleted();
-    virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout);
-    virtual void Synchronize();
-    virtual void UpdateWaitChain(const phi::DeviceContext& ctx);
+    virtual bool Wait(std::chrono::milliseconds timeout = kWaitTimeout) {
+      return false;
+    }
+    virtual void Synchronize() {}
+    virtual void UpdateWaitChain(const phi::DeviceContext& ctx) {}
     bool IsSync() const { return sync_op_; }
     // TODO(sunyilun): methods below will be removed later
     Task(int rank,
          const std::vector<phi::DenseTensor>& inputs,
-         CommType comm_type);
+         CommType comm_type)
+        : rank_(rank), comm_type_(comm_type) {}
     Task(int rank,
          const std::vector<phi::DenseTensor>& inputs,
          CommType comm_type,
-         bool sync_op);
+         bool sync_op)
+        : rank_(rank), comm_type_(comm_type), sync_op_(sync_op) {}
    protected:
     const int rank_;
@@ -92,20 +98,26 @@ class ProcessGroup {
   virtual std::string GetBackendName() const = 0;
   virtual phi::DeviceContext* GetDeviceContext(const Place& place) const {
-    PADDLE_THROW(platform::errors::Unimplemented(
+    PADDLE_THROW(phi::errors::Unimplemented(
         "ProcessGroup%s does not support get device_context.",
         GetBackendName()));
   }
+  virtual phi::DeviceContext* GetDeviceContext(const Place& place,
+                                               bool use_calc_stream) const {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "ProcessGroup%s does not support get device_context.",
+        GetBackendName()));
+  }
+  // without stream APIs
   virtual std::shared_ptr<ProcessGroup::Task> AllGather(
       phi::DenseTensor* out_tensor,
       const phi::DenseTensor& in_tensor,
       bool sync_op) {
-    return AllGather(out_tensor,
-                     in_tensor,
-                     /*offset*/ 0,
-                     /*numel*/ -1,  // -1 indicates the whole tensor
-                     sync_op);
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "ProcessGroup%s does not support all_gather with sync_op flag.",
+        GetBackendName()));
   }
virtual std::shared_ptr<ProcessGroup::Task> AllGather( virtual std::shared_ptr<ProcessGroup::Task> AllGather(
...@@ -114,7 +126,7 @@ class ProcessGroup { ...@@ -114,7 +126,7 @@ class ProcessGroup {
int64_t offset, int64_t offset,
int64_t numel, int64_t numel,
bool sync_op) { bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support all_gather with sync_op flag.", "ProcessGroup%s does not support all_gather with sync_op flag.",
GetBackendName())); GetBackendName()));
} }
...@@ -124,7 +136,7 @@ class ProcessGroup { ...@@ -124,7 +136,7 @@ class ProcessGroup {
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts, const AllreduceOptions& opts,
bool sync_op) { bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support all_reduce with sync_op flag.", "ProcessGroup%s does not support all_reduce with sync_op flag.",
GetBackendName())); GetBackendName()));
} }
...@@ -135,14 +147,14 @@ class ProcessGroup { ...@@ -135,14 +147,14 @@ class ProcessGroup {
const std::vector<int64_t>& out_size_each_rank, const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank, const std::vector<int64_t>& in_size_each_rank,
bool sync_op) { bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support all_to_all with sync_op flag.", "ProcessGroup%s does not support all_to_all with sync_op flag.",
GetBackendName())); GetBackendName()));
} }
virtual std::shared_ptr<ProcessGroup::Task> Barrier( virtual std::shared_ptr<ProcessGroup::Task> Barrier(
const BarrierOptions& = BarrierOptions()) { const BarrierOptions& = BarrierOptions()) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support barrier.", GetBackendName())); "ProcessGroup%s does not support barrier.", GetBackendName()));
} }
...@@ -151,7 +163,7 @@ class ProcessGroup { ...@@ -151,7 +163,7 @@ class ProcessGroup {
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts, const BroadcastOptions& opts,
bool sync_op) { bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support broadcast with sync_op flag", "ProcessGroup%s does not support broadcast with sync_op flag",
GetBackendName())); GetBackendName()));
} }
...@@ -161,7 +173,7 @@ class ProcessGroup { ...@@ -161,7 +173,7 @@ class ProcessGroup {
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const ReduceOptions& opts, const ReduceOptions& opts,
bool sync_op) { bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support reduce with sync_op flag.", "ProcessGroup%s does not support reduce with sync_op flag.",
GetBackendName())); GetBackendName()));
} }
...@@ -171,7 +183,7 @@ class ProcessGroup { ...@@ -171,7 +183,7 @@ class ProcessGroup {
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts, const ReduceScatterOptions& opts,
bool sync_op) { bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support reduce_scatter with sync_op flag.", "ProcessGroup%s does not support reduce_scatter with sync_op flag.",
GetBackendName())); GetBackendName()));
} }
...@@ -181,7 +193,7 @@ class ProcessGroup { ...@@ -181,7 +193,7 @@ class ProcessGroup {
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const ScatterOptions& opts, const ScatterOptions& opts,
bool sync_op) { bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support scatter with sync_op flag.", "ProcessGroup%s does not support scatter with sync_op flag.",
GetBackendName())); GetBackendName()));
} }
...@@ -189,11 +201,9 @@ class ProcessGroup { ...@@ -189,11 +201,9 @@ class ProcessGroup {
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor, virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank, int src_rank,
bool sync_op) { bool sync_op) {
return Recv(tensor, PADDLE_THROW(phi::errors::Unimplemented(
src_rank, "ProcessGroup%s does not support recv with sync_op flag.",
/*offset*/ 0, GetBackendName()));
/*numel*/ -1, // -1 indicates the whole tensor
sync_op);
} }
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor, virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
...@@ -201,18 +211,16 @@ class ProcessGroup { ...@@ -201,18 +211,16 @@ class ProcessGroup {
int64_t offset, int64_t offset,
int64_t numel, int64_t numel,
bool sync_op) { bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support recv with sync_op flag.", "ProcessGroup%s does not support recv with sync_op flag.",
GetBackendName())); GetBackendName()));
} }
virtual std::shared_ptr<ProcessGroup::Task> Send( virtual std::shared_ptr<ProcessGroup::Task> Send(
const phi::DenseTensor& tensor, int dst_rank, bool sync_op) { const phi::DenseTensor& tensor, int dst_rank, bool sync_op) {
return Send(tensor, PADDLE_THROW(phi::errors::Unimplemented(
dst_rank, "ProcessGroup%s does not support send with sync_op flag.",
/*offset*/ 0, GetBackendName()));
/*numel*/ -1, // -1 indicates the whole tensor
sync_op);
} }
virtual std::shared_ptr<ProcessGroup::Task> Send( virtual std::shared_ptr<ProcessGroup::Task> Send(
...@@ -221,17 +229,162 @@ class ProcessGroup { ...@@ -221,17 +229,162 @@ class ProcessGroup {
int64_t offset, int64_t offset,
int64_t numel, int64_t numel,
bool sync_op) { bool sync_op) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support send with sync_op flag.", "ProcessGroup%s does not support send with sync_op flag.",
GetBackendName())); GetBackendName()));
} }
// stream APIs
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support all_gather "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support all_gather "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support all_reduce "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support all_to_all "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support broadcast "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support reduce "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(phi::errors::Unimplemented(
"ProcessGroup%s does not support reduce_scatter "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support scatter "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support recv with "
"sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support recv "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send(
const phi::DenseTensor& tensor,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support send "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Send(
const phi::DenseTensor& tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support send "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
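Note (illustrative sketch, not part of this commit): every stream overload above ships a throwing default so that a backend only overrides the operations it supports, and an unsupported call fails at runtime with a clear Unimplemented message rather than at compile or link time. The mock below, with hypothetical names and plain exceptions instead of PADDLE_THROW, shows the effect:

#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for the real classes; names are illustrative only.
class MiniProcessGroup {
 public:
  virtual ~MiniProcessGroup() = default;
  virtual std::string GetBackendName() const = 0;
  // Throwing default: subclasses override only what they support.
  virtual void AllGather(bool /*sync_op*/, bool /*use_calc_stream*/) {
    throw std::runtime_error("ProcessGroup" + GetBackendName() +
                             " does not support all_gather.");
  }
  virtual void Barrier() {
    throw std::runtime_error("ProcessGroup" + GetBackendName() +
                             " does not support barrier.");
  }
};

class MiniBackend : public MiniProcessGroup {
 public:
  std::string GetBackendName() const override { return "MINI"; }
  void AllGather(bool, bool) override { /* real implementation here */ }
  // Barrier is intentionally not overridden.
};

int main() {
  MiniBackend pg;
  pg.AllGather(true, false);  // fine
  try {
    pg.Barrier();             // surfaces the Unimplemented-style error
  } catch (const std::exception& e) {
    std::cout << e.what() << '\n';
  }
}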
// legacy APIs
// TODO(liyurui): This API will be moved later // TODO(liyurui): This API will be moved later
virtual std::shared_ptr<ProcessGroup::Task> AllReduce( virtual std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const AllreduceOptions& = AllreduceOptions()) { const AllreduceOptions& = AllreduceOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce", GetBackendName())); "ProcessGroup%s does not support allreduce", GetBackendName()));
} }
...@@ -240,7 +393,7 @@ class ProcessGroup { ...@@ -240,7 +393,7 @@ class ProcessGroup {
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const AllreduceOptions&, const AllreduceOptions&,
bool) { bool) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support allreduce with sync_op flag", "ProcessGroup%s does not support allreduce with sync_op flag",
GetBackendName())); GetBackendName()));
} }
...@@ -250,7 +403,7 @@ class ProcessGroup { ...@@ -250,7 +403,7 @@ class ProcessGroup {
std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT std::vector<phi::DenseTensor>& /* input tensors */, // NOLINT
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const BroadcastOptions& = BroadcastOptions()) { const BroadcastOptions& = BroadcastOptions()) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast", GetBackendName())); "ProcessGroup%s does not support broadcast", GetBackendName()));
} }
...@@ -259,27 +412,27 @@ class ProcessGroup { ...@@ -259,27 +412,27 @@ class ProcessGroup {
std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT std::vector<phi::DenseTensor>& /* output tensors */, // NOLINT
const BroadcastOptions&, const BroadcastOptions&,
bool) { bool) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support broadcast with sync_op flag", "ProcessGroup%s does not support broadcast with sync_op flag",
GetBackendName())); GetBackendName()));
} }
virtual std::shared_ptr<ProcessGroup::Task> Send( virtual std::shared_ptr<ProcessGroup::Task> Send(
std::vector<phi::DenseTensor>&, int) { // NOLINT std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName())); "ProcessGroup%s does not support send", GetBackendName()));
} }
virtual std::shared_ptr<ProcessGroup::Task> Recv( virtual std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>&, int) { // NOLINT std::vector<phi::DenseTensor>&, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support recv", GetBackendName())); "ProcessGroup%s does not support recv", GetBackendName()));
} }
virtual std::shared_ptr<ProcessGroup::Task> AllGather( virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT std::vector<phi::DenseTensor>&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather", GetBackendName())); "ProcessGroup%s does not support all_gather", GetBackendName()));
} }
...@@ -287,7 +440,7 @@ class ProcessGroup { ...@@ -287,7 +440,7 @@ class ProcessGroup {
std::vector<phi::DenseTensor>&, // NOLINT std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT std::vector<phi::DenseTensor>&, // NOLINT
bool) { bool) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather with sync_op flag", "ProcessGroup%s does not support all_gather with sync_op flag",
GetBackendName())); GetBackendName()));
} }
...@@ -295,7 +448,7 @@ class ProcessGroup { ...@@ -295,7 +448,7 @@ class ProcessGroup {
virtual std::shared_ptr<ProcessGroup::Task> AllToAll( virtual std::shared_ptr<ProcessGroup::Task> AllToAll(
std::vector<phi::DenseTensor>&, // NOLINT std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT std::vector<phi::DenseTensor>&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support AllToAll", GetBackendName())); "ProcessGroup%s does not support AllToAll", GetBackendName()));
} }
...@@ -303,7 +456,7 @@ class ProcessGroup { ...@@ -303,7 +456,7 @@ class ProcessGroup {
std::vector<phi::DenseTensor>&, // NOLINT std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT std::vector<phi::DenseTensor>&, // NOLINT
const ReduceOptions& opts) { const ReduceOptions& opts) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support reduce", GetBackendName())); "ProcessGroup%s does not support reduce", GetBackendName()));
} }
...@@ -311,7 +464,7 @@ class ProcessGroup { ...@@ -311,7 +464,7 @@ class ProcessGroup {
std::vector<phi::DenseTensor>&, // NOLINT std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT std::vector<phi::DenseTensor>&, // NOLINT
const ScatterOptions&) { const ScatterOptions&) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"ProcessGroup%s does not support scatter", GetBackendName())); "ProcessGroup%s does not support scatter", GetBackendName()));
} }
@@ -330,28 +483,11 @@ class ProcessGroupIdMap
 // TODO(dev): The following method will be removed soon.
 class ProcessGroupMapFromGid {
  public:
-  bool has(int gid) {
-    auto it = map_.find(gid);
-    return it != map_.end();
-  }
+  bool has(int gid) { return map_.find(gid) != map_.end(); }
-  void insert(int gid, ProcessGroup* pg) {
-    // TODO(sandyhouse): address ut and uncomment the following codes
-    // PADDLE_ENFORCE_EQ(has(gid), false,
-    //                   platform::errors::PreconditionNotMet(
-    //                       "The process group with id %d doesnot exist.",
-    //                       gid));
-    map_[gid] = pg;
-  }
+  void insert(int gid, ProcessGroup* pg) { map_[gid] = pg; }
-  ProcessGroup* get(int gid) {
-    // TODO(sandyhouse): address ut and uncomment the following codes
-    // PADDLE_ENFORCE_EQ(has(gid), true,
-    //                   platform::errors::PreconditionNotMet(
-    //                       "The process group with id %d doesnot exist.",
-    //                       gid));
-    return map_.find(gid)->second;
-  }
+  ProcessGroup* get(int gid) { return map_.find(gid)->second; }
   static std::shared_ptr<ProcessGroupMapFromGid> getInstance() {
     static auto s_instance = std::make_shared<ProcessGroupMapFromGid>();
......
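Note (illustrative sketch, not part of this commit): ProcessGroupMapFromGid is now a thin gid-to-ProcessGroup* registry, and the process_group.cc hunk above shows the constructor registering `this` whenever gid != kIgnoreId. The standalone mock below restates that self-registration flow with hypothetical names:

#include <iostream>
#include <unordered_map>

// Hypothetical mock of the gid -> group registry; names are illustrative.
constexpr int kIgnoreId = -1;

struct Group;  // forward declaration for the registry

std::unordered_map<int, Group*>& Registry() {
  static std::unordered_map<int, Group*> map;
  return map;
}

struct Group {
  explicit Group(int gid) : gid_(gid) {
    if (gid_ != kIgnoreId) Registry()[gid_] = this;  // self-registration
  }
  int gid_;
};

int main() {
  Group g0(0), anonymous(kIgnoreId);
  std::cout << Registry().count(0) << ' ' << Registry().count(kIgnoreId) << '\n';
  // prints "1 0": only groups with a real gid are registered
}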
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include "paddle/fluid/distributed/collective/common.h" #include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" #include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h" #include "paddle/fluid/platform/device/xpu/xpu_info.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/phi/core/device_context.h"
#include "paddle/fluid/platform/place.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h" #include "paddle/phi/core/errors.h"
namespace paddle { namespace paddle {
...@@ -76,7 +76,7 @@ ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr<Store>& store, ...@@ -76,7 +76,7 @@ ProcessGroupBKCL::ProcessGroupBKCL(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, int size,
int gid) int gid)
: ProcessGroupStream(rank, size, gid), store_(store) {} : ProcessGroupWithStream(rank, size, gid), store_(store) {}
void ProcessGroupBKCL::GroupStart() { void ProcessGroupBKCL::GroupStart() {
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start()); PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
......
...@@ -15,17 +15,16 @@ ...@@ -15,17 +15,16 @@
#pragma once #pragma once
#include <chrono> #include <chrono>
#include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/distributed/collective/process_group_stream.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
#if defined(PADDLE_WITH_XPU) #if defined(PADDLE_WITH_XPU)
...@@ -37,12 +36,12 @@ constexpr const char* BKCL_BACKEND_NAME = "BKCL"; ...@@ -37,12 +36,12 @@ constexpr const char* BKCL_BACKEND_NAME = "BKCL";
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using Place = paddle::platform::Place; using Place = phi::Place;
// BKCL funcs use separate communication stream by default // BKCL funcs use separate communication stream by default
class ProcessGroupBKCL : public ProcessGroupStream { class ProcessGroupBKCL : public ProcessGroupWithStream {
public: public:
class BKCLTask final : public ProcessGroupStream::TaskStream, class BKCLTask final : public ProcessGroupWithStream::TaskStream,
public std::enable_shared_from_this<BKCLTask> { public std::enable_shared_from_this<BKCLTask> {
public: public:
BKCLTask(const Place& place, BKCLTask(const Place& place,
...@@ -161,12 +160,12 @@ class ProcessGroupBKCL : public ProcessGroupStream { ...@@ -161,12 +160,12 @@ class ProcessGroupBKCL : public ProcessGroupStream {
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream);
void BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id); // NOLINT void BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id);
void CreateBKCLEnvCache(const Place& place, const std::string& place_key); void CreateBKCLEnvCache(const Place& place, const std::string& place_key);
template <typename Fn> template <typename Fn>
std::shared_ptr<ProcessGroupStream::Task> Collective( std::shared_ptr<ProcessGroup::Task> Collective(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
Fn fn, Fn fn,
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
DECLARE_bool(xccl_blocking_wait); DECLARE_bool(xccl_blocking_wait);
...@@ -102,7 +103,9 @@ ProcessGroupCustom::ProcessGroupCustom(const std::shared_ptr<Store>& store, ...@@ -102,7 +103,9 @@ ProcessGroupCustom::ProcessGroupCustom(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, int size,
int gid) int gid)
: ProcessGroup(rank, size, gid), store_(store), device_type_(device_type) {} : ProcessGroupWithoutStream(rank, size, gid),
store_(store),
device_type_(device_type) {}
void ProcessGroupCustom::BroadcastUniqueCustomID( void ProcessGroupCustom::BroadcastUniqueCustomID(
std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT std::vector<phi::ccl::CCLRootId>& ccl_ids) { // NOLINT
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/distributed/collective/custom_ccl_tools.h" #include "paddle/fluid/distributed/collective/custom_ccl_tools.h"
#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
#include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h" #include "paddle/fluid/platform/device/npu/npu_stream.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -34,7 +35,7 @@ namespace paddle { ...@@ -34,7 +35,7 @@ namespace paddle {
namespace distributed { namespace distributed {
using Place = paddle::platform::Place; using Place = paddle::platform::Place;
using CustomDeviceContext = paddle::platform::CustomDeviceContext; using CustomDeviceContext = paddle::platform::CustomDeviceContext;
class ProcessGroupCustom : public ProcessGroup { class ProcessGroupCustom : public ProcessGroupWithoutStream {
public: public:
class CustomTask : public ProcessGroup::Task, class CustomTask : public ProcessGroup::Task,
public std::enable_shared_from_this<CustomTask> { public std::enable_shared_from_this<CustomTask> {
......
@@ -182,7 +182,7 @@ ProcessGroupGloo::ProcessGroupGloo(
     int world_size,
     int gid,
     const std::shared_ptr<GlooOptions> options)
-    : ProcessGroup(rank, world_size, gid),
+    : ProcessGroupWithoutStream(rank, world_size, gid),
       _tag(0),
       _store(new GlooStore(store)) {
   _context = std::make_shared<gloo::rendezvous::Context>(rank, world_size);
@@ -400,15 +400,6 @@ class AllgatherGlooTask : public ProcessGroupGloo::GlooTask {
   }
 };
-std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
-    phi::DenseTensor* out_tensor,
-    const phi::DenseTensor& in_tensor,
-    bool sync_op) {
-  std::vector<phi::DenseTensor> in_wrapper{in_tensor};
-  std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
-  return AllGather(in_wrapper, out_wrapper, true);
-}
 std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllGather(
     phi::DenseTensor* out_tensor,
     const phi::DenseTensor& in_tensor,
......
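Note (assumption, not shown in this diff): the short AllGather(out_tensor, in_tensor, sync_op) overload removed from ProcessGroupGloo here is presumably inherited from the new ProcessGroupWithoutStream base, whose header is not part of the displayed hunks; the assumed behavior is a forward to the offset/numel overload that ProcessGroupGloo still overrides. A hedged mock of that forwarding, with hypothetical types standing in for tensors:

#include <cstdint>
#include <iostream>

// Hypothetical mini hierarchy; the real one lives in
// process_group_without_stream.h, which this diff does not display.
class MiniGroup {
 public:
  virtual ~MiniGroup() = default;
  // Short overload forwards to the full one; -1 numel means "whole tensor".
  virtual void AllGather(int out, int in, bool sync_op) {
    AllGather(out, in, /*offset=*/0, /*numel=*/-1, sync_op);
  }
  virtual void AllGather(int, int, int64_t offset, int64_t numel, bool) {
    std::cout << "full overload: offset=" << offset << " numel=" << numel << '\n';
  }
};

class MiniGloo : public MiniGroup {
 public:
  using MiniGroup::AllGather;  // keep the forwarding overload visible
  void AllGather(int, int, int64_t, int64_t, bool) override {
    std::cout << "gloo all_gather\n";
  }
};

int main() {
  MiniGloo pg;
  pg.AllGather(/*out=*/1, /*in=*/2, /*sync_op=*/true);  // prints "gloo all_gather"
}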
...@@ -19,18 +19,18 @@ ...@@ -19,18 +19,18 @@
#include <mutex> #include <mutex>
#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
#ifdef PADDLE_WITH_GLOO #ifdef PADDLE_WITH_GLOO
#include "paddle/fluid/framework/fleet/gloo_wrapper.h" #include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif #endif
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
class ProcessGroupGloo : public ProcessGroup { class ProcessGroupGloo : public ProcessGroupWithoutStream {
public: public:
class GlooTask : public ProcessGroup::Task, class GlooTask : public ProcessGroup::Task,
public std::enable_shared_from_this<GlooTask> { public std::enable_shared_from_this<GlooTask> {
...@@ -120,11 +120,6 @@ class ProcessGroupGloo : public ProcessGroup { ...@@ -120,11 +120,6 @@ class ProcessGroupGloo : public ProcessGroup {
int64_t /*numel*/, // for compatibility, no use now int64_t /*numel*/, // for compatibility, no use now
bool sync_op) override; bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) override;
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
......
...@@ -174,7 +174,9 @@ std::shared_ptr<ProcessGroupMPI> ProcessGroupMPI::CreateProcessGroupMPI( ...@@ -174,7 +174,9 @@ std::shared_ptr<ProcessGroupMPI> ProcessGroupMPI::CreateProcessGroupMPI(
} }
ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pg_comm, int gid) ProcessGroupMPI::ProcessGroupMPI(int rank, int size, MPI_Comm pg_comm, int gid)
: ProcessGroup(rank, size, gid), stop_(false), pg_comm(pg_comm) { : ProcessGroupWithoutStream(rank, size, gid),
stop_(false),
pg_comm(pg_comm) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
pg_comm == MPI_COMM_NULL, pg_comm == MPI_COMM_NULL,
false, false,
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <mutex> #include <mutex>
#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_without_stream.h"
#include "paddle/fluid/distributed/collective/types.h" #include "paddle/fluid/distributed/collective/types.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -57,7 +58,7 @@ struct TaskEntry { ...@@ -57,7 +58,7 @@ struct TaskEntry {
std::function<void(std::unique_ptr<TaskEntry>&)> run_; std::function<void(std::unique_ptr<TaskEntry>&)> run_;
}; };
class ProcessGroupMPI : public ProcessGroup { class ProcessGroupMPI : public ProcessGroupWithoutStream {
public: public:
class MPITask : public ProcessGroup::Task { class MPITask : public ProcessGroup::Task {
public: public:
......
...@@ -18,9 +18,11 @@ ...@@ -18,9 +18,11 @@
#include "paddle/fluid/distributed/collective/common.h" #include "paddle/fluid/distributed/collective/common.h"
#include "paddle/fluid/distributed/collective/nccl_tools.h" #include "paddle/fluid/distributed/collective/nccl_tools.h"
#include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/enforce.h"
DECLARE_bool(nccl_blocking_wait); DECLARE_bool(nccl_blocking_wait);
DECLARE_bool(use_stream_safe_cuda_allocator); DECLARE_bool(use_stream_safe_cuda_allocator);
...@@ -88,7 +90,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store, ...@@ -88,7 +90,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(const std::shared_ptr<Store>& store,
int rank, int rank,
int size, int size,
int gid) int gid)
: ProcessGroupStream(rank, size, gid), store_(store) {} : ProcessGroupWithStream(rank, size, gid), store_(store) {}
void ProcessGroupNCCL::GroupStart() { void ProcessGroupNCCL::GroupStart() {
NCCL_CHECK(phi::dynload::ncclGroupStart()); NCCL_CHECK(phi::dynload::ncclGroupStart());
...@@ -112,7 +114,7 @@ phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext( ...@@ -112,7 +114,7 @@ phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
iter, iter,
place_to_comm_ctx_.end(), place_to_comm_ctx_.end(),
platform::errors::NotFound( phi::errors::NotFound(
"Cannot find the device context in this process group.")); "Cannot find the device context in this process group."));
return iter->second.get(); return iter->second.get();
} }
...@@ -124,7 +126,7 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const { ...@@ -124,7 +126,7 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
iter, iter,
place_to_comm_ctx_.end(), place_to_comm_ctx_.end(),
platform::errors::NotFound( phi::errors::NotFound(
"Cannot find the NCCL commmunicator in this process group.")); "Cannot find the NCCL commmunicator in this process group."));
return iter->second->nccl_comm(); return iter->second->nccl_comm();
} }
...@@ -207,7 +209,7 @@ void CheckSizeOnEachRank(const phi::DDim& tensor_dim, ...@@ -207,7 +209,7 @@ void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
length_size_on_each_rank, length_size_on_each_rank,
world_size, world_size,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The length of size_on_each_rank must be equal to world_size.")); "The length of size_on_each_rank must be equal to world_size."));
int64_t sum_size_on_each_rank = int64_t sum_size_on_each_rank =
...@@ -215,7 +217,7 @@ void CheckSizeOnEachRank(const phi::DDim& tensor_dim, ...@@ -215,7 +217,7 @@ void CheckSizeOnEachRank(const phi::DDim& tensor_dim,
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
sum_size_on_each_rank, sum_size_on_each_rank,
tensor_dim[0], tensor_dim[0],
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The sum of size_on_each_rank must be equal to tensor's dim[0].")); "The sum of size_on_each_rank must be equal to tensor's dim[0]."));
} }
...@@ -289,7 +291,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier( ...@@ -289,7 +291,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
const BarrierOptions& opts) { const BarrierOptions& opts) {
PADDLE_ENFORCE_GE(opts.device_id, PADDLE_ENFORCE_GE(opts.device_id,
0, 0,
platform::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"The barrier device id must greater or equal than 0.")); "The barrier device id must greater or equal than 0."));
platform::CUDAPlace place(opts.device_id); platform::CUDAPlace place(opts.device_id);
auto allocator = std::unique_ptr<phi::Allocator>( auto allocator = std::unique_ptr<phi::Allocator>(
...@@ -681,7 +683,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache( ...@@ -681,7 +683,7 @@ void ProcessGroupNCCL::CreateNCCLManagerCache(
const std::string& places_key, const std::vector<Place>& places) { const std::string& places_key, const std::vector<Place>& places) {
PADDLE_ENFORCE_EQ(places_key.empty(), PADDLE_ENFORCE_EQ(places_key.empty(),
false, false,
platform::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"Not able to create/get the NCCL Communicator since " "Not able to create/get the NCCL Communicator since "
"the GPU place are not known")); "the GPU place are not known"));
...@@ -837,7 +839,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce( ...@@ -837,7 +839,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), CheckTensorsInCudaPlace(in_tensors),
true, true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, in_tensors,
out_tensors, out_tensors,
...@@ -864,7 +866,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast( ...@@ -864,7 +866,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), CheckTensorsInCudaPlace(in_tensors),
true, true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, in_tensors,
...@@ -892,25 +894,25 @@ void CheckTensorsInDifferentDevices( ...@@ -892,25 +894,25 @@ void CheckTensorsInDifferentDevices(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
tensors.size() == 0, tensors.size() == 0,
false, false,
platform::errors::InvalidArgument("Tensor list must be nonempty.")); phi::errors::InvalidArgument("Tensor list must be nonempty."));
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
tensors.size(), tensors.size(),
num_devices, num_devices,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Tensor list mustn't be larger than the number of available GPUs.")); "Tensor list mustn't be larger than the number of available GPUs."));
std::set<Place> used_devices; std::set<Place> used_devices;
for (const auto& t : tensors) { for (const auto& t : tensors) {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(t.place()), PADDLE_ENFORCE_EQ(
true, platform::is_gpu_place(t.place()),
platform::errors::InvalidArgument( true,
"Tensors must be CUDA and dense tensor.")); phi::errors::InvalidArgument("Tensors must be CUDA and dense tensor."));
const auto inserted = used_devices.insert(t.place()).second; const auto inserted = used_devices.insert(t.place()).second;
PADDLE_ENFORCE_EQ(inserted, PADDLE_ENFORCE_EQ(inserted,
true, true,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Tensors must be on distinct GPU devices.")); "Tensors must be on distinct GPU devices."));
} }
} }
...@@ -965,11 +967,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather( ...@@ -965,11 +967,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), CheckTensorsInCudaPlace(in_tensors),
true, true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), CheckTensorsInCudaPlace(out_tensors),
true, true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace.")); phi::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, in_tensors,
out_tensors, out_tensors,
...@@ -1019,7 +1021,7 @@ void* GetPointerByOffset(void* raw_pointer, ...@@ -1019,7 +1021,7 @@ void* GetPointerByOffset(void* raw_pointer,
return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) + return reinterpret_cast<void*>(reinterpret_cast<uint16_t*>(raw_pointer) +
offset); offset);
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Datatype %s in NCCL is not supported.", type)); "Datatype %s in NCCL is not supported.", type));
} }
return nullptr; return nullptr;
...@@ -1031,11 +1033,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll( ...@@ -1031,11 +1033,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), CheckTensorsInCudaPlace(in_tensors),
true, true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), CheckTensorsInCudaPlace(out_tensors),
true, true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, in_tensors,
out_tensors, out_tensors,
...@@ -1074,7 +1076,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce( ...@@ -1074,7 +1076,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), CheckTensorsInCudaPlace(in_tensors),
true, true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, in_tensors,
out_tensors, out_tensors,
...@@ -1102,11 +1104,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter( ...@@ -1102,11 +1104,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors), CheckTensorsInCudaPlace(in_tensors),
true, true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors), CheckTensorsInCudaPlace(out_tensors),
true, true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); phi::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective( return Collective(
in_tensors, in_tensors,
out_tensors, out_tensors,
......
...@@ -20,11 +20,10 @@ ...@@ -20,11 +20,10 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/distributed/collective/process_group_stream.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_with_stream.h"
#include "paddle/fluid/distributed/store/store.h" #include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_event.h" #include "paddle/fluid/platform/device_event.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
...@@ -41,11 +40,11 @@ ...@@ -41,11 +40,11 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using Place = paddle::platform::Place; using Place = phi::Place;
class ProcessGroupNCCL final : public ProcessGroupStream { class ProcessGroupNCCL final : public ProcessGroupWithStream {
public: public:
class NCCLTask final : public ProcessGroupStream::TaskStream, class NCCLTask final : public ProcessGroupWithStream::TaskStream,
public std::enable_shared_from_this<NCCLTask> { public std::enable_shared_from_this<NCCLTask> {
public: public:
NCCLTask(const Place& place, NCCLTask(const Place& place,
...@@ -86,11 +85,11 @@ class ProcessGroupNCCL final : public ProcessGroupStream { ...@@ -86,11 +85,11 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
std::string GetBackendName() const override { return "NCCL"; } std::string GetBackendName() const override { return "NCCL"; }
phi::DeviceContext* GetDeviceContext(const Place& place) const override;
phi::DeviceContext* GetDeviceContext(const Place& place, phi::DeviceContext* GetDeviceContext(const Place& place,
bool use_calc_stream) const override; bool use_calc_stream) const override;
phi::DeviceContext* GetDeviceContext(const Place& place) const override;
std::shared_ptr<ProcessGroup::Task> AllGather( std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/process_group_stream.h"
namespace paddle {
namespace distributed {
ProcessGroupStream::ProcessGroupStream(int rank, int size, int gid)
: ProcessGroup(rank, size, gid) {}
phi::DeviceContext* ProcessGroupStream::GetDeviceContext(
const Place& place, bool use_calc_stream) const {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support get device_context.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op) {
return AllGather(out_tensor,
in_tensor,
offset,
numel,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op,
bool use_calc_stream) {
return AllGather(out_tensor,
in_tensor,
/*offset*/ 0,
/*numel*/ -1, // -1 indicates the whole tensor
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_gather.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op) {
return AllReduce(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_reduce.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op) {
return AllToAll(out_tensor,
in_tensor,
out_size_each_rank,
in_size_each_rank,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllToAll(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support all_to_all.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op) {
return Broadcast(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Broadcast(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support broadcast.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op) {
return Reduce(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op) {
return ReduceScatter(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::ReduceScatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support reduce_scatter.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op) {
return Scatter(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Scatter(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support scatter.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op) {
return Recv(tensor,
src_rank,
offset,
numel,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
bool use_calc_stream) {
return Recv(tensor,
src_rank,
/*offset*/ 0,
/*numel*/ -1, // -1 indicates receiving the whole tensor
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Recv(
phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support recv.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
const phi::DenseTensor& tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op) {
return Send(tensor,
dst_rank,
offset,
numel,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
const phi::DenseTensor& tensor,
int dst_rank,
bool sync_op,
bool use_calc_stream) {
return Send(tensor,
dst_rank,
/*offset*/ 0,
/*numel*/ -1, // -1 indicates sending the whole tensor
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
const phi::DenseTensor& tensor,
int dst_rank,
int64_t offset,
int64_t numel,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::Unimplemented(
"ProcessGroup%s does not support send.", GetBackendName()));
}
} // namespace distributed
} // namespace paddle
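
A note on the Send/Recv signatures above: `offset` and `numel` address a contiguous slice of the flattened tensor, and the convenience overloads fill in `numel = -1` to mean the whole tensor. The following is a minimal, self-contained sketch of that convention only; `SendSlice` is a hypothetical helper written for illustration, not a Paddle API.

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch only: a plain float buffer stands in for a flattened DenseTensor.
void SendSlice(const std::vector<float>& tensor, int dst_rank, int64_t offset,
               int64_t numel) {
  if (numel == -1) {  // -1 indicates the whole tensor
    offset = 0;
    numel = static_cast<int64_t>(tensor.size());
  }
  std::cout << "send " << numel << " element(s) starting at offset " << offset
            << " to rank " << dst_rank << "\n";
}

int main() {
  std::vector<float> t(8, 1.0f);
  SendSlice(t, /*dst_rank*/ 1, /*offset*/ 0, /*numel*/ -1);  // whole tensor
  SendSlice(t, /*dst_rank*/ 1, /*offset*/ 4, /*numel*/ 4);   // second half
  return 0;
}

Callers that want the whole tensor simply pass -1, which is exactly what the convenience overloads above do.
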
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -15,15 +15,17 @@ ...@@ -15,15 +15,17 @@
#pragma once #pragma once
#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
// NOTE(liyurui): Notice that some backends use `stream` as an abstract // NOTE: Notice that some backends use `stream` as an abstract conception of
// conception of hardware resource. We provide this base class allowing users to // hardware resource. We provide this base class allowing users to put
// put communications on calculation stream. In some scenarios, we found this // communications on calculation stream. In some scenarios, we found this will
// will save the time of switching streams. // save the time of switching streams.
class ProcessGroupStream : public ProcessGroup { class ProcessGroupWithStream : public ProcessGroup {
public: public:
class TaskStream : public ProcessGroup::Task { class TaskStream : public ProcessGroup::Task {
public: public:
...@@ -55,148 +57,201 @@ class ProcessGroupStream : public ProcessGroup { ...@@ -55,148 +57,201 @@ class ProcessGroupStream : public ProcessGroup {
}; };
public: public:
ProcessGroupStream(int rank, int size, int gid); ProcessGroupWithStream(int rank, int size, int gid)
virtual ~ProcessGroupStream() = default; : ProcessGroup(rank, size, gid) {}
using ProcessGroup::GetDeviceContext;
virtual ~ProcessGroupWithStream() = default;
virtual phi::DeviceContext* GetDeviceContext(const Place& place,
bool use_calc_stream) const; // methods from base class
using ProcessGroup::AllGather;
using ProcessGroup::AllReduce;
using ProcessGroup::AllToAll;
using ProcessGroup::Broadcast;
using ProcessGroup::Recv;
using ProcessGroup::Reduce;
using ProcessGroup::ReduceScatter;
using ProcessGroup::Scatter;
using ProcessGroup::Send;
std::shared_ptr<ProcessGroup::Task> AllGather( std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
int64_t offset, bool sync_op) override {
int64_t numel, return AllGather(out_tensor,
bool sync_op) override; in_tensor,
/*offset*/ 0,
virtual std::shared_ptr<ProcessGroup::Task> AllGather( /*numel*/ -1, // -1 indicates the whole tensor
phi::DenseTensor* out_tensor, sync_op);
const phi::DenseTensor& in_tensor, }
bool sync_op,
bool use_calc_stream);
virtual std::shared_ptr<ProcessGroup::Task> AllGather( std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
int64_t offset, int64_t offset,
int64_t numel, int64_t numel,
bool sync_op, bool sync_op) override {
bool use_calc_stream); return AllGather(out_tensor,
in_tensor,
offset,
numel,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts, bool sync_op,
bool sync_op) override; bool use_calc_stream) override {
return AllGather(out_tensor,
in_tensor,
/*offset*/ 0,
/*numel*/ -1, // -1 indicates the whole tensor
sync_op,
use_calc_stream);
}
virtual std::shared_ptr<ProcessGroup::Task> AllReduce( std::shared_ptr<ProcessGroup::Task> AllReduce(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const AllreduceOptions& opts, const AllreduceOptions& opts,
bool sync_op, bool sync_op) override {
bool use_calc_stream); return AllReduce(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> AllToAll( std::shared_ptr<ProcessGroup::Task> AllToAll(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const std::vector<int64_t>& out_size_each_rank, const std::vector<int64_t>& out_size_each_rank,
const std::vector<int64_t>& in_size_each_rank, const std::vector<int64_t>& in_size_each_rank,
bool sync_op) override; bool sync_op) override {
return AllToAll(out_tensor,
virtual std::shared_ptr<ProcessGroup::Task> AllToAll( in_tensor,
phi::DenseTensor* out_tensor, out_size_each_rank,
const phi::DenseTensor& in_tensor, in_size_each_rank,
const std::vector<int64_t>& out_size_each_rank, sync_op,
const std::vector<int64_t>& in_size_each_rank, /*use_calc_stream*/ false);
bool sync_op, }
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Broadcast( std::shared_ptr<ProcessGroup::Task> Broadcast(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const BroadcastOptions& opts, const BroadcastOptions& opts,
bool sync_op) override; bool sync_op) override {
return Broadcast(out_tensor,
virtual std::shared_ptr<ProcessGroup::Task> Broadcast( in_tensor,
phi::DenseTensor* out_tensor, opts,
const phi::DenseTensor& in_tensor, sync_op,
const BroadcastOptions& opts, /*use_calc_stream*/ false);
bool sync_op, }
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor, std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const ReduceOptions& opts, const ReduceOptions& opts,
bool sync_op) override; bool sync_op) override {
return Reduce(out_tensor,
virtual std::shared_ptr<ProcessGroup::Task> Reduce( in_tensor,
phi::DenseTensor* out_tensor, opts,
const phi::DenseTensor& in_tensor, sync_op,
const ReduceOptions& opts, /*use_calc_stream*/ false);
bool sync_op, }
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> ReduceScatter( std::shared_ptr<ProcessGroup::Task> ReduceScatter(
phi::DenseTensor* out_tensor, phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts, const ReduceScatterOptions& opts,
bool sync_op) override; bool sync_op) override {
return ReduceScatter(out_tensor,
virtual std::shared_ptr<ProcessGroup::Task> ReduceScatter( in_tensor,
phi::DenseTensor* out_tensor, opts,
const phi::DenseTensor& in_tensor, sync_op,
const ReduceScatterOptions& opts, /*use_calc_stream*/ false);
bool sync_op, }
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor, std::shared_ptr<ProcessGroup::Task> Scatter(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor, const phi::DenseTensor& in_tensor,
const ScatterOptions& opts, const ScatterOptions& opts,
bool sync_op) override; bool sync_op) override {
return Scatter(out_tensor,
in_tensor,
opts,
sync_op,
/*use_calc_stream*/ false);
}
virtual std::shared_ptr<ProcessGroup::Task> Scatter( std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
phi::DenseTensor* out_tensor, int src_rank,
const phi::DenseTensor& in_tensor, bool sync_op) override {
const ScatterOptions& opts, return Recv(tensor,
bool sync_op, src_rank,
bool use_calc_stream); /*offset*/ 0,
/*numel*/ -1, // -1 indicates the whole tensor
sync_op);
}
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor, std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank, int src_rank,
int64_t offset, int64_t offset,
int64_t numel, int64_t numel,
bool sync_op) override; bool sync_op) override {
return Recv(tensor,
src_rank,
offset,
numel,
sync_op,
/*use_calc_stream*/ false);
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor, std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank, int src_rank,
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream) override {
return Recv(tensor,
src_rank,
/*offset*/ 0,
/*numel*/ -1, // -1 indicates receiving the whole tensor
sync_op,
use_calc_stream);
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor, std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
int src_rank, int dst_rank,
int64_t offset, bool sync_op) override {
int64_t numel, return Send(tensor,
bool sync_op, dst_rank,
bool use_calc_stream); /*offset*/ 0,
/*numel*/ -1, // -1 indicates the whole tensor
sync_op);
}
std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor, std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
int dst_rank, int dst_rank,
int64_t offset, int64_t offset,
int64_t numel, int64_t numel,
bool sync_op) override; bool sync_op) override {
return Send(tensor,
dst_rank,
offset,
numel,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor, std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
int dst_rank, int dst_rank,
bool sync_op, bool sync_op,
bool use_calc_stream); bool use_calc_stream) override {
return Send(tensor,
virtual std::shared_ptr<ProcessGroup::Task> Send( dst_rank,
const phi::DenseTensor& tensor, /*offset*/ 0,
int dst_rank, /*numel*/ -1, // -1 indicates sending the whole tensor
int64_t offset, sync_op,
int64_t numel, use_calc_stream);
bool sync_op, }
bool use_calc_stream);
}; };
} // namespace distributed } // namespace distributed
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
class ProcessGroupWithoutStream : public ProcessGroup {
public:
ProcessGroupWithoutStream(int rank, int size, int gid)
: ProcessGroup(rank, size, gid) {}
virtual ~ProcessGroupWithoutStream() = default;
// methods from base class
using ProcessGroup::AllGather;
using ProcessGroup::Recv;
using ProcessGroup::Send;
std::shared_ptr<ProcessGroup::Task> AllGather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
bool sync_op) override {
return AllGather(out_tensor,
in_tensor,
/*offset*/ 0,
/*numel*/ -1, // -1 indicates the whole tensor
sync_op);
}
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op) override {
return Recv(tensor,
src_rank,
/*offset*/ 0,
/*numel*/ -1, // -1 indicates the whole tensor
sync_op);
}
std::shared_ptr<ProcessGroup::Task> Send(const phi::DenseTensor& tensor,
int dst_rank,
bool sync_op) override {
return Send(tensor,
dst_rank,
/*offset*/ 0,
/*numel*/ -1, // -1 indicates the whole tensor
sync_op);
}
};
} // namespace distributed
} // namespace paddle
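
Taken together, the two helper bases split the boilerplate cleanly: `ProcessGroupWithStream` pins both the offset/numel defaults and `use_calc_stream = false`, while `ProcessGroupWithoutStream` only pins offset/numel, since its backends have no separate communication stream to choose from. A concrete backend then implements a single fully parameterized virtual per primitive. Below is a self-contained sketch of that split using hypothetical class names (`WithStreamSketch`, `CpuLikeGroup`, and so on), not the Paddle types.

#include <cstdint>
#include <iostream>

// Hypothetical stand-in for ProcessGroupWithoutStream: only the offset/numel
// defaults are filled in; there is no stream choice to make.
struct WithoutStreamSketch {
  virtual ~WithoutStreamSketch() = default;
  void Send(int dst_rank, bool sync_op) {
    Send(dst_rank, /*offset*/ 0, /*numel*/ -1, sync_op);  // -1: whole tensor
  }
  virtual void Send(int dst_rank, int64_t offset, int64_t numel,
                    bool sync_op) = 0;
};

// Hypothetical stand-in for ProcessGroupWithStream: additionally pins
// use_calc_stream = false for the plain overloads.
struct WithStreamSketch {
  virtual ~WithStreamSketch() = default;
  void Send(int dst_rank, bool sync_op) {
    Send(dst_rank, /*offset*/ 0, /*numel*/ -1, sync_op,
         /*use_calc_stream*/ false);
  }
  virtual void Send(int dst_rank, int64_t offset, int64_t numel, bool sync_op,
                    bool use_calc_stream) = 0;
};

struct CpuLikeGroup : WithoutStreamSketch {  // hypothetical CPU backend
  using WithoutStreamSketch::Send;  // keep the convenience overload visible
  void Send(int dst_rank, int64_t offset, int64_t numel,
            bool sync_op) override {
    std::cout << "cpu send to rank " << dst_rank << "\n";
  }
};

struct GpuLikeGroup : WithStreamSketch {  // hypothetical GPU backend
  using WithStreamSketch::Send;
  void Send(int dst_rank, int64_t offset, int64_t numel, bool sync_op,
            bool use_calc_stream) override {
    std::cout << "gpu send to rank " << dst_rank
              << (use_calc_stream ? " on calc stream" : " on comm stream")
              << "\n";
  }
};

int main() {
  CpuLikeGroup cpu;
  GpuLikeGroup gpu;
  cpu.Send(/*dst_rank*/ 1, /*sync_op*/ true);
  gpu.Send(/*dst_rank*/ 1, /*sync_op*/ false);  // defaults to the comm stream
  return 0;
}

The `using ProcessGroup::Send;` style declarations in the real headers serve the same purpose as the `using ...::Send;` lines in this sketch: they keep the convenience overloads visible once a derived class overrides the fully parameterized one.
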
...@@ -22,7 +22,6 @@ limitations under the License. */ ...@@ -22,7 +22,6 @@ limitations under the License. */
#endif #endif
#include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group.h"
#include "paddle/fluid/distributed/collective/process_group_stream.h"
#include "paddle/fluid/distributed/collective/reducer.h" #include "paddle/fluid/distributed/collective/reducer.h"
#include "paddle/fluid/distributed/collective/types.h" #include "paddle/fluid/distributed/collective/types.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -733,15 +732,11 @@ void BindDistributed(py::module *m) { ...@@ -733,15 +732,11 @@ void BindDistributed(py::module *m) {
py::arg("in"), py::arg("in"),
py::arg("out"), py::arg("out"),
py::arg("src"), py::arg("src"),
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>())
auto ProcessGroupStream =
py::class_<distributed::ProcessGroupStream,
std::shared_ptr<distributed::ProcessGroupStream>>(
*m, "ProcessGroupStream", ProcessGroup)
.def( .def(
"all_gather_on_calc_stream", "all_gather_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor_list, py::handle py_out_tensor_list,
py::handle py_in_tensor) { py::handle py_in_tensor) {
auto out_tensor_list = auto out_tensor_list =
...@@ -770,7 +765,7 @@ void BindDistributed(py::module *m) { ...@@ -770,7 +765,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"all_gather_into_tensor_on_calc_stream", "all_gather_into_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor) { py::handle py_in_tensor) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
...@@ -794,7 +789,7 @@ void BindDistributed(py::module *m) { ...@@ -794,7 +789,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"all_gather_partial_on_calc_stream", "all_gather_partial_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor, py::handle py_in_tensor,
int nranks, int nranks,
...@@ -828,7 +823,7 @@ void BindDistributed(py::module *m) { ...@@ -828,7 +823,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"all_reduce_on_calc_stream", "all_reduce_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_tensor, py::handle py_tensor,
distributed::ReduceOp op) { distributed::ReduceOp op) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
...@@ -849,7 +844,7 @@ void BindDistributed(py::module *m) { ...@@ -849,7 +844,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"all_to_all_on_calc_stream", "all_to_all_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor_list, py::handle py_out_tensor_list,
py::handle py_in_tensor_list) { py::handle py_in_tensor_list) {
auto out_tensor_list = auto out_tensor_list =
...@@ -886,7 +881,7 @@ void BindDistributed(py::module *m) { ...@@ -886,7 +881,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"all_to_all_tensor_on_calc_stream", "all_to_all_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor) { py::handle py_in_tensor) {
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
...@@ -914,7 +909,7 @@ void BindDistributed(py::module *m) { ...@@ -914,7 +909,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"all_to_all_single_on_calc_stream", "all_to_all_single_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor, py::handle py_in_tensor,
const std::vector<int64_t> &out_sizes, const std::vector<int64_t> &out_sizes,
...@@ -944,7 +939,7 @@ void BindDistributed(py::module *m) { ...@@ -944,7 +939,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"broadcast_on_calc_stream", "broadcast_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_tensor, py::handle py_tensor,
int src) { int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
...@@ -965,7 +960,7 @@ void BindDistributed(py::module *m) { ...@@ -965,7 +960,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"reduce_on_calc_stream", "reduce_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_tensor, py::handle py_tensor,
int dst, int dst,
distributed::ReduceOp op) { distributed::ReduceOp op) {
...@@ -988,7 +983,7 @@ void BindDistributed(py::module *m) { ...@@ -988,7 +983,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"reduce_scatter_on_calc_stream", "reduce_scatter_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor_list, py::handle py_in_tensor_list,
distributed::ReduceOp op) { distributed::ReduceOp op) {
...@@ -1018,7 +1013,7 @@ void BindDistributed(py::module *m) { ...@@ -1018,7 +1013,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"reduce_scatter_tensor_on_calc_stream", "reduce_scatter_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor, py::handle py_in_tensor,
distributed::ReduceOp op) { distributed::ReduceOp op) {
...@@ -1046,7 +1041,7 @@ void BindDistributed(py::module *m) { ...@@ -1046,7 +1041,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"scatter_on_calc_stream", "scatter_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor_list, py::handle py_in_tensor_list,
int src) { int src) {
...@@ -1076,7 +1071,7 @@ void BindDistributed(py::module *m) { ...@@ -1076,7 +1071,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"scatter_tensor_on_calc_stream", "scatter_tensor_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_out_tensor, py::handle py_out_tensor,
py::handle py_in_tensor, py::handle py_in_tensor,
int src) { int src) {
...@@ -1104,7 +1099,7 @@ void BindDistributed(py::module *m) { ...@@ -1104,7 +1099,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"send_on_calc_stream", "send_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_tensor, py::handle py_tensor,
int dst) { int dst) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
...@@ -1122,7 +1117,7 @@ void BindDistributed(py::module *m) { ...@@ -1122,7 +1117,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"send_partial_on_calc_stream", "send_partial_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_tensor, py::handle py_tensor,
int dst_rank, int dst_rank,
int nranks, int nranks,
...@@ -1151,7 +1146,7 @@ void BindDistributed(py::module *m) { ...@@ -1151,7 +1146,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"recv_on_calc_stream", "recv_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_tensor, py::handle py_tensor,
int src) { int src) {
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
...@@ -1169,7 +1164,7 @@ void BindDistributed(py::module *m) { ...@@ -1169,7 +1164,7 @@ void BindDistributed(py::module *m) {
.def( .def(
"recv_partial_on_calc_stream", "recv_partial_on_calc_stream",
[](distributed::ProcessGroupStream &self, [](distributed::ProcessGroup &self,
py::handle py_tensor, py::handle py_tensor,
int src_rank, int src_rank,
int nranks, int nranks,
...@@ -1199,7 +1194,7 @@ void BindDistributed(py::module *m) { ...@@ -1199,7 +1194,7 @@ void BindDistributed(py::module *m) {
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL) #if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
py::class_<distributed::ProcessGroupNCCL, py::class_<distributed::ProcessGroupNCCL,
std::shared_ptr<distributed::ProcessGroupNCCL>>( std::shared_ptr<distributed::ProcessGroupNCCL>>(
*m, "ProcessGroupNCCL", ProcessGroupStream) *m, "ProcessGroupNCCL", ProcessGroup)
.def_static("create", .def_static("create",
distributed::ProcessGroupNCCL::CreateProcessGroupNCCL, distributed::ProcessGroupNCCL::CreateProcessGroupNCCL,
py::arg("store"), py::arg("store"),
...@@ -1250,7 +1245,7 @@ void BindDistributed(py::module *m) { ...@@ -1250,7 +1245,7 @@ void BindDistributed(py::module *m) {
auto processGroupBKCL = auto processGroupBKCL =
py::class_<distributed::ProcessGroupBKCL, py::class_<distributed::ProcessGroupBKCL,
std::shared_ptr<distributed::ProcessGroupBKCL>>( std::shared_ptr<distributed::ProcessGroupBKCL>>(
*m, "ProcessGroupBKCL", ProcessGroupStream) *m, "ProcessGroupBKCL", ProcessGroup)
.def_static("create", .def_static("create",
distributed::ProcessGroupBKCL::CreateProcessGroupBKCL, distributed::ProcessGroupBKCL::CreateProcessGroupBKCL,
py::arg("store"), py::arg("store"),
......
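
Since the `*_on_calc_stream` bindings above now take `distributed::ProcessGroup` by reference, the separate `ProcessGroupStream` Python class is gone and every backend registered as a `ProcessGroup` inherits these entry points, presumably dispatching through the virtual overloads that carry the `use_calc_stream` flag. Below is a minimal pybind11 sketch of that binding pattern, with hypothetical module and class names and the tensor/option plumbing omitted.

#include <memory>

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical stand-in for distributed::ProcessGroup; only the virtual
// overload that carries use_calc_stream is kept in the sketch.
struct CommGroupBase {
  virtual ~CommGroupBase() = default;
  virtual void AllReduce(bool sync_op, bool use_calc_stream) = 0;
};

PYBIND11_MODULE(comm_sketch, m) {
  py::class_<CommGroupBase, std::shared_ptr<CommGroupBase>>(m, "CommGroupBase")
      .def("all_reduce_on_calc_stream", [](CommGroupBase &self) {
        // In this sketch the "on_calc_stream" entry point runs synchronously
        // on the calculation stream, so both flags are fixed here.
        self.AllReduce(/*sync_op*/ true, /*use_calc_stream*/ true);
      });
}

Because the lambda only needs the abstract base, any concrete backend exposed to Python with `CommGroupBase` as its declared parent gets this method without an intermediate wrapper type.
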