From 22d3c560d87a0d9635ab358bbf5d19b34d8dc468 Mon Sep 17 00:00:00 2001
From: FlyingQianMM <245467267@qq.com>
Date: Thu, 28 Apr 2022 08:59:38 +0800
Subject: [PATCH] set device id of Place() to get GPUContext needed by
 LimitGridDim in ElemwiseGradBroadcast (#42320)

* set device id of Place() to get GPUContext needed by LimitGridDim in
  ElemwiseGradBroadcast

* fix code style
---
 paddle/phi/kernels/funcs/elementwise_grad_base.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/kernels/funcs/elementwise_grad_base.h b/paddle/phi/kernels/funcs/elementwise_grad_base.h
index 1021b510b2..7508d8ee8c 100644
--- a/paddle/phi/kernels/funcs/elementwise_grad_base.h
+++ b/paddle/phi/kernels/funcs/elementwise_grad_base.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
 #include "paddle/phi/kernels/funcs/elementwise_utils.h"
@@ -978,7 +979,7 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
     // suppose perfoemance improves with h increased.
     dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
     dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
-    auto gplace = phi::GPUPlace();
+    auto gplace = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
     auto *ctx = static_cast<GPUContext *>(
         paddle::platform::DeviceContextPool::Instance().Get(gplace));
     paddle::platform::LimitGridDim(*ctx, &grid_size);
@@ -1003,7 +1004,7 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
                                        T *dy) {
   int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
   dim3 grid_size = dim3(n);
-  auto gplace = phi::GPUPlace();
+  auto gplace = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
   auto *ctx = static_cast<GPUContext *>(
       paddle::platform::DeviceContextPool::Instance().Get(gplace));
   paddle::platform::LimitGridDim(*ctx, &grid_size);
--
GitLab
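
What the patch fixes, in brief: DeviceContextPool hands out one context per
place, and a default-constructed phi::GPUPlace() always names device 0. On a
multi-GPU run where the kernel is launching on some other device, the pool
therefore returns device 0's GPUContext, and LimitGridDim clamps the launch
grid against the wrong device's limits. Passing
phi::backends::gpu::GetCurrentDeviceId() keys the lookup by the device
actually in use. Below is a minimal standalone sketch of that lookup-key
problem; Place, Context, and ContextPool are mock stand-ins for
phi::GPUPlace, phi::GPUContext, and DeviceContextPool, not Paddle's real
classes.

#include <cassert>
#include <map>

// Stand-in for phi::GPUPlace: identifies a device; defaults to device 0,
// which is exactly the trap the patch removes.
struct Place {
  int device;
  explicit Place(int d = 0) : device(d) {}
  bool operator<(const Place &o) const { return device < o.device; }
};

// Stand-in for phi::GPUContext: carries a per-device launch limit of the
// kind LimitGridDim consults.
struct Context {
  int max_grid_dim_x;
};

// Stand-in for paddle::platform::DeviceContextPool: one context per place.
struct ContextPool {
  std::map<Place, Context> ctxs;
  const Context *Get(const Place &p) const { return &ctxs.at(p); }
};

// Stand-in for phi::backends::gpu::GetCurrentDeviceId(); imagine the
// surrounding code has already switched to device 1.
int GetCurrentDeviceId() { return 1; }

int main() {
  ContextPool pool;
  pool.ctxs[Place(0)] = Context{2147483647};  // device 0: large x-grid limit
  pool.ctxs[Place(1)] = Context{65535};       // device 1: tighter limit

  // Before the patch: Place() silently means device 0, so a kernel about
  // to launch on device 1 is clamped against device 0's limits.
  const Context *before = pool.Get(Place());
  // After the patch: the lookup is keyed by the device actually in use.
  const Context *after = pool.Get(Place(GetCurrentDeviceId()));

  assert(before->max_grid_dim_x == 2147483647);
  assert(after->max_grid_dim_x == 65535);  // the two contexts really differ
  return 0;
}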