未验证 提交 17c8e3ad 编写于 作者: A Aurelius84 提交者: GitHub

Polish code in gpu_launch_config.h (#29730)

上级 068d905e
......@@ -37,19 +37,20 @@ struct GpuLaunchConfig {
inline GpuLaunchConfig GetGpuLaunchConfig1D(
const platform::CUDADeviceContext& context, int element_count) {
PADDLE_ENFORCE_GT(element_count, 0, platform::errors::InvalidArgument(
"element count should greater than 0,"
" but received value is %d.",
element_count));
PADDLE_ENFORCE_GT(element_count, 0,
platform::errors::InvalidArgument(
"element count should be greater than 0,"
" but received value is: %d.",
element_count));
const int theory_thread_count = element_count;
// Get Max threads in all SM
int max_pyhsical_threads = context.GetMaxPhysicalThreadCount();
int max_physical_threads = context.GetMaxPhysicalThreadCount();
int sm = context.GetSMCount();
// Compute pyhsical threads we need, should small than max sm threads
// Compute the physical threads we need; this should not exceed the max SM threads
const int physical_thread_count =
std::min(max_pyhsical_threads, theory_thread_count);
std::min(max_physical_threads, theory_thread_count);
// Need get from device
const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock());
......@@ -64,18 +65,18 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(
}
inline GpuLaunchConfig GetGpuLaunchConfig2D(
const platform::CUDADeviceContext& context, int xdim, int ydim) {
PADDLE_ENFORCE_GT(xdim, 0, platform::errors::InvalidArgument(
"x dim number should greater than 0,"
" but received value is:%d",
xdim));
PADDLE_ENFORCE_GT(ydim, 0, platform::errors::InvalidArgument(
"y dim number should greater than 0,"
" but received value is:%d",
ydim));
const platform::CUDADeviceContext& context, int x_dim, int y_dim) {
PADDLE_ENFORCE_GT(x_dim, 0, platform::errors::InvalidArgument(
"x dim number should greater than 0,"
" but received value is: %d",
x_dim));
PADDLE_ENFORCE_GT(y_dim, 0, platform::errors::InvalidArgument(
"y dim number should greater than 0,"
" but received value is: %d",
y_dim));
const int kThreadsPerBlock = 256;
int block_cols = std::min(xdim, kThreadsPerBlock);
int block_cols = std::min(x_dim, kThreadsPerBlock);
int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
int max_physical_threads = context.GetMaxPhysicalThreadCount();
......@@ -83,11 +84,11 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
GpuLaunchConfig config;
// Note: the block size is not aligned to 32; align it yourself if needed.
config.theory_thread_count = dim3(xdim, ydim, 1);
config.theory_thread_count = dim3(x_dim, y_dim, 1);
config.thread_per_block = dim3(block_cols, block_rows, 1);
int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
int grid_y = std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1));
int grid_x = std::min(DivUp(x_dim, block_cols), max_blocks);
int grid_y = std::min(max_blocks / grid_x, std::max(y_dim / block_rows, 1));
config.block_per_grid = dim3(grid_x, grid_y, 1);
return config;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册