diff --git a/paddle/fluid/distributed/collective/nccl_tools.cc b/paddle/fluid/distributed/collective/nccl_tools.cc
index 940c8d47ccb88261c85888bec0daf5eb1b302970..b490ea78514b13c0dd3772f816d1cce47e0ffd7b 100644
--- a/paddle/fluid/distributed/collective/nccl_tools.cc
+++ b/paddle/fluid/distributed/collective/nccl_tools.cc
@@ -47,5 +47,43 @@ std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID) {
   return oss.str();
 }
 
+std::string NCCLDTypeToString(ncclDataType_t dtype) {
+#define PD_NCCL_DTYPE_TO_STR(__nccl_dtype, __str_dtype) \
+  if (dtype == __nccl_dtype) return __str_dtype;
+  PD_NCCL_DTYPE_TO_STR(ncclFloat, "float32");
+  PD_NCCL_DTYPE_TO_STR(ncclFloat32, "float32");
+  PD_NCCL_DTYPE_TO_STR(ncclHalf, "float16");
+  PD_NCCL_DTYPE_TO_STR(ncclFloat16, "float16");
+#if NCCL_VERSION_CODE >= 21000
+  PD_NCCL_DTYPE_TO_STR(ncclBfloat16, "bfloat16");
+#endif
+  PD_NCCL_DTYPE_TO_STR(ncclDouble, "float64");
+  PD_NCCL_DTYPE_TO_STR(ncclFloat64, "float64");
+
+  PD_NCCL_DTYPE_TO_STR(ncclInt8, "int8");
+  PD_NCCL_DTYPE_TO_STR(ncclChar, "int8");
+  PD_NCCL_DTYPE_TO_STR(ncclUint8, "uint8");
+  PD_NCCL_DTYPE_TO_STR(ncclInt32, "int32");
+  PD_NCCL_DTYPE_TO_STR(ncclInt, "int32");
+  PD_NCCL_DTYPE_TO_STR(ncclUint32, "uint32");
+  PD_NCCL_DTYPE_TO_STR(ncclInt64, "int64");
+  PD_NCCL_DTYPE_TO_STR(ncclUint64, "uint64");
+
+#undef PD_NCCL_DTYPE_TO_STR
+  PADDLE_THROW(phi::errors::InvalidArgument(
+      "This datatype %d in nccl is not supported.", static_cast<int>(dtype)));
+}
+
+std::string NCCLRedTypeToString(ncclRedOp_t op) {
+  if (op == ncclSum) return "SUM";
+  if (op == ncclProd) return "PROD";
+  if (op == ncclMin) return "MIN";
+  if (op == ncclMax) return "MAX";
+#if NCCL_VERSION_CODE >= 21000
+  if (op == ncclAvg) return "AVG";
+#endif
+  return "UDF_" + std::to_string(op);
+}
+
 }  // namespace distributed
 }  // namespace paddle
diff --git a/paddle/fluid/distributed/collective/nccl_tools.h b/paddle/fluid/distributed/collective/nccl_tools.h
index 135aadd2a241451db7f4e64a358ff468259984ec..ba29bb1d13c4cd9ce7e9ad0cf5243872fb61e6b4 100644
--- a/paddle/fluid/distributed/collective/nccl_tools.h
+++ b/paddle/fluid/distributed/collective/nccl_tools.h
@@ -29,21 +29,25 @@
 namespace paddle {
 namespace distributed {
 
-#define NCCL_CHECK(cmd)                            \
-  do {                                             \
-    ncclResult_t r = cmd;                          \
-    if (r != ncclSuccess) {                        \
-      printf("Failed, NCCL error %s:%d '%s'\n",    \
-             __FILE__,                             \
-             __LINE__,                             \
-             phi::dynload::ncclGetErrorString(r)); \
-      exit(EXIT_FAILURE);                          \
-    }                                              \
+#define NCCL_CHECK(cmd)                                                \
+  do {                                                                 \
+    ncclResult_t r = cmd;                                              \
+    if (r != ncclSuccess) {                                            \
+      PADDLE_THROW(                                                    \
+          phi::errors::External("Failed, NCCL error %s:%d '%s'\n",     \
+                                __FILE__,                              \
+                                __LINE__,                              \
+                                phi::dynload::ncclGetErrorString(r))); \
+    }                                                                  \
   } while (0)
 
 ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
 
 std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
 
+std::string NCCLDTypeToString(ncclDataType_t dtype);
+
+std::string NCCLRedTypeToString(ncclRedOp_t op);
+
 }  // namespace distributed
 }  // namespace paddle
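Note: the NCCL_CHECK change above replaces printf plus exit(EXIT_FAILURE) with Paddle's enforce machinery, so a failing NCCL call now surfaces as a catchable C++ exception instead of terminating the process. A minimal illustrative sketch of what this enables at a call site (the try/catch wrapper below is not part of the patch; PADDLE_THROW raises phi::enforce::EnforceNotMet, which derives from std::exception):

    try {
      NCCL_CHECK(phi::dynload::ncclGroupEnd());
    } catch (const std::exception& e) {
      // The error message carries file/line and the NCCL error string,
      // and can be logged or rethrown rather than aborting the process.
      LOG(ERROR) << "NCCL call failed: " << e.what();
      throw;
    }
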
diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc
index a67e78da1f14ccb52bd6957476293701ff63cdc7..effd61a3b5009b77b650088978586973fc833ad8 100644
--- a/paddle/fluid/distributed/collective/process_group_nccl.cc
+++ b/paddle/fluid/distributed/collective/process_group_nccl.cc
@@ -24,6 +24,7 @@
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/utils/data_type.h"
 
+DECLARE_bool(benchmark);
 DECLARE_bool(nccl_blocking_wait);
 DECLARE_bool(use_stream_safe_cuda_allocator);
 
@@ -58,7 +59,7 @@ void ProcessGroupNCCL::NCCLTask::UpdateWaitChain(
 bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
   // Warning here when use calc stream but also invoke waiting explicitly.
   if (UseCalcStream()) {
-    VLOG(3) << "Warning: The communication is on calc stream, wait here is "
+    VLOG(5) << "Warning: The communication is on calc stream, wait here is "
               "useless.";
     return true;
   }
@@ -103,6 +104,11 @@ void ProcessGroupNCCL::GroupStart() {
 void ProcessGroupNCCL::GroupEnd() {
   NCCL_CHECK(phi::dynload::ncclGroupEnd());
   --s_group_call_counter;
+  // NOTE: This is to sync the calc stream and comm stream for debug using
+  // batch_isend_irecv
+  if (FLAGS_benchmark) {
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
+  }
 }
 
 phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
@@ -163,6 +169,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
                                                          rank_,
                                                          comm);
         }
+
+        VLOG(3) << "[ncclAllGather] "
+                << "sendbuff: " << in_tensor_maybe_partial.data()
+                << ", recvbuff: " << out_tensor->data()
+                << ", count: " << in_tensor_maybe_partial.numel()
+                << ", datatype: "
+                << NCCLDTypeToString(
+                       phi::ToNCCLDataType(in_tensor_maybe_partial.dtype()))
+                << ", ncclcomm: " << comm << ", stream: " << stream
+                << ", rank_in_group: " << rank_ << ", nranks: " << size_
+                << ", offset: " << offset << ", sync_op: " << sync_op
+                << ", use_calc_stream: " << use_calc_stream;
+
         NCCL_CHECK(phi::dynload::ncclAllGather(
             in_tensor_maybe_partial.data(),
             out_tensor->data(),
@@ -196,6 +215,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
                                                          rank_,
                                                          comm);
         }
+
+        VLOG(3) << "[ncclAllReduce] "
+                << "sendbuff: " << in_tensor.data()
+                << ", recvbuff: " << out_tensor->data()
+                << ", count: " << in_tensor.numel() << ", datatype: "
+                << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype()))
+                << ", redop: "
+                << NCCLRedTypeToString(ToNCCLRedType(opts.reduce_op))
+                << ", ncclcomm: " << comm << ", stream: " << stream
+                << ", rank_in_group: " << rank_ << ", nranks: " << size_
+                << ", sync_op: " << sync_op
+                << ", use_calc_stream: " << use_calc_stream;
+
         NCCL_CHECK(
             phi::dynload::ncclAllReduce(in_tensor.data(),
                                         out_tensor->data(),
@@ -264,6 +296,20 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
         int64_t in_offset = 0, in_numel = 0, out_offset = 0, out_numel = 0;
         phi::DenseTensor input_partial, output_partial;
 
+        VLOG(3) << "[AllToAll] "
+                << "sendbuff: " << in_tensor.data()
+                << ", recvbuff: " << out_tensor->data()
+                << ", count: " << in_tensor.numel() << ", datatype: "
+                << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype()))
+                << ", ncclcomm: " << comm << ", stream: " << stream
+                << ", rank_in_group: " << rank_ << ", nranks: " << size_
+                << ", out_size_each_rank: "
+                << string::join_strings(out_size_each_rank, ',')
+                << ", in_size_each_rank: "
+                << string::join_strings(in_size_each_rank, ',')
+                << ", sync_op: " << sync_op
+                << ", use_calc_stream: " << use_calc_stream;
+
         GroupStart();
         for (auto i = 0; i < size_; i++) {
           in_numel = in_size_each_rank[i] * in_row_size;
@@ -308,6 +354,9 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Barrier(
   phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim{1});
   phi::DenseTensor barrier_tensor{allocator.get(), meta};
 
+  VLOG(3) << "[Barrier] "
+          << "barrier opt: " << opts.device_id;
+
   auto task = AllReduce(&barrier_tensor,
                         barrier_tensor,
                         {},
@@ -336,6 +385,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
           phi::distributed::NCCLDynamicCheck::CheckShape(
               *out_tensor, root, rank_, comm);
         }
+
+        VLOG(3) << "[ncclBroadcast] "
+                << "sendbuff: " << in_tensor.data()
+                << ", recvbuff: " << out_tensor->data()
+                << ", count: " << in_tensor.numel() << ", datatype: "
+                << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype()))
+                << ", root: " << root << ", ncclcomm: " << comm
+                << ", stream: " << stream << ", rank_in_group: " << rank_
+                << ", nranks: " << size_ << ", sync_op: " << sync_op
+                << ", use_calc_stream: " << use_calc_stream;
+
         NCCL_CHECK(
             phi::dynload::ncclBroadcast(in_tensor.data(),
                                         out_tensor->data(),
@@ -371,6 +431,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
                                                          rank_,
                                                          comm);
         }
+
+        VLOG(3) << "[ncclReduce] "
+                << "sendbuff: " << in_tensor.data()
+                << ", recvbuff: " << out_tensor->data()
+                << ", count: " << in_tensor.numel() << ", datatype: "
+                << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype()))
+                << ", redop: "
+                << NCCLRedTypeToString(ToNCCLRedType(opts.reduce_op))
+                << ", root: " << opts.root_rank << ", ncclcomm: " << comm
+                << ", stream: " << stream << ", rank_in_group: " << rank_
+                << ", nranks: " << size_ << ", sync_op: " << sync_op
+                << ", use_calc_stream: " << use_calc_stream;
+
         NCCL_CHECK(
             phi::dynload::ncclReduce(in_tensor.data(),
                                      out_tensor->data(),
@@ -406,6 +479,19 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
                                                          rank_,
                                                          comm);
         }
+
+        VLOG(3) << "[ncclReduceScatter] "
+                << "sendbuff: " << in_tensor.data()
+                << ", recvbuff: " << out_tensor->data()
+                << ", count: " << in_tensor.numel() << ", datatype: "
+                << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype()))
+                << ", redop: "
+                << NCCLRedTypeToString(ToNCCLRedType(opts.reduce_op))
+                << ", ncclcomm: " << comm << ", stream: " << stream
+                << ", rank_in_group: " << rank_ << ", nranks: " << size_
+                << ", sync_op: " << sync_op
+                << ", use_calc_stream: " << use_calc_stream;
+
         NCCL_CHECK(phi::dynload::ncclReduceScatter(
             in_tensor.data(),
             out_tensor->data(),
@@ -442,6 +528,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
                                                          rank_,
                                                          comm);
         }
+
+        VLOG(3) << "[Scatter] "
+                << "sendbuff: " << in_tensor.data()
+                << ", recvbuff: " << out_tensor->data()
+                << ", count: " << in_tensor.numel() << ", datatype: "
+                << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype()))
+                << ", root: " << opts.root_rank << ", ncclcomm: " << comm
+                << ", stream: " << stream << ", rank_in_group: " << rank_
+                << ", nranks: " << size_ << ", sync_op: " << sync_op
+                << ", use_calc_stream: " << use_calc_stream;
+
         int64_t numel = in_tensor.numel() / size_;
         if (rank_ == opts.root_rank) {
           int64_t offset = 0;
@@ -520,6 +617,16 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Gather(
       phi::distributed::NCCLDynamicCheck::CheckGatherShape(
           in_tensor, gather_tensors, opts.root_rank, rank_, size_, comm);
     }
+
+    VLOG(3) << "[Gather] "
+            << "sendbuff: " << in_tensor.data()
+            << ", count: " << in_tensor.numel() << ", datatype: "
+            << NCCLDTypeToString(phi::ToNCCLDataType(in_tensor.dtype()))
+            << ", root: " << opts.root_rank << ", ncclcomm: " << comm
+            << ", stream: " << stream << ", rank_in_group: " << rank_
+            << ", nranks: " << size_ << ", sync_op: " << sync_op
+            << ", use_calc_stream: " << use_calc_stream;
+
     GroupStart();
     // root receive from all devices
     if (rank_ == opts.root_rank) {
@@ -570,6 +677,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
                                                          rank_,
                                                          comm);
         }
+
+        VLOG(3) << "[ncclRecv] "
+                << "recvbuff: " << tensor->data()
+                << ", count: " << tensor->numel() << ", datatype: "
+                << NCCLDTypeToString(phi::ToNCCLDataType(tensor->dtype()))
+                << ", src_in_group: " << src_rank << ", ncclcomm: " << comm
+                << ", stream: " << stream << ", rank_in_group: " << rank_
+                << ", nranks: " << size_ << ", offset: " << offset
+                << ", sync_op: " << sync_op
+                << ", use_calc_stream: " << use_calc_stream;
+
         NCCL_CHECK(phi::dynload::ncclRecv(tensor->data(),
                                           tensor->numel(),
                                           phi::ToNCCLDataType(tensor->dtype()),
@@ -605,6 +723,18 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
                                                          rank_,
                                                          comm);
         }
+
+        VLOG(3) << "[ncclSend] "
+                << "sendbuff: " << tensor_maybe_partial.data()
+                << ", count: " << tensor_maybe_partial.numel() << ", datatype: "
+                << NCCLDTypeToString(
+                       phi::ToNCCLDataType(tensor_maybe_partial.dtype()))
+                << ", dst_in_group: " << dst_rank << ", ncclcomm: " << comm
+                << ", stream: " << stream << ", rank_in_group: " << rank_
+                << ", nranks: " << size_ << ", offset: " << offset
+                << ", sync_op: " << sync_op
+                << ", use_calc_stream: " << use_calc_stream;
+
         NCCL_CHECK(phi::dynload::ncclSend(
             tensor_maybe_partial.data(),
             tensor_maybe_partial.numel(),
@@ -669,7 +799,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place,
 
   BroadcastUniqueNCCLID(&nccl_id, is_p2p_op, place_key, p2p_rank);
 
-  VLOG(3) << "init nccl rank: " << rank_ << ", nranks: " << size_
+  VLOG(3) << "init nccl rank_in_group: " << rank_ << ", nranks: " << size_
           << ", place key: " << place_key
           << ", nccl uniqueid: " << SerializeNCCLUniqueId(nccl_id);
 
@@ -687,7 +817,7 @@
   NCCL_CHECK(phi::dynload::ncclGroupEnd());
 
   VLOG(3) << "Get nccl comm: " << nccl_comm << " for place_key: " << place_key
-          << " on rank: " << rank << " nranks: " << num_ranks;
+          << " on rank_in_group: " << rank << " nranks: " << num_ranks;
 
   auto comm_ctx = std::make_unique<phi::GPUContext>(place);
   comm_ctx->set_nccl_comm(nccl_comm);
@@ -754,6 +884,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Collective(
   if (sync_op) {
     task->Wait();
   }
+
+  if (FLAGS_benchmark) {
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
+  }
+
   return task;
 }
 
@@ -816,6 +951,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Point2Point(
     task->Wait();
   }
 
+  if (!is_batch_p2p && FLAGS_benchmark) {
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
+  }
+
   return task;
 }
 
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index e3e3e33bb49100613ee1e52228a27cbaf6724d8f..cfabb974f95af05c197dc0d7b67822e071656e7d 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -88,7 +88,11 @@ class FakeMicroDataset:
                     self._acc_steps,
                     len(data),
                 )
-                output.append(data[micro_step].detach())
+                output.append(
+                    data[micro_step].detach()
+                    if data[micro_step] is not None
+                    else None
+                )
             elif data is not None:
                 self._check_data_vaild(data)
                 output.append(data[begin:end, :].detach())
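Note: the pipeline_parallel.py change guards against micro-batch slots that are None. A tuple entry in data may legitimately be None (e.g. an input that a given pipeline stage does not consume), and the previous code called .detach() on it unconditionally, raising AttributeError; the guard preserves None placeholders, mirroring the None checks in the surrounding branches.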