Unverified commit 8cc2e28c, authored by ShenLiang, committed by GitHub

[Bug Fix] Fix global_scatter/global_gather in ProcessGroup (#43027)

* fix alltoall

* rename utest
Parent commit: 9eb18c75
@@ -113,6 +113,19 @@ class ProcessGroup {
        "ProcessGroup%s does not support receive", GetBackendName()));
  }
virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&,
int, int,
int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support send", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
phi::DenseTensor& tensors, int, int, int) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support receive", GetBackendName()));
}
  virtual std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>&,    // NOLINT
      std::vector<phi::DenseTensor>&) {  // NOLINT
......
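The `Send_Partial` / `Recv_Partial` virtuals added above take an element `offset` and `length` into a flattened view of the tensor, so a single tensor can be exchanged piecewise. The following standalone sketch (plain C++, not Paddle code; the buffer and sizes are made up for illustration) shows how a caller would translate a row range of a 2-D `[rows, in_feat]` tensor into that flat offset/length pair:

// Standalone illustration (not Paddle code): the offset/length arguments of
// Send_Partial/Recv_Partial are element counts on a flattened, row-major view
// of the tensor, not byte offsets.
#include <cstdio>
#include <vector>

int main() {
  const int rows = 6, in_feat = 4;             // a hypothetical 6 x 4 tensor
  std::vector<float> buf(rows * in_feat);      // flattened row-major storage
  for (int i = 0; i < rows * in_feat; ++i) buf[i] = static_cast<float>(i);

  // To exchange rows [2, 5) of the tensor, a caller would pass:
  const int offset = 2 * in_feat;              // first element of row 2
  const int length = (5 - 2) * in_feat;        // three full rows of elements
  std::printf("offset=%d length=%d first=%g last=%g\n", offset, length,
              buf[offset], buf[offset + length - 1]);
  return 0;
}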
@@ -428,6 +428,53 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
  return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
phi::DenseTensor& tensors, int dst_rank, int offset, int length) {
// CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);
std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);
auto task = PointToPoint(shared_tensors,
[&](phi::DenseTensor& input, ncclComm_t comm,
const gpuStream_t& stream, int dst_rank) {
return platform::dynload::ncclSend(
input.data(), input.numel(),
platform::ToNCCLDataType(input.dtype()),
dst_rank, comm, stream);
},
dst_rank, CommType::SEND);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
phi::DenseTensor& tensors, int src_rank, int offset, int length) {
// phi::DenseTensor shared_input = tensors.Slice(offset, offset+length);
phi::DenseTensor flatten_tensor;
flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);
std::vector<phi::DenseTensor> shared_tensors;
shared_tensors.push_back(shared_input);
auto task = PointToPoint(shared_tensors,
[&](phi::DenseTensor& output, ncclComm_t comm,
const gpuStream_t& stream, int src_rank) {
return platform::dynload::ncclRecv(
output.data(), output.numel(),
platform::ToNCCLDataType(output.dtype()),
src_rank, comm, stream);
},
src_rank, CommType::RECV);
return task;
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
    std::vector<phi::DenseTensor>& in_tensors,
    std::vector<phi::DenseTensor>& out_tensors) {
......
@@ -102,6 +102,14 @@ class ProcessGroupNCCL : public ProcessGroup {
  std::shared_ptr<ProcessGroup::Task> Recv(
      std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
int dst_rank, int offset,
int length) override;
std::shared_ptr<ProcessGroup::Task> Recv_Partial(phi::DenseTensor& tensors,
int src_rank, int offset,
int length) override;
  std::shared_ptr<ProcessGroup::Task> AllGather(
      std::vector<phi::DenseTensor>& in_tensors,
      std::vector<phi::DenseTensor>& out_tensors) override;
......
@@ -22,10 +22,10 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
-class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+struct GlobalGatherFunctor<phi::GPUContext, T> {
+  void operator()(const framework::ExecutionContext& ctx) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
    auto x = ctx.Input<framework::LoDTensor>("X");
@@ -137,6 +137,132 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
  }
};
template <typename T>
struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
auto x = ctx.Input<framework::LoDTensor>("X");
auto local_count = ctx.Input<framework::LoDTensor>("local_count");
auto global_count = ctx.Input<framework::LoDTensor>("global_count");
auto local_count_type =
framework::TransToProtoVarType(local_count->dtype());
auto global_count_type =
framework::TransToProtoVarType(global_count->dtype());
if (local_count_type != framework::proto::VarType::INT64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Please use int64 type in local_count."));
}
if (global_count_type != framework::proto::VarType::INT64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Please use int64 type in global_count."));
}
auto out = ctx.Output<framework::LoDTensor>("Out");
const int64_t* cpu_local_count_data;
const int64_t* cpu_global_count_data;
auto local_count_len = 0;
framework::Tensor cpu_local_count;
if (platform::is_cpu_place(local_count->place())) {
cpu_local_count_data = local_count->data<int64_t>();
local_count_len = local_count->numel();
} else {
framework::TensorCopySync(*local_count, platform::CPUPlace(),
&cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
local_count_len = cpu_local_count.numel();
}
framework::Tensor cpu_global_count;
if (platform::is_cpu_place(global_count->place())) {
cpu_global_count_data = global_count->data<int64_t>();
} else {
framework::TensorCopySync(*global_count, platform::CPUPlace(),
&cpu_global_count);
cpu_global_count_data = cpu_global_count.data<int64_t>();
}
int ring_id = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for global gather op must be non-negative.",
ring_id));
auto place = ctx.GetPlace();
auto map = distributed::ProcessGroupMapFromGid::getInstance();
distributed::ProcessGroup* pg = map->get(ring_id);
int nranks = pg->GetSize();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;
auto fwd_count = 0;
for (auto i = 0; i < local_count_len; ++i) {
fwd_count += cpu_local_count_data[i];
}
framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat});
int64_t* expert_ptr = new int64_t[n_expert * nranks];
expert_ptr[0] = 0;
auto tot_experts = n_expert * nranks;
for (auto i = 1; i < tot_experts; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
}
auto send_ptr = 0;
out->mutable_data<T>(out_dims, place);
for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_global_count_data[idx]) {
phi::DenseTensor tmp = *x;
pg->Send_Partial(tmp, j, send_ptr * in_feat,
cpu_global_count_data[idx] * in_feat);
send_ptr += cpu_global_count_data[idx];
}
if (cpu_local_count_data[idx]) {
pg->Recv_Partial(*out, j, expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
}
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
#else
PADDLE_THROW(
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
#endif
#else
PADDLE_THROW(
platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
#endif
}
};
template <typename T>
class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
GlobalGatherProcessGroupFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
} else {
GlobalGatherFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
}
}
};
}  // namespace operators
}  // namespace paddle
......
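The rewritten `GlobalGatherOpCUDAKernel::Compute` above now only dispatches: when `ProcessGroupMapFromGid` has an entry for `ring_id` it runs `GlobalGatherProcessGroupFunctor`, otherwise it falls back to the legacy `GlobalGatherFunctor`. A minimal, self-contained sketch of that selection pattern follows; `Ctx`, `g_registered_rings`, and the functor names here are hypothetical stand-ins, not Paddle's actual types:

// Minimal sketch of the kernel's new dispatch: route to the ProcessGroup
// implementation when a group is registered for ring_id, else use the
// legacy NCCL path. All names here are hypothetical stand-ins.
#include <cstdio>
#include <set>

struct Ctx { int ring_id; };
struct LegacyFunctor       { void operator()(const Ctx&) { std::puts("legacy NCCL path"); } };
struct ProcessGroupFunctor { void operator()(const Ctx&) { std::puts("ProcessGroup path"); } };

static const std::set<int> g_registered_rings = {0};  // pretend ring 0 has a ProcessGroup

void Compute(const Ctx& ctx) {
  if (g_registered_rings.count(ctx.ring_id)) {
    ProcessGroupFunctor()(ctx);   // new code path added by this PR
  } else {
    LegacyFunctor()(ctx);         // pre-existing ncclSend/ncclRecv path
  }
}

int main() {
  Compute({0});  // prints "ProcessGroup path"
  Compute({1});  // prints "legacy NCCL path"
  return 0;
}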
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -33,5 +34,15 @@ class GlobalGatherOpCPUKernel : public framework::OpKernel<T> { ...@@ -33,5 +34,15 @@ class GlobalGatherOpCPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename Context, typename T>
struct GlobalGatherFunctor {
void operator()(const framework::ExecutionContext& ctx);
};
template <typename Context, typename T>
struct GlobalGatherProcessGroupFunctor {
void operator()(const framework::ExecutionContext& ctx);
};
}  // namespace operators
}  // namespace paddle
@@ -22,10 +22,10 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
-class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+struct GlobalScatterFunctor<phi::GPUContext, T> {
+  void operator()(const framework::ExecutionContext& ctx) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
    auto x = ctx.Input<framework::LoDTensor>("X");
@@ -137,6 +137,130 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
  }
};
template <typename T>
struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
void operator()(const framework::ExecutionContext& ctx) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2703
auto x = ctx.Input<framework::LoDTensor>("X");
auto local_count = ctx.Input<framework::LoDTensor>("local_count");
auto global_count = ctx.Input<framework::LoDTensor>("global_count");
auto local_count_type =
framework::TransToProtoVarType(local_count->dtype());
auto global_count_type =
framework::TransToProtoVarType(global_count->dtype());
if (local_count_type != framework::proto::VarType::INT64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Please use int64 type in local_count."));
}
if (global_count_type != framework::proto::VarType::INT64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Please use int64 type in global_count."));
}
auto out = ctx.Output<framework::LoDTensor>("Out");
const int64_t* cpu_local_count_data;
const int64_t* cpu_global_count_data;
framework::Tensor cpu_local_count;
if (platform::is_cpu_place(local_count->place())) {
cpu_local_count_data = local_count->data<int64_t>();
} else {
framework::TensorCopySync(*local_count, platform::CPUPlace(),
&cpu_local_count);
cpu_local_count_data = cpu_local_count.data<int64_t>();
}
auto global_count_len = 0;
framework::Tensor cpu_global_count;
if (platform::is_cpu_place(global_count->place())) {
cpu_global_count_data = global_count->data<int64_t>();
global_count_len = global_count->numel();
} else {
framework::TensorCopySync(*global_count, platform::CPUPlace(),
&cpu_global_count);
cpu_global_count_data = cpu_global_count.data<int64_t>();
global_count_len = cpu_global_count.numel();
}
int ring_id = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for global scatter op must be non-negative.",
ring_id));
auto place = ctx.GetPlace();
auto map = distributed::ProcessGroupMapFromGid::getInstance();
distributed::ProcessGroup* pg = map->get(ring_id);
int nranks = pg->GetSize();
auto in_feat = x->dims()[1];
auto n_expert = local_count->dims()[0] / nranks;
int64_t fwd_count = 0;
for (auto i = 0; i < global_count_len; ++i) {
fwd_count += cpu_global_count_data[i];
}
framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat});
int64_t* expert_ptr = new int64_t[n_expert * nranks];
expert_ptr[0] = 0;
auto tot_experts = n_expert * nranks;
for (auto i = 1; i < tot_experts; ++i) {
expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
}
auto recv_ptr = 0;
out->mutable_data<T>(out_dims, place);
for (auto i = 0; i < n_expert; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto j = 0; j < nranks; ++j) {
int idx = i + j * n_expert;
if (cpu_local_count_data[idx]) {
phi::DenseTensor tmp = *x;
pg->Send_Partial(tmp, j, expert_ptr[idx] * in_feat,
cpu_local_count_data[idx] * in_feat);
}
if (cpu_global_count_data[idx]) {
pg->Recv_Partial(*out, j, recv_ptr * in_feat,
cpu_global_count_data[idx] * in_feat);
recv_ptr += cpu_global_count_data[idx];
}
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#else
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
#else
PADDLE_THROW(
platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
#endif
#else
PADDLE_THROW(
platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
#endif
}
};
template <typename T>
class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const int rid = ctx.Attr<int>("ring_id");
auto map = distributed::ProcessGroupMapFromGid::getInstance();
if (map->has(rid)) {
GlobalScatterProcessGroupFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
} else {
GlobalScatterFunctor<phi::GPUContext, T> functor_;
functor_(ctx);
}
}
};
}  // namespace operators
}  // namespace paddle
......
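Both new functors compute transfer offsets the same way: `expert_ptr` is an exclusive prefix sum over `local_count`, and the nested loops index counts with `idx = i + j * n_expert` (expert `i`, rank `j`). The runnable sketch below, with made-up counts and `in_feat`, prints the `(offset, count)` element ranges that would be handed to `Send_Partial` / `Recv_Partial`:

// Runnable sketch (made-up numbers) of the offset bookkeeping shared by the
// gather and scatter functors: expert_ptr is an exclusive prefix sum over
// local_count, and entry idx = i + j * n_expert holds the row count for
// expert i on rank j.
#include <cstdio>
#include <vector>

int main() {
  const int n_expert = 2, nranks = 2, in_feat = 4;         // hypothetical sizes
  const std::vector<long long> local_count = {3, 1, 2, 4}; // rows per (expert, rank)

  const int tot_experts = n_expert * nranks;
  std::vector<long long> expert_ptr(tot_experts, 0);       // exclusive prefix sum
  for (int k = 1; k < tot_experts; ++k)
    expert_ptr[k] = expert_ptr[k - 1] + local_count[k - 1];

  for (int i = 0; i < n_expert; ++i) {
    for (int j = 0; j < nranks; ++j) {
      const int idx = i + j * n_expert;
      if (local_count[idx]) {
        // Element range for this (expert, rank) pair, as passed to the
        // partial send/recv calls in the functors above.
        std::printf("expert %d, rank %d: offset=%lld count=%lld\n", i, j,
                    expert_ptr[idx] * in_feat, local_count[idx] * in_feat);
      }
    }
  }
  return 0;
}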
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -33,5 +34,15 @@ class GlobalScatterOpCPUKernel : public framework::OpKernel<T> { ...@@ -33,5 +34,15 @@ class GlobalScatterOpCPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename Context, typename T>
struct GlobalScatterFunctor {
void operator()(const framework::ExecutionContext& ctx);
};
template <typename Context, typename T>
struct GlobalScatterProcessGroupFunctor {
void operator()(const framework::ExecutionContext& ctx);
};
}  // namespace operators
}  // namespace paddle
@@ -1182,8 +1182,8 @@ endif()
if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
  set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
  set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120)
-  set_tests_properties(test_collective_global_gather PROPERTIES TIMEOUT 120)
-  set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_collective_global_gather PROPERTIES TIMEOUT 200)
+  set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT 200)
  set_tests_properties(test_collective_sendrecv_api PROPERTIES TIMEOUT 120)
  set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120)
  set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120)
......
@@ -191,7 +191,8 @@ class TestDistBase(unittest.TestCase):
                         path_id="0",
                         static_mode="1",
                         check_error_log=False,
-                        need_envs={}):
+                        need_envs={},
+                        eager_mode=True):
if backend == "nccl" or backend == "bkcl":
    with_gloo = '0'
else:
@@ -216,6 +217,12 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
required_envs["GLOO_LOG_LEVEL"] = "TRACE"
if eager_mode:
required_envs["FLAGS_enable_eager_mode"] = "%d" % 0
else:
required_envs["FLAGS_enable_eager_mode"] = "%d" % 1
tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file,
                                                 required_envs)
np.random.seed(pid0)
......
@@ -35,7 +35,16 @@ class TestCollectiveGlobalGatherAPI(TestDistBase):
    "collective_global_gather_dygraph.py",
    "global_gather",
    "nccl",
-   static_mode="0")
+   static_mode="0",
+   eager_mode=False)
def test_global_gather_nccl_dygraph_eager(self):
self.check_with_place(
"collective_global_gather_dygraph.py",
"global_gather",
"nccl",
static_mode="0",
eager_mode=True)
if __name__ == '__main__':
......
@@ -35,7 +35,16 @@ class TestCollectiveSelectScatterAPI(TestDistBase):
    "collective_global_scatter_dygraph.py",
    "global_scatter",
    "nccl",
-   static_mode="0")
+   static_mode="0",
+   eager_mode=False)
def test_global_scatter_nccl_dygraph_eager(self):
self.check_with_place(
"collective_global_scatter_dygraph.py",
"global_scatter",
"nccl",
static_mode="0",
eager_mode=True)
if __name__ == '__main__':
......