diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h
index fca395c5f2bf71b8446ecf15c1290f7a2f44436c..52e09792d5d80a47ded9c1f692079774d3483dcc 100644
--- a/paddle/fluid/distributed/collective/ProcessGroup.h
+++ b/paddle/fluid/distributed/collective/ProcessGroup.h
@@ -113,6 +113,19 @@ class ProcessGroup {
         "ProcessGroup%s does not support receive", GetBackendName()));
   }
 
+  virtual std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor&,
+                                                           int, int,
+                                                           int) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support send", GetBackendName()));
+  }
+
+  virtual std::shared_ptr<ProcessGroup::Task> Recv_Partial(
+      phi::DenseTensor& tensors, int, int, int) {  // NOLINT
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "ProcessGroup%s does not support receive", GetBackendName()));
+  }
+
   virtual std::shared_ptr<ProcessGroup::Task> AllGather(
       std::vector<phi::DenseTensor>&,    // NOLINT
       std::vector<phi::DenseTensor>&) {  // NOLINT
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
index 86cc5b5db7cd7486f61dae937c579064f298db17..f1b66864b29309818cce5167c0aa1e63afa5db86 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
@@ -428,6 +428,53 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
   return task;
 }
 
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send_Partial(
+    phi::DenseTensor& tensors, int dst_rank, int offset, int length) {
+  // CheckTensorsInDifferentDevices(tensors, static_cast<size_t>(GetSize()));
+
+  phi::DenseTensor flatten_tensor;
+  flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
+
+  phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);
+
+  std::vector<phi::DenseTensor> shared_tensors;
+  shared_tensors.push_back(shared_input);
+
+  auto task = PointToPoint(shared_tensors,
+                           [&](phi::DenseTensor& input, ncclComm_t comm,
+                               const gpuStream_t& stream, int dst_rank) {
+                             return platform::dynload::ncclSend(
+                                 input.data(), input.numel(),
+                                 platform::ToNCCLDataType(input.dtype()),
+                                 dst_rank, comm, stream);
+                           },
+                           dst_rank, CommType::SEND);
+  return task;
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv_Partial(
+    phi::DenseTensor& tensors, int src_rank, int offset, int length) {
+  // phi::DenseTensor shared_input = tensors.Slice(offset, offset+length);
+
+  phi::DenseTensor flatten_tensor;
+  flatten_tensor.ShareDataWith(tensors).Resize({tensors.numel()});
+  phi::DenseTensor shared_input = flatten_tensor.Slice(offset, offset + length);
+
+  std::vector<phi::DenseTensor> shared_tensors;
+  shared_tensors.push_back(shared_input);
+
+  auto task = PointToPoint(shared_tensors,
+                           [&](phi::DenseTensor& output, ncclComm_t comm,
+                               const gpuStream_t& stream, int src_rank) {
+                             return platform::dynload::ncclRecv(
+                                 output.data(), output.numel(),
+                                 platform::ToNCCLDataType(output.dtype()),
+                                 src_rank, comm, stream);
+                           },
+                           src_rank, CommType::RECV);
+  return task;
+}
+
 std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
     std::vector<phi::DenseTensor>& in_tensors,
     std::vector<phi::DenseTensor>& out_tensors) {
diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
index 4b6c3f4031354d8d6530206d034595d07c361cda..82ced6e135ac93745dbfdb241697c87b60a730cc 100644
--- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h
@@ -102,6 +102,14 @@ class ProcessGroupNCCL : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> Recv(
       std::vector<phi::DenseTensor>& tensors, int src_rank) override;
 
+  std::shared_ptr<ProcessGroup::Task> Send_Partial(phi::DenseTensor& tensors,
+                                                   int dst_rank, int offset,
+                                                   int length) override;
+
+  std::shared_ptr<ProcessGroup::Task> Recv_Partial(phi::DenseTensor& tensors,
+                                                   int src_rank, int offset,
+                                                   int length) override;
+
   std::shared_ptr<ProcessGroup::Task> AllGather(
       std::vector<phi::DenseTensor>& in_tensors,
       std::vector<phi::DenseTensor>& out_tensors) override;
diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc
index 6684470e881cb0daf56d740f1f820e881c632df1..c256063090cc81dddbc667017d75341d3d00ddfd 100644
--- a/paddle/fluid/operators/collective/global_gather_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -22,10 +22,10 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+
 template <typename T>
-class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+struct GlobalGatherFunctor<phi::GPUContext, T> {
+  void operator()(const framework::ExecutionContext& ctx) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
@@ -137,6 +137,132 @@ class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+struct GlobalGatherProcessGroupFunctor<phi::GPUContext, T> {
+  void operator()(const framework::ExecutionContext& ctx) {
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#if NCCL_VERSION_CODE >= 2703
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    auto local_count = ctx.Input<framework::LoDTensor>("local_count");
+    auto global_count = ctx.Input<framework::LoDTensor>("global_count");
+    auto local_count_type =
+        framework::TransToProtoVarType(local_count->dtype());
+    auto global_count_type =
+        framework::TransToProtoVarType(global_count->dtype());
+    if (local_count_type != framework::proto::VarType::INT64) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Please use int64 type in local_count."));
+    }
+    if (global_count_type != framework::proto::VarType::INT64) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Please use int64 type in global_count."));
+    }
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    const int64_t* cpu_local_count_data;
+    const int64_t* cpu_global_count_data;
+    auto local_count_len = 0;
+
+    framework::Tensor cpu_local_count;
+    if (platform::is_cpu_place(local_count->place())) {
+      cpu_local_count_data = local_count->data<int64_t>();
+      local_count_len = local_count->numel();
+    } else {
+      framework::TensorCopySync(*local_count, platform::CPUPlace(),
+                                &cpu_local_count);
+      cpu_local_count_data = cpu_local_count.data<int64_t>();
+      local_count_len = cpu_local_count.numel();
+    }
+
+    framework::Tensor cpu_global_count;
+    if (platform::is_cpu_place(global_count->place())) {
+      cpu_global_count_data = global_count->data<int64_t>();
+    } else {
+      framework::TensorCopySync(*global_count, platform::CPUPlace(),
+                                &cpu_global_count);
+      cpu_global_count_data = cpu_global_count.data<int64_t>();
+    }
+
+    int ring_id = ctx.Attr<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for global gather op must be non-negative.",
+            ring_id));
+    auto place = ctx.GetPlace();
+
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    distributed::ProcessGroup* pg = map->get(ring_id);
+
+    int nranks = pg->GetSize();
+    auto in_feat = x->dims()[1];
+    auto n_expert = local_count->dims()[0] / nranks;
+
+    auto fwd_count = 0;
+
+    for (auto i = 0; i < local_count_len; ++i) {
+      fwd_count += cpu_local_count_data[i];
+    }
+    framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat});
+    int64_t* expert_ptr = new int64_t[n_expert * nranks];
+    expert_ptr[0] = 0;
+    auto tot_experts = n_expert * nranks;
+    for (auto i = 1; i < tot_experts; ++i) {
+      expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
+    }
+    auto send_ptr = 0;
+    out->mutable_data<T>(out_dims, place);
+
+    for (auto i = 0; i < n_expert; ++i) {
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+      for (auto j = 0; j < nranks; ++j) {
+        int idx = i + j * n_expert;
+        if (cpu_global_count_data[idx]) {
+          phi::DenseTensor tmp = *x;
+          pg->Send_Partial(tmp, j, send_ptr * in_feat,
+                           cpu_global_count_data[idx] * in_feat);
+          send_ptr += cpu_global_count_data[idx];
+        }
+        if (cpu_local_count_data[idx]) {
+          pg->Recv_Partial(*out, j, expert_ptr[idx] * in_feat,
+                           cpu_local_count_data[idx] * in_feat);
+        }
+      }
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+    }
+
+#ifdef PADDLE_WITH_CUDA
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
+#else
+    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
+#endif
+
+#else
+    PADDLE_THROW(
+        platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
+#endif
+#else
+    PADDLE_THROW(
+        platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
+#endif
+  }
+};
+
+template <typename T>
+class GlobalGatherOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const int rid = ctx.Attr<int>("ring_id");
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      GlobalGatherProcessGroupFunctor<phi::GPUContext, T> functor_;
+      functor_(ctx);
+    } else {
+      GlobalGatherFunctor<phi::GPUContext, T> functor_;
+      functor_(ctx);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/collective/global_gather_op.h b/paddle/fluid/operators/collective/global_gather_op.h
index 3ff2df9e48f3d93ac8589745378e435d31ecf6bc..47212b1d15581c5f0bccfde01a52021c2c485915 100644
--- a/paddle/fluid/operators/collective/global_gather_op.h
+++ b/paddle/fluid/operators/collective/global_gather_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -33,5 +34,15 @@ class GlobalGatherOpCPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename Context, typename T>
+struct GlobalGatherFunctor {
+  void operator()(const framework::ExecutionContext& ctx);
+};
+
+template <typename Context, typename T>
+struct GlobalGatherProcessGroupFunctor {
+  void operator()(const framework::ExecutionContext& ctx);
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
index cd3c3a3229ca068b634e0111e85195e5c7935e34..df8d675ec9d7154c74603ca2edae14f5958e054a 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -22,10 +22,10 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+
 template <typename T>
-class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+struct GlobalScatterFunctor<phi::GPUContext, T> {
+  void operator()(const framework::ExecutionContext& ctx) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #if NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
@@ -137,6 +137,130 @@ class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+struct GlobalScatterProcessGroupFunctor<phi::GPUContext, T> {
+  void operator()(const framework::ExecutionContext& ctx) {
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#if NCCL_VERSION_CODE >= 2703
+    auto x = ctx.Input<framework::LoDTensor>("X");
+    auto local_count = ctx.Input<framework::LoDTensor>("local_count");
+    auto global_count = ctx.Input<framework::LoDTensor>("global_count");
+    auto local_count_type =
+        framework::TransToProtoVarType(local_count->dtype());
+    auto global_count_type =
+        framework::TransToProtoVarType(global_count->dtype());
+    if (local_count_type != framework::proto::VarType::INT64) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Please use int64 type in local_count."));
+    }
+    if (global_count_type != framework::proto::VarType::INT64) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Please use int64 type in global_count."));
+    }
+    auto out = ctx.Output<framework::LoDTensor>("Out");
+    const int64_t* cpu_local_count_data;
+    const int64_t* cpu_global_count_data;
+    framework::Tensor cpu_local_count;
+    if (platform::is_cpu_place(local_count->place())) {
+      cpu_local_count_data = local_count->data<int64_t>();
+    } else {
+      framework::TensorCopySync(*local_count, platform::CPUPlace(),
+                                &cpu_local_count);
+      cpu_local_count_data = cpu_local_count.data<int64_t>();
+    }
+    auto global_count_len = 0;
+    framework::Tensor cpu_global_count;
+    if (platform::is_cpu_place(global_count->place())) {
+      cpu_global_count_data = global_count->data<int64_t>();
+      global_count_len = global_count->numel();
+    } else {
+      framework::TensorCopySync(*global_count, platform::CPUPlace(),
+                                &cpu_global_count);
+      cpu_global_count_data = cpu_global_count.data<int64_t>();
+      global_count_len = cpu_global_count.numel();
+    }
+
+    int ring_id = ctx.Attr<int>("ring_id");
+    PADDLE_ENFORCE_GE(
+        ring_id, 0,
+        platform::errors::InvalidArgument(
+            "The ring_id (%d) for global scatter op must be non-negative.",
+            ring_id));
+
+    auto place = ctx.GetPlace();
+
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    distributed::ProcessGroup* pg = map->get(ring_id);
+    int nranks = pg->GetSize();
+    auto in_feat = x->dims()[1];
+    auto n_expert = local_count->dims()[0] / nranks;
+    int64_t fwd_count = 0;
+
+    for (auto i = 0; i < global_count_len; ++i) {
+      fwd_count += cpu_global_count_data[i];
+    }
+    framework::DDim out_dims = phi::make_ddim({fwd_count, in_feat});
+    int64_t* expert_ptr = new int64_t[n_expert * nranks];
+    expert_ptr[0] = 0;
+    auto tot_experts = n_expert * nranks;
+    for (auto i = 1; i < tot_experts; ++i) {
+      expert_ptr[i] = expert_ptr[i - 1] + cpu_local_count_data[i - 1];
+    }
+
+    auto recv_ptr = 0;
+    out->mutable_data<T>(out_dims, place);
+
+    for (auto i = 0; i < n_expert; ++i) {
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
+      for (auto j = 0; j < nranks; ++j) {
+        int idx = i + j * n_expert;
+        if (cpu_local_count_data[idx]) {
+          phi::DenseTensor tmp = *x;
+          pg->Send_Partial(tmp, j, expert_ptr[idx] * in_feat,
+                           cpu_local_count_data[idx] * in_feat);
+        }
+        if (cpu_global_count_data[idx]) {
+          pg->Recv_Partial(*out, j, recv_ptr * in_feat,
+                           cpu_global_count_data[idx] * in_feat);
+          recv_ptr += cpu_global_count_data[idx];
+        }
+      }
+      PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
+    }
+
+#ifdef PADDLE_WITH_CUDA
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
+#else
+    PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
+#endif
+
+#else
+    PADDLE_THROW(
+        platform::errors::Unavailable("NCCL version >= 2.7.3 is needed."));
+#endif
+#else
+    PADDLE_THROW(
+        platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
+#endif
+  }
+};
+
+template <typename T>
+class GlobalScatterOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const int rid = ctx.Attr<int>("ring_id");
+    auto map = distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      GlobalScatterProcessGroupFunctor<phi::GPUContext, T> functor_;
+      functor_(ctx);
+    } else {
+      GlobalScatterFunctor<phi::GPUContext, T> functor_;
+      functor_(ctx);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/collective/global_scatter_op.h b/paddle/fluid/operators/collective/global_scatter_op.h
index 52b486aef25c2bda43abef1880ad9c974c9b6f15..aa567a284a6f7314a76afe574e2c5c6719344342 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.h
+++ b/paddle/fluid/operators/collective/global_scatter_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/fluid/distributed/collective/ProcessGroup.h"
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -33,5 +34,15 @@ class GlobalScatterOpCPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename Context, typename T>
+struct GlobalScatterFunctor {
+  void operator()(const framework::ExecutionContext& ctx);
+};
+
+template <typename Context, typename T>
+struct GlobalScatterProcessGroupFunctor {
+  void operator()(const framework::ExecutionContext& ctx);
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 2918e8501c3d08bcafe4764ea9411cf2dc023ba3..402e65f76d5b1716a12a815c22c34f6098943401 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1182,8 +1182,8 @@ endif()
 if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
     set_tests_properties(test_collective_allgather_api PROPERTIES TIMEOUT 120)
     set_tests_properties(test_collective_alltoall_api PROPERTIES TIMEOUT 120)
-    set_tests_properties(test_collective_global_gather PROPERTIES TIMEOUT 120)
-    set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_collective_global_gather PROPERTIES TIMEOUT 200)
+    set_tests_properties(test_collective_global_scatter PROPERTIES TIMEOUT 200)
     set_tests_properties(test_collective_sendrecv_api PROPERTIES TIMEOUT 120)
     set_tests_properties(test_collective_broadcast_api PROPERTIES TIMEOUT 120)
     set_tests_properties(test_collective_allreduce_api PROPERTIES TIMEOUT 120)
diff --git a/python/paddle/fluid/tests/unittests/test_collective_api_base.py b/python/paddle/fluid/tests/unittests/test_collective_api_base.py
index 00294bf6071b329526768415453a3f80f8039ffa..dbd982947265fd1cc7e10108da612e465588ac45 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_api_base.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_api_base.py
@@ -191,7 +191,8 @@ class TestDistBase(unittest.TestCase):
                          path_id="0",
                          static_mode="1",
                          check_error_log=False,
-                         need_envs={}):
+                         need_envs={},
+                         eager_mode=True):
         if backend == "nccl" or backend == "bkcl":
             with_gloo = '0'
         else:
@@ -216,6 +217,12 @@ class TestDistBase(unittest.TestCase):
             required_envs["GLOG_v"] = "3"
             required_envs["GLOG_logtostderr"] = "1"
             required_envs["GLOO_LOG_LEVEL"] = "TRACE"
+
+        if eager_mode:
+            required_envs["FLAGS_enable_eager_mode"] = "%d" % 0
+        else:
+            required_envs["FLAGS_enable_eager_mode"] = "%d" % 1
+
         tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file,
                                                          required_envs)
         np.random.seed(pid0)
diff --git a/python/paddle/fluid/tests/unittests/test_collective_global_gather.py b/python/paddle/fluid/tests/unittests/test_collective_global_gather.py
index c9dee529c21a1650f1eca5893f4dc966b62bf907..6809f3970f6833beb52c625c129699f5cc899c8d 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_global_gather.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_global_gather.py
@@ -35,7 +35,16 @@ class TestCollectiveGlobalGatherAPI(TestDistBase):
             "collective_global_gather_dygraph.py",
             "global_gather",
             "nccl",
-            static_mode="0")
+            static_mode="0",
+            eager_mode=False)
+
+    def test_global_gather_nccl_dygraph_eager(self):
+        self.check_with_place(
+            "collective_global_gather_dygraph.py",
+            "global_gather",
+            "nccl",
+            static_mode="0",
+            eager_mode=True)
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py b/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py
index 2b4555de2744d6f1263e841ae92869abc8bd9ee0..1485bafa387f59f4031101d997de4945c2ffa5ca 100644
--- a/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py
+++ b/python/paddle/fluid/tests/unittests/test_collective_global_scatter.py
@@ -35,7 +35,16 @@ class TestCollectiveSelectScatterAPI(TestDistBase):
            "collective_global_scatter_dygraph.py",
            "global_scatter",
            "nccl",
-            static_mode="0")
+            static_mode="0",
+            eager_mode=False)
+
+    def test_global_scatter_nccl_dygraph_eager(self):
+        self.check_with_place(
+            "collective_global_scatter_dygraph.py",
+            "global_scatter",
+            "nccl",
+            static_mode="0",
+            eager_mode=True)
 
 
 if __name__ == '__main__':
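
Reviewer note, not part of the patch: both operators address the [rows, in_feat] tensors through a flattened view, so every Send_Partial/Recv_Partial call passes an element offset and an element length rather than a row range, and ProcessGroupNCCL rebuilds the block with ShareDataWith(...).Resize({numel()}) followed by Slice(offset, offset + length), as shown in the ProcessGroupNCCL.cc hunk. A minimal sketch of that row-to-element mapping; the helper name RowBlockToOffsetLength is hypothetical and only mirrors the arithmetic the kernels above already use:

```cpp
// Illustrative only -- not part of the patch.
#include <cstdint>
#include <utility>

// Maps a contiguous block of rows [row_begin, row_begin + rows) of a
// [total_rows, in_feat] tensor to the (offset, length) pair, in elements,
// that Send_Partial/Recv_Partial expect on the flattened tensor.
inline std::pair<int64_t, int64_t> RowBlockToOffsetLength(int64_t row_begin,
                                                          int64_t rows,
                                                          int64_t in_feat) {
  return {row_begin * in_feat, rows * in_feat};
}

// Example: in the global_scatter loop, Send_Partial(tmp, j,
//   expert_ptr[idx] * in_feat, cpu_local_count_data[idx] * in_feat)
// sends the block that starts at row expert_ptr[idx] and spans
// cpu_local_count_data[idx] rows to rank j.
```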