diff --git a/paddle/fluid/operators/math/im2col.cu b/paddle/fluid/operators/math/im2col.cu
index f2a2148ba6954f50cf59ae30f4f4be6aa070739f..3eadaa2677ab4f6cb69f4163079c2d891717eec1 100644
--- a/paddle/fluid/operators/math/im2col.cu
+++ b/paddle/fluid/operators/math/im2col.cu
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/operators/math/im2col.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
@@ -104,10 +105,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
     int col_width = col->dims()[4];
 
     int num_outputs = im_channels * col_height * col_width;
-    int blocks = (num_outputs + 1024 - 1) / 1024;
+    int num_thread = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &num_thread);
+#endif
+    int blocks = (num_outputs + num_thread - 1) / num_thread;
     int block_x = 512;
     int block_y = (blocks + 512 - 1) / 512;
-    dim3 threads(1024, 1);
+    dim3 threads(num_thread, 1);
     dim3 grid(block_x, block_y);
     im2col<T><<<grid, threads, 0, context.stream()>>>(
         im.data<T>(), num_outputs, im_height, im_width, dilation[0],
@@ -228,10 +233,14 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
     int num_kernels = im_channels * im_height * im_width;
 
-    size_t blocks = (num_kernels + 1024 - 1) / 1024;
+    int num_thread = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &num_thread);
+#endif
+    size_t blocks = (num_kernels + num_thread - 1) / num_thread;
     size_t block_x = 512;
     size_t block_y = (blocks + 512 - 1) / 512;
-    dim3 threads(1024, 1);
+    dim3 threads(num_thread, 1);
     dim3 grid(block_x, block_y);
     col2im<T><<<grid, threads, 0, context.stream()>>>(
         num_kernels, col.data<T>(), im_height, im_width, dilation[0],
diff --git a/paddle/fluid/operators/math/pool.cu b/paddle/fluid/operators/math/pool.cu
--- a/paddle/fluid/operators/math/pool.cu
+++ b/paddle/fluid/operators/math/pool.cu
@@ ... @@ limitations under the License. */
 #include "paddle/fluid/operators/math/pool.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
@@ ... @@ void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
     const int padding_width = paddings[1];
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    // platform::ChangeThreadNum(context, &thread_num);
+    thread_num = 512;
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelPool2D<PoolProcess, T><<<grid, threads, 0, stream>>>(
@@ -298,10 +304,13 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
     T* output_data = output->mutable_data<T>(context.GetPlace());
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
-
     KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
@@ -341,10 +350,13 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
     T* output_data = output->mutable_data<T>(context.GetPlace());
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
-
     KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
@@ -911,8 +923,12 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
@@ -962,8 +978,12 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelPool3D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
@@ -1377,10 +1397,14 @@ class MaxPool2dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
     T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
 
     int nthreads = batch_size * output_channels * output_height * output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
-    dim3 grid(blocks, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
+    dim3 grid(blocks, 1);
 
     KernelMaxPool2dWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
         nthreads, input_data, input_channels, input_height, input_width,
         output_height, output_width, ksize_height, ksize_width, stride_height,
@@ -1613,8 +1637,13 @@ class MaxPool3dWithIndexFunctor<platform::CUDADeviceContext, T1, T2> {
     int nthreads = batch_size * output_channels * output_depth * output_height *
                    output_width;
-    int blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
+    int thread_num = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &thread_num);
+#endif
+
+    int blocks = (nthreads + thread_num - 1) / thread_num;
+    dim3 threads(thread_num, 1);
     dim3 grid(blocks, 1);
 
     KernelMaxPool3DWithIdx<T1, T2><<<grid, threads, 0, context.stream()>>>(
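Every hunk above follows the same recipe: start from the old fixed 1024 threads per block, let platform::ChangeThreadNum (or, in Pool2dDirectCUDAFunctor, which only receives a raw stream and so cannot query the device context, a hard-coded 512) shrink the block size on Jetson parts, then recompute the block count by ceiling division so the grid still covers every element. A minimal standalone sketch of that recipe, with a toy kernel and a direct cudaGetDeviceProperties query standing in for Paddle's CUDADeviceContext (all names here are illustrative, not taken from the patch):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void FillKernel(float* data, int n, float value) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) data[idx] = value;  // guard: the last block may be partial
}

// Stand-in for platform::ChangeThreadNum: cap the block size on SM 5.3
// (Nano/TX1) and SM 6.2 (TX2), mirroring the check in gpu_launch_config.h.
void CapThreadsForJetson(int device, int* num_thread, int alternative = 512) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  int capability = prop.major * 10 + prop.minor;
  if (capability == 53 || capability == 62) *num_thread = alternative;
}

int main() {
  const int n = 1 << 20;
  float* data = nullptr;
  cudaMalloc(&data, n * sizeof(float));

  int num_thread = 1024;  // default, as in the pre-Jetson code path
  CapThreadsForJetson(0, &num_thread);
  int blocks = (n + num_thread - 1) / num_thread;  // ceiling division

  FillKernel<<<blocks, num_thread>>>(data, n, 1.0f);
  cudaDeviceSynchronize();
  printf("launched %d blocks x %d threads for %d elements\n", blocks,
         num_thread, n);
  cudaFree(data);
  return 0;
}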
diff --git a/paddle/fluid/operators/math/vol2col.cu b/paddle/fluid/operators/math/vol2col.cu
index eca39e919737210267d7af1856903d3e1fc697d1..d83b5b0fe3afb390009b1153c8e8c6175abe415f 100644
--- a/paddle/fluid/operators/math/vol2col.cu
+++ b/paddle/fluid/operators/math/vol2col.cu
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/operators/math/vol2col.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
 
 namespace paddle {
 namespace operators {
@@ -152,8 +153,14 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
     int num_outputs =
         input_channels * output_depth * output_height * output_width;
 
-    const int threads = 1024;
-    const int blocks = (num_outputs + 1024 - 1) / 1024;
+    int max_threads = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &max_threads);
+#endif
+
+    const int threads = max_threads;
+    const int blocks = (num_outputs + max_threads - 1) / max_threads;
+
     vol2col<T><<<blocks, threads, 0, context.stream()>>>(
         num_outputs, vol.data<T>(), input_depth, input_height, input_width,
         dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
@@ -313,8 +320,13 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
     int num_kernels =
         input_channels * input_depth * input_height * input_width;
 
-    const int threads = 1024;
-    const int blocks = (num_kernels + 1024 - 1) / 1024;
+    int max_threads = 1024;
+#ifdef WITH_NV_JETSON
+    platform::ChangeThreadNum(context, &max_threads);
+#endif
+
+    const int threads = max_threads;
+    const int blocks = (num_kernels + max_threads - 1) / max_threads;
 
     col2vol<T><<<blocks, threads, 0, context.stream()>>>(
         num_kernels, col.data<T>(), input_depth, input_height, input_width,
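The (n + max_threads - 1) / max_threads expression used in both functors is integer ceiling division, the same operation gpu_launch_config.h exposes as DivUp; it guarantees blocks * threads >= n, so shrinking the block size on Jetson only grows the grid and never drops tail elements. A standalone compile-time check of that property (sample numbers chosen purely for illustration):

// Same definition as DivUp in gpu_launch_config.h, usable in static_assert.
constexpr int DivUp(int a, int b) { return (a + b - 1) / b; }

// With the Jetson cap of 512 threads, 1000 outputs need 2 blocks, and the
// rounded-up grid always covers at least the requested element count.
static_assert(DivUp(1000, 512) == 2, "1000 elements need two 512-thread blocks");
static_assert(DivUp(1024, 1024) == 1, "exact multiples do not over-allocate");
static_assert(DivUp(1025, 1024) == 2, "one extra element adds a whole block");
static_assert(DivUp(1000, 512) * 512 >= 1000, "grid covers all outputs");

int main() { return 0; }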
*/ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/roi_align_op.h" #include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace operators { @@ -261,7 +262,9 @@ class GPUROIAlignOpKernel : public framework::OpKernel { int output_size = out->numel(); int blocks = NumBlocks(output_size); int threads = kNumCUDAThreads; - +#ifdef WITH_NV_JETSON + platform::ChangeThreadNum(ctx.cuda_device_context(), &threads, 256); +#endif Tensor roi_batch_id_list; roi_batch_id_list.Resize({rois_num}); auto cplace = platform::CPUPlace(); diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index 2cd6e44dd7a1a9b62499c4c8367d83979c0ba52d..6e5c7f4e9166093a3263ef4b1ef7e5374343a889 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace platform { @@ -65,6 +66,11 @@ struct ForRange { #ifdef __HIPCC__ // HIP will throw core dump when threads > 256 constexpr int num_threads = 256; +#elif WITH_NV_JETSON + // JETSON_NANO will throw core dump when threads > 128 + int num_thread = 256; + platform::ChangeThreadNum(dev_ctx_, &num_thread, 128); + const int num_threads = num_thread; #else constexpr int num_threads = 1024; #endif diff --git a/paddle/fluid/platform/gpu_launch_config.h b/paddle/fluid/platform/gpu_launch_config.h index a82262419066fab0c1a58d3b6781bc765fa1a4c6..399f1dbaa03e1f1325cec670e9dabfcfedeab6d4 100644 --- a/paddle/fluid/platform/gpu_launch_config.h +++ b/paddle/fluid/platform/gpu_launch_config.h @@ -23,6 +23,7 @@ #else #include #endif + #include #include #include @@ -33,6 +34,18 @@ namespace platform { inline int DivUp(int a, int b) { return (a + b - 1) / b; } +#ifdef WITH_NV_JETSON +// The number of threads cannot be assigned 1024 in some cases when the device +// is nano or tx2 . 
diff --git a/paddle/fluid/platform/gpu_launch_config.h b/paddle/fluid/platform/gpu_launch_config.h
index a82262419066fab0c1a58d3b6781bc765fa1a4c6..399f1dbaa03e1f1325cec670e9dabfcfedeab6d4 100644
--- a/paddle/fluid/platform/gpu_launch_config.h
+++ b/paddle/fluid/platform/gpu_launch_config.h
@@ -23,6 +23,7 @@
 #else
 #include <hip/hip_runtime.h>
 #endif
+
 #include <stddef.h>
 #include <algorithm>
 #include <string>
@@ -33,6 +34,18 @@ namespace platform {
 
 inline int DivUp(int a, int b) { return (a + b - 1) / b; }
 
+#ifdef WITH_NV_JETSON
+// The number of threads cannot be assigned 1024 in some cases when the device
+// is nano or tx2.
+inline void ChangeThreadNum(const platform::CUDADeviceContext& context,
+                            int* num_thread, int alternative_num_thread = 512) {
+  if (context.GetComputeCapability() == 53 ||
+      context.GetComputeCapability() == 62) {
+    *num_thread = alternative_num_thread;
+  }
+}
+#endif
+
 struct GpuLaunchConfig {
   dim3 theory_thread_count = dim3(1, 1, 1);
   dim3 thread_per_block = dim3(1, 1, 1);
@@ -61,15 +74,22 @@ inline GpuLaunchConfig GetGpuLaunchConfig1D(
 
   // Compute physical threads we need, should small than max sm threads
   const int physical_thread_count =
-      std::min(max_physical_threads, theory_thread_count);
+      (std::min)(max_physical_threads, theory_thread_count);
+
+  // Get compute_capability
+  const int capability = context.GetComputeCapability();
+
+#ifdef WITH_NV_JETSON
+  if (capability == 53 || capability == 62) {
+    max_threads = 512;
+  }
+#endif
 
   // Need get from device
   const int thread_per_block =
-      std::min(max_threads, context.GetMaxThreadsPerBlock());
+      (std::min)(max_threads, context.GetMaxThreadsPerBlock());
   const int block_count =
-      std::min(DivUp(physical_thread_count, thread_per_block), sm);
-  // Get compute_capability
-  const int capability = context.GetComputeCapability();
+      (std::min)(DivUp(physical_thread_count, thread_per_block), sm);
 
   GpuLaunchConfig config;
   config.theory_thread_count.x = theory_thread_count;
@@ -91,19 +111,20 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
                         y_dim));
 
   const int kThreadsPerBlock = 256;
-  int block_cols = std::min(x_dim, kThreadsPerBlock);
-  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
+  int block_cols = (std::min)(x_dim, kThreadsPerBlock);
+  int block_rows = (std::max)(kThreadsPerBlock / block_cols, 1);
 
   int max_physical_threads = context.GetMaxPhysicalThreadCount();
-  const int max_blocks = std::max(max_physical_threads / kThreadsPerBlock, 1);
+  const int max_blocks = (std::max)(max_physical_threads / kThreadsPerBlock, 1);
 
   GpuLaunchConfig config;
   // Noticed, block size is not align to 32, if needed do it yourself.
   config.theory_thread_count = dim3(x_dim, y_dim, 1);
   config.thread_per_block = dim3(block_cols, block_rows, 1);
-  int grid_x = std::min(DivUp(x_dim, block_cols), max_blocks);
-  int grid_y = std::min(max_blocks / grid_x, std::max(y_dim / block_rows, 1));
+  int grid_x = (std::min)(DivUp(x_dim, block_cols), max_blocks);
+  int grid_y =
+      (std::min)(max_blocks / grid_x, (std::max)(y_dim / block_rows, 1));
 
   config.block_per_grid = dim3(grid_x, grid_y, 1);
   return config;
 }
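Two details in this header are easy to miss. Callers consume the result through config.block_per_grid and config.thread_per_block together with the context's stream. And the (std::min) / (std::max) spellings are not noise: the extra parentheses stop the calls from being expanded as the function-style min/max macros that <windows.h> defines, the standard way to keep these calls safe when such macros leak in. The snippet below reproduces the failure mode with a locally defined macro standing in for the Windows headers:

#include <algorithm>

// Simulate <windows.h>'s function-style macro (normally suppressed only if
// NOMINMAX is defined before including it).
#define min(a, b) (((a) < (b)) ? (a) : (b))

int main() {
  int a = 3, b = 7;
  // int x = std::min(a, b);   // breaks: preprocessor expands min(a, b) first
  int x = (std::min)(a, b);    // parentheses suppress macro expansion
  return x == 3 ? 0 : 1;
}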
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 2c001614d1bacb997c1c1b082b50b6cfac3b88f7..00f2d2aa0b2fa264e2c0143e3690585bfe5b1cc9 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -753,7 +757,11 @@ endif()
 if (NOT WIN32)
     set_tests_properties(test_multiprocess_reader_exception PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
     set_tests_properties(test_layers PROPERTIES TIMEOUT 120)
+    if (WITH_NV_JETSON)
+        set_tests_properties(test_ir_memory_optimize_transformer PROPERTIES TIMEOUT 1200)
+    else ()
     set_tests_properties(test_ir_memory_optimize_transformer PROPERTIES TIMEOUT 120)
+    endif ()
 endif()
 
 if (WITH_DISTRIBUTE AND NOT WIN32)
@@ -799,7 +803,11 @@ set_tests_properties(test_elementwise_div_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_regularizer_api PROPERTIES TIMEOUT 150)
 set_tests_properties(test_multiclass_nms_op PROPERTIES TIMEOUT 120)
 if(NOT WIN32)
+    if (WITH_NV_JETSON)
+        set_tests_properties(test_ir_memory_optimize_nlp PROPERTIES TIMEOUT 1200)
+    else ()
     set_tests_properties(test_ir_memory_optimize_nlp PROPERTIES TIMEOUT 120)
+    endif ()
 endif()
 set_tests_properties(test_add_reader_dependency PROPERTIES TIMEOUT 120)
 set_tests_properties(test_bilateral_slice_op PROPERTIES TIMEOUT 120)
@@ -822,12 +830,28 @@ else()
     set_tests_properties(test_static_save_load_large PROPERTIES TIMEOUT 600)
     set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250)
 endif()
+if (WITH_NV_JETSON)
+    set_tests_properties(test_concat_op PROPERTIES TIMEOUT 1200)
+    set_tests_properties(test_conv3d_transpose_part2_op PROPERTIES TIMEOUT 1200)
+    set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 1200)
+    set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 1200)
+    set_tests_properties(test_norm_op PROPERTIES TIMEOUT 1200)
+    set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 1500)
+    set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 1500)
+else()
+    set_tests_properties(test_concat_op PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_conv3d_transpose_part2_op PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120)
+    set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 150)
+    set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
+endif()
 set_tests_properties(test_imperative_selected_rows_to_lod_tensor PROPERTIES TIMEOUT 120)
 set_tests_properties(test_index_select_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_parallel_ssa_graph_inference_feed_partial_data PROPERTIES TIMEOUT 120)
 set_tests_properties(test_parallel_executor_crf PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_save_load PROPERTIES TIMEOUT 120)
-set_tests_properties(test_concat_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_partial_eager_deletion_transformer PROPERTIES TIMEOUT 120)
 set_tests_properties(test_parallel_executor_seresnext_with_reduce_gpu PROPERTIES TIMEOUT 120)
 set_tests_properties(test_dropout_op PROPERTIES TIMEOUT 120)
@@ -851,8 +875,6 @@ set_tests_properties(test_parallel_executor_mnist PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_ptb_rnn PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_save_load_v2 PROPERTIES TIMEOUT 120)
 set_tests_properties(test_conv2d_transpose_op PROPERTIES TIMEOUT 120)
-set_tests_properties(test_conv3d_transpose_part2_op PROPERTIES TIMEOUT 120)
-set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_prroi_pool_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_multiprocess_dataloader_iterable_dataset_static PROPERTIES TIMEOUT 120)
 set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120)
@@ -882,7 +904,6 @@ set_tests_properties(test_adam_optimizer_fp32_fp64 PROPERTIES TIMEOUT 120)
 set_tests_properties(test_elementwise_nn_grad PROPERTIES TIMEOUT 120)
 set_tests_properties(test_buffer_shared_memory_reuse_pass_and_fuse_optimization_op_pass PROPERTIES TIMEOUT 120)
 set_tests_properties(test_conv_nn_grad PROPERTIES TIMEOUT 120)
-set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_program_prune_backward PROPERTIES TIMEOUT 120)
 set_tests_properties(test_group_norm_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_optimizer PROPERTIES TIMEOUT 120)
@@ -902,13 +923,10 @@ set_tests_properties(test_elementwise_mul_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_cyclic_cifar_dataset PROPERTIES TIMEOUT 120)
 set_tests_properties(test_fuse_all_reduce_pass PROPERTIES TIMEOUT 120)
 set_tests_properties(test_dygraph_multi_forward PROPERTIES TIMEOUT 120)
-set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_ocr_attention_model PROPERTIES TIMEOUT 120)
 set_tests_properties(test_imperative_mnist PROPERTIES TIMEOUT 120)
 set_tests_properties(test_fused_elemwise_activation_op PROPERTIES TIMEOUT 270)
 set_tests_properties(test_gru_op PROPERTIES TIMEOUT 200)
-set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 150)
-set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
 set_tests_properties(test_regularizer PROPERTIES TIMEOUT 150)
 set_tests_properties(test_imperative_resnet PROPERTIES TIMEOUT 200)
 set_tests_properties(test_imperative_resnet_sorted_gradient PROPERTIES TIMEOUT 200)
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index f73327f8248d8a7c9d9cc9357b1812526efc437a..8b4a0ee98fa54efe95902160b91df27dfcd740e8 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -49,7 +49,11 @@ set_tests_properties(test_trt_activation_pass PROPERTIES TIMEOUT 120)
 set_tests_properties(test_trt_conv_pass PROPERTIES TIMEOUT 120)
 #set_tests_properties(test_trt_multiclass_nms_op PROPERTIES TIMEOUT 200)
 set_tests_properties(test_trt_dynamic_shape PROPERTIES TIMEOUT 120)
-set_tests_properties(test_trt_pool_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 45)
+if(WITH_NV_JETSON)
+    set_tests_properties(test_trt_pool_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 450)
+else()
+    set_tests_properties(test_trt_pool_op PROPERTIES ENVIRONMENT FLAGS_fraction_of_gpu_memory_to_use=0.1 TIMEOUT 45)
+endif()
 set_tests_properties(test_trt_reduce_mean_op PROPERTIES TIMEOUT 60)
 set_tests_properties(test_trt_tile_op PROPERTIES TIMEOUT 60)
 set_tests_properties(test_trt_fc_fuse_quant_dequant_pass PROPERTIES TIMEOUT 100)
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index cd0d624eb40bc2af0ca3be6e8e87ad30fc144d53..018a979bc5eaa4510e5da17b8e5f6407119706ea 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -1349,7 +1349,8 @@ class OpTest(unittest.TestCase):
         places = self._get_places()
         for place in places:
             res = self.check_output_with_place(place, atol, no_check_set,
-                                               equal_nan, check_dygraph)
+                                               equal_nan, check_dygraph,
+                                               inplace_atol)
             if check_dygraph:
                 outs, dygraph_outs, fetch_list = res
             else:
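Finally, the behavioral core of the C++ changes is small enough to check off-device: the cap only fires for compute capabilities 5.3 (Nano/TX1) and 6.2 (TX2), and any call site that recomputes its block count afterwards keeps full coverage. A standalone approximation of platform::ChangeThreadNum with the capability passed in explicitly (the real helper reads it from the CUDADeviceContext):

#include <cassert>

// Mirrors platform::ChangeThreadNum, with the capability as a parameter.
void ChangeThreadNum(int capability, int* num_thread,
                     int alternative_num_thread = 512) {
  if (capability == 53 || capability == 62) {
    *num_thread = alternative_num_thread;
  }
}

int main() {
  int threads = 1024;
  ChangeThreadNum(53, &threads);       // Jetson Nano / TX1
  assert(threads == 512);

  threads = 1024;
  ChangeThreadNum(62, &threads, 256);  // TX2, with the ROI-align cap of 256
  assert(threads == 256);

  threads = 1024;
  ChangeThreadNum(70, &threads);       // e.g. V100: left untouched
  assert(threads == 1024);
  return 0;
}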