未验证 提交 0fe0aea9 编写于 作者: F FlyingQianMM 提交者: GitHub

set device id of Place() to get GPUContext needed by LimitGridDim in...

set device id of Place() to get GPUContext needed by LimitGridDim in ElemwiseGradBroadcast (PaddlePaddle#42320) (#42332)
上级 2ea56c90
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/phi/backends/all_context.h" #include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/elementwise_utils.h" #include "paddle/phi/kernels/funcs/elementwise_utils.h"
...@@ -978,7 +979,7 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream, ...@@ -978,7 +979,7 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
// suppose perfoemance improves with h increased. // suppose perfoemance improves with h increased.
dim3 block_size = dim3(BLOCK_X, BLOCK_Y); dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X); dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
auto gplace = phi::GPUPlace(); auto gplace = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
auto *ctx = static_cast<GPUContext *>( auto *ctx = static_cast<GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(gplace)); paddle::platform::DeviceContextPool::Instance().Get(gplace));
paddle::platform::LimitGridDim(*ctx, &grid_size); paddle::platform::LimitGridDim(*ctx, &grid_size);
...@@ -1003,7 +1004,7 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream, ...@@ -1003,7 +1004,7 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
T *dy) { T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post); int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
dim3 grid_size = dim3(n); dim3 grid_size = dim3(n);
auto gplace = phi::GPUPlace(); auto gplace = phi::GPUPlace(phi::backends::gpu::GetCurrentDeviceId());
auto *ctx = static_cast<GPUContext *>( auto *ctx = static_cast<GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(gplace)); paddle::platform::DeviceContextPool::Instance().Get(gplace));
paddle::platform::LimitGridDim(*ctx, &grid_size); paddle::platform::LimitGridDim(*ctx, &grid_size);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册