Unverified commit a66bb67a, authored by JZ-LIANG, committed by GitHub

Bugfix for Collective default calc stream (#48308)

* get default calc stream from execution ctx instead of global dev ctx pool.
Parent: fc882c7b
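Every hunk below applies the same two-line change. As a minimal sketch (kernel boilerplate omitted; `ctx`, `comm`, and the `use_calc_stream` attribute are taken from the hunks that follow), the stream selection after this fix is:

// Sketch only: the shared pattern this commit converges on, not a complete kernel.
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
  // Take the default calculation stream from the op's ExecutionContext
  // instead of looking it up in the global DeviceContextPool.
  stream = ctx.cuda_device_context().stream();
} else {
  // Otherwise use the dedicated communication stream of the NCCL comm context.
  stream = comm->stream();
}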
@@ -46,8 +46,8 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -39,8 +39,8 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
     int rid = ctx.Attr<int>("ring_id");
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    auto stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+    // should ExecutionContext for calc stream.
+    auto stream = ctx.cuda_device_context().stream();
     ncclRedOp_t nccl_red_type = ncclSum;
     PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
         sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
...
@@ -67,8 +67,8 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -482,8 +482,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should not use global ctx for calc stream.
+      // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+      // stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -53,8 +53,8 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -89,8 +89,8 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
     const T* send_buff = x->data<T>();
     T* recv_buff = temp_out.data<T>();
     gpuStream_t stream = nullptr;
-    auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-    stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+    // should ExecutionContext for calc stream.
+    stream = ctx.cuda_device_context().stream();
     PADDLE_ENFORCE_GPU_SUCCESS(
         platform::dynload::ncclAllGather(send_buff,
...
@@ -311,8 +311,8 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -54,8 +54,8 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -60,8 +60,8 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -82,8 +82,8 @@ struct GlobalGatherFunctor<phi::GPUContext, T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -81,8 +81,8 @@ struct GlobalScatterFunctor<phi::GPUContext, T> {
     auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -75,8 +75,8 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -81,8 +81,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
     gpuStream_t stream = nullptr;
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext *>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -77,8 +77,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -157,8 +157,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
     }
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext *>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...
@@ -151,8 +151,8 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
-      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
+      // should ExecutionContext for calc stream.
+      stream = ctx.cuda_device_context().stream();
     } else {
       stream = comm->stream();
     }
...