Commit 4406f902 authored by: H hjchen2

Fix 1x1 depthwise conv

Parent c02076f0
@@ -71,34 +71,11 @@ void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
         a[index++] = tmp[i * n + j];
       }
     }
-    if (M == 1) {
-#ifdef _OPENMP
-      gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
-                     matrix_out->data<float>(), N, relu, bias);
-#else
-      gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
-                 matrix_out->data<float>(), N, relu, bias);
-#endif
-    } else {
-      cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data<float>(), N,
-                  beta, matrix_out->data<float>(), N);
-    }
+    cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data<float>(), N,
+                beta, matrix_out->data<float>(), N);
   } else {
-    if (M == 1) {
-#ifdef _OPENMP
-      gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data<float>(), K,
-                     matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
-                     N, relu, bias);
-#else
-      gemm.Sgemm(M, N, K, alpha, matrix_a.data<float>(), K,
-                 matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N,
-                 relu, bias);
-#endif
-    } else {
-      cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data<float>(), K,
-                  matrix_b.data<float>(), N, beta, matrix_out->data<float>(),
-                  N);
-    }
+    cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data<float>(), K,
+                matrix_b.data<float>(), N, beta, matrix_out->data<float>(), N);
   }
 }
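Background: the patch deletes the M == 1 special case in both branches of MatMul<float, float>, so the single-output-row case hit by a 1x1 depthwise convolution now goes through cblas_sgemm unconditionally instead of dispatching to gemm.Sgemm / gemm.Sgemm_omp. For context, below is a minimal, self-contained sketch of why a 1x1 convolution (stride 1, no padding) reduces to exactly the M x K x N GEMM this function computes; the function and variable names are illustrative, not taken from the repository.

#include <vector>

// Illustrative sketch (not library code): a 1x1 convolution over an input of
// shape [Cin, H, W] with weights of shape [Cout, Cin] is exactly the matrix
// product  out[Cout, H*W] = w[Cout, Cin] * in[Cin, H*W],
// i.e. a GEMM with M = Cout, K = Cin, N = H*W. A depthwise-style single
// output channel gives M == 1, the case the removed fast path targeted.
void Conv1x1AsGemm(const std::vector<float> &in,  // Cin * H * W, CHW layout
                   const std::vector<float> &w,   // Cout * Cin, row-major
                   std::vector<float> *out,       // Cout * H * W
                   int Cin, int Cout, int H, int W) {
  const int N = H * W;  // each spatial position becomes one GEMM column
  for (int m = 0; m < Cout; ++m) {     // M dimension: output channels
    for (int n = 0; n < N; ++n) {      // N dimension: spatial positions
      float acc = 0.f;
      for (int k = 0; k < Cin; ++k) {  // K dimension: input channels
        acc += w[m * Cin + k] * in[k * N + n];
      }
      (*out)[m * N + n] = acc;
    }
  }
}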