提交 02cbb13b 作者: Megvii Engine Team 提交者: Xu Xinran

fix(dnn/arm): fix nchw44 fp32 direct algo oh block and unused stride2 algo

GitOrigin-RevId: 8012678faebedb0e21ac7373394b0e035ed16695
上级 d2f5874a
......@@ -107,7 +107,7 @@ static void do_conv_kern(WorkspaceBundle bundle,
constexpr int oc_idx = 0;
int oc_block = oc;
int oh_block = block_helper(kern_param.nr_threads, oh2,
ic * iw * sizeof(float) * 2);
ic * iw * sizeof(float) * stride_h);
const int oh_idx = ncb_index.ndrange_id[2];
const int oh_block_real = std::min(oh - oh_idx * oh_block, oh_block);
const int ih_real = oh_block_real * stride_h + fh - stride_h;
......@@ -297,8 +297,9 @@ ConvBiasImpl::AlgoF32DirectNCHW44::dispatch_kerns(
int oh = param.osz[0];
int ic = param.filter_meta.icpg;
int iw = param.isz[1];
int oh_block =
block_helper(param.nr_threads, oh, ic * iw * sizeof(float) * 2);
int stride_h = param.filter_meta.stride[0];
int oh_block = block_helper(param.nr_threads, oh,
ic * iw * sizeof(float) * stride_h);
CpuNDRange ncb_range = {static_cast<size_t>(batch),
static_cast<size_t>(group),
static_cast<size_t>(div_ceil(oh, oh_block))};
......
......@@ -118,24 +118,30 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
conv_bias::ConvBiasAlgoChecker<ConvBias>(
"IM2COLMATMUL:AARCH64_F32K8X12X1:192"));
Benchmarker<ConvBias> benchmarker_int_nchw44(handle);
Benchmarker<ConvBias> benchmarker_nchw44(handle);
if (is_fp32) {
benchmarker_int_nchw44.set_times(RUNS)
benchmarker_nchw44.set_times(RUNS)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_display(false);
} else {
benchmarker_int_nchw44.set_times(RUNS)
benchmarker_nchw44.set_times(RUNS)
.set_dtype(0, dtype::QuantizedS8(2.5))
.set_dtype(1, dtype::QuantizedS8(2.5))
.set_dtype(2, dtype::QuantizedS32(6.25))
.set_dtype(4, dtype::QuantizedS8(60.25))
.set_display(false);
}
benchmarker_int_nchw44.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(".+"));
auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*";
#if __ARM_FEATURE_DOTPROD
if (!is_fp32) {
nchw44_algo_regx = ".*DOT.*";
}
#endif
benchmarker_nchw44.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(nchw44_algo_regx));
auto run = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W,
size_t FS, size_t stride, bool input_nchw = false) {
......@@ -171,7 +177,7 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
bias = {1, OC / 4, 1, 1, 4};
dst = {N, OC / 4, OH, OW, 4};
auto int_nchw44_used = benchmarker_int_nchw44.set_param(param).exec(
auto int_nchw44_used = benchmarker_nchw44.set_param(param).exec(
{src, filter, bias, {}, dst}) /
RUNS;
float computations = IC * (FS * FS) * dst.total_nr_elems() * 2 * 1e-6;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册