From adf75a291df0bc4f0e36e6a4a9981ee49cc5dc6b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 21 Apr 2021 18:13:21 +0800 Subject: [PATCH] perf(dnn/cuda): add sass int4 128x128 GitOrigin-RevId: 1bc54821023814c3c80b5fd22c24ab0007dd6203 --- dnn/test/cuda/conv_test_utils.cpp | 63 ++++++++++++++++++++++--------- src/plugin/impl/opr_footprint.cpp | 5 ++- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/dnn/test/cuda/conv_test_utils.cpp b/dnn/test/cuda/conv_test_utils.cpp index 324a5979c..ff9ae0dfe 100644 --- a/dnn/test/cuda/conv_test_utils.cpp +++ b/dnn/test/cuda/conv_test_utils.cpp @@ -260,7 +260,7 @@ void benchmark_target_algo_with_cudnn_tsc( megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); CUBenchmarker benchmarker(handle); CUBenchmarker benchmarker_cudnn(handle); - size_t RUNS = 1000; + size_t RUNS = 200; benchmarker.set_display(false).set_times(RUNS); benchmarker.set_dtype(0, src_dtype) .set_dtype(1, filter_dtype) @@ -282,9 +282,6 @@ void benchmark_target_algo_with_cudnn_tsc( .set_dtype(2, change_cudnn_bias_dtype) .set_dtype(3, change_cudnn_dst_dtype) .set_dtype(4, change_cudnn_dst_dtype); - benchmarker_cudnn.set_before_exec_callback( - conv_bias::ConvBiasAlgoChecker( - change_cudnn_algo)); } else { benchmarker_cudnn.set_dtype(0, src_dtype) .set_dtype(1, filter_dtype) @@ -391,13 +388,28 @@ void benchmark_target_algo_with_cudnn_tsc( } float time_in_ms_cudnn = 0; if (with_cudnn) { - time_in_ms_cudnn = benchmarker_cudnn.execs( - {get_tensor_shape(src, format_cudnn), - get_tensor_shape(filter, format_cudnn), - get_tensor_shape(bias, format_cudnn), - {}, - {}}) / - RUNS; + if (change_cudnn_algo) { + time_in_ms_cudnn = + algo_benchmark, CUTimer>( + benchmarker_cudnn, + {get_tensor_shape(src, format_cudnn), + get_tensor_shape(filter, format_cudnn), + get_tensor_shape(bias, format_cudnn), + {}, + {}}, + change_cudnn_algo) / + RUNS; + } else { + time_in_ms_cudnn = + benchmarker_cudnn.execs( + {get_tensor_shape(src, format_cudnn), + get_tensor_shape(filter, format_cudnn), + get_tensor_shape(bias, format_cudnn), + {}, + {}}) / + RUNS; + } } float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * arg.f / @@ -432,13 +444,28 @@ void benchmark_target_algo_with_cudnn_tsc( } time_in_ms_cudnn = 0; if (with_cudnn) { - time_in_ms_cudnn = benchmarker_cudnn.execs( - {get_tensor_shape(src, format_cudnn), - get_tensor_shape(filter, format_cudnn), - get_tensor_shape(bias, format_cudnn), - get_tensor_shape(z, format_cudnn), - {}}) / - RUNS; + if (change_cudnn_algo) { + time_in_ms_cudnn = + algo_benchmark, CUTimer>( + benchmarker_cudnn, + {get_tensor_shape(src, format_cudnn), + get_tensor_shape(filter, format_cudnn), + get_tensor_shape(bias, format_cudnn), + get_tensor_shape(z, format_cudnn), + {}}, + change_cudnn_algo) / + RUNS; + } else { + time_in_ms_cudnn = + benchmarker_cudnn.execs( + {get_tensor_shape(src, format_cudnn), + get_tensor_shape(filter, format_cudnn), + get_tensor_shape(bias, format_cudnn), + get_tensor_shape(z, format_cudnn), + {}}) / + RUNS; + } } printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " "time(cudnn)=%.2f %.2fTops, " diff --git a/src/plugin/impl/opr_footprint.cpp b/src/plugin/impl/opr_footprint.cpp index e5f71e502..40d981343 100644 --- a/src/plugin/impl/opr_footprint.cpp +++ b/src/plugin/impl/opr_footprint.cpp @@ -151,7 +151,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, "format should be NCHW4/NCHW4_NCHW/NCHW4_NCHW32"); packed_size = 4; } - return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 4 / group * + return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * packed_size / group * 2; }; auto eval_conv_computation_chwn4 = [¶m, &src_shape, &filter_shape, @@ -178,7 +178,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, param.format == Param::Format::NCHW44 || param.format == Param::Format::NCHW44_DOT || param.format == Param::Format::NCHW32 || - param.format == Param::Format::NCHW32_NCHW4) { + param.format == Param::Format::NCHW32_NCHW4 || + param.format == Param::Format::NCHW64) { return eval_conv_computation_nchwx(); } if (param.format == Param::Format::CHWN4) { -- GitLab