未验证 提交 a66bb67a 编写于 作者: J JZ-LIANG 提交者: GitHub

Bugfix for Collective default calc stream (#48308)

* get default calc stream from execution ctx instead of global dev ctx pool.
上级 fc882c7b
......@@ -46,8 +46,8 @@ class AllToAllOpCUDAKernel : public framework::OpKernel<T> {
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -39,8 +39,8 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
auto stream = ctx.cuda_device_context().stream();
ncclRedOp_t nccl_red_type = ncclSum;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
......
......@@ -67,8 +67,8 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -482,8 +482,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should not use global ctx for calc stream.
// auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
// stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -53,8 +53,8 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -89,8 +89,8 @@ class CConcatOpCUDAKernel : public framework::OpKernel<T> {
const T* send_buff = x->data<T>();
T* recv_buff = temp_out.data<T>();
gpuStream_t stream = nullptr;
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::ncclAllGather(send_buff,
......
......@@ -311,8 +311,8 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -54,8 +54,8 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -60,8 +60,8 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -82,8 +82,8 @@ struct GlobalGatherFunctor<phi::GPUContext, T> {
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -81,8 +81,8 @@ struct GlobalScatterFunctor<phi::GPUContext, T> {
auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -75,8 +75,8 @@ class PartialAllGatherOpCUDAKernel : public framework::OpKernel<T> {
gpuStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -81,8 +81,8 @@ class PartialRecvOpCUDAKernel : public framework::OpKernel<T> {
gpuStream_t stream = nullptr;
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext *>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -77,8 +77,8 @@ class PartialSendCUDAKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -157,8 +157,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
}
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext *>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
......@@ -151,8 +151,8 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
// should ExecutionContext for calc stream.
stream = ctx.cuda_device_context().stream();
} else {
stream = comm->stream();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册