Commit adf75a29 authored by Megvii Engine Team

perf(dnn/cuda): add sass int4 128x128

GitOrigin-RevId: 1bc54821023814c3c80b5fd22c24ab0007dd6203
Parent 8da2f698
@@ -260,7 +260,7 @@ void benchmark_target_algo_with_cudnn_tsc(
     megdnn_assert(src_dtype.enumv() == filter_dtype.enumv());
     CUBenchmarker<ConvBiasForward> benchmarker(handle);
     CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle);
-    size_t RUNS = 1000;
+    size_t RUNS = 200;
     benchmarker.set_display(false).set_times(RUNS);
     benchmarker.set_dtype(0, src_dtype)
             .set_dtype(1, filter_dtype)
@@ -282,9 +282,6 @@ void benchmark_target_algo_with_cudnn_tsc(
                 .set_dtype(2, change_cudnn_bias_dtype)
                 .set_dtype(3, change_cudnn_dst_dtype)
                 .set_dtype(4, change_cudnn_dst_dtype);
-        benchmarker_cudnn.set_before_exec_callback(
-                conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
-                        change_cudnn_algo));
     } else {
         benchmarker_cudnn.set_dtype(0, src_dtype)
                 .set_dtype(1, filter_dtype)
@@ -391,13 +388,28 @@ void benchmark_target_algo_with_cudnn_tsc(
         }
         float time_in_ms_cudnn = 0;
         if (with_cudnn) {
-            time_in_ms_cudnn = benchmarker_cudnn.execs(
-                    {get_tensor_shape(src, format_cudnn),
-                     get_tensor_shape(filter, format_cudnn),
-                     get_tensor_shape(bias, format_cudnn),
-                     {},
-                     {}}) /
-                    RUNS;
+            if (change_cudnn_algo) {
+                time_in_ms_cudnn =
+                        algo_benchmark<ConvBiasForward,
+                                       OprProxy<ConvBiasForward>, CUTimer>(
+                                benchmarker_cudnn,
+                                {get_tensor_shape(src, format_cudnn),
+                                 get_tensor_shape(filter, format_cudnn),
+                                 get_tensor_shape(bias, format_cudnn),
+                                 {},
+                                 {}},
+                                change_cudnn_algo) /
+                        RUNS;
+            } else {
+                time_in_ms_cudnn =
+                        benchmarker_cudnn.execs(
+                                {get_tensor_shape(src, format_cudnn),
+                                 get_tensor_shape(filter, format_cudnn),
+                                 get_tensor_shape(bias, format_cudnn),
+                                 {},
+                                 {}}) /
+                        RUNS;
+            }
         }
         float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * arg.f /
@@ -432,13 +444,28 @@ void benchmark_target_algo_with_cudnn_tsc(
         }
         time_in_ms_cudnn = 0;
         if (with_cudnn) {
-            time_in_ms_cudnn = benchmarker_cudnn.execs(
-                    {get_tensor_shape(src, format_cudnn),
-                     get_tensor_shape(filter, format_cudnn),
-                     get_tensor_shape(bias, format_cudnn),
-                     get_tensor_shape(z, format_cudnn),
-                     {}}) /
-                    RUNS;
+            if (change_cudnn_algo) {
+                time_in_ms_cudnn =
+                        algo_benchmark<ConvBiasForward,
+                                       OprProxy<ConvBiasForward>, CUTimer>(
+                                benchmarker_cudnn,
+                                {get_tensor_shape(src, format_cudnn),
+                                 get_tensor_shape(filter, format_cudnn),
+                                 get_tensor_shape(bias, format_cudnn),
+                                 get_tensor_shape(z, format_cudnn),
+                                 {}},
+                                change_cudnn_algo) /
+                        RUNS;
+            } else {
+                time_in_ms_cudnn =
+                        benchmarker_cudnn.execs(
+                                {get_tensor_shape(src, format_cudnn),
+                                 get_tensor_shape(filter, format_cudnn),
+                                 get_tensor_shape(bias, format_cudnn),
+                                 get_tensor_shape(z, format_cudnn),
+                                 {}}) /
+                        RUNS;
+            }
         }
         printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, "
                "time(cudnn)=%.2f %.2fTops, "
...
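A note on the timing change above: when `change_cudnn_algo` is set, the commit stops pinning the algorithm via `set_before_exec_callback` and instead times it with `algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, CUTimer>`; either way the benchmarker reports the total time for `RUNS` executions, hence the trailing `/ RUNS`. Below is a self-contained sketch of that total-time-over-runs averaging only; `total_time_in_ms` and the toy workload are hypothetical stand-ins, not MegDNN API:

```cpp
#include <chrono>
#include <cstdio>
#include <functional>

// Hypothetical stand-in for the benchmarker: run a workload `runs` times
// and return the *total* wall time in milliseconds, i.e. the quantity the
// test divides by RUNS to obtain per-run latency.
static float total_time_in_ms(const std::function<void()>& workload,
                              size_t runs) {
    auto start = std::chrono::high_resolution_clock::now();
    for (size_t i = 0; i < runs; ++i)
        workload();
    auto stop = std::chrono::high_resolution_clock::now();
    return std::chrono::duration<float, std::milli>(stop - start).count();
}

int main() {
    const size_t RUNS = 200;  // lowered from 1000 in this commit
    volatile double sink = 0;
    float total = total_time_in_ms([&] { sink = sink + 1.0; }, RUNS);
    printf("%.4f ms/run\n", total / RUNS);  // same averaging as the diff
}
```

Lowering RUNS from 1000 to 200 keeps the average meaningful while cutting benchmark wall time, which matters when sweeping many layer shapes.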
@@ -151,7 +151,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
                       "format should be NCHW4/NCHW4_NCHW/NCHW4_NCHW32");
         packed_size = 4;
     }
-    return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * 4 / group *
+    return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * packed_size / group *
            2;
     };
     auto eval_conv_computation_chwn4 = [&param, &src_shape, &filter_shape,
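The `* 4` to `* packed_size` fix above matters because in NCHWx layouts `src_shape[1]` stores the number of channel blocks rather than channels, so the hard-coded 4 under-counts FLOPs whenever the block is wider (e.g. the 64-wide format this commit adds). A minimal self-contained sketch of the corrected formula; `conv_flops` and the example shapes are hypothetical, not MegDNN code:

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the corrected FLOP count: each output element costs one
// multiply-accumulate (2 FLOPs) per filter tap per input channel of its
// group. src_channel_blocks corresponds to src_shape[1] in the diff, and
// packed_size restores the true input-channel count.
uint64_t conv_flops(uint64_t dst_elems, uint64_t fh, uint64_t fw,
                    uint64_t src_channel_blocks, uint64_t packed_size,
                    uint64_t group) {
    return dst_elems * fh * fw * src_channel_blocks * packed_size / group * 2;
}

int main() {
    // Example: 1x64x56x56 output, 3x3 filter, 64 input channels stored as
    // one 64-wide block (src_shape[1] == 1), group == 1:
    uint64_t dst_elems = 1ull * 64 * 56 * 56;
    printf("%llu\n",
           (unsigned long long)conv_flops(dst_elems, 3, 3, 1, 64, 1));
    // prints 231211008; the old hard-coded "* 4" would report 16x less
}
```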
@@ -178,7 +178,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
         param.format == Param::Format::NCHW44 ||
         param.format == Param::Format::NCHW44_DOT ||
         param.format == Param::Format::NCHW32 ||
-        param.format == Param::Format::NCHW32_NCHW4) {
+        param.format == Param::Format::NCHW32_NCHW4 ||
+        param.format == Param::Format::NCHW64) {
         return eval_conv_computation_nchwx();
     }
     if (param.format == Param::Format::CHWN4) {
...
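For context on the new `NCHW64` branch: this format packs the channel axis into blocks of 64 (a natural fit for the int4 data this commit targets, since a 64-channel block occupies 32 bytes), so a tensor is laid out as (N, C/64, H, W, 64). A hedged illustration of the logical-to-storage index mapping under that assumption; `nchw64_offset` is a hypothetical helper, not code from this repository:

```cpp
#include <cassert>
#include <cstddef>

// Map a logical NCHW index (n, c, h, w) into a tensor laid out as
// (N, C/64, H, W, 64); C is assumed to be a multiple of 64.
size_t nchw64_offset(size_t n, size_t c, size_t h, size_t w,
                     size_t C, size_t H, size_t W) {
    size_t cb = c / 64, cr = c % 64;  // channel block, position in block
    return (((n * (C / 64) + cb) * H + h) * W + w) * 64 + cr;
}

int main() {
    // Channel 65 of a 128-channel 8x8 tensor lives in block 1, slot 1,
    // and block 1 starts after H*W 64-wide positions:
    assert(nchw64_offset(0, 65, 0, 0, 128, 8, 8) == (1 * 8 * 8) * 64 + 1);
}
```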