diff --git a/paddle/fluid/operators/collective/alltoall_op.cu.cc b/paddle/fluid/operators/collective/alltoall_op.cu.cc
index e50d14e5ef6ae18406cd057dae9798f074dfd709..fd67342b3affa3e7b97f46c03734058c1a2b74cb 100644
--- a/paddle/fluid/operators/collective/alltoall_op.cu.cc
+++ b/paddle/fluid/operators/collective/alltoall_op.cu.cc
@@ -46,8 +46,8 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/barrier_op.cu.cc b/paddle/fluid/operators/collective/barrier_op.cu.cc
index 622b25f2a49bb364dcca33a3c57e1511505c6528..648b8fdc83b878be13a2b0b885e721bafe5ea2f6 100644
--- a/paddle/fluid/operators/collective/barrier_op.cu.cc
+++ b/paddle/fluid/operators/collective/barrier_op.cu.cc
@@ -39,8 +39,8 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
     int rid = ctx.Attr<int>("ring_id");
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    auto stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+    // should use ExecutionContext for calc stream.
+    auto stream = ctx.cuda_device_context().stream();
     ncclRedOp_t nccl_red_type = ncclSum;
     PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
         sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
index ddef85d73e084123d74af2e0d6c4f6dbb43c8848..947475ece482ab2cc4d51f2fd9fce8881672909f 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -67,8 +67,8 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index 4d90442afbc5ab06d9b490acf73f42ebcb4904c9..8d3af26f0c2542ee412112525a3bf92e46f0eaa9 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -482,8 +482,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should not use global ctx for calc stream.
+      // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      // stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
index 78fb50ce31c62d63919417fbc1962968c91ae3ec..47e5bfd825d650b2958f1119d1df78f68797bd24 100644
--- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
@@ -53,8 +53,8 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_concat_op.cu.cc b/paddle/fluid/operators/collective/c_concat_op.cu.cc
index e2ee9cefdbfb28d1804b3e9f39ddda23d75c0f8e..2d7eaf26ea420ad5ffb915e720a02b8f13b88081 100644
--- a/paddle/fluid/operators/collective/c_concat_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_concat_op.cu.cc
@@ -89,8 +89,8 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
     const T* send_buff = x->data<T>();
     T* recv_buff = temp_out.data<T>();
     gpuStream_t stream = nullptr;
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+    // should use ExecutionContext for calc stream.
+    stream = ctx.cuda_device_context().stream();
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::ncclAllGather(send_buff,
diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h
index f9288dea063f057a4beaa9124565c42b01e77482..3e752011f152e2762580881bf87b41dcb76ac311 100644
--- a/paddle/fluid/operators/collective/c_reduce_op.h
+++ b/paddle/fluid/operators/collective/c_reduce_op.h
@@ -311,8 +311,8 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
index b4eba9d124243c3d22603dfbf415ea9d24d127d2..e0b0800f77769dd6b4d2141c656930815df7aa91 100644
--- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
@@ -54,8 +54,8 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
index 903d3d568861a8513ef6618115aa7246f8787355..72493e51505cd03a9ce891ef9887cea70a837cb5 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
@@ -60,8 +60,8 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/global_gather_op.cu.cc b/paddle/fluid/operators/collective/global_gather_op.cu.cc
index 439630a7f1dd7c305e985fe0410e8af91bbaddd0..83e1a4d4ca778c8f5e7b93d29ebbb1ad4d627011 100644
--- a/paddle/fluid/operators/collective/global_gather_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_gather_op.cu.cc
@@ -82,8 +82,8 @@ struct GlobalGatherFunctor {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/global_scatter_op.cu.cc b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
index 4ccf9dee2631f294b2a0ae23763b828cc2fe0d8d..017398413b372b0c694378aefd049d86e69037af 100644
--- a/paddle/fluid/operators/collective/global_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/global_scatter_op.cu.cc
@@ -81,8 +81,8 @@ struct GlobalScatterFunctor {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
index cd1e12d7e1bab20c192e4953b8be2f82721ac17a..c4565a94500639b776cff9475c7bc0e669493f91 100644
--- a/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_allgather_op.cu.cc
@@ -75,8 +75,8 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_recv_op.cu.cc b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
index c8a49f51d5c4681f6d7d9cc681924d4441f5fa40..c95d1fe4bc6195e2c2c52bcd5df7956170aa2a1f 100644
--- a/paddle/fluid/operators/collective/partial_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_recv_op.cu.cc
@@ -81,8 +81,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/partial_send_op.cu.cc b/paddle/fluid/operators/collective/partial_send_op.cu.cc
index 7d4125be8d32e740a29376f54a74cefb8d1812c3..7b9c154bd44997bc218ab34f8d5d5e43fa353436 100644
--- a/paddle/fluid/operators/collective/partial_send_op.cu.cc
+++ b/paddle/fluid/operators/collective/partial_send_op.cu.cc
@@ -77,8 +77,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index 06e06a79c6b623b4301b60ace6b10168f954001d..a32376f3e842da4185349c0e75da016416ca6b3d 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -157,8 +157,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
     }
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
index c7ab3c749b9b73aada30baa07abf0b63d323878d..631595ccd08695fab158bda7f469fd3a26a2a794 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -151,8 +151,8 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
    if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should use ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
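
Note on the pattern: every hunk above replaces a lookup of the globally pooled device context with the stream carried by the operator's own ExecutionContext when use_calc_stream is set; otherwise the communicator's dedicated stream is used. The following is a minimal, framework-free sketch of that stream-selection logic under stated assumptions: CalcContext, Communicator, and SelectStream are hypothetical stand-ins for illustration only, not Paddle APIs.

// stream_selection_sketch.cu -- compiles with nvcc; CUDA runtime only.
#include <cuda_runtime.h>

struct CalcContext {        // stands in for the op's ExecutionContext
  cudaStream_t calc_stream{};  // the per-op compute ("calc") stream
};

struct Communicator {       // stands in for the NCCL communicator wrapper
  cudaStream_t comm_stream{};  // the dedicated communication stream
};

// Mirrors the branch in the kernels above: run the collective on the op's
// own compute stream when requested, otherwise on the communicator's stream.
inline cudaStream_t SelectStream(const CalcContext& ctx,
                                 const Communicator& comm,
                                 bool use_calc_stream) {
  return use_calc_stream ? ctx.calc_stream : comm.comm_stream;
}

int main() {
  CalcContext ctx;
  Communicator comm;
  cudaStreamCreate(&ctx.calc_stream);
  cudaStreamCreate(&comm.comm_stream);

  // With use_calc_stream=true the collective is enqueued on the op's own
  // stream, so it stays ordered with the surrounding compute kernels.
  cudaStream_t s = SelectStream(ctx, comm, /*use_calc_stream=*/true);
  (void)s;  // an NCCL collective would be launched on `s` here

  cudaStreamDestroy(ctx.calc_stream);
  cudaStreamDestroy(comm.comm_stream);
  return 0;
}

The design point the sketch illustrates: taking the stream from the per-op context rather than from a global pool keeps the choice of stream tied to the executing operator, which is what the added "should use ExecutionContext for calc stream" comments describe.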