未验证 提交 0f4b2186 编写于 作者: LielinJiang 提交者: GitHub

Enable bilateral_slice unit test on Windows platform (#29896)

* enable bilateral_slice unit test on Windows platform

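* reduce max threads per block to 512 for the bilateral_slice gradient kernel launches (via a new max_threads parameter of GetGpuLaunchConfig1D, default 1024)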
上级 95df0e14
...@@ -472,8 +472,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -472,8 +472,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
grid_sizes.gw = gw; grid_sizes.gw = gw;
grid_sizes.input_chans = input_chans; grid_sizes.input_chans = input_chans;
platform::GpuLaunchConfig config = platform::GpuLaunchConfig config = platform::GetGpuLaunchConfig1D(
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), grid_count); ctx.cuda_device_context(), grid_count, 512);
BilateralSliceCudaGridGradKernel< BilateralSliceCudaGridGradKernel<
T><<<config.block_per_grid, config.thread_per_block, 0, T><<<config.block_per_grid, config.thread_per_block, 0,
...@@ -481,8 +481,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -481,8 +481,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
grid_grad_data, output_grad_data, guide_data, input_data, grid_sizes, grid_grad_data, output_grad_data, guide_data, input_data, grid_sizes,
has_offset, grid_count, output_chans); has_offset, grid_count, output_chans);
config = config = platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(),
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), guide_count); guide_count, 512);
BilateralSliceCudaGuideGradKernel< BilateralSliceCudaGuideGradKernel<
T><<<config.block_per_grid, config.thread_per_block, 0, T><<<config.block_per_grid, config.thread_per_block, 0,
...@@ -490,8 +490,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> { ...@@ -490,8 +490,8 @@ class BilateralSliceGradOpCUDAKernel : public framework::OpKernel<T> {
guide_grad_data, output_grad_data, grid_data, guide_data, input_data, guide_grad_data, output_grad_data, grid_data, guide_data, input_data,
grid_sizes, has_offset, guide_count, output_chans); grid_sizes, has_offset, guide_count, output_chans);
config = config = platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(),
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), input_count); input_count, 512);
BilateralSliceCudaInputGradKernel< BilateralSliceCudaInputGradKernel<
T><<<config.block_per_grid, config.thread_per_block, 0, T><<<config.block_per_grid, config.thread_per_block, 0,
......
...@@ -36,7 +36,8 @@ struct GpuLaunchConfig { ...@@ -36,7 +36,8 @@ struct GpuLaunchConfig {
}; };
inline GpuLaunchConfig GetGpuLaunchConfig1D( inline GpuLaunchConfig GetGpuLaunchConfig1D(
const platform::CUDADeviceContext& context, int element_count) { const platform::CUDADeviceContext& context, int element_count,
int max_threads = 1024) {
PADDLE_ENFORCE_GT(element_count, 0, PADDLE_ENFORCE_GT(element_count, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"element count should be greater than 0," "element count should be greater than 0,"
...@@ -53,7 +54,8 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D( ...@@ -53,7 +54,8 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(
std::min(max_physical_threads, theory_thread_count); std::min(max_physical_threads, theory_thread_count);
// Need get from device // Need get from device
const int thread_per_block = std::min(1024, context.GetMaxThreadsPerBlock()); const int thread_per_block =
std::min(max_threads, context.GetMaxThreadsPerBlock());
const int block_count = const int block_count =
std::min(DivUp(physical_thread_count, thread_per_block), sm); std::min(DivUp(physical_thread_count, thread_per_block), sm);
......
...@@ -53,7 +53,6 @@ diable_wingpu_test="^test_analysis_predictor$|\ ...@@ -53,7 +53,6 @@ diable_wingpu_test="^test_analysis_predictor$|\
^test_dataloader_unkeep_order$|\ ^test_dataloader_unkeep_order$|\
^test_model$|\ ^test_model$|\
^test_add_reader_dependency$|\ ^test_add_reader_dependency$|\
^test_bilateral_slice_op$|\
^test_cholesky_op$|\ ^test_cholesky_op$|\
^test_dataloader_early_reset$|\ ^test_dataloader_early_reset$|\
^test_decoupled_py_reader$|\ ^test_decoupled_py_reader$|\
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册