未验证 提交 17c8e3ad 编写于 作者: A Aurelius84 提交者: GitHub

Polish code in gpu_launch_config.h (#29730)

上级 068d905e
...@@ -37,19 +37,20 @@ struct GpuLaunchConfig { ...@@ -37,19 +37,20 @@ struct GpuLaunchConfig {
inline GpuLaunchConfig GetGpuLaunchConfig1D( inline GpuLaunchConfig GetGpuLaunchConfig1D(
const platform::CUDADeviceContext& context, int element_count) { const platform::CUDADeviceContext& context, int element_count) {
PADDLE_ENFORCE_GT(element_count, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(element_count, 0,
"element count should greater than 0," platform::errors::InvalidArgument(
" but received value is %d.", "element count should be greater than 0,"
element_count)); " but received value is: %d.",
element_count));
const int theory_thread_count = element_count; const int theory_thread_count = element_count;
// Get Max threads in all SM // Get Max threads in all SM
int max_pyhsical_threads = context.GetMaxPhysicalThreadCount(); int max_physical_threads = context.GetMaxPhysicalThreadCount();
int sm = context.GetSMCount(); int sm = context.GetSMCount();
// Compute pyhsical threads we need, should small than max sm threads // Compute physical threads we need, should small than max sm threads
const int physical_thread_count = const int physical_thread_count =
std::min(max_pyhsical_threads, theory_thread_count); std::min(max_physical_threads, theory_thread_count);
// Needs to be obtained from the device // Needs to be obtained from the device
const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock()); const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock());
...@@ -64,18 +65,18 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D( ...@@ -64,18 +65,18 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(
} }
inline GpuLaunchConfig GetGpuLaunchConfig2D( inline GpuLaunchConfig GetGpuLaunchConfig2D(
const platform::CUDADeviceContext& context, int xdim, int ydim) { const platform::CUDADeviceContext& context, int x_dim, int y_dim) {
PADDLE_ENFORCE_GT(xdim, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(x_dim, 0, platform::errors::InvalidArgument(
"x dim number should greater than 0," "x dim number should greater than 0,"
" but received value is:%d", " but received value is: %d",
xdim)); x_dim));
PADDLE_ENFORCE_GT(ydim, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(y_dim, 0, platform::errors::InvalidArgument(
"y dim number should greater than 0," "y dim number should greater than 0,"
" but received value is:%d", " but received value is: %d",
ydim)); y_dim));
const int kThreadsPerBlock = 256; const int kThreadsPerBlock = 256;
int block_cols = std::min(xdim, kThreadsPerBlock); int block_cols = std::min(x_dim, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1); int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
int max_physical_threads = context.GetMaxPhysicalThreadCount(); int max_physical_threads = context.GetMaxPhysicalThreadCount();
...@@ -83,11 +84,11 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D( ...@@ -83,11 +84,11 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
GpuLaunchConfig config; GpuLaunchConfig config;
// Note: the block size is not aligned to 32; align it yourself if needed. // Note: the block size is not aligned to 32; align it yourself if needed.
config.theory_thread_count = dim3(xdim, ydim, 1); config.theory_thread_count = dim3(x_dim, y_dim, 1);
config.thread_per_block = dim3(block_cols, block_rows, 1); config.thread_per_block = dim3(block_cols, block_rows, 1);
int grid_x = std::min(DivUp(xdim, block_cols), max_blocks); int grid_x = std::min(DivUp(x_dim, block_cols), max_blocks);
int grid_y = std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)); int grid_y = std::min(max_blocks / grid_x, std::max(y_dim / block_rows, 1));
config.block_per_grid = dim3(grid_x, grid_y, 1); config.block_per_grid = dim3(grid_x, grid_y, 1);
return config; return config;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册