Unverified commit 0f4b2186, authored by LielinJiang, committed by GitHub

Enable bilateral_slice unittest on windows platform (#29896)

* enable bilateral_slice unittest on windows platform

* reduce max threads
Parent: 95df0e14
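Summary of the change: `platform::GetGpuLaunchConfig1D` gains an optional `max_threads` argument (defaulting to 1024, so existing call sites keep their old behavior), the three bilateral_slice gradient kernel launches pass 512 to halve the per-block thread cap, and with that in place `test_bilateral_slice_op` is removed from the Windows GPU test exclusion list, presumably because the test only fails under the larger launch configuration on Windows.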
@@ -472,8 +472,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
     grid_sizes.gw = gw;
     grid_sizes.input_chans = input_chans;
 
-    platform::GpuLaunchConfig config =
-        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), grid_count);
+    platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
+        ctx.cuda_device_context(), grid_count, 512);
 
     BilateralSliceCudaGridGradKernel<
         T><<<config.block_per_grid, config.thread_per_block, 0,
@@ -481,8 +481,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
         grid_grad_data, output_grad_data, guide_data, input_data, grid_sizes,
         has_offset, grid_count, output_chans);
 
-    config =
-        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), guide_count);
+    config = platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(),
+                                            guide_count, 512);
 
     BilateralSliceCudaGuideGradKernel<
         T><<<config.block_per_grid, config.thread_per_block, 0,
@@ -490,8 +490,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
         guide_grad_data, output_grad_data, grid_data, guide_data, input_data,
         grid_sizes, has_offset, guide_count, output_chans);
 
-    config =
-        platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_count);
+    config = platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(),
+                                            input_count, 512);
 
     BilateralSliceCudaInputGradKernel<
         T><<<config.block_per_grid, config.thread_per_block, 0,
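Each launch above asks the helper for a 1-D configuration sized to its element count, now with threads per block capped at 512. The following is a minimal, self-contained CUDA sketch of that launch pattern, not the Paddle source; the kernel and all names are illustrative.

#include <algorithm>

__global__ void ScaleKernel(float* data, int n, float factor) {
  // One thread per element; guard against the tail of the last block.
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) data[idx] *= factor;
}

int main() {
  const int n = 1 << 20;
  float* d_data = nullptr;
  cudaMalloc(&d_data, n * sizeof(float));

  // Query the device limit, then cap it at 512 as the patched launches do.
  int device_max = 0;
  cudaDeviceGetAttribute(&device_max, cudaDevAttrMaxThreadsPerBlock, 0);
  const int thread_per_block = std::min(512, device_max);
  const int block_per_grid = (n + thread_per_block - 1) / thread_per_block;

  ScaleKernel<<<block_per_grid, thread_per_block>>>(d_data, n, 2.0f);
  cudaDeviceSynchronize();
  cudaFree(d_data);
  return 0;
}

The real helper additionally bounds the block count by the number of SMs, as the second file in this commit shows.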
@@ -36,7 +36,8 @@ struct GpuLaunchConfig {
 };
 
 inline GpuLaunchConfig GetGpuLaunchConfig1D(
-    const platform::CUDADeviceContext& context, int element_count) {
+    const platform::CUDADeviceContext& context, int element_count,
+    int max_threads = 1024) {
   PADDLE_ENFORCE_GT(element_count, 0,
                     platform::errors::InvalidArgument(
                         "element count should be greater than 0,"
@@ -53,7 +54,8 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(
       std::min(max_physical_threads, theory_thread_count);
 
   // Need get from device
-  const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock());
+  const int thread_per_block =
+      std::min(max_threads, context.GetMaxThreadsPerBlock());
   const int block_count =
       std::min(DivUp(physical_thread_count, thread_per_block), sm);
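Putting the two hunks together, the helper's 1-D logic after the patch can be summarized with the standalone sketch below. The device queries and the `DivUp` helper here are stand-ins for Paddle's own; only the `max_threads` cap is new in this commit.

#include <algorithm>

struct GpuLaunchConfig {
  int thread_per_block;
  int block_per_grid;
};

inline int DivUp(int a, int b) { return (a + b - 1) / b; }

// max_threads defaults to 1024 so existing call sites are unchanged;
// bilateral_slice's gradient kernels now pass 512 explicitly.
GpuLaunchConfig GetGpuLaunchConfig1D(int element_count, int sm_count,
                                     int max_threads_per_sm,
                                     int device_max_threads_per_block,
                                     int max_threads = 1024) {
  // Ideal thread count is one per element, bounded by what the
  // device can physically run at once.
  const int theory_thread_count = element_count;
  const int max_physical_threads = sm_count * max_threads_per_sm;
  const int physical_thread_count =
      std::min(max_physical_threads, theory_thread_count);
  // Block size is the smaller of the requested cap and the device limit.
  const int thread_per_block =
      std::min(max_threads, device_max_threads_per_block);
  // Block count covers the physical threads, bounded by the SM count.
  const int block_count =
      std::min(DivUp(physical_thread_count, thread_per_block), sm_count);
  return {thread_per_block, block_count};
}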
@@ -53,7 +53,6 @@ diable_wingpu_test="^test_analysis_predictor$|\
 ^test_dataloader_unkeep_order$|\
 ^test_model$|\
 ^test_add_reader_dependency$|\
-^test_bilateral_slice_op$|\
 ^test_cholesky_op$|\
 ^test_dataloader_early_reset$|\
 ^test_decoupled_py_reader$|\
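Finally, dropping `^test_bilateral_slice_op$|\` from the `diable_wingpu_test` filter (the misspelling is the script's own variable name) re-enables the test in the Windows GPU CI run, which is the point of the commit.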