diff --git a/paddle/fluid/distributed/collective/ProcessGroup.h b/paddle/fluid/distributed/collective/ProcessGroup.h index 09d140c5416538fc6830a60e05a703df2b9611d4..ca1cf7dd48ba707a04e4cbaae56125187d1ecf8a 100644 --- a/paddle/fluid/distributed/collective/ProcessGroup.h +++ b/paddle/fluid/distributed/collective/ProcessGroup.h @@ -125,6 +125,16 @@ class ProcessGroup { "ProcessGroup%s does not support broadcast", GetBackendName())); } + virtual std::shared_ptr Broadcast( + std::vector& /* input tensors */, // NOLINT + std::vector& /* output tensors */, // NOLINT + const BroadcastOptions&, + bool) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support broadcast with sync_op flag", + GetBackendName())); + } + virtual std::shared_ptr Barrier( const BarrierOptions& = BarrierOptions()) { PADDLE_THROW(platform::errors::InvalidArgument( @@ -160,14 +170,14 @@ class ProcessGroup { virtual std::shared_ptr Send_Partial( phi::DenseTensor&, // NOLINT int, - int, - int) { + int64_t, + int64_t) { PADDLE_THROW(platform::errors::InvalidArgument( "ProcessGroup%s does not support send_partial", GetBackendName())); } virtual std::shared_ptr Send_Partial( - phi::DenseTensor&, int, int, int, bool) { // NOLINT + phi::DenseTensor&, int, int64_t, int64_t, bool) { // NOLINT PADDLE_THROW(platform::errors::InvalidArgument( "ProcessGroup%s does not support send_partial with sync_op flag", GetBackendName())); @@ -176,14 +186,14 @@ class ProcessGroup { virtual std::shared_ptr Recv_Partial( phi::DenseTensor&, // NOLINT int, - int, - int) { + int64_t, + int64_t) { PADDLE_THROW(platform::errors::InvalidArgument( "ProcessGroup%s does not support recv_partial", GetBackendName())); } virtual std::shared_ptr Recv_Partial( - phi::DenseTensor&, int, int, int, bool) { // NOLINT + phi::DenseTensor&, int, int64_t, int64_t, bool) { // NOLINT PADDLE_THROW(platform::errors::InvalidArgument( "ProcessGroup%s does not support recv_partial with sync_op flag", GetBackendName())); @@ -208,8 +218,8 @@ class ProcessGroup { virtual std::shared_ptr AllGather_Partial( std::vector& in_tensors, // NOLINT std::vector& out_tensors, // NOLINT - int offset, - int length) { // NOLINT + int64_t offset, + int64_t length) { PADDLE_THROW(platform::errors::InvalidArgument( "ProcessGroup%s does not support AllGather_Partial", GetBackendName())); } @@ -217,9 +227,9 @@ class ProcessGroup { virtual std::shared_ptr AllGather_Partial( std::vector& in_tensors, // NOLINT std::vector& out_tensors, // NOLINT - int offset, - int length, - bool) { // NOLINT + int64_t offset, + int64_t length, + bool) { PADDLE_THROW(platform::errors::InvalidArgument( "ProcessGroup%s does not support AllGather_Partial", GetBackendName())); } @@ -231,6 +241,14 @@ class ProcessGroup { "ProcessGroup%s does not support AllToAll", GetBackendName())); } + virtual std::shared_ptr AllToAll( + std::vector&, // NOLINT + std::vector&, // NOLINT + bool) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support alltoall", GetBackendName())); + } + virtual std::shared_ptr AllToAll_Single( std::vector&, // NOLINT std::vector&, // NOLINT @@ -240,26 +258,66 @@ class ProcessGroup { "ProcessGroup%s does not support AllToAll_Single", GetBackendName())); } + virtual std::shared_ptr AllToAllSingle( + std::vector&, // NOLINT + std::vector&, // NOLINT + std::vector&, + std::vector&, + bool) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support alltoall_single", GetBackendName())); + } + virtual std::shared_ptr Reduce( 
std::vector&, // NOLINT std::vector&, // NOLINT const ReduceOptions& opts) { PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support Reduce", GetBackendName())); + "ProcessGroup%s does not support reduce", GetBackendName())); + } + + virtual std::shared_ptr Reduce( + std::vector& /* input tensors */, // NOLINT + std::vector& /* output tensors */, // NOLINT + const ReduceOptions&, + bool) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support reduce with sync_op flag", + GetBackendName())); } virtual std::shared_ptr Scatter( std::vector&, // NOLINT std::vector&, // NOLINT - const ScatterOptions&) { // NOLINT + const ScatterOptions&) { PADDLE_THROW(platform::errors::InvalidArgument( - "ProcessGroup%s does not support Scatter", GetBackendName())); + "ProcessGroup%s does not support scatter", GetBackendName())); + } + + virtual std::shared_ptr Scatter( + std::vector&, // NOLINT + std::vector&, // NOLINT + const ScatterOptions&, + bool) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support scatter with sync_op flag", + GetBackendName())); + } + + virtual std::shared_ptr ReduceScatter( + std::vector&, // NOLINT + std::vector&, // NOLINT + const ReduceScatterOptions&, + bool) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support reduce_scatter with sync_op flag", + GetBackendName())); } virtual std::shared_ptr _ReduceScatterBase( - phi::DenseTensor&, // NOLINT - phi::DenseTensor&, // NOLINT - const ReduceScatterOptions&) { // NOLINT + phi::DenseTensor&, // NOLINT + phi::DenseTensor&, // NOLINT + const ReduceScatterOptions&) { PADDLE_THROW(platform::errors::InvalidArgument( "ProcessGroup%s does not support ReduceScatter", GetBackendName())); } diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc index ad9356b368ea264dafde34393c945cf522d63210..f18765a05f619051f041923314d1e5703c3f0e44 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc @@ -267,8 +267,8 @@ void* XcclGetPointerByOffset(void* raw_pointer, std::shared_ptr ProcessGroupCustom::AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, - int offset, - int length) { + int64_t offset, + int64_t length) { PADDLE_ENFORCE_EQ( CheckTensorsInCustomPlace(in_tensors, device_type_), true, diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.h b/paddle/fluid/distributed/collective/ProcessGroupCustom.h index ccce66603afe69b96fdd11f3e575373284966cc9..ce3532bbb6f0e2a8534638d3f20f7cf57c042cc3 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupCustom.h +++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.h @@ -80,8 +80,8 @@ class ProcessGroupCustom : public ProcessGroup { std::shared_ptr AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, - int offset, - int length) override; + int64_t offset, + int64_t length) override; std::shared_ptr AllReduce( std::vector& in_tensors, @@ -117,8 +117,8 @@ class ProcessGroupCustom : public ProcessGroup { std::set used_place_ids_; private: - void BcastCustomId(std::vector& ccl_ids, - int root, // NOLINT + void BcastCustomId(std::vector& ccl_ids, // NOLINT + int root, int server_fd); void BroadcastUniqueCustomID( diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc index 
75f061f693b9ba4aa335eed70124214935788d30..2e18dfcc3ba1208f47a4ceeb2529826d46b44c34 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.cc @@ -628,6 +628,40 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( CommType::BROADCAST); } +std::shared_ptr ProcessGroupNCCL::Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(in_tensors), + true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + + return Collective( + in_tensors, + out_tensors, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream) { + const auto root = + opts.source_rank * in_tensors.size() + opts.source_root; + return platform::dynload::ncclBroadcast( + input.data(), + output.data(), + input.numel(), + platform::ToNCCLDataType(input.type()), + root, + comm, + stream); + }, + CommType::BROADCAST, + sync_op, + use_calc_stream); +} + std::shared_ptr ProcessGroupNCCL::Barrier( const BarrierOptions& opts) { // Only support single card single process @@ -782,7 +816,7 @@ std::shared_ptr ProcessGroupNCCL::Recv( } std::shared_ptr ProcessGroupNCCL::Send_Partial( - phi::DenseTensor& tensors, int dst_rank, int offset, int length) { + phi::DenseTensor& tensors, int dst_rank, int64_t offset, int64_t length) { // CheckTensorsInDifferentDevices(tensors, static_cast(GetSize())); phi::DenseTensor flatten_tensor; @@ -813,8 +847,8 @@ std::shared_ptr ProcessGroupNCCL::Send_Partial( std::shared_ptr ProcessGroupNCCL::Send_Partial( phi::DenseTensor& tensors, int dst_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream) { phi::DenseTensor flatten_tensor; @@ -845,7 +879,7 @@ std::shared_ptr ProcessGroupNCCL::Send_Partial( } std::shared_ptr ProcessGroupNCCL::Recv_Partial( - phi::DenseTensor& tensors, int src_rank, int offset, int length) { + phi::DenseTensor& tensors, int src_rank, int64_t offset, int64_t length) { // phi::DenseTensor shared_input = tensors.Slice(offset, offset+length); phi::DenseTensor flatten_tensor; @@ -876,8 +910,8 @@ std::shared_ptr ProcessGroupNCCL::Recv_Partial( std::shared_ptr ProcessGroupNCCL::Recv_Partial( phi::DenseTensor& tensors, int src_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream) { phi::DenseTensor flatten_tensor; @@ -1009,8 +1043,8 @@ void* GetPointerByOffset(void* raw_pointer, std::shared_ptr ProcessGroupNCCL::AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, - int offset, - int length) { + int64_t offset, + int64_t length) { PADDLE_ENFORCE_EQ( CheckTensorsInCudaPlace(in_tensors), true, @@ -1040,8 +1074,8 @@ std::shared_ptr ProcessGroupNCCL::AllGather_Partial( std::shared_ptr ProcessGroupNCCL::AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream) { PADDLE_ENFORCE_EQ( @@ -1114,6 +1148,52 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( CommType::ALLTOALL); } +std::shared_ptr ProcessGroupNCCL::AllToAll( + std::vector& in_tensors, + std::vector& out_tensors, + bool sync_op, + bool use_calc_stream) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(in_tensors), + true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(out_tensors), + 
true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + return Collective( + in_tensors, + out_tensors, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream) { + size_t offset = 0; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < size_; i++) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + GetPointerByOffset(input.data(), offset, input.dtype()), + input.numel() / size_, + platform::ToNCCLDataType(input.dtype()), + i, + comm, + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + GetPointerByOffset(output.data(), offset, input.dtype()), + input.numel() / size_, + platform::ToNCCLDataType(input.dtype()), + i, + comm, + stream)); + offset += input.numel() / size_; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + }, + CommType::ALLTOALL, + sync_op, + use_calc_stream); +} + std::shared_ptr ProcessGroupNCCL::AllToAll_Single( std::vector& in_tensors, std::vector& out_tensors, @@ -1176,6 +1256,72 @@ std::shared_ptr ProcessGroupNCCL::AllToAll_Single( CommType::ALLTOALL_SINGLE); } +std::shared_ptr ProcessGroupNCCL::AllToAllSingle( + std::vector& in_tensors, + std::vector& out_tensors, + std::vector& in_sizes, + std::vector& out_sizes, + bool sync_op, + bool use_calc_stream) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(in_tensors), + true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(out_tensors), + true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + return Collective( + in_tensors, + out_tensors, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream) { + PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(), + true, + platform::errors::InvalidArgument( + "The dtypes of input and output must be equal.")); + + std::vector in_dims = phi::vectorize(input.dims()); + std::vector out_dims = phi::vectorize(output.dims()); + CheckSplitSizes(&in_sizes, in_dims); + CheckSplitSizes(&out_sizes, out_dims); + + size_t in_offset = 0, out_offset = 0; + size_t in_length = 0, out_length = 0; + size_t in_row_size = input.numel() / in_dims[0]; + size_t out_row_size = output.numel() / out_dims[0]; + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < size_; i++) { + in_length = in_sizes[i] * in_row_size; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + GetPointerByOffset(input.data(), in_offset, input.dtype()), + in_length, + platform::ToNCCLDataType(input.dtype()), + i, + comm, + stream)); + in_offset += in_length; + + out_length = out_sizes[i] * out_row_size; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + GetPointerByOffset(output.data(), out_offset, input.dtype()), + out_length, + platform::ToNCCLDataType(input.dtype()), + i, + comm, + stream)); + out_offset += out_length; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + }, + CommType::ALLTOALL_SINGLE, + sync_op, + use_calc_stream); +} + std::shared_ptr ProcessGroupNCCL::Reduce( std::vector& in_tensors, std::vector& out_tensors, @@ -1204,6 +1350,70 @@ std::shared_ptr ProcessGroupNCCL::Reduce( CommType::REDUCE); } +std::shared_ptr ProcessGroupNCCL::Reduce( + std::vector& in_tensors, + std::vector& out_tensors, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(in_tensors), + true, + 
platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + return Collective( + in_tensors, + out_tensors, + [&](const phi::DenseTensor& input, + phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduce( + input.data(), + output.data(), + input.numel(), + platform::ToNCCLDataType(input.dtype()), + ToNCCLRedType(opts.reduce_op), + opts.root_rank, + comm, + stream)); + }, + CommType::REDUCE, + sync_op, + use_calc_stream); +} + +std::shared_ptr ProcessGroupNCCL::ReduceScatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + return Collective( + in_tensors, + out_tensors, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream) { + if (FLAGS_use_stream_safe_cuda_allocator) { + platform::CUDADeviceGuard cuda_guard; + cuda_guard.SetDevice(output.place()); + memory::RecordStream(output.Holder(), stream); + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter( + input.data(), + output.data(), + output.numel(), + platform::ToNCCLDataType(input.dtype()), + ToNCCLRedType(opts.reduce_op), + comm, + stream)); + }, + CommType::REDUCE_SCATTER, + sync_op, + use_calc_stream); +} + std::shared_ptr ProcessGroupNCCL::Scatter( std::vector& in_tensors, std::vector& out_tensors, @@ -1257,6 +1467,68 @@ std::shared_ptr ProcessGroupNCCL::Scatter( CommType::SCATTER); } +std::shared_ptr ProcessGroupNCCL::Scatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(in_tensors), + true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + PADDLE_ENFORCE_EQ( + CheckTensorsInCudaPlace(out_tensors), + true, + platform::errors::InvalidArgument("All inputs should be in CudaPlace.")); + return Collective( + in_tensors, + out_tensors, + [&](phi::DenseTensor& input, + phi::DenseTensor& output, + ncclComm_t comm, + const gpuStream_t& stream) { + PADDLE_ENFORCE_EQ( + output.numel(), + input.numel() / size_, + platform::errors::InvalidArgument( + "Input and output tensors should have the same shape.")); + size_t offset = 0; + if (rank_ == opts.root_rank) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart()); + for (auto i = 0; i < size_; i++) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend( + GetPointerByOffset(input.data(), offset, input.dtype()), + input.numel() / size_, + platform::ToNCCLDataType(input.dtype()), + i, + comm, + stream)); + offset += input.numel() / size_; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + output.data(), + input.numel() / size_, + platform::ToNCCLDataType(input.dtype()), + opts.root_rank, + comm, + stream)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd()); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv( + output.data(), + input.numel() / size_, + platform::ToNCCLDataType(input.dtype()), + opts.root_rank, + comm, + stream)); + } + }, + CommType::SCATTER, + sync_op, + use_calc_stream); +} + std::shared_ptr ProcessGroupNCCL::_ReduceScatterBase( phi::DenseTensor& out_tensor, phi::DenseTensor& in_tensor, diff --git a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h index 6d15f6ebdeff05a5dd5c00f12838debaa9772626..6427e9e3e2ab1c8a06224cd88b1a60a4cf067c61 100644 --- 
a/paddle/fluid/distributed/collective/ProcessGroupNCCL.h +++ b/paddle/fluid/distributed/collective/ProcessGroupNCCL.h @@ -119,6 +119,13 @@ class ProcessGroupNCCL : public ProcessGroupStream { std::vector& out_tensors, const BroadcastOptions& = BroadcastOptions()) override; + std::shared_ptr Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr Barrier( const BarrierOptions& = BarrierOptions()) override; @@ -142,27 +149,27 @@ class ProcessGroupNCCL : public ProcessGroupStream { std::shared_ptr Send_Partial(phi::DenseTensor& tensors, int dst_rank, - int offset, - int length) override; + int64_t offset, + int64_t length) override; std::shared_ptr Send_Partial( phi::DenseTensor& tensors, int dst_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream) override; std::shared_ptr Recv_Partial(phi::DenseTensor& tensors, int src_rank, - int offset, - int length) override; + int64_t offset, + int64_t length) override; std::shared_ptr Recv_Partial( phi::DenseTensor& tensors, int src_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream) override; @@ -179,20 +186,26 @@ class ProcessGroupNCCL : public ProcessGroupStream { std::shared_ptr AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, - int offset, - int length) override; + int64_t offset, + int64_t length) override; std::shared_ptr AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream) override; std::shared_ptr AllToAll( - std::vector& in, - std::vector& out) override; + std::vector& in_tensors, + std::vector& out_tensors) override; + + std::shared_ptr AllToAll( + std::vector& in_tensors, + std::vector& out_tensors, + bool sync_op, + bool use_calc_stream) override; std::shared_ptr AllToAll_Single( std::vector& in, @@ -200,15 +213,44 @@ class ProcessGroupNCCL : public ProcessGroupStream { std::vector& in_sizes, std::vector& out_sizes) override; + std::shared_ptr AllToAllSingle( + std::vector& in_tensors, + std::vector& out_tensors, + std::vector& in_sizes, + std::vector& out_sizes, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr Reduce( std::vector& tensors, std::vector& out_tensors, const ReduceOptions& opts) override; + std::shared_ptr Reduce( + std::vector& in_tensors, + std::vector& out_tensors, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) override; + + std::shared_ptr ReduceScatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream) override; + std::shared_ptr Scatter( std::vector& in_tensors, std::vector& out_tensors, - const ScatterOptions&) override; + const ScatterOptions& opts) override; + + std::shared_ptr Scatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) override; std::shared_ptr _ReduceScatterBase( phi::DenseTensor&, // NOLINT diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.cc b/paddle/fluid/distributed/collective/ProcessGroupStream.cc index 222fe03b60bf19d3ce41a24d0e69873a8b668b15..b2cfae088b2271ab455430a7b06bd9714d31a1f4 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.cc @@ -70,6 +70,138 
@@ std::shared_ptr ProcessGroupStream::AllReduce( "ProcessGroup%s does not support do all_reduce", GetBackendName())); } +std::shared_ptr ProcessGroupStream::AllToAll( + std::vector& in_tensors, + std::vector& out_tensors, + bool sync_op) { + return AllToAll(in_tensors, + out_tensors, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::AllToAll( + std::vector& in_tensors, + std::vector& out_tensors, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do alltoall", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::AllToAllSingle( + std::vector& in_tensors, + std::vector& out_tensors, + std::vector& in_sizes, + std::vector& out_sizes, + bool sync_op) { + return AllToAllSingle(in_tensors, + out_tensors, + in_sizes, + out_sizes, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::AllToAllSingle( + std::vector& in_tensors, + std::vector& out_tensors, + std::vector& in_sizes, + std::vector& out_sizes, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do alltoall_single", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions& opts, + bool sync_op) { + return Broadcast(in_tensors, + out_tensors, + opts, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Broadcast( + std::vector& in_tensors, + std::vector& out_tensors, + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do broadcast", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::Reduce( + std::vector& in_tensors, + std::vector& out_tensors, + const ReduceOptions& opts, + bool sync_op) { + return Reduce(in_tensors, + out_tensors, + opts, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Reduce( + std::vector& in_tensors, + std::vector& out_tensors, + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do reduce", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::ReduceScatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ReduceScatterOptions& opts, + bool sync_op) { + return ReduceScatter(in_tensors, + out_tensors, + opts, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::ReduceScatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do reduce_scatter", GetBackendName())); +} + +std::shared_ptr ProcessGroupStream::Scatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ScatterOptions& opts, + bool sync_op) { + return Scatter(in_tensors, + out_tensors, + opts, + sync_op, + /*use_calc_stream*/ false); +} + +std::shared_ptr ProcessGroupStream::Scatter( + std::vector& in_tensors, + std::vector& out_tensors, + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "ProcessGroup%s does not support do scatter", GetBackendName())); +} + std::shared_ptr ProcessGroupStream::Send( std::vector& tensors, int dst_rank, bool sync_op) { return 
Send(tensors, @@ -90,8 +222,8 @@ std::shared_ptr ProcessGroupStream::Send( std::shared_ptr ProcessGroupStream::Send_Partial( phi::DenseTensor& tensors, int dst_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op) { return Send_Partial(tensors, dst_rank, @@ -104,8 +236,8 @@ std::shared_ptr ProcessGroupStream::Send_Partial( std::shared_ptr ProcessGroupStream::Send_Partial( phi::DenseTensor& tensors, int dst_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream) { PADDLE_THROW(platform::errors::InvalidArgument( @@ -132,8 +264,8 @@ std::shared_ptr ProcessGroupStream::Recv( std::shared_ptr ProcessGroupStream::Recv_Partial( phi::DenseTensor& tensors, int src_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op) { return Recv_Partial(tensors, src_rank, @@ -146,8 +278,8 @@ std::shared_ptr ProcessGroupStream::Recv_Partial( std::shared_ptr ProcessGroupStream::Recv_Partial( phi::DenseTensor& tensors, int src_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream) { PADDLE_THROW(platform::errors::InvalidArgument( @@ -157,8 +289,8 @@ std::shared_ptr ProcessGroupStream::Recv_Partial( std::shared_ptr ProcessGroupStream::AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op) { return AllGather_Partial(in_tensors, out_tensors, @@ -171,8 +303,8 @@ std::shared_ptr ProcessGroupStream::AllGather_Partial( std::shared_ptr ProcessGroupStream::AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream) { PADDLE_THROW(platform::errors::InvalidArgument( diff --git a/paddle/fluid/distributed/collective/ProcessGroupStream.h b/paddle/fluid/distributed/collective/ProcessGroupStream.h index 1162c3e050925ab2e525b2ef25b0c94defe02733..2f0aa139104e929a24740f5827ee263648df18fe 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupStream.h +++ b/paddle/fluid/distributed/collective/ProcessGroupStream.h @@ -81,6 +81,84 @@ class ProcessGroupStream : public ProcessGroup { bool sync_op, bool use_calc_stream); + std::shared_ptr AllToAll( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + bool sync_op) override; + + virtual std::shared_ptr AllToAll( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + bool sync_op, + bool use_calc_stream); + + std::shared_ptr AllToAllSingle( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + std::vector& in_sizes, // NOLINT + std::vector& out_sizes, // NOLINT + bool sync_op) override; + + virtual std::shared_ptr AllToAllSingle( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + std::vector& in_sizes, // NOLINT + std::vector& out_sizes, // NOLINT + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Broadcast( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + const BroadcastOptions& opts, + bool sync_op) override; + + virtual std::shared_ptr Broadcast( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + const BroadcastOptions& opts, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Reduce( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + const ReduceOptions& opts, + bool sync_op) override; + + virtual std::shared_ptr Reduce( + std::vector& 
in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + const ReduceOptions& opts, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr ReduceScatter( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + const ReduceScatterOptions& opts, + bool sync_op) override; + + virtual std::shared_ptr ReduceScatter( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + const ReduceScatterOptions& opts, + bool sync_op, + bool use_calc_stream); + + std::shared_ptr Scatter( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + const ScatterOptions& opts, + bool sync_op) override; + + virtual std::shared_ptr Scatter( + std::vector& in_tensors, // NOLINT + std::vector& out_tensors, // NOLINT + const ScatterOptions& opts, + bool sync_op, + bool use_calc_stream); + std::shared_ptr Send( std::vector& tensors, // NOLINT int dst_rank, @@ -95,15 +173,15 @@ class ProcessGroupStream : public ProcessGroup { std::shared_ptr Send_Partial( phi::DenseTensor& tensors, // NOLINT int dst_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op) override; virtual std::shared_ptr Send_Partial( phi::DenseTensor& tensors, // NOLINT int dst_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream); @@ -121,30 +199,30 @@ class ProcessGroupStream : public ProcessGroup { std::shared_ptr Recv_Partial( phi::DenseTensor& tensors, // NOLINT int src_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op) override; virtual std::shared_ptr Recv_Partial( phi::DenseTensor& tensors, // NOLINT int src_rank, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream); std::shared_ptr AllGather_Partial( std::vector& in_tensors, std::vector& out_tensors, - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op) override; virtual std::shared_ptr AllGather_Partial( std::vector& in_tensors, // NOLINT std::vector& out_tensors, // NOLINT - int offset, - int length, + int64_t offset, + int64_t length, bool sync_op, bool use_calc_stream); }; diff --git a/paddle/fluid/distributed/collective/Utils.h b/paddle/fluid/distributed/collective/Utils.h index 79146febdf80928a7c6cbfe68ca812da882da905..c06c0345163ed7c6d68e7256bc84ee07c183507e 100644 --- a/paddle/fluid/distributed/collective/Utils.h +++ b/paddle/fluid/distributed/collective/Utils.h @@ -14,14 +14,26 @@ #pragma once -#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/backends/device_guard.h" #include "paddle/phi/backends/device_manager.h" +#include "paddle/phi/kernels/funcs/concat_and_split_functor.h" namespace paddle { namespace distributed { +template +struct ConcatDenseTensor { + void operator()(const DeviceContext *context, + const std::vector &in, + phi::DenseTensor *out, + int axis = 0) { + phi::funcs::ConcatFunctor concat_functor; + concat_functor(*context, in, axis, out); + } +}; + template struct SplitDenseTensor { void operator()(const DeviceContext *context, @@ -33,17 +45,36 @@ struct SplitDenseTensor { for (auto *p_tensor : *out) { shape_refer.emplace_back(p_tensor); } - operators::math::SplitFunctor split_functor_; - split_functor_(*context, in, shape_refer, axis, out); + phi::funcs::SplitFunctor split_functor; + split_functor(*context, in, shape_refer, axis, out); } }; #ifdef PADDLE_WITH_CUSTOM_DEVICE +template +struct 
ConcatDenseTensor { + void operator()(const platform::CustomDeviceContext *context, + const std::vector &in, + phi::DenseTensor *out, + int axis = 0) { + auto *out_data = out->data(); + auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace()); + size_t offset = 0; + for (const auto &tensor : in) { + const auto *in_data = tensor.data(); + auto sz = tensor.numel() * sizeof(T); + device->MemoryCopyD2D(out_data + offset, in_data, sz, nullptr); + offset += sz; + } + } +}; + template struct SplitDenseTensor { void operator()(const platform::CustomDeviceContext *context, const phi::DenseTensor &in, - std::vector *out) { + std::vector *out, + int axis = 0) { auto *in_data = in.data(); auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace()); size_t offset = 0; @@ -57,42 +88,119 @@ struct SplitDenseTensor { }; #endif +template +void ConcatDenseTensorWithType(const DeviceContext *dev_ctx, + const std::vector &t_list, + phi::DenseTensor *p_out, + phi::DataType type) { + switch (type) { + case phi::DataType::BOOL: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; + case phi::DataType::UINT8: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; + case phi::DataType::INT8: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; + case phi::DataType::INT32: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; + case phi::DataType::INT64: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; + case phi::DataType::FLOAT16: + ConcatDenseTensor()( + dev_ctx, t_list, p_out); + break; + case phi::DataType::FLOAT32: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; + case phi::DataType::FLOAT64: + ConcatDenseTensor()(dev_ctx, t_list, p_out); + break; + default: + PADDLE_THROW(platform::errors::Unimplemented( + "Data type (%s) is not supported when it concats tensors.", type)); + } +} + template void SplitDenseTensorWithType(const DeviceContext *dev_ctx, - const phi::DenseTensor &p_dense, + const phi::DenseTensor &t_in, std::vector *p_list, phi::DataType type) { switch (type) { case phi::DataType::BOOL: - SplitDenseTensor()(dev_ctx, p_dense, p_list); + SplitDenseTensor()(dev_ctx, t_in, p_list); break; case phi::DataType::UINT8: - SplitDenseTensor()(dev_ctx, p_dense, p_list); + SplitDenseTensor()(dev_ctx, t_in, p_list); break; case phi::DataType::INT8: - SplitDenseTensor()(dev_ctx, p_dense, p_list); + SplitDenseTensor()(dev_ctx, t_in, p_list); break; case phi::DataType::INT32: - SplitDenseTensor()(dev_ctx, p_dense, p_list); + SplitDenseTensor()(dev_ctx, t_in, p_list); break; case phi::DataType::INT64: - SplitDenseTensor()(dev_ctx, p_dense, p_list); + SplitDenseTensor()(dev_ctx, t_in, p_list); break; case phi::DataType::FLOAT16: SplitDenseTensor()( - dev_ctx, p_dense, p_list); + dev_ctx, t_in, p_list); break; case phi::DataType::FLOAT32: - SplitDenseTensor()(dev_ctx, p_dense, p_list); + SplitDenseTensor()(dev_ctx, t_in, p_list); break; case phi::DataType::FLOAT64: - SplitDenseTensor()(dev_ctx, p_dense, p_list); + SplitDenseTensor()(dev_ctx, t_in, p_list); break; default: PADDLE_THROW(platform::errors::Unimplemented( - "Data type (%s) is not supported when it splits tensors for " - "allgather.", - type)); + "Data type (%s) is not supported when it splits tensors.", type)); + } +} + +void ConcatTensor(const phi::DeviceContext *dev_ctx, + const std::vector &tensor_list, + const experimental::Tensor *tensor) { + auto *dense_tensor = + std::dynamic_pointer_cast(tensor->impl()).get(); + + const auto &place = dev_ctx->GetPlace(); + if (platform::is_gpu_place(place)) 
{ +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + ConcatDenseTensorWithType(static_cast(dev_ctx), + tensor_list, + dense_tensor, + tensor->dtype()); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't concat tensor since it's not support GPU, please " + "recompile or reinstall Paddle with GPU support.")); +#endif + } else if (platform::is_custom_place(place)) { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + ConcatDenseTensorWithType( + static_cast(dev_ctx), + tensor_list, + dense_tensor, + tensor->dtype()); +#else + PADDLE_THROW(platform::errors::PermissionDenied( + "Paddle can't concat tensor since it's not compiled with " + "CUSTOM_DEVICE, please recompile or reinstall Paddle with " + "CUSTOM_DEVICE support.")); +#endif + } else if (platform::is_cpu_place(place)) { + ConcatDenseTensorWithType(static_cast(dev_ctx), + tensor_list, + dense_tensor, + tensor->dtype()); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Concat tensor not supported on place (%s)", place)); } } @@ -115,8 +223,8 @@ void SplitTensor(const phi::DeviceContext *dev_ctx, tensor.dtype()); #else PADDLE_THROW(platform::errors::PermissionDenied( - "Paddle can't split tensor since it's not support NCCL/RCCL, please " - "recompile or reinstall Paddle with NCCL/RCCL support.")); + "Paddle can't split tensor since it's not support GPU, please " + "recompile or reinstall Paddle with GPU support.")); #endif } else if (platform::is_custom_place(place)) { #ifdef PADDLE_WITH_CUSTOM_DEVICE diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 39986c604c03630177934e0ff3be5d6c7fb04ccd..6aa8e19c99c61cc131a1e8610e47e6880a528a5c 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -176,6 +176,24 @@ void BindDistributed(py::module *m) { py::arg("source_rank"), py::call_guard()) + .def( + "broadcast", + [](distributed::ProcessGroup &self, + py::handle py_tensor, + int src, + bool sync_op) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + distributed::BroadcastOptions opts{src}; + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.Broadcast(tensors, tensors, opts, sync_op); + }, + py::arg("tensor"), + py::arg("src"), + py::arg("sync_op"), + py::call_guard()) + .def( "barrier", [](distributed::ProcessGroup &self, std::vector place_ids) { @@ -228,9 +246,9 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - int numel = (*dense).numel(); - int send_numel = numel / nranks; - int offset = send_numel * rank_id; + int64_t numel = (*dense).numel(); + int64_t send_numel = numel / nranks; + int64_t offset = send_numel * rank_id; return self.Send_Partial(*dense, dst_rank, offset, send_numel); }, py::arg("tensor"), @@ -250,9 +268,9 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - int numel = (*dense).numel(); - int send_numel = numel / nranks; - int offset = send_numel * rank_id; + int64_t numel = (*dense).numel(); + int64_t send_numel = numel / nranks; + int64_t offset = send_numel * rank_id; return self.Send_Partial( *dense, dst_rank, offset, send_numel, sync_op); }, @@ -305,9 +323,9 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - int numel = (*dense).numel(); - int 
recv_numel = numel / nranks; - int offset = recv_numel * rank_id; + int64_t numel = (*dense).numel(); + int64_t recv_numel = numel / nranks; + int64_t offset = recv_numel * rank_id; return self.Recv_Partial(*dense, src_rank, offset, recv_numel); }, py::arg("tensor"), @@ -327,9 +345,9 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - int numel = (*dense).numel(); - int recv_numel = numel / nranks; - int offset = recv_numel * rank_id; + int64_t numel = (*dense).numel(); + int64_t recv_numel = numel / nranks; + int64_t offset = recv_numel * rank_id; return self.Recv_Partial( *dense, src_rank, offset, recv_numel, sync_op); }, @@ -388,7 +406,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "allgather_base", + "allgather_into_tensor", [](distributed::ProcessGroup &self, py::handle py_in_tensor, py::handle py_out_tensor, @@ -425,9 +443,9 @@ void BindDistributed(py::module *m) { out_tensor.impl()); std::vector in_tensors = {*in_dense}; std::vector out_tensors = {*out_dense}; - int numel = (*in_dense).numel(); - int send_numel = numel / nranks; - int offset = send_numel * rank_id; + int64_t numel = (*in_dense).numel(); + int64_t send_numel = numel / nranks; + int64_t offset = send_numel * rank_id; return self.AllGather_Partial( in_tensors, out_tensors, offset, send_numel); }, @@ -456,6 +474,61 @@ void BindDistributed(py::module *m) { py::arg("out"), py::call_guard()) + .def( + "alltoall", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor_list, + py::handle py_out_tensor_list, + bool sync_op) { + auto in_tensor_list = + CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); + Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); + auto in_dense = std::dynamic_pointer_cast( + concat_in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor_list = + CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0); + Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0); + auto out_dense = std::dynamic_pointer_cast( + concat_out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + // in_tensor_list should not be empty + const auto *dev_ctx = + self.GetDeviceContext(in_tensor_list.back().place()); + auto task = self.AllToAll(in_wrapper, out_wrapper, sync_op); + distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); + return task; + }, + py::arg("in"), + py::arg("out"), + py::arg("sync_op"), + py::call_guard()) + + .def( + "alltoall_tensor", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + bool sync_op) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + return self.AllToAll(in_wrapper, out_wrapper, sync_op); + }, + py::arg("in"), + py::arg("out"), + py::arg("sync_op"), + py::call_guard()) + .def( "alltoall_single", [](distributed::ProcessGroup &self, @@ -480,6 +553,34 @@ void BindDistributed(py::module *m) { py::arg("out_sizes"), py::call_guard()) + .def( + "alltoall_single", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + std::vector &in_sizes, + std::vector &out_sizes, + bool sync_op) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + 
auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + return self.AllToAllSingle( + in_wrapper, out_wrapper, in_sizes, out_sizes, sync_op); + }, + py::arg("in"), + py::arg("out"), + py::arg("in_sizes"), + py::arg("out_sizes"), + py::arg("sync_op"), + py::call_guard()) + .def( "reduce", [](distributed::ProcessGroup &self, @@ -499,6 +600,83 @@ void BindDistributed(py::module *m) { py::arg("dst"), py::arg("op") = distributed::ReduceOp::SUM, py::call_guard()) + + .def( + "reduce", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + int dst, + distributed::ReduceOp op, + bool sync_op) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + distributed::ReduceOptions opts{op, dst}; + auto dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector tensors = {*dense}; + return self.Reduce(tensors, tensors, opts, sync_op); + }, + py::arg("tensor"), + py::arg("dst"), + py::arg("op"), + py::arg("sync_op"), + py::call_guard()) + + .def( + "reduce_scatter", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor_list, + py::handle py_out_tensor, + distributed::ReduceOp op, + bool sync_op) { + auto in_tensor_list = + CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); + Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); + auto in_dense = std::dynamic_pointer_cast( + concat_in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + distributed::ReduceScatterOptions opts{op}; + return self.ReduceScatter( + in_wrapper, out_wrapper, opts, sync_op); + }, + py::arg("in"), + py::arg("out"), + py::arg("op"), + py::arg("sync_op"), + py::call_guard()) + + .def( + "reduce_scatter_tensor", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + distributed::ReduceOp op, + bool sync_op) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + distributed::ReduceScatterOptions opts{op}; + return self.ReduceScatter( + in_wrapper, out_wrapper, opts, sync_op); + }, + py::arg("in"), + py::arg("out"), + py::arg("op"), + py::arg("sync_op"), + py::call_guard()) + .def( "scatter", [](distributed::ProcessGroup &self, @@ -521,6 +699,61 @@ void BindDistributed(py::module *m) { py::arg("out"), py::arg("src"), py::call_guard()) + + .def( + "scatter", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor_list, + py::handle py_out_tensor, + int src, + bool sync_op) { + auto in_tensor_list = + CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); + Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); + auto in_dense = std::dynamic_pointer_cast( + concat_in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + distributed::ScatterOptions opts{src}; + return 
self.Scatter(in_wrapper, out_wrapper, opts, sync_op); + }, + py::arg("in"), + py::arg("out"), + py::arg("src"), + py::arg("sync_op"), + py::call_guard()) + + .def( + "scatter_tensor", + [](distributed::ProcessGroup &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + int src, + bool sync_op) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + distributed::ScatterOptions opts{src}; + return self.Scatter(in_wrapper, out_wrapper, opts, sync_op); + }, + py::arg("in"), + py::arg("out"), + py::arg("src"), + py::arg("sync_op"), + py::call_guard()) + .def( "_reduce_scatter_base", [](distributed::ProcessGroup &self, @@ -577,7 +810,7 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "allgather_base_on_calc_stream", + "allgather_into_tensor_on_calc_stream", [](distributed::ProcessGroupStream &self, py::handle py_in_tensor, py::handle py_out_tensor) { @@ -600,6 +833,37 @@ void BindDistributed(py::module *m) { py::arg("out"), py::call_guard()) + .def( + "all_gather_partial_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + int nranks, + int rank_id) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector in_tensors = {*in_dense}; + std::vector out_tensors = {*out_dense}; + int64_t numel = (*in_dense).numel(); + int64_t send_numel = numel / nranks; + int64_t offset = send_numel * rank_id; + return self.AllGather_Partial(in_tensors, + out_tensors, + offset, + send_numel, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("in"), + py::arg("out"), + py::arg("num"), + py::arg("id"), + py::call_guard()) + .def( "allreduce_on_calc_stream", [](distributed::ProcessGroupStream &self, @@ -622,34 +886,248 @@ void BindDistributed(py::module *m) { py::call_guard()) .def( - "all_gather_partial_on_calc_stream", + "alltoall_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_in_tensor_list, + py::handle py_out_tensor_list) { + auto in_tensor_list = + CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); + Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); + auto in_dense = std::dynamic_pointer_cast( + concat_in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor_list = + CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0); + Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0); + auto out_dense = std::dynamic_pointer_cast( + concat_out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + // in_tensor_list must not be empty + const auto *dev_ctx = self.GetDeviceContext( + in_tensor_list.back().place(), /*use_calc_stream*/ true); + auto task = self.AllToAll(in_wrapper, + out_wrapper, + /*sync_op*/ true, + /*use_calc_stream*/ true); + distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list); + return task; + }, + py::arg("in"), + py::arg("out"), + py::call_guard()) + + .def( + "alltoall_tensor_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_in_tensor, + py::handle py_out_tensor) { + auto in_tensor = 
CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + return self.AllToAll(in_wrapper, + out_wrapper, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("in"), + py::arg("out"), + py::call_guard()) + + .def( + "alltoall_single_on_calc_stream", [](distributed::ProcessGroupStream &self, py::handle py_in_tensor, py::handle py_out_tensor, - int nranks, - int rank_id) { + std::vector &in_sizes, + std::vector &out_sizes) { auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + return self.AllToAllSingle(in_wrapper, + out_wrapper, + in_sizes, + out_sizes, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("in"), + py::arg("out"), + py::arg("in_sizes"), + py::arg("out_sizes"), + py::call_guard()) + + .def( + "broadcast_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_tensor, + int src) { + auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); + distributed::BroadcastOptions opts{src}; + auto dense = + std::dynamic_pointer_cast(tensor.impl()); + std::vector tensors = {*dense}; + return self.Broadcast(tensors, + tensors, + opts, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("tensor"), + py::arg("src"), + py::call_guard()) + + .def( + "reduce_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_in_tensor, + int dst, + distributed::ReduceOp op) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + distributed::ReduceOptions opts{op, dst}; + auto dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector tensors = {*dense}; + return self.Reduce(tensors, + tensors, + opts, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("tensor"), + py::arg("dst"), + py::arg("op"), + py::call_guard()) + + .def( + "reduce_scatter_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_in_tensor_list, + py::handle py_out_tensor, + distributed::ReduceOp op) { + auto in_tensor_list = + CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); + Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); + auto in_dense = std::dynamic_pointer_cast( + concat_in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + distributed::ReduceScatterOptions opts{op}; + return self.ReduceScatter(in_wrapper, + out_wrapper, + opts, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("in"), + py::arg("out"), + py::arg("op"), + py::call_guard()) + + .def( + "reduce_scatter_tensor_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + distributed::ReduceOp op) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); auto in_dense = std::dynamic_pointer_cast( in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); auto out_dense = 
std::dynamic_pointer_cast( out_tensor.impl()); - std::vector in_tensors = {*in_dense}; - std::vector out_tensors = {*out_dense}; - int numel = (*in_dense).numel(); - int send_numel = numel / nranks; - int offset = send_numel * rank_id; - return self.AllGather_Partial(in_tensors, - out_tensors, - offset, - send_numel, - /*sync_op*/ true, - /*use_calc_stream*/ true); + std::vector out_wrapper = {*out_dense}; + + distributed::ReduceScatterOptions opts{op}; + return self.ReduceScatter(in_wrapper, + out_wrapper, + opts, + /*sync_op*/ true, + /*use_calc_stream*/ true); }, py::arg("in"), py::arg("out"), - py::arg("num"), - py::arg("id"), + py::arg("op"), + py::call_guard()) + + .def( + "scatter_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_in_tensor_list, + py::handle py_out_tensor, + int src) { + auto in_tensor_list = + CastPyArg2VectorOfTensor(py_in_tensor_list.ptr(), 0); + Tensor concat_in_tensor = paddle::concat(in_tensor_list, 0); + auto in_dense = std::dynamic_pointer_cast( + concat_in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + distributed::ScatterOptions opts{src}; + return self.Scatter(in_wrapper, + out_wrapper, + opts, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("in"), + py::arg("out"), + py::arg("src"), + py::call_guard()) + + .def( + "scatter_tensor_on_calc_stream", + [](distributed::ProcessGroupStream &self, + py::handle py_in_tensor, + py::handle py_out_tensor, + int src) { + auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0); + auto in_dense = std::dynamic_pointer_cast( + in_tensor.impl()); + std::vector in_wrapper = {*in_dense}; + + auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0); + auto out_dense = std::dynamic_pointer_cast( + out_tensor.impl()); + std::vector out_wrapper = {*out_dense}; + + distributed::ScatterOptions opts{src}; + return self.Scatter(in_wrapper, + out_wrapper, + opts, + /*sync_op*/ true, + /*use_calc_stream*/ true); + }, + py::arg("in"), + py::arg("out"), + py::arg("src"), py::call_guard()) .def( @@ -680,9 +1158,9 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - int numel = (*dense).numel(); - int send_numel = numel / nranks; - int offset = send_numel * rank_id; + int64_t numel = (*dense).numel(); + int64_t send_numel = numel / nranks; + int64_t offset = send_numel * rank_id; return self.Send_Partial(*dense, dst_rank, offset, @@ -724,9 +1202,9 @@ void BindDistributed(py::module *m) { auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0); auto dense = std::dynamic_pointer_cast(tensor.impl()); - int numel = (*dense).numel(); - int recv_numel = numel / nranks; - int offset = recv_numel * rank_id; + int64_t numel = (*dense).numel(); + int64_t recv_numel = numel / nranks; + int64_t offset = recv_numel * rank_id; return self.Recv_Partial(*dense, src_rank, offset, diff --git a/python/paddle/distributed/communication/stream/__init__.py b/python/paddle/distributed/communication/stream/__init__.py index deab1f97ea28e101f4d5f95fec5b499582876445..a1844decf94781ea53923d4b5bec9eaac7218e80 100644 --- a/python/paddle/distributed/communication/stream/__init__.py +++ b/python/paddle/distributed/communication/stream/__init__.py @@ -14,7 +14,16 @@ from .all_gather import all_gather from .all_reduce import all_reduce -from .send 
import send +from .alltoall import alltoall +from .alltoall_single import alltoall_single +from .broadcast import broadcast +from .reduce import reduce +from .reduce_scatter import _reduce_scatter_base, reduce_scatter from .recv import recv +from .scatter import scatter +from .send import send -__all__ = ["all_gather", "all_reduce", "send", "recv"] +__all__ = [ + "_reduce_scatter_base", "all_reduce", "alltoall", "alltoall_single", + "broadcast", "reduce", "reduce_scatter", "recv", "scatter", "send" +] diff --git a/python/paddle/distributed/communication/stream/all_gather.py b/python/paddle/distributed/communication/stream/all_gather.py index dca2957309068f3e7b948dd7229e8601c4959da6..9eb961cda171d42bb38c84ae16067dd65e1e5e69 100644 --- a/python/paddle/distributed/communication/stream/all_gather.py +++ b/python/paddle/distributed/communication/stream/all_gather.py @@ -34,17 +34,18 @@ def _check_tensor_list_shape(tensor_list, shape, nranks=1): 'The tensor_list for all_gather is not correctly-sized.') -def _all_gather_base_in_dygraph(out_tensor, in_tensor, group, sync_op, - use_calc_stream): +def _all_gather_into_tensor_in_dygraph(out_tensor, in_tensor, group, sync_op, + use_calc_stream): group = collective._get_default_group() if group is None else group _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks) if use_calc_stream: - return group.process_group.allgather_base_on_calc_stream( + return group.process_group.allgather_into_tensor_on_calc_stream( in_tensor, out_tensor) - task = group.process_group.allgather_base(in_tensor, out_tensor, sync_op) + task = group.process_group.allgather_into_tensor(in_tensor, out_tensor, + sync_op) if sync_op: task.wait() @@ -83,7 +84,7 @@ def all_gather(tensor_or_tensor_list, tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The output. If it is a tensor, it should be correctly-sized. If it is a list, it should be empty or contain correctly-sized tensors. tensor (Tensor): The input tensor on each rank. The result will overwrite this tenor after communication. Support - float16, float32, float64, int32 or int64 as the input data type. + float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type. group (Group, optional): Communicate in which group. If none is given, use the global group as default. sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This @@ -125,8 +126,9 @@ def all_gather(tensor_or_tensor_list, if framework.in_dygraph_mode(): if paddle.is_tensor(tensor_or_tensor_list): - return _all_gather_base_in_dygraph(tensor_or_tensor_list, tensor, - group, sync_op, use_calc_stream) + return _all_gather_into_tensor_in_dygraph(tensor_or_tensor_list, + tensor, group, sync_op, + use_calc_stream) else: return _all_gather_in_dygraph(tensor_or_tensor_list, tensor, group, sync_op, use_calc_stream) diff --git a/python/paddle/distributed/communication/stream/all_reduce.py b/python/paddle/distributed/communication/stream/all_reduce.py index 965a6ae89008a3c362afa7a171572df0d8f19427..0ba161a078ab89e82631aedd623d7934f160d3df 100644 --- a/python/paddle/distributed/communication/stream/all_reduce.py +++ b/python/paddle/distributed/communication/stream/all_reduce.py @@ -70,7 +70,7 @@ def all_reduce(tensor, Args: tensor (Tensor): The input tensor on each rank. The result will overwrite this tenor after communication. 
Support - float16, float32, float64, int32 or int64 as the input data type. + float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type. op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default. group (Group, optional): Communicate in which group. If none is given, use the global group as default. sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. diff --git a/python/paddle/distributed/communication/stream/alltoall.py b/python/paddle/distributed/communication/stream/alltoall.py new file mode 100644 index 0000000000000000000000000000000000000000..b216906d0456888285405f170b48a090b03a61a9 --- /dev/null +++ b/python/paddle/distributed/communication/stream/alltoall.py @@ -0,0 +1,157 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid.framework as framework +from paddle.distributed import collective + + +def _check_tensor_shape(tensor, shape, nranks=1): + if tensor.shape != shape: + raise RuntimeError('The tensor for alltoall is not correctly-sized.') + + +def _check_tensor_list_shape(tensor_list, shape, nranks=1): + if len(tensor_list) != nranks: + raise RuntimeError( + 'The tensor_list for alltoall is not correctly-sized.') + for tensor in tensor_list: + if tensor.shape != shape: + raise RuntimeError( + 'The tensor_list for alltoall is not correctly-sized.') + + +def _alltoall_tensor_in_dygraph(out_tensor, in_tensor, group, sync_op, + use_calc_stream): + group = collective._get_default_group() if group is None else group + + _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks) + + if use_calc_stream: + return group.process_group.alltoall_tensor_on_calc_stream( + in_tensor, out_tensor) + + task = group.process_group.alltoall_tensor(in_tensor, out_tensor, sync_op) + if sync_op: + task.wait() + + return task + + +def _alltoall_in_dygraph(out_tensor_list, in_tensor_list, group, sync_op, + use_calc_stream): + group = collective._get_default_group() if group is None else group + + if len(in_tensor_list) == 0: + raise RuntimeError("The input tensor_list should not be empty.") + + if len(out_tensor_list) == 0: + out_tensor_list += [ + paddle.empty_like(tensor) for tensor in in_tensor_list + ] + else: + _check_tensor_list_shape(out_tensor_list, in_tensor_list[0].shape, + group.nranks) + + if use_calc_stream: + return group.process_group.alltoall_on_calc_stream( + in_tensor_list, out_tensor_list) + + task = group.process_group.alltoall(in_tensor_list, out_tensor_list, + sync_op) + if sync_op: + task.wait() + + return task + + +def alltoall(out_tensor_or_tensor_list, + in_tensor_or_tensor_list, + group=None, + sync_op=True, + use_calc_stream=False): + """ + + Scatter a tensor (or a tensor list) across devices and gather outputs to another tensor (or a tensor list, respectively). 
+ + Args: + out_tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The output. If it is a tensor, it should be correctly-sized. + If it is a list, it should be empty or contain correctly-sized tensors. Its data type should be the same as the input. + in_tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The input to scatter (must be specified on the source rank). + If it is a tensor, it should be correctly-sized. If it is a list, it should contain correctly-sized tensors. Support + float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type. + group (Group, optional): Communicate in which group. If none is given, use the global group as default. + sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. + use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This + option is designed for high performance demand, be careful to turn it on except you are clearly know its meaning. + + Returns: + Return a task object. + + Warning: + This API only supports the dygraph mode now. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + out_tensor_list = [] + if dist.get_rank() == 0: + data1 = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) + data2 = paddle.to_tensor([[7, 8, 9], [10, 11, 12]]) + else: + data1 = paddle.to_tensor([[13, 14, 15], [16, 17, 18]]) + data2 = paddle.to_tensor([[19, 20, 21], [22, 23, 24]]) + task = dist.stream.alltoall(out_tensor_list, [data1, data2], sync_op=False) + task.wait() + print(out_tensor_list) + # [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0) + # [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1) + """ + if group is not None and not group.is_member(): + raise RuntimeError( + "The group should not be None and all ranks which invoke this operation should be the member of this group." + ) + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be true in sync op behavior.") + + if out_tensor_or_tensor_list is None: + raise RuntimeError("The output should be specified.") + if in_tensor_or_tensor_list is None: + raise RuntimeError("The input should be specified.") + + if framework.in_dygraph_mode(): + out_is_tensor = paddle.is_tensor(out_tensor_or_tensor_list) + in_is_tensor = paddle.is_tensor(in_tensor_or_tensor_list) + if out_is_tensor and in_is_tensor: + return _alltoall_tensor_in_dygraph(out_tensor_or_tensor_list, + in_tensor_or_tensor_list, group, + sync_op, use_calc_stream) + elif not out_is_tensor and not in_is_tensor: + return _alltoall_in_dygraph(out_tensor_or_tensor_list, + in_tensor_or_tensor_list, group, + sync_op, use_calc_stream) + else: + raise RuntimeError( + "The output and input should be both tensor or tensor list.") + + raise RuntimeError( + "paddle.distributed.stream.alltoall is only supported in dygraph mode now." + ) diff --git a/python/paddle/distributed/communication/stream/alltoall_single.py b/python/paddle/distributed/communication/stream/alltoall_single.py new file mode 100644 index 0000000000000000000000000000000000000000..b2187cc06e343984ae11d005a0ffb5bc27bb8d6f --- /dev/null +++ b/python/paddle/distributed/communication/stream/alltoall_single.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
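# A minimal usage sketch for the dist.stream.alltoall wrapper added above, condensed from
# its docstring; assumes 2 GPUs launched with e.g. `python -m paddle.distributed.launch --gpus=0,1 <script>`.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()
# Each rank contributes one chunk per peer; chunk j from every rank ends up on rank j.
chunks = [paddle.to_tensor([rank * 10 + 0]), paddle.to_tensor([rank * 10 + 1])]
out = []  # an empty list is filled with paddle.empty_like() tensors by the wrapper
task = dist.stream.alltoall(out, chunks, sync_op=False)
task.wait()
# rank 0 receives chunk 0 of every rank: [[0], [10]]
# rank 1 receives chunk 1 of every rank: [[1], [11]]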
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid.framework as framework +from paddle.distributed import collective + + +def _alltoall_single_in_dygraph(out_tensor, in_tensor, out_split_sizes, + in_split_sizes, group, sync_op, + use_calc_stream): + group = collective._get_default_group() if group is None else group + + if out_split_sizes is None: + out_split_sizes = [] + if in_split_sizes is None: + in_split_sizes = [] + + if use_calc_stream: + return group.process_group.alltoall_single_on_calc_stream( + in_tensor, out_tensor, in_split_sizes, out_split_sizes) + + task = group.process_group.alltoall_single(in_tensor, out_tensor, + in_split_sizes, out_split_sizes, + sync_op) + if sync_op: + task.wait() + + return task + + +def alltoall_single(out_tensor, + in_tensor, + out_split_sizes=None, + in_split_sizes=None, + group=None, + sync_op=True, + use_calc_stream=False): + """ + + Split and scatter the input tensor to the out tensor across devices. + + Args: + out_tensor (Tensor): The output tensor. Its data type should be the same as the input. + in_tensor (Tensor): The input tensor. Its data type should be float16, float32, float64, int32, int64, int8, uint8 or bool. + out_split_sizes (List[int], optional): Split sizes of out_tensor for dim[0]. If not given, dim[0] of out_tensor must be divisible + by group size and out_tensor will be gathered evenly from all participators. If none is given, use an empty list as default. + in_split_sizes (List[int], optional): Split sizes of in_tensor for dim[0]. If not given, dim[0] of in_tensor must be divisible + by group size and in_tensor will be scattered evenly to all participators. If none is given, use an empty list as default. + group (Group, optional): Communicate in which group. If none is given, use the global group as default. + sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. + use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This + option is designed for high performance demand, be careful to turn it on only when you clearly understand its meaning. + + Returns: + Return a task object. + + Warning: + This API only supports the dygraph mode now. + + Examples: + ..
code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + local_rank = dist.get_rank() + + # case 1 + output = paddle.empty([2], dtype="int64") + if local_rank == 0: + data = paddle.to_tensor([0, 1]) + else: + data = paddle.to_tensor([2, 3]) + task = dist.stream.alltoall_single(output, data, sync_op=False) + task.wait() + out = output.numpy() + # [0, 2] (2 GPUs, out for rank 0) + # [1, 3] (2 GPUs, out for rank 1) + + # case 2 + size = dist.get_world_size() + output = paddle.empty([(local_rank + 1) * size, size], dtype='float32') + if local_rank == 0: + data = paddle.to_tensor([[0., 0.], [0., 0.], [0., 0.]]) + else: + data = paddle.to_tensor([[1., 1.], [1., 1.], [1., 1.]]) + out_split_sizes = [local_rank + 1 for i in range(size)] + in_split_sizes = [i + 1 for i in range(size)] + task = dist.stream.alltoall_single(output, + data, + out_split_sizes, + in_split_sizes, + sync_op=False) + task.wait() + out = output.numpy() + # [[0., 0.], [1., 1.]] (2 GPUs, out for rank 0) + # [[0., 0.], [0., 0.], [1., 1.], [1., 1.]] (2 GPUs, out for rank 1) + """ + if group is not None and not group.is_member(): + raise RuntimeError( + "The group should not be None and all ranks which invoke this operation should be the member of this group." + ) + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be true in sync op behavior.") + + if framework.in_dygraph_mode(): + return _alltoall_single_in_dygraph(out_tensor, in_tensor, + out_split_sizes, in_split_sizes, + group, sync_op, use_calc_stream) + + raise RuntimeError( + "paddle.distributed.stream.alltoall_single is only supported in dygraph mode now." + ) diff --git a/python/paddle/distributed/communication/stream/broadcast.py b/python/paddle/distributed/communication/stream/broadcast.py new file mode 100644 index 0000000000000000000000000000000000000000..06bde316937a9d92325969324c249225991d10e7 --- /dev/null +++ b/python/paddle/distributed/communication/stream/broadcast.py @@ -0,0 +1,83 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid.framework as framework +from paddle.distributed import collective + + +def _broadcast_in_dygraph(tensor, src, group, sync_op, use_calc_stream): + group = collective._get_default_group() if group is None else group + if use_calc_stream: + return group.process_group.broadcast_on_calc_stream(tensor, src) + + task = group.process_group.broadcast(tensor, src, sync_op) + if sync_op: + task.wait() + + return task + + +def broadcast(tensor, src=0, group=None, sync_op=True, use_calc_stream=False): + """ + + Broadcast a tensor to all devices. + + Args: + tensor (Tensor): The tensor to broadcast. Support float16, float32, float64, int32, int64, int8, uint8 or bool as its data type. + src (int, optional): Rank of the source device. If none is given, use `0` as default. + group (Group, optional): Communicate in which group. 
If none is given, use the global group as default. + sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. + use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This + option is designed for high performance demand, be careful to turn it on except you are clearly know its meaning. + + Returns: + Return a task object. + + Warning: + This API only supports the dygraph mode now. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + local_rank = dist.get_rank() + if local_rank == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) + else: + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + task = dist.stream.broadcast(data, src=1, sync_op=False) + task.wait() + out = data.numpy() + # [[1, 2, 3], [1, 2, 3]] (2 GPUs) + """ + if group is not None and not group.is_member(): + raise RuntimeError( + "The group should not be None and all ranks which invoke this operation should be the member of this group." + ) + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be True in sync op behavior.") + + if framework.in_dygraph_mode(): + return _broadcast_in_dygraph(tensor, src, group, sync_op, + use_calc_stream) + + raise RuntimeError( + "paddle.distributed.stream.broadcast is only supported in dygraph mode now." + ) diff --git a/python/paddle/distributed/communication/stream/recv.py b/python/paddle/distributed/communication/stream/recv.py index 8b77fa51b0de079e760446609023fdfbc57066f3..25a8173788473aa79f9f32ddae9945d69156fb80 100644 --- a/python/paddle/distributed/communication/stream/recv.py +++ b/python/paddle/distributed/communication/stream/recv.py @@ -64,7 +64,7 @@ def recv(tensor, src=0, group=None, sync_op=True, use_calc_stream=False): task = dist.stream.recv(data, src=0, sync_op=False) task.wait() out = data.numpy() - # [[4, 5, 6], [4, 5, 6] + # [[4, 5, 6], [4, 5, 6]] (2 GPUs) """ if group is not None and not group.is_member(): raise RuntimeError( diff --git a/python/paddle/distributed/communication/stream/reduce.py b/python/paddle/distributed/communication/stream/reduce.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f7f5c884743d3bd1332e9679c45ff52c868b77 --- /dev/null +++ b/python/paddle/distributed/communication/stream/reduce.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
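# A short sketch of the three call modes shared by these stream APIs, using the broadcast
# wrapper above as the example; assumes 2 GPUs under paddle.distributed.launch.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
data = paddle.to_tensor([dist.get_rank()])

# 1. Default: synchronous op on the communication stream.
dist.stream.broadcast(data, src=0)

# 2. Asynchronous op: returns a task that must be waited on before reading `data`.
task = dist.stream.broadcast(data, src=0, sync_op=False)
task.wait()

# 3. Synchronous op issued on the calculation stream (the high-performance path);
#    sync_op=False together with use_calc_stream=True is rejected by these wrappers.
dist.stream.broadcast(data, src=0, sync_op=True, use_calc_stream=True)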
+ +import paddle.fluid.framework as framework +from paddle.distributed.communication.group import _get_global_group +from paddle.distributed.communication.reduce import _get_reduce_op, ReduceOp + + +def _reduce_in_dygraph(tensor, dst, op, group, sync_op, use_calc_stream): + op_type = _get_reduce_op(op, "reduce") + group = _get_global_group() if group is None else group + if use_calc_stream: + return group.process_group.reduce_on_calc_stream(tensor, dst, op_type) + + task = group.process_group.reduce(tensor, dst, op_type, sync_op) + if sync_op: + task.wait() + + return task + + +def reduce(tensor, + dst=0, + op=ReduceOp.SUM, + group=None, + sync_op=True, + use_calc_stream=False): + """ + + Perform a specific reduction (for example, sum, max) on a tensor across devices and send the result to the destination device. + + Args: + tensor (Tensor): The input tensor on each rank. The result will overwrite this tensor after communication. Support + float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type. + dst (int, optional): Rank of the destination device. If none is given, use `0` as default. + op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default. + group (Group, optional): Communicate in which group. If none is given, use the global group as default. + sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. + use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This + option is designed for high performance demand, be careful to turn it on only when you clearly understand its meaning. + + Returns: + Return a task object. + + Warning: + This API only supports the dygraph mode now. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + local_rank = dist.get_rank() + if local_rank == 0: + data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]]) + else: + data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + task = dist.stream.reduce(data, dst=0, sync_op=False) + task.wait() + out = data.numpy() + # [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0) + # [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1) + """ + if group is not None and not group.is_member(): + raise RuntimeError( + "The group should not be None and all ranks which invoke this operation should be the member of this group." + ) + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be true in sync op behavior.") + + if framework.in_dygraph_mode(): + return _reduce_in_dygraph(tensor, dst, op, group, sync_op, + use_calc_stream) + + raise RuntimeError( + "paddle.distributed.stream.reduce is only supported in dygraph mode now." + ) diff --git a/python/paddle/distributed/communication/stream/reduce_scatter.py b/python/paddle/distributed/communication/stream/reduce_scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..a4aeae6312a302b49f6562e6d8bf9b909d7f387f --- /dev/null +++ b/python/paddle/distributed/communication/stream/reduce_scatter.py @@ -0,0 +1,216 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
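# A minimal sketch of the reduce wrapper above: only the destination rank observes the
# reduced value, other ranks keep their local data; assumes 2 GPUs under paddle.distributed.launch.
import paddle
import paddle.distributed as dist
from paddle.distributed import ReduceOp

dist.init_parallel_env()
data = paddle.to_tensor([1.0 + dist.get_rank()])
task = dist.stream.reduce(data, dst=0, op=ReduceOp.SUM, sync_op=False)
task.wait()
# rank 0: data == [3.0] (1.0 + 2.0); rank 1: data is unchanged, [2.0]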
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.distributed as dist +import paddle.fluid.framework as framework +from paddle.distributed.communication.group import _get_global_group +from paddle.distributed.communication.reduce import _get_reduce_op, ReduceOp + + +def _check_tensor_shape(tensor, shape, nranks=1): + expect_shape = list(shape) + expect_shape[0] //= nranks + if list(tensor.shape) != expect_shape: + raise RuntimeError( + "The in_tensor for reduce_scatter is not correctly-sized.") + + +def _check_tensor_list_shape(tensor_list, shape, nranks=1): + if len(tensor_list) != nranks: + raise RuntimeError( + f"The tensor_list for reduce_scatter is not correctly-sized.") + for tensor in tensor_list: + if tensor.shape != shape: + raise RuntimeError( + f"The tensor_list for reduce_scatter is not correctly-sized.") + + +def _reduce_scatter_tensor_in_dygraph(out_tensor, + in_tensor, + op, + group, + sync_op, + use_calc_stream, + caller="reduce_scatter"): + op_type = _get_reduce_op(op, caller) + group = _get_global_group() if group is None else group + + _check_tensor_shape(out_tensor, in_tensor.shape, group.nranks) + + if use_calc_stream: + return group.process_group.reduce_scatter_tensor_on_calc_stream( + in_tensor, out_tensor, op_type) + + task = group.process_group.reduce_scatter_tensor(in_tensor, out_tensor, + op_type, sync_op) + if sync_op: + task.wait() + + return task + + +def _reduce_scatter_in_dygraph(tensor, tensor_list, op, group, sync_op, + use_calc_stream): + op_type = _get_reduce_op(op, "reduce_scatter") + group = _get_global_group() if group is None else group + + _check_tensor_list_shape(tensor_list, tensor.shape, group.nranks) + + if use_calc_stream: + return group.process_group.reduce_scatter_on_calc_stream( + tensor_list, tensor, op_type) + + task = group.process_group.reduce_scatter(tensor_list, tensor, op_type, + sync_op) + if sync_op: + task.wait() + + return task + + +def reduce_scatter(tensor, + tensor_or_tensor_list, + op=ReduceOp.SUM, + group=None, + sync_op=True, + use_calc_stream=False): + """ + + Reduce, then scatter a tensor (or a tensor list) across devices. + + Args: + tensor (Tensor): The output tensor on each rank. The result will overwrite this tensor after communication. Support + float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type. + tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The input to reduce and scatter. + If it is a tensor, it should be correctly-sized. If it is a list, it should contain correctly-sized tensors. + op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default. + group (Group, optional): Communicate in which group. If none is given, use the global group as default. + sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. + use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This + option is designed for high performance demand, be careful to turn it on only when you clearly understand its meaning.
+ + Returns: + Return a task object. + + Warning: + This API only supports the dygraph mode now. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + if dist.get_rank() == 0: + data1 = paddle.to_tensor([0, 1]) + data2 = paddle.to_tensor([2, 3]) + else: + data1 = paddle.to_tensor([4, 5]) + data2 = paddle.to_tensor([6, 7]) + dist.stream.reduce_scatter(data1, [data1, data2]) + out = data1.numpy() + # [4, 6] (2 GPUs, out for rank 0) + # [8, 10] (2 GPUs, out for rank 1) + """ + if group is not None and not group.is_member(): + raise RuntimeError( + "The group should not be None and all ranks which invoke this operation should be the member of this group." + ) + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be true in sync op behavior.") + + if framework.in_dygraph_mode(): + if paddle.is_tensor(tensor_or_tensor_list): + return _reduce_scatter_tensor_in_dygraph(tensor, + tensor_or_tensor_list, op, + group, sync_op, + use_calc_stream) + else: + return _reduce_scatter_in_dygraph(tensor, tensor_or_tensor_list, op, + group, sync_op, use_calc_stream) + + raise RuntimeError( + "paddle.distributed.stream.reduce_scatter is only supported in dygraph mode now." + ) + + +def _reduce_scatter_base(out_tensor, + in_tensor, + op=ReduceOp.SUM, + group=None, + sync_op=True, + use_calc_stream=False): + """ + + Reduce, then scatter a flattened tensor across devices. + + Args: + out_tensor (Tensor): The output tensor on each rank. The result will overwrite this tensor after communication. Support + float16, float32, float64, int32 or int64 as the input data type. + in_tensor (Tensor): The input tensor to reduce and scatter. + op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default. + group (Group, optional): Communicate in which group. If none is given, use the global group as default. + sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. + use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This + option is designed for high performance demand, be careful to turn it on only when you clearly understand its meaning. + + Returns: + Return a task object. + + Warning: + This API will be deprecated in the future, and only supports the dygraph mode now. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + if dist.get_rank() == 0: + data = paddle.to_tensor([0, 1, 2, 3]) + else: + data = paddle.to_tensor([4, 5, 6, 7]) + output = paddle.empty(shape=[2], dtype=data.dtype) + task = dist.stream._reduce_scatter_base(output, data, sync_op=False) + task.wait() + out = output.numpy() + # [4, 6] (2 GPUs, out for rank 0) + # [8, 10] (2 GPUs, out for rank 1) + """ + if group is not None and not group.is_member(): + raise RuntimeError( + "The group should not be None and all ranks which invoke this operation should be the member of this group."
+ ) + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be true in sync op behavior.") + + if framework.in_dygraph_mode(): + return _reduce_scatter_tensor_in_dygraph(out_tensor, in_tensor, op, + group, sync_op, + use_calc_stream, + "_reduce_scatter_base") + + raise RuntimeError( + "paddle.distributed.stream._reduce_scatter_base is only supported in dygraph mode now." + ) diff --git a/python/paddle/distributed/communication/stream/scatter.py b/python/paddle/distributed/communication/stream/scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3fb00534783897bfc9c75856f1a54dd1969773 --- /dev/null +++ b/python/paddle/distributed/communication/stream/scatter.py @@ -0,0 +1,162 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.distributed as dist +import paddle.fluid.framework as framework +from paddle.distributed import collective + + +def _check_tensor_shape(tensor, shape, nranks=1): + expect_shape = list(shape) + expect_shape[0] //= nranks + if list(tensor.shape) != expect_shape: + raise RuntimeError("The in_tensor for scatter is not correctly-sized.") + + +def _check_tensor_list_shape(tensor_list, shape, nranks=1): + if len(tensor_list) != nranks: + raise RuntimeError( + f"The tensor_list for scatter is not correctly-sized.") + for tensor in tensor_list: + if tensor.shape != shape: + raise RuntimeError( + f"The tensor_list for scatter is not correctly-sized.") + + +def _scatter_tensor_in_dygraph(out_tensor, in_tensor, src, group, sync_op, + use_calc_stream): + group = collective._get_default_group() if group is None else group + + src_rank = group.get_group_rank(src) + if src_rank == -1: + raise RuntimeError("Src rank out of group.") + + nranks = group.nranks + rank = dist.get_rank() + if rank == src_rank: + _check_tensor_shape(out_tensor, in_tensor.shape, nranks) + + if use_calc_stream: + return group.process_group.scatter_tensor_on_calc_stream( + in_tensor, out_tensor, src) + + task = group.process_group.scatter_tensor(in_tensor, out_tensor, src, + sync_op) + if sync_op: + task.wait() + + return task + + +def _scatter_in_dygraph(tensor, tensor_list, src, group, sync_op, + use_calc_stream): + group = collective._get_default_group() if group is None else group + + src_rank = group.get_group_rank(src) + if src_rank == -1: + raise RuntimeError("Src rank out of group.") + + nranks = group.nranks + rank = dist.get_rank() + if rank == src_rank: + if len(tensor_list) == 0: + raise RuntimeError( + "The tensor_list should not be empty on src rank.") + _check_tensor_list_shape(tensor_list, tensor.shape, nranks) + else: + tensor_list = [tensor for _ in range(nranks)] + + if use_calc_stream: + return group.process_group.scatter_on_calc_stream( + tensor_list, tensor, src) + + task = group.process_group.scatter(tensor_list, tensor, src, sync_op) + if sync_op: + task.wait() + + return task + + +def scatter(tensor, + tensor_or_tensor_list=None, + 
src=0, + group=None, + sync_op=True, + use_calc_stream=False): + """ + + Scatter a tensor (or a tensor list) across devices. + + Args: + tensor (Tensor): The output tensor on each rank. The result will overwrite this tensor after communication. Support + float16, float32, float64, int32, int64, int8, uint8 or bool as the input data type. + tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The input to scatter (must not be `None`; only the value on the source rank is used). + If it is a tensor, it should be correctly-sized. If it is a list, it should contain correctly-sized tensors. + src (int, optional): Rank of the source device. If none is given, use `0` as default. + group (Group, optional): Communicate in which group. If none is given, use the global group as default. + sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default. + use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This + option is designed for high performance demand, be careful to turn it on only when you clearly understand its meaning. + + Returns: + Return a task object. + + Warning: + This API only supports the dygraph mode now. + + Examples: + .. code-block:: python + + # required: distributed + import paddle + import paddle.distributed as dist + + dist.init_parallel_env() + if dist.get_rank() == 0: + data1 = paddle.to_tensor([7, 8, 9]) + data2 = paddle.to_tensor([10, 11, 12]) + dist.stream.scatter(data1, [data1, data2], src=1) + else: + data1 = paddle.to_tensor([1, 2, 3]) + data2 = paddle.to_tensor([4, 5, 6]) + dist.stream.scatter(data1, [data1, data2], src=1) + out = data1.numpy() + # [1, 2, 3] (2 GPUs, out for rank 0) + # [4, 5, 6] (2 GPUs, out for rank 1) + """ + if group is not None and not group.is_member(): + raise RuntimeError( + "The group should not be None and all ranks which invoke this operation should be the member of this group." + ) + + if not sync_op and use_calc_stream: + raise RuntimeError( + "use_calc_stream can only be true in sync op behavior.") + + if tensor_or_tensor_list is None: + raise RuntimeError("The input should be specified.") + + if framework.in_dygraph_mode(): + if paddle.is_tensor(tensor_or_tensor_list): + return _scatter_tensor_in_dygraph(tensor, tensor_or_tensor_list, + src, group, sync_op, + use_calc_stream) + else: + return _scatter_in_dygraph(tensor, tensor_or_tensor_list, src, + group, sync_op, use_calc_stream) + + raise RuntimeError( + "paddle.distributed.stream.scatter is only supported in dygraph mode now."
+ ) diff --git a/python/paddle/distributed/communication/stream/send.py b/python/paddle/distributed/communication/stream/send.py index d46f0f768cc03ccd513c411c67b68f5ac4f9f046..41ec2c0141b1227933a4df5c523455f2e02a8e9d 100644 --- a/python/paddle/distributed/communication/stream/send.py +++ b/python/paddle/distributed/communication/stream/send.py @@ -64,7 +64,7 @@ def send(tensor, dst=0, group=None, sync_op=True, use_calc_stream=False): task = dist.stream.recv(data, src=0, sync_op=False) task.wait() out = data.numpy() - # [[4, 5, 6], [4, 5, 6] + # [[4, 5, 6], [4, 5, 6]] (2 GPUs) """ if group is not None and not group.is_member(): raise RuntimeError( diff --git a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt index 6631b7f46e0d0804c25b039f5e309a7cb7cddf58..a4f42cdb6ed218c11506c349aaddd112ca4fe207 100644 --- a/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/CMakeLists.txt @@ -282,6 +282,54 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX)) set_tests_properties(test_communication_stream_allreduce_api PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_communication_stream_alltoall_api MODULES + test_communication_stream_alltoall_api ENVS + "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") + set_tests_properties(test_communication_stream_alltoall_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") +endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_communication_stream_alltoall_single_api MODULES + test_communication_stream_alltoall_single_api ENVS + "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") + set_tests_properties(test_communication_stream_alltoall_single_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") +endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_communication_stream_broadcast_api MODULES + test_communication_stream_broadcast_api ENVS + "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") + set_tests_properties(test_communication_stream_broadcast_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") +endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_communication_stream_reduce_api MODULES + test_communication_stream_reduce_api ENVS + "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") + set_tests_properties(test_communication_stream_reduce_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") +endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_communication_stream_reduce_scatter_api MODULES + test_communication_stream_reduce_scatter_api ENVS + "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") + set_tests_properties(test_communication_stream_reduce_scatter_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") +endif() +if((WITH_GPU OR WITH_ROCM) AND (LINUX)) + py_test_modules( + test_communication_stream_scatter_api MODULES + test_communication_stream_scatter_api ENVS + "PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=") + set_tests_properties(test_communication_stream_scatter_api + PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST") +endif() if((WITH_GPU OR WITH_ROCM) AND (LINUX)) py_test_modules( test_communication_stream_sendrecv_api MODULES diff --git a/python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_api_dygraph.py 
b/python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..8e65ea8d8aee5e70a52ff4b57a127c50834fc4c6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_api_dygraph.py @@ -0,0 +1,113 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.distributed as dist +import test_communication_api_base as test_base +import test_collective_api_base as test_collective_base + + +class StreamAllToAllTestCase(): + + def __init__(self): + self._sync_op = eval(os.getenv("sync_op")) + self._use_calc_stream = eval(os.getenv("use_calc_stream")) + self._backend = os.getenv("backend") + self._shape = eval(os.getenv("shape")) + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + if self._backend not in ["nccl", "gloo"]: + raise NotImplementedError( + "Only support nccl and gloo as the backend for now.") + os.environ["PADDLE_DISTRI_BACKEND"] = self._backend + + def run_test_case(self): + dist.init_parallel_env() + + test_data_list = [] + for seed in self._seeds: + test_data_list.append( + test_collective_base.create_test_data(shape=self._shape, + dtype=self._dtype, + seed=seed)) + + nranks = len(test_data_list) + data1 = test_data_list[0] + data2 = test_data_list[1] + result1 = np.vstack( + [data1[0:data1.shape[0] // 2, :], data2[0:data2.shape[0] // 2, :]]) + result2 = np.vstack( + [data1[data1.shape[0] // 2:, :], data2[data2.shape[0] // 2:, :]]) + + rank = dist.get_rank() + tensor = paddle.to_tensor(test_data_list[rank]) + t1, t2 = paddle.split(tensor, nranks, axis=0) + + # case 1: pass an empty tensor list + empty_tensor_list = [] + task = dist.stream.alltoall(empty_tensor_list, [t1, t2], + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + result_tensor_list = np.vstack(empty_tensor_list) + if rank == 0: + assert np.allclose(result_tensor_list, + result1, + rtol=1e-05, + atol=1e-05) + else: + assert np.allclose(result_tensor_list, + result2, + rtol=1e-05, + atol=1e-05) + + # case 2: pass a pre-sized tensor list + full_tensor_list = [paddle.empty_like(t1) for _ in test_data_list] + task = dist.stream.alltoall(full_tensor_list, [t1, t2], + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + result_tensor_list = np.vstack(full_tensor_list) + if rank == 0: + assert np.allclose(result_tensor_list, + result1, + rtol=1e-05, + atol=1e-05) + else: + assert np.allclose(result_tensor_list, + result2, + rtol=1e-05, + atol=1e-05) + + # case 3: pass a pre-sized tensor + out_tensor = paddle.empty_like(tensor) + task = dist.stream.alltoall(out_tensor, + tensor, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + if rank == 0: + assert 
np.allclose(out_tensor, result1, rtol=1e-05, atol=1e-05) + else: + assert np.allclose(out_tensor, result2, rtol=1e-05, atol=1e-05) + + +if __name__ == "__main__": + StreamAllToAllTestCase().run_test_case() diff --git a/python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_single_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_single_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..9bdfe124b0b492f52541b6a93e167f37e1269e9d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/communication_stream_alltoall_single_api_dygraph.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.distributed as dist +import test_communication_api_base as test_base +import test_collective_api_base as test_collective_base + + +class StreamAllToAllSingleTestCase(): + + def __init__(self): + self._sync_op = eval(os.getenv("sync_op")) + self._use_calc_stream = eval(os.getenv("use_calc_stream")) + self._backend = os.getenv("backend") + self._shape = eval(os.getenv("shape")) + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + if self._backend not in ["nccl", "gloo"]: + raise NotImplementedError( + "Only support nccl and gloo as the backend for now.") + os.environ["PADDLE_DISTRI_BACKEND"] = self._backend + + def run_test_case(self): + dist.init_parallel_env() + + test_data_list = [] + for seed in self._seeds: + test_data_list.append( + test_collective_base.create_test_data(shape=self._shape, + dtype=self._dtype, + seed=seed)) + + nranks = len(test_data_list) + data1 = paddle.to_tensor(test_data_list[0]) + data2 = paddle.to_tensor(test_data_list[1]) + result1 = np.vstack( + (data1[0:data1.shape[0] // 2, :], data2[0:data2.shape[0] // 2, :])) + result2 = np.vstack( + (data1[data1.shape[0] // 2:, :], data2[data2.shape[0] // 2:, :])) + + rank = dist.get_rank() + tensor = paddle.to_tensor(test_data_list[rank]) + + out_tensor = paddle.empty_like(tensor) + task = dist.stream.alltoall_single( + out_tensor, + tensor, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + if rank == 0: + assert np.allclose(out_tensor, result1, rtol=1e-05, atol=1e-05) + else: + assert np.allclose(out_tensor, result2, rtol=1e-05, atol=1e-05) + + +if __name__ == "__main__": + StreamAllToAllSingleTestCase().run_test_case() diff --git a/python/paddle/fluid/tests/unittests/collective/communication_stream_broadcast_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/communication_stream_broadcast_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..487dfd6ae68942c751a8ad6a5806e08ec10622f8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/communication_stream_broadcast_api_dygraph.py @@ -0,0 +1,54 @@ +# Copyright (c) 2022 PaddlePaddle 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import paddle +import paddle.distributed as dist +import test_collective_api_base as test_collective_base + + +class StreamBroadcastTestCase(): + + def __init__(self): + self._sync_op = eval(os.getenv("sync_op")) + self._use_calc_stream = eval(os.getenv("use_calc_stream")) + self._backend = os.getenv("backend") + self._shape = eval(os.getenv("shape")) + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + if self._backend not in ["nccl", "gloo"]: + raise NotImplementedError( + "Only support nccl and gloo as the backend for now.") + os.environ["PADDLE_DISTRI_BACKEND"] = self._backend + + def run_test_case(self): + dist.init_parallel_env() + + src_rank = 1 + result = test_collective_base.create_test_data( + shape=self._shape, dtype=self._dtype, seed=self._seeds[src_rank]) + tensor = paddle.to_tensor(result) + task = dist.stream.broadcast(tensor, + src=src_rank, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + + assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05) + + +if __name__ == "__main__": + StreamBroadcastTestCase().run_test_case() diff --git a/python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..a487eac566ab5e29398f6738f915b3133486a628 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_api_dygraph.py @@ -0,0 +1,66 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
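# A minimal sketch of the scatter wrapper added earlier in this patch, following its
# docstring example: the source rank provides one chunk per rank and every rank receives
# its slot; assumes 2 GPUs under paddle.distributed.launch.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 1:
    chunks = [paddle.to_tensor([0, 1]), paddle.to_tensor([2, 3])]
else:
    chunks = [paddle.to_tensor([9, 9]), paddle.to_tensor([9, 9])]  # placeholder, ignored on non-source ranks
out = paddle.empty([2], dtype="int64")
task = dist.stream.scatter(out, chunks, src=1, sync_op=False)
task.wait()
# rank 0: out == [0, 1]; rank 1: out == [2, 3]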
+ +import os +import numpy as np +import paddle +import paddle.distributed as dist +import test_collective_api_base as test_collective_base + + +class StreamReduceTestCase(): + + def __init__(self): + self._sync_op = eval(os.getenv("sync_op")) + self._use_calc_stream = eval(os.getenv("use_calc_stream")) + self._backend = os.getenv("backend") + self._shape = eval(os.getenv("shape")) + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + if self._backend not in ["nccl", "gloo"]: + raise NotImplementedError( + "Only support nccl and gloo as the backend for now.") + os.environ["PADDLE_DISTRI_BACKEND"] = self._backend + + def run_test_case(self): + dist.init_parallel_env() + + test_data_list = [] + for seed in self._seeds: + test_data_list.append( + test_collective_base.create_test_data(shape=self._shape, + dtype=self._dtype, + seed=seed)) + + rank = dist.get_rank() + tensor = paddle.to_tensor(test_data_list[rank]) + task = dist.stream.reduce(tensor, + dst=1, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + + result = sum(test_data_list) + if rank == 1: + assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05) + else: + assert np.allclose(tensor, + test_data_list[rank], + rtol=1e-05, + atol=1e-05) + + +if __name__ == "__main__": + StreamReduceTestCase().run_test_case() diff --git a/python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_scatter_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_scatter_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..8f66d67e0d58c614a30ce8551f1022757fa058e5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/communication_stream_reduce_scatter_api_dygraph.py @@ -0,0 +1,94 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
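# A minimal sketch of the reduce_scatter wrapper from earlier in this patch: inputs are
# reduced (summed) element-wise across ranks, then each rank keeps its own chunk;
# assumes 2 GPUs under paddle.distributed.launch.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
if dist.get_rank() == 0:
    chunks = [paddle.to_tensor([0, 1]), paddle.to_tensor([2, 3])]
else:
    chunks = [paddle.to_tensor([4, 5]), paddle.to_tensor([6, 7])]
out = paddle.empty([2], dtype="int64")
task = dist.stream.reduce_scatter(out, chunks, sync_op=False)
task.wait()
# rank 0: out == [4, 6]; rank 1: out == [8, 10]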
+ +import os +import numpy as np +import paddle +import paddle.distributed as dist +import test_collective_api_base as test_collective_base + + +class StreamReduceScatterTestCase(): + + def __init__(self): + self._sync_op = eval(os.getenv("sync_op")) + self._use_calc_stream = eval(os.getenv("use_calc_stream")) + self._backend = os.getenv("backend") + self._shape = eval(os.getenv("shape")) + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + if self._backend not in ["nccl", "gloo"]: + raise NotImplementedError( + "Only support nccl and gloo as the backend for now.") + os.environ["PADDLE_DISTRI_BACKEND"] = self._backend + + def run_test_case(self): + dist.init_parallel_env() + + test_data_list = [] + for seed in self._seeds: + test_data_list.append( + test_collective_base.create_test_data(shape=self._shape, + dtype=self._dtype, + seed=seed)) + reduce_result = sum(test_data_list) + result1 = reduce_result[0:reduce_result.shape[0] // 2] + result2 = reduce_result[reduce_result.shape[0] // 2:] + + rank = dist.get_rank() + tensor = paddle.to_tensor(test_data_list[rank]) + + # case 1: pass a pre-sized tensor list + t1, t2 = paddle.split(tensor, 2, axis=0) + result_tensor = paddle.empty_like(t1) + task = dist.stream.reduce_scatter(result_tensor, [t1, t2], + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + if rank == 0: + assert np.allclose(result_tensor, result1, rtol=1e-05, atol=1e-05) + else: + assert np.allclose(result_tensor, result2, rtol=1e-05, atol=1e-05) + + # case 2: pass a pre-sized tensor + result_tensor = paddle.empty_like(t1) + task = dist.stream.reduce_scatter(result_tensor, + tensor, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + if rank == 0: + assert np.allclose(result_tensor, result1, rtol=1e-05, atol=1e-05) + else: + assert np.allclose(result_tensor, result2, rtol=1e-05, atol=1e-05) + + # case 3: test the legacy API + result_tensor = paddle.empty_like(t1) + task = dist.stream._reduce_scatter_base( + result_tensor, + tensor, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + if rank == 0: + assert np.allclose(result_tensor, result1, rtol=1e-05, atol=1e-05) + else: + assert np.allclose(result_tensor, result2, rtol=1e-05, atol=1e-05) + + +if __name__ == "__main__": + StreamReduceScatterTestCase().run_test_case() diff --git a/python/paddle/fluid/tests/unittests/collective/communication_stream_scatter_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/communication_stream_scatter_api_dygraph.py new file mode 100644 index 0000000000000000000000000000000000000000..6060e5050ca09bdb0e04b2eae073c4a47f170006 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/communication_stream_scatter_api_dygraph.py @@ -0,0 +1,84 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
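# A minimal sketch of the alltoall_single wrapper added earlier in this patch, using
# uneven splits (condensed from case 2 of its docstring); assumes 2 GPUs.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()
size = dist.get_world_size()
data = paddle.full([3, 2], float(rank))            # rank r contributes rows filled with r
out = paddle.empty([(rank + 1) * size, 2], dtype=data.dtype)
out_split_sizes = [rank + 1 for _ in range(size)]  # rows to receive from each rank
in_split_sizes = [i + 1 for i in range(size)]      # rows to send to each rank
task = dist.stream.alltoall_single(out, data, out_split_sizes, in_split_sizes,
                                   sync_op=False)
task.wait()
# rank 0: out == [[0., 0.], [1., 1.]]
# rank 1: out == [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]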
+ +import os +import numpy as np +import paddle +import paddle.distributed as dist +import test_collective_api_base as test_collective_base + + +class StreamScatterTestCase(): + + def __init__(self): + self._sync_op = eval(os.getenv("sync_op")) + self._use_calc_stream = eval(os.getenv("use_calc_stream")) + self._backend = os.getenv("backend") + self._shape = eval(os.getenv("shape")) + self._dtype = os.getenv("dtype") + self._seeds = eval(os.getenv("seeds")) + if self._backend not in ["nccl", "gloo"]: + raise NotImplementedError( + "Only support nccl and gloo as the backend for now.") + os.environ["PADDLE_DISTRI_BACKEND"] = self._backend + + def run_test_case(self): + dist.init_parallel_env() + + test_data_list = [] + for seed in self._seeds: + test_data_list.append( + test_collective_base.create_test_data(shape=self._shape, + dtype=self._dtype, + seed=seed)) + + src_rank = 1 + src_data = test_data_list[src_rank] + result1 = src_data[0:src_data.shape[0] // 2] + result2 = src_data[src_data.shape[0] // 2:] + + rank = dist.get_rank() + + # case 1: pass a pre-sized tensor list + tensor = paddle.to_tensor(test_data_list[rank]) + t1, t2 = paddle.split(tensor, 2, axis=0) + task = dist.stream.scatter(t1, [t1, t2], + src=src_rank, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + if rank == src_rank: + assert np.allclose(t1, result2, rtol=1e-05, atol=1e-05) + else: + assert np.allclose(t1, result1, rtol=1e-05, atol=1e-05) + + # case 2: pass a pre-sized tensor + tensor = paddle.to_tensor(src_data) + t1 = paddle.empty_like(t1) + task = dist.stream.scatter(t1, + tensor, + src=src_rank, + sync_op=self._sync_op, + use_calc_stream=self._use_calc_stream) + if not self._sync_op: + task.wait() + if rank == src_rank: + assert np.allclose(t1, result2, rtol=1e-05, atol=1e-05) + else: + assert np.allclose(t1, result1, rtol=1e-05, atol=1e-05) + + +if __name__ == "__main__": + StreamScatterTestCase().run_test_case() diff --git a/python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py b/python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py index a666e6cfe0a3e65880cfce9639a6e46dd2437aca..f9b2806fc1b840350691490f6d0fd6787d0f963f 100644 --- a/python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py +++ b/python/paddle/fluid/tests/unittests/collective/communication_stream_sendrecv_api_dygraph.py @@ -43,22 +43,25 @@ class StreamSendRecvTestCase(): dtype=self._dtype, seed=seed)) + src_rank = 0 + dst_rank = 1 + rank = dist.get_rank() tensor = paddle.to_tensor(test_data_list[rank]) if rank == 0: task = dist.stream.send(tensor, - dst=1, + dst=dst_rank, sync_op=self._sync_op, use_calc_stream=self._use_calc_stream) else: task = dist.stream.recv(tensor, - src=0, + src=src_rank, sync_op=self._sync_op, use_calc_stream=self._use_calc_stream) if not self._sync_op: task.wait() - result = test_data_list[0] + result = test_data_list[src_rank] assert np.allclose(tensor, result, rtol=1e-05, atol=1e-05) diff --git a/python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_api.py b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_api.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa55d86840bc844630418087f1a21486a6592df --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_api.py @@ -0,0 +1,51 @@ +# Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import itertools +import test_communication_api_base as test_base + + +class TestCommunicationStreamAllToAllAPI(test_base.CommunicationTestDistBase): + + def setUp(self): + super(TestCommunicationStreamAllToAllAPI, self).setUp(num_of_devices=2, + timeout=120) + self._default_envs = { + "backend": "nccl", + "shape": "(100, 200)", + "dtype": "float32", + "seeds": str(self._seeds) + } + self._changeable_envs = { + "sync_op": ["True", "False"], + "use_calc_stream": ["True", "False"] + } + + def test_alltoall_stream(self): + envs_list = test_base.gen_product_envs_list(self._default_envs, + self._changeable_envs) + for envs in envs_list: + if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]): + continue + self.run_test_case("communication_stream_alltoall_api_dygraph.py", + user_defined_envs=envs) + + def tearDown(self): + super(TestCommunicationStreamAllToAllAPI, self).tearDown() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_single_api.py b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_single_api.py new file mode 100644 index 0000000000000000000000000000000000000000..f1f099b9571f8fc16bf6484e892e39c4a3cb0a06 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_alltoall_single_api.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
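# How these launcher tests enumerate configurations: every combination of the changeable
# env vars is generated and the unsupported pair (use_calc_stream=True with sync_op=False)
# is skipped, mirroring the RuntimeError raised by the Python wrappers. A standalone sketch
# of that filter (itertools-based; the real gen_product_envs_list lives in
# test_communication_api_base):
import itertools

changeable_envs = {"sync_op": ["True", "False"], "use_calc_stream": ["True", "False"]}
keys = list(changeable_envs.keys())
for values in itertools.product(*(changeable_envs[k] for k in keys)):
    envs = dict(zip(keys, values))
    if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
        continue  # async + calc-stream is rejected by the wrappers, nothing to test
    print(envs)  # the three valid combinations remain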
+ +import unittest +import paddle +import itertools +import test_communication_api_base as test_base + + +class TestCommunicationStreamAllToAllSingleAPI( + test_base.CommunicationTestDistBase): + + def setUp(self): + super(TestCommunicationStreamAllToAllSingleAPI, + self).setUp(num_of_devices=2, timeout=120) + self._default_envs = { + "backend": "nccl", + "shape": "(100, 200)", + "dtype": "float32", + "seeds": str(self._seeds) + } + self._changeable_envs = { + "sync_op": ["True", "False"], + "use_calc_stream": ["True", "False"] + } + + def test_alltoall_single_stream(self): + envs_list = test_base.gen_product_envs_list(self._default_envs, + self._changeable_envs) + for envs in envs_list: + if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]): + continue + self.run_test_case( + "communication_stream_alltoall_single_api_dygraph.py", + user_defined_envs=envs) + + def tearDown(self): + super(TestCommunicationStreamAllToAllSingleAPI, self).tearDown() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_communication_stream_broadcast_api.py b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_broadcast_api.py new file mode 100644 index 0000000000000000000000000000000000000000..07537a480e851ae4be67e9edd4a4af861df33398 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_broadcast_api.py @@ -0,0 +1,51 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import itertools +import test_communication_api_base as test_base + + +class TestCommunicationStreamBroadcastAPI(test_base.CommunicationTestDistBase): + + def setUp(self): + super(TestCommunicationStreamBroadcastAPI, self).setUp(num_of_devices=2, + timeout=120) + self._default_envs = { + "backend": "nccl", + "shape": "(100, 200)", + "dtype": "float32", + "seeds": str(self._seeds) + } + self._changeable_envs = { + "sync_op": ["True", "False"], + "use_calc_stream": ["True", "False"] + } + + def test_broadcast_stream(self): + envs_list = test_base.gen_product_envs_list(self._default_envs, + self._changeable_envs) + for envs in envs_list: + if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]): + continue + self.run_test_case("communication_stream_broadcast_api_dygraph.py", + user_defined_envs=envs) + + def tearDown(self): + super(TestCommunicationStreamBroadcastAPI, self).tearDown() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_api.py b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_api.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a04c8d893e1603ab8aff949a67d901d6bd12c8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_api.py @@ -0,0 +1,51 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import itertools +import test_communication_api_base as test_base + + +class TestCommunicationStreamReduceAPI(test_base.CommunicationTestDistBase): + + def setUp(self): + super(TestCommunicationStreamReduceAPI, self).setUp(num_of_devices=2, + timeout=120) + self._default_envs = { + "backend": "nccl", + "shape": "(100, 200)", + "dtype": "float32", + "seeds": str(self._seeds) + } + self._changeable_envs = { + "sync_op": ["True", "False"], + "use_calc_stream": ["True", "False"] + } + + def test_reduce_stream(self): + envs_list = test_base.gen_product_envs_list(self._default_envs, + self._changeable_envs) + for envs in envs_list: + if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]): + continue + self.run_test_case("communication_stream_reduce_api_dygraph.py", + user_defined_envs=envs) + + def tearDown(self): + super(TestCommunicationStreamReduceAPI, self).tearDown() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_scatter_api.py b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_scatter_api.py new file mode 100644 index 0000000000000000000000000000000000000000..a90e634860d95bdf91e11a8de7e04052f4a8a548 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_reduce_scatter_api.py @@ -0,0 +1,53 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import paddle +import itertools +import test_communication_api_base as test_base + + +class TestCommunicationStreamReduceScatterAPI( + test_base.CommunicationTestDistBase): + + def setUp(self): + super(TestCommunicationStreamReduceScatterAPI, + self).setUp(num_of_devices=2, timeout=120) + self._default_envs = { + "backend": "nccl", + "shape": "(100, 200)", + "dtype": "float32", + "seeds": str(self._seeds) + } + self._changeable_envs = { + "sync_op": ["True", "False"], + "use_calc_stream": ["True", "False"] + } + + def test_reduce_scatter_stream(self): + envs_list = test_base.gen_product_envs_list(self._default_envs, + self._changeable_envs) + for envs in envs_list: + if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]): + continue + self.run_test_case( + "communication_stream_reduce_scatter_api_dygraph.py", + user_defined_envs=envs) + + def tearDown(self): + super(TestCommunicationStreamReduceScatterAPI, self).tearDown() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/collective/test_communication_stream_scatter_api.py b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_scatter_api.py new file mode 100644 index 0000000000000000000000000000000000000000..d96d931f43fbf474d698f6a2adee53271b8dc07a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective/test_communication_stream_scatter_api.py @@ -0,0 +1,51 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+import paddle
+import itertools
+import test_communication_api_base as test_base
+
+
+class TestCommunicationStreamScatterAPI(test_base.CommunicationTestDistBase):
+
+    def setUp(self):
+        super(TestCommunicationStreamScatterAPI, self).setUp(num_of_devices=2,
+                                                             timeout=120)
+        self._default_envs = {
+            "backend": "nccl",
+            "shape": "(100, 200)",
+            "dtype": "float32",
+            "seeds": str(self._seeds)
+        }
+        self._changeable_envs = {
+            "sync_op": ["True", "False"],
+            "use_calc_stream": ["True", "False"]
+        }
+
+    def test_scatter_stream(self):
+        envs_list = test_base.gen_product_envs_list(self._default_envs,
+                                                    self._changeable_envs)
+        for envs in envs_list:
+            if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
+                continue
+            self.run_test_case("communication_stream_scatter_api_dygraph.py",
+                               user_defined_envs=envs)
+
+    def tearDown(self):
+        super(TestCommunicationStreamScatterAPI, self).tearDown()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/collective/testslist.csv b/python/paddle/fluid/tests/unittests/collective/testslist.csv
index 883cf7941e3685b22d54d997130a79d60843775c..1f6584f7b9072960329eacc446e8a5e445f613b4 100644
--- a/python/paddle/fluid/tests/unittests/collective/testslist.csv
+++ b/python/paddle/fluid/tests/unittests/collective/testslist.csv
@@ -34,6 +34,12 @@ test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_
 test_collective_wait,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_communication_stream_allgather_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
 test_communication_stream_allreduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
+test_communication_stream_alltoall_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
+test_communication_stream_alltoall_single_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
+test_communication_stream_broadcast_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
+test_communication_stream_reduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
+test_communication_stream_reduce_scatter_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
+test_communication_stream_scatter_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
 test_communication_stream_sendrecv_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
 test_eager_dist_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
 test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
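
Note for reviewers: every new test driver above builds its environment matrix with test_base.gen_product_envs_list(...) and skips the unsupported use_calc_stream=True with sync_op=False combination before launching the corresponding communication_stream_*_api_dygraph.py script. The helper itself lives in test_communication_api_base.py and is not part of this diff; the sketch below only illustrates the behavior the drivers assume (cross product of the changeable values merged onto the defaults), not the actual implementation.

    import itertools

    def gen_product_envs_list(default_envs, changeable_envs):
        # Sketch only (assumed behavior): expand every combination of the
        # changeable values (e.g. sync_op x use_calc_stream) and merge each
        # combination onto a copy of the defaults, so each resulting dict can
        # be exported as the env vars that the dygraph scripts read back via
        # os.getenv("sync_op"), os.getenv("use_calc_stream"), etc.
        envs_list = []
        for values in itertools.product(*changeable_envs.values()):
            envs = dict(default_envs)
            envs.update(dict(zip(changeable_envs.keys(), values)))
            envs_list.append(envs)
        return envs_list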