From a66bb67afbb292889e688056f2752fec9cf2011c Mon Sep 17 00:00:00 2001
From: JZ-LIANG
Date: Tue, 29 Nov 2022 15:16:39 +0800
Subject: [PATCH] Bugfix for Collective default calc stream (#48308)

* get default calc stream from execution ctx instead of global dev ctx pool.
---
 paddle/fluid/operators/collective/alltoall_op.cu.cc         | 4 ++--
 paddle/fluid/operators/collective/barrier_op.cu.cc          | 4 ++--
 paddle/fluid/operators/collective/c_allgather_op.cu.cc      | 4 ++--
 paddle/fluid/operators/collective/c_allreduce_op.h          | 6 ++++--
 paddle/fluid/operators/collective/c_broadcast_op.cu.cc      | 4 ++--
 paddle/fluid/operators/collective/c_concat_op.cu.cc         | 4 ++--
 paddle/fluid/operators/collective/c_reduce_op.h             | 4 ++--
 paddle/fluid/operators/collective/c_reducescatter_op.cu.cc  | 4 ++--
 paddle/fluid/operators/collective/c_scatter_op.cu.cc        | 4 ++--
 paddle/fluid/operators/collective/global_gather_op.cu.cc    | 4 ++--
 paddle/fluid/operators/collective/global_scatter_op.cu.cc   | 4 ++--
 .../fluid/operators/collective/partial_allgather_op.cu.cc   | 4 ++--
 paddle/fluid/operators/collective/partial_recv_op.cu.cc     | 4 ++--
 paddle/fluid/operators/collective/partial_send_op.cu.cc     | 4 ++--
 paddle/fluid/operators/collective/recv_v2_op.cu.cc          | 4 ++--
 paddle/fluid/operators/collective/send_v2_op.cu.cc          | 4 ++--
 16 files changed, 34 insertions(+), 32 deletions(-)

diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc
index e50d14e5ef6..fd67342b3af 100644
--- a/paddle/fluid/operators/collective/alltoall_op.cu.cc
+++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -46,8 +46,8 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/barrier_op.cu.cc b/paddle/fluid/operators/collective/barrier_op.cu.cc
index 622b25f2a49..648b8fdc83b 100644
--- a/paddle/fluid/operators/collective/barrier_op.cu.cc
+++ b/paddle/fluid/operators/collective/barrier_op.cu.cc
@@ -39,8 +39,8 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
 
     int rid = ctx.Attr<int>("ring_id");
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    auto stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+    // should use ExecutionContext for calc stream.
+    auto stream = ctx.cuda_device_context().stream();
     ncclRedOp_t nccl_red_type = ncclSum;
     PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
         sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
index ddef85d73e0..947475ece48 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -67,8 +67,8 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index 4d90442afbc..8d3af26f0c2 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -482,8 +482,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should not use global ctx for calc stream.
+      // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      // stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
index 78fb50ce31c..47e5bfd825d 100644
--- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
@@ -53,8 +53,8 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc
index e2ee9cefdbf..2d7eaf26ea4 100644
--- a/paddle/fluid/operators/collective/c_concat_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc
@@ -89,8 +89,8 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
     const T* send_buff = x->data<T>();
     T* recv_buff = temp_out.data<T>();
     gpuStream_t stream = nullptr;
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+    // should use ExecutionContext for calc stream.
+    stream = ctx.cuda_device_context().stream();
 
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::ncclAllGather(send_buff,
diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h
index f9288dea063..3e752011f15 100644
--- a/paddle/fluid/operators/collective/c_reduce_op.h
+++ b/paddle/fluid/operators/collective/c_reduce_op.h
@@ -311,8 +311,8 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
index b4eba9d1242..e0b0800f777 100644
--- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
@@ -54,8 +54,8 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
index 903d3d56886..72493e51505 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
@@ -60,8 +60,8 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc
index 439630a7f1d..83e1a4d4ca7 100644
--- a/paddle/fluid/operators/collective/global_gather_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -82,8 +82,8 @@ struct GlobalGatherFunctor<phi::GPUContext, T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
index 4ccf9dee263..017398413b3 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -81,8 +81,8 @@ struct GlobalScatterFunctor<phi::GPUContext, T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
index cd1e12d7e1b..c4565a94500 100644
--- a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
@@ -75,8 +75,8 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel<T> {
 
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
index c8a49f51d5c..c95d1fe4bc6 100644
--- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
@@ -81,8 +81,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc
index 7d4125be8d3..7b9c154bd44 100644
--- a/paddle/fluid/operators/collective/partial_send_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc
@@ -77,8 +77,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index 06e06a79c6b..a32376f3e84 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -157,8 +157,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
     }
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
index c7ab3c749b9..631595ccd08 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -151,8 +151,8 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
    if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
--
GitLab
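The change above is mechanical across all 16 files: instead of looking up the default calc stream through the global platform::DeviceContextPool, each kernel now reads it from its own framework::ExecutionContext via ctx.cuda_device_context().stream(), so the collective launches on the stream the executor actually assigned to that op. The following self-contained C++ sketch illustrates that principle only; DeviceContextPool, ExecutionContext, and the stream type below are simplified stand-ins written for this note, not the real Paddle APIs.

// Minimal sketch, assuming simplified mock types (not Paddle's real classes):
// take per-op resources from the context you were handed, not from a global
// singleton pool that only knows about the process-wide default.
#include <iostream>
#include <string>

using Stream = std::string;  // stand-in for gpuStream_t

struct DeviceContext {
  Stream stream;
};

// Mock of a global pool: always hands back the process-wide default context,
// regardless of which context the executor picked for this op run.
struct DeviceContextPool {
  static DeviceContextPool& Instance() {
    static DeviceContextPool pool;
    return pool;
  }
  DeviceContext* Get(const std::string& /*place*/) { return &default_ctx_; }
  DeviceContext default_ctx_{"global_default_stream"};
};

// Mock of an execution context: carries the device context chosen for this
// particular op invocation.
struct ExecutionContext {
  explicit ExecutionContext(const DeviceContext& ctx) : dev_ctx_(ctx) {}
  const DeviceContext& cuda_device_context() const { return dev_ctx_; }
  const DeviceContext& dev_ctx_;
};

int main() {
  DeviceContext executor_ctx{"executor_assigned_stream"};
  ExecutionContext ctx(executor_ctx);

  // Pattern removed by the patch: the global pool returns the default stream,
  // which can differ from the stream this op was actually scheduled on.
  Stream before = DeviceContextPool::Instance().Get("gpu:0")->stream;

  // Pattern added by the patch: read the stream from the execution context.
  Stream after = ctx.cuda_device_context().stream;

  std::cout << "pool stream:    " << before << "\n"
            << "exe ctx stream: " << after << "\n";
  return 0;
}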