提交 bca00f2e 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(dnn): midout at where neccessary in megdnn

GitOrigin-RevId: 191334bd96ff361dc10c048e504d49c2e283d9a4
上级 105e4450
......@@ -116,9 +116,9 @@ if(${MGE_ARCH} STREQUAL "AUTO")
endif()
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64")
option(MGB_ENABLE_CPUINFO "Build cpuinfo library for check runtime." ON)
if(MGB_ENABLE_CPUINFO)
message("-- Enable cpuinfo runtime check.")
option(MGE_ENABLE_CPUINFO "Build cpuinfo library for check runtime." ON)
if(MGE_ENABLE_CPUINFO)
message("-- Enable cpuinfo runtime check and little kernel optimize.")
add_definitions(-DMGB_ENABLE_CPUINFO_CHECK)
include(cmake/cpuinfo.cmake)
endif()
......
......@@ -53,7 +53,7 @@ add_library(megdnn EXCLUDE_FROM_ALL OBJECT ${SOURCES})
target_link_libraries(megdnn PUBLIC opr_param_defs)
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64")
if(MGB_ENABLE_CPUINFO)
if(MGE_ENABLE_CPUINFO)
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cpuinfo>)
endif()
endif()
......
......@@ -49,13 +49,13 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace(
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return false;
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) {
MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) {
return get_kimpls(param);
}
MIDOUT_END();
......
......@@ -58,6 +58,7 @@ size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
......@@ -118,6 +119,7 @@ size_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_8x12x1::get_kern(
......@@ -177,6 +179,7 @@ size_t MatrixMulImpl::AlgoF32K4x16x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern(
......@@ -237,6 +240,7 @@ size_t MatrixMulImpl::AlgoF32MK4_4x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x16::get_kern(
......@@ -313,6 +317,7 @@ size_t MatrixMulImpl::AlgoF16K8x24x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern(
......@@ -352,6 +357,7 @@ size_t MatrixMulImpl::AlgoF16MK8_8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern(
......@@ -431,6 +437,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::get_kern(
......@@ -501,6 +508,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::get_kern(
......@@ -573,6 +581,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16::get_kern(
......@@ -635,6 +644,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x4x16::get_kern(
......@@ -696,6 +706,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K8x8x8::get_kern(
......@@ -762,6 +773,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K8x8x8::get_kern(
......@@ -828,6 +840,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x4x16::get_kern(
......@@ -905,6 +918,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern(
......@@ -981,6 +995,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8::get_kern(
......@@ -1051,6 +1066,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x8x1::get_kern(
......@@ -1092,6 +1108,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern(
......@@ -1172,6 +1189,7 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern(
......@@ -1277,6 +1295,7 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x8::get_kern(
......
......@@ -160,7 +160,12 @@ bool ConvBiasImpl::AlgoF32DirectNCHW44::usable(const NCBKernSizeParam& param,
size_t ConvBiasImpl::AlgoF32DirectNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw44_stride1,
midout_iv("AlgoF32DirectNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
......
......@@ -199,7 +199,12 @@ bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable(
size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_fp32_nchw_nchw44,
midout_iv("AlgoF32DirectNCHWNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
......
......@@ -33,29 +33,39 @@ bool ConvBiasImpl::AlgoS8DirectStride1::usable(const NCBKernSizeParam& param,
return direct_int8_stride1::can_conv_direct_stride1_int8(param);
}
bool ConvBiasImpl::AlgoS8DirectStride1::is_preferred(
const NCBKernSizeParam& param) const {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto OC = fm.ocpg;
auto IC = fm.icpg;
bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) ||
((FH == 3 || FH == 5 || FH == 7) &&
(OC <= 16 || (IC <= 4 && OC <= 32)))) &&
param.bias_mode != BiasMode::BIAS;
return preferred;
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride1::is_preferred"_hash)) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto OC = fm.ocpg;
auto IC = fm.icpg;
bool preferred = ((FH == 2 && (OC <= 10 || IC <= 8)) ||
((FH == 3 || FH == 5 || FH == 7) &&
(OC <= 16 || (IC <= 4 && OC <= 32)))) &&
param.bias_mode != BiasMode::BIAS;
return preferred;
}
MIDOUT_END();
}
size_t ConvBiasImpl::AlgoS8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_int8_stride1::get_kimpls(param, large_group);
}
......@@ -72,15 +82,20 @@ bool ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::usable(
size_t ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = channel_wise_nchw44::stride1::get_bundle(param);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8ChanWiseStride1NCHW44::get_workspace"_hash)) {
auto bundle = channel_wise_nchw44::stride1::get_bundle(param);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8ChanWiseStride1NCHW44"_hash)) {
midout_iv("AlgoS8ChanWiseStride1NCHW44::dispatch_kerns"_hash)) {
return channel_wise_nchw44::stride1::get_kimpls(param);
}
MIDOUT_END();
......@@ -96,15 +111,20 @@ bool ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::usable(
size_t ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::get_workspace(
const NCBKernSizeParam& param) const {
auto bundle = channel_wise_nchw44::stride2::get_bundle(param);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8ChanWiseStride2NCHW44::get_workspace"_hash)) {
auto bundle = channel_wise_nchw44::stride2::get_bundle(param);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8ChanWiseStride2NCHW44"_hash)) {
midout_iv("AlgoS8ChanWiseStride2NCHW44::dispatch_kerns"_hash)) {
return channel_wise_nchw44::stride2::get_kimpls(param);
}
MIDOUT_END();
......@@ -119,15 +139,21 @@ bool ConvBiasImpl::AlgoS8DirectStride2::usable(const NCBKernSizeParam& param,
size_t ConvBiasImpl::AlgoS8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride2::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 1, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8DirectStride2::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_int8_stride2::get_kimpls(param, large_group);
}
......@@ -144,15 +170,21 @@ bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param,
size_t ConvBiasImpl::AlgoDotS8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoDotS8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoDotS8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_int8_stride1::get_kimpls(param, large_group);
}
......@@ -168,15 +200,21 @@ bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param,
size_t ConvBiasImpl::AlgoDotS8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoDotS8DirectStride2::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_int8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8, 2, 2) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoDotS8DirectStride2::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_int8_stride2::get_kimpls(param, large_group);
}
......@@ -188,37 +226,45 @@ ConvBiasImpl::AlgoDotS8DirectStride2::dispatch_kerns(
/* ======================= AlgoS8WinogradF23_8x8 ======================== */
bool ConvBiasImpl::AlgoS8WinogradF23_8x8::usable(
const NCBKernSizeParam& param,
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false;
using Strategy = winograd::winograd_2x3_8x8_s8;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy, param::MatrixMul::Format::MK8>(
strategy, m_tile_size, param)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
((param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.filter_meta.format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
param.output_block_size == 2 &&
param.winograd_matmul_format == param::MatrixMul::Format::MK8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS16)) &&
!param.filter_meta.should_flip &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] == param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.bias_type.enumv() == DTypeEnum::QuantizedS32 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8;
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("AlgoS8WinogradF23_8x8::usable"_hash)) {
if (param.filter_meta.icpg % 8 != 0 || param.filter_meta.ocpg % 8 != 0)
return false;
using Strategy = winograd::winograd_2x3_8x8_s8;
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
Strategy strategy(param.src_type, param.filter_type, param.dst_type);
auto&& matmul_param =
megdnn::winograd::ConvBias<Strategy,
param::MatrixMul::Format::MK8>(
strategy, m_tile_size, param)
.get_matmul_kern_param(param);
return m_matmul_algo->usable(matmul_param) &&
m_matmul_algo->packmode() == PackMode::NO_PACK &&
((param.filter_meta.format == param::ConvBias::Format::NCHW &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.filter_meta.format ==
param::ConvBias::Format::NCHW_WINOGRAD &&
param.output_block_size == 2 &&
param.winograd_matmul_format ==
param::MatrixMul::Format::MK8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS16)) &&
!param.filter_meta.should_flip &&
(param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
param.filter_meta.spatial[0] == 3) &&
(param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.bias_type.enumv() == DTypeEnum::QuantizedS32 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8;
}
MIDOUT_END();
return false;
}
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoS8WinogradF23_8x8,
......
......@@ -202,14 +202,19 @@ bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::is_preferred(
size_t ConvBiasImpl::AlgoDotS8Direct_NCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("ALGODOTS8DIRECT_NCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotS8Direct_NCHW44::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8,
midout_iv("ALGODOTS8DIRECT_NCHW44"_hash)) {
midout_iv("ALGODOTS8DIRECT_NCHW44::dispatch_kerns"_hash)) {
auto fm = param.filter_meta;
size_t BATCH = param.n;
size_t GROUP = fm.group;
......
......@@ -223,7 +223,12 @@ bool ConvBiasImpl::AlgoS8DirectNCHW44::is_preferred(
size_t ConvBiasImpl::AlgoS8DirectNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44,
midout_iv("AlgoS8DirectNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
......
......@@ -232,7 +232,12 @@ bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred(
size_t ConvBiasImpl::AlgoS8DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw_nchw44,
midout_iv("AlgoS8DirectNCHWNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
......
......@@ -183,7 +183,12 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8_nchw44_dot,
midout_iv("AlgoDotS8DirectNCHWNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
......
......@@ -83,7 +83,8 @@ void get_rectified_size_str2(size_t IH, size_t IW, size_t OH, size_t OW,
/* ===================== direct algo ===================== */
bool ConvBiasImpl::AlgoI8x8x16Direct::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Direct::usable"_hash)) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
return param.bias_mode == BiasMode::NO_BIAS &&
......@@ -122,7 +123,8 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Direct::get_bundle(
}
size_t ConvBiasImpl::AlgoI8x8x16Direct::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Direct::get_workspace"_hash)) {
auto bundle = get_bundle(param);
return bundle.total_size_in_bytes();
}
......@@ -287,7 +289,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Direct::get_kimpls(
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 1, 2) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Direct::dispatch_kerns"_hash)) {
return get_kimpls(param);
}
MIDOUT_END();
......@@ -297,7 +300,8 @@ ConvBiasImpl::AlgoI8x8x16Direct::dispatch_kerns(
/* ===================== stride-2 algo ===================== */
bool ConvBiasImpl::AlgoI8x8x16Stride2::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2::usable"_hash)) {
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
return param.bias_mode == BiasMode::NO_BIAS &&
......@@ -337,7 +341,8 @@ WorkspaceBundle ConvBiasImpl::AlgoI8x8x16Stride2::get_bundle(
}
size_t ConvBiasImpl::AlgoI8x8x16Stride2::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2::get_workspace"_hash)) {
auto bundle = get_bundle(param);
return bundle.total_size_in_bytes();
}
......@@ -501,7 +506,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2::get_kimpls(
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 2, 2) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2::dispatch_kerns"_hash)) {
return get_kimpls(param);
}
MIDOUT_END();
......@@ -510,7 +516,8 @@ ConvBiasImpl::AlgoI8x8x16Stride2::dispatch_kerns(
bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2Filter2::usable"_hash)) {
return param.bias_mode == BiasMode::NO_BIAS &&
param.nonlineMode == NonlineMode::IDENTITY &&
param.nr_threads == 1_z &&
......@@ -522,7 +529,8 @@ bool ConvBiasImpl::AlgoI8x8x16Stride2Filter2::usable(
size_t ConvBiasImpl::AlgoI8x8x16Stride2Filter2::get_workspace(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2Filter2::get_workspace"_hash)) {
return conv_bias::get_workspace_in_bytes_conv_int8x8x16_stride2_flt2(
param);
}
......@@ -535,7 +543,8 @@ ConvBiasImpl::AlgoI8x8x16Stride2Filter2::dispatch_kerns(
const NCBKernSizeParam& param) const {
// return {conv_bias::conv_int8x8x16_stride2_flt2,true};
auto kern = [](const NCBKernParam& param, const NCBKernIndex& ncb_index) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl, 3, 2) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv("AlgoI8x8x16Stride2Filter2::dispatch_kerns"_hash)) {
auto ncb_param = param;
ncb_param.src_ptr = param.src<void>(0, ncb_index.ndrange_id[0]);
ncb_param.dst_ptr = param.dst<void>(0, ncb_index.ndrange_id[0]);
......@@ -573,18 +582,25 @@ bool ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::usable(
size_t ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace(
const NCBKernSizeParam& param) const {
size_t stride_h = param.filter_meta.stride[0];
size_t stride_w = param.filter_meta.stride[1];
megdnn_assert(stride_h == stride_w);
if (stride_h == 1) {
return channel_wise_nchw44_8x8x16::stride1::get_bundle(param)
.total_size_in_bytes();
} else if (stride_h == 2) {
return channel_wise_nchw44_8x8x16::stride2::get_bundle(param)
.total_size_in_bytes();
} else {
return 0;
MIDOUT_BEGIN(
megdnn_arm_common_conv_bias_int8816_kimpl,
midout_iv(
"AlgoS8x8x16ChanWiseStride1Stride2NCHW44::get_workspace"_hash)) {
size_t stride_h = param.filter_meta.stride[0];
size_t stride_w = param.filter_meta.stride[1];
megdnn_assert(stride_h == stride_w);
if (stride_h == 1) {
return channel_wise_nchw44_8x8x16::stride1::get_bundle(param)
.total_size_in_bytes();
} else if (stride_h == 2) {
return channel_wise_nchw44_8x8x16::stride2::get_bundle(param)
.total_size_in_bytes();
} else {
return 0;
}
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
......
......@@ -13,11 +13,8 @@
#include "src/common/utils.h"
#include <cstring>
#include "midout.h"
#include "src/arm_common/simd_macro/marm_neon.h"
MIDOUT_DECL(megdnn_arm_common_conv_bias_int8816_filter)
using namespace megdnn;
using namespace arm_common;
using namespace conv_bias;
......
......@@ -12,9 +12,7 @@
#include "src/common/utils.h"
#include <cstring>
#include "midout.h"
#include "src/arm_common/simd_macro/marm_neon.h"
MIDOUT_DECL(megdnn_arm_common_conv_bias_s2_filter)
#pragma GCC diagnostic ignored "-Wunused-parameter"
......
......@@ -229,7 +229,12 @@ bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable(
size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_i8i8i16_nchw_nchw44,
midout_iv("AlgoI8x8x16DirectNCHWNCHW44::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
......
......@@ -32,15 +32,21 @@ bool ConvBiasImpl::AlgoQU8DirectStride1::usable(const NCBKernSizeParam& param,
size_t ConvBiasImpl::AlgoQU8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoQU8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_quint8_stride1::get_kimpls(param, large_group);
}
......@@ -57,15 +63,21 @@ bool ConvBiasImpl::AlgoQU8DirectStride2::usable(
size_t ConvBiasImpl::AlgoQU8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 0, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_quint8_stride2::get_kimpls(param, large_group);
}
......@@ -81,15 +93,21 @@ bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param,
size_t ConvBiasImpl::AlgoDotU8DirectStride1::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride1::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 0) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_quint8_stride1::get_kimpls(param, large_group);
}
......@@ -105,15 +123,21 @@ bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param,
size_t ConvBiasImpl::AlgoDotU8DirectStride2::get_workspace(
const NCBKernSizeParam& param) const {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::get_workspace"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
auto bundle = direct_dotprod_quint8_stride2::get_bundle(param, large_group);
return bundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoDotU8DirectStride2::dispatch_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8, 1, 1) {
MIDOUT_BEGIN(megdnn_arm_common_conv_bias_quint8,
midout_iv("AlgoQU8DirectStride1::dispatch_kerns"_hash)) {
bool large_group = param.filter_meta.group >= param.nr_threads;
return direct_dotprod_quint8_stride2::get_kimpls(param, large_group);
}
......
......@@ -32,13 +32,23 @@ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable(
size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param);
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride1::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(param);
}
MIDOUT_END();
return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return deconv::stride1_int8x8x32_dot;
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride1::dispatch_kern"_hash)) {
return deconv::stride1_int8x8x32_dot;
}
MIDOUT_END();
return {};
}
/* ===================== direct stride 2 algo ===================== */
......@@ -49,13 +59,23 @@ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable(
size_t ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param);
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride2::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride2_int8x8x32_dot(param);
}
MIDOUT_END();
return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return deconv::stride2_int8x8x32_dot;
MIDOUT_BEGIN(megdnn_arm_conv_int8832_kimpl,
midout_iv("AlgoSdot8DirectStride2::dispatch_kern"_hash)) {
return deconv::stride2_int8x8x32_dot;
}
MIDOUT_END();
return {};
}
#endif
......
......@@ -33,13 +33,23 @@ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable(
size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return deconv::get_workspace_in_bytes_stride1_quint8_dot(param);
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride1::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride1_quint8_dot(param);
}
MIDOUT_END();
return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return deconv::stride1_quint8_dot;
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride1::dispatch_kern"_hash)) {
return deconv::stride1_quint8_dot;
}
MIDOUT_END();
return {};
}
/* ===================== direct stride 2 algo ===================== */
......@@ -50,13 +60,23 @@ bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable(
size_t ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return deconv::get_workspace_in_bytes_stride2_quint8_dot(param);
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride2::get_workspace"_hash)) {
return deconv::get_workspace_in_bytes_stride2_quint8_dot(param);
}
MIDOUT_END();
return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return deconv::stride2_quint8_dot;
MIDOUT_BEGIN(megdnn_arm_conv_quint8_kimpl,
midout_iv("AlgoUdot8DirectStride2::dispatch_kern"_hash)) {
return deconv::stride2_quint8_dot;
}
MIDOUT_END();
return {};
}
#endif
// vim: syntax=cpp.doxygen
......@@ -18,6 +18,8 @@
MIDOUT_DECL(megdnn_arm_hgemv)
MIDOUT_DECL(megdnn_arm_exec_int8816)
MIDOUT_DECL(megdnn_arm_exec_int8832)
MIDOUT_DECL(megdnn_arm_exec_fp32)
using namespace megdnn;
using namespace arm_common;
......@@ -63,8 +65,13 @@ bool MatrixMulImpl::AlgoInt8x8x16::usable(
size_t MatrixMulImpl::AlgoInt8x8x16::get_workspace(
const KernSizeParam& kern_size_param) const {
auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param);
return wbundle.total_size_in_bytes();
MIDOUT_BEGIN(megdnn_arm_exec_int8816,
midout_iv("AlgoInt8x8x16::get_workspace"_hash)) {
auto wbundle = get_workspace_bundle_int_8x8x16(kern_size_param);
return wbundle.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern(
......@@ -75,11 +82,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern(
/* ===================== Int8x8x32 Gemv algo ===================== */
namespace {
void int8x8x32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_int8832,
midout_iv("int8x8x32_gemv_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace
......@@ -104,11 +115,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
/* ===================== Int8x8x32 Gemv MK4 algo ===================== */
namespace {
void int8x8x32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_int8832,
midout_iv("int8x8x32_gemv_mk4_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace
......@@ -147,11 +162,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
/* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
namespace {
void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_int8832,
midout_iv("int8x8x32_gemv_mk4_dot_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_int8>(), Bptr = kern_param.B<dt_int8>();
auto Cptr = kern_param.C<dt_int32>();
gemv_like_mk4_dot(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace
......@@ -189,12 +208,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::get_kern(
/* ===================== F32 Gemv algo ===================== */
namespace {
void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_fp32,
midout_iv("f32_gemv_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace
......@@ -225,12 +248,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32Gemv::get_kern(
/* ================== F32 Gemv MK4 algo ================== */
namespace {
void f32_gemv_mk4_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
MIDOUT_BEGIN(megdnn_arm_exec_fp32,
midout_iv("f32_gemv_mk4_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDA = kern_param.LDA, LDB = kern_param.LDB, LDC = kern_param.LDC;
const auto Aptr = kern_param.A<dt_float32>(),
Bptr = kern_param.B<dt_float32>();
auto Cptr = kern_param.C<dt_float32>();
gemv_like_mk4(Aptr, Bptr, Cptr, M, N, K, LDA, LDB, LDC);
}
MIDOUT_END();
}
} // anonymous namespace
......@@ -266,11 +293,15 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GemvMK4::get_kern(
namespace {
template <typename stype, typename dtype>
void gevm_like_kern(const MatrixMulImpl::KernParam& kern_param) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>();
auto Cptr = kern_param.C<dtype>();
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
MIDOUT_BEGIN(megdnn_arm_exec_fp32,
midout_iv("gevm_like_kern"_hash)) {
auto M = kern_param.M, N = kern_param.N, K = kern_param.K;
auto LDB = kern_param.LDB;
const auto Aptr = kern_param.A<stype>(), Bptr = kern_param.B<stype>();
auto Cptr = kern_param.C<dtype>();
megdnn::arm_common::gemv_like(Bptr, Aptr, Cptr, N, M, K, LDB, 1, 1);
}
MIDOUT_END();
}
} // anonymous namespace
......
......@@ -75,6 +75,7 @@ size_t MatrixMulImpl::AlgoF32::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern(
......@@ -141,6 +142,7 @@ size_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4Pack4x12::get_kern(
......@@ -202,6 +204,7 @@ size_t MatrixMulImpl::AlgoF16K4x16x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern(
......@@ -265,6 +268,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x2x16::get_kern(
......@@ -326,6 +330,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K4x8x8::get_kern(
......@@ -386,6 +391,7 @@ size_t MatrixMulImpl::AlgoQuint8K4x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern(
......@@ -445,6 +451,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x2x16::get_kern(
......@@ -510,6 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16K4x8x8::get_kern(
......@@ -577,6 +585,7 @@ size_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4::get_kern(
......@@ -642,6 +651,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32K12x4x1::get_kern(
......@@ -702,6 +712,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32K6x8x4::get_kern(
......@@ -764,6 +775,7 @@ size_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8DotK4x8x4::get_kern(
......@@ -830,6 +842,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::get_kern(
......@@ -894,6 +907,7 @@ size_t MatrixMulImpl::AlgoF32MK4_4x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK4_4x8::get_kern(
......@@ -929,6 +943,7 @@ size_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_4x8::get_kern(
......@@ -986,6 +1001,7 @@ size_t MatrixMulImpl::AlgoF16MK8_4x8::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_4x8::get_kern(
......@@ -1066,6 +1082,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_workspace(
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16::get_kern(
......
......@@ -261,6 +261,7 @@ size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace(
.total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
SmallVector<ConvBiasImpl::NCBKern>
......
......@@ -175,44 +175,54 @@ bool ConvolutionImpl::AlgoFallback::usable(
}
size_t ConvolutionImpl::AlgoFallback::get_workspace(
const NCBKernSizeParam& param) const {
auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1];
size_t nr_threads = param.nr_threads;
if (param.filter_meta.should_flip) {
// need transpose filter
return WorkspaceBundle{nullptr, {FH * FW * sizeof(float)}}
.total_size_in_bytes() *
nr_threads;
} else {
return 0;
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoFallback::get_workspace"_hash)) {
auto FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
size_t nr_threads = param.nr_threads;
if (param.filter_meta.should_flip) {
// need transpose filter
return WorkspaceBundle{nullptr, {FH * FW * sizeof(float)}}
.total_size_in_bytes() *
nr_threads;
} else {
return 0;
}
}
MIDOUT_END();
return 0;
}
SmallVector<ConvolutionImpl::NCBKern>
ConvolutionImpl::AlgoFallback::dispatch_kern(
const NCBKernSizeParam& param) const {
size_t group = param.filter_meta.group;
size_t N = param.n;
size_t nr_threads = param.nr_threads;
size_t workspace_per_thread = get_workspace( param) / nr_threads;
auto kern_fallback = [workspace_per_thread](const NCBKernParam& p,
const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(p);
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
MEGDNN_MARK_USED_VAR(N);
auto src = p.src<float>(batch_id, group_id),
filter = p.filter<float>(group_id);
auto dst = p.dst<float>(batch_id, group_id);
size_t thread_id = ncb_index.thread_id;
void* workspace_ptr = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(p.workspace_ptr) +
workspace_per_thread * thread_id);
convolution::run_conv(src, filter, dst, workspace_ptr, IH, IW, IC, FH,
FW, OH, OW, OC, PH, PW, SH, SW,
!p.filter_meta.should_flip);
};
return {{kern_fallback, {group, N, 1_z}}};
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoFallback::dispatch_kern"_hash)) {
size_t group = param.filter_meta.group;
size_t N = param.n;
size_t nr_threads = param.nr_threads;
size_t workspace_per_thread = get_workspace( param) / nr_threads;
auto kern_fallback = [workspace_per_thread](const NCBKernParam& p,
const NCBKernIndex& ncb_index) {
UNPACK_CONV_F32_NCB_KERN_SIZES(p);
size_t batch_id = ncb_index.ndrange_id[1];
size_t group_id = ncb_index.ndrange_id[0];
MEGDNN_MARK_USED_VAR(N);
auto src = p.src<float>(batch_id, group_id),
filter = p.filter<float>(group_id);
auto dst = p.dst<float>(batch_id, group_id);
size_t thread_id = ncb_index.thread_id;
void* workspace_ptr = reinterpret_cast<void*>(
reinterpret_cast<ptrdiff_t>(p.workspace_ptr) +
workspace_per_thread * thread_id);
convolution::run_conv(src, filter, dst, workspace_ptr, IH, IW, IC, FH,
FW, OH, OW, OC, PH, PW, SH, SW,
!p.filter_meta.should_flip);
};
return {{kern_fallback, {group, N, 1_z}}};
}
MIDOUT_END();
}
/* ===================== naive algo ===================== */
......@@ -339,22 +349,36 @@ WorkspaceBundle ConvolutionImpl::AlgoDefault::get_bundle(
size_t ConvolutionImpl::AlgoDefault::get_workspace(
const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoDefault::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
size_t ConvolutionImpl::AlgoDefault::get_preprocess_workspace(
const NCBKernSizeParam& param) const {
::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
init_conv_bias_param(param);
return m_algorithm->get_preprocess_workspace(conv_bias_param);
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoDefault::get_preprocess_workspace"_hash)) {
::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
init_conv_bias_param(param);
return m_algorithm->get_preprocess_workspace(conv_bias_param);
}
MIDOUT_END();
}
SmallVector<TensorLayout>
ConvolutionImpl::AlgoDefault::deduce_preprocessed_filter_layout(
const NCBKernSizeParam& param) const {
::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
init_conv_bias_param( param);
return m_algorithm->deduce_preprocessed_filter_layout(conv_bias_param);
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_fallback_conv,
midout_iv("AlgoDefault::deduce_preprocessed_filter_layout"_hash)) {
::ConvBiasImpl::NCBKernSizeParam conv_bias_param =
init_conv_bias_param(param);
return m_algorithm->deduce_preprocessed_filter_layout(conv_bias_param);
}
MIDOUT_END();
}
//! Return the implement preprocess kernel
......@@ -450,19 +474,29 @@ bool ConvolutionBackwardDataImpl::AlgoDirect::usable(
size_t ConvolutionBackwardDataImpl::AlgoDirect::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
auto FH = param.filter_meta.spatial[0], FW = param.filter_meta.spatial[1];
if (param.filter_meta.should_flip) {
// need transpose filter
return FH * FW * sizeof(float);
} else {
return 0;
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoDirect::get_workspace"_hash)) {
auto FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
if (param.filter_meta.should_flip) {
// need transpose filter
return FH * FW * sizeof(float);
} else {
return 0;
}
}
MIDOUT_END();
return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoDirect::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const {
return kern_direct;
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoDirect::dispatch_kern"_hash)) {
return kern_direct;
}
MIDOUT_END();
}
/* ===================== Matrix mul algo ===================== */
......@@ -477,35 +511,48 @@ bool ConvolutionBackwardDataImpl::AlgoMatrixMul::usable(
size_t ConvolutionBackwardDataImpl::AlgoMatrixMul::get_workspace(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
return get_bundle(param).total_size_in_bytes();
MIDOUT_BEGIN(megdnn_fallback_conv,
midout_iv("AlgoMatrixMul::get_workspace"_hash)) {
return get_bundle(param).total_size_in_bytes();
}
MIDOUT_END();
return 0;
}
ConvolutionBackwardDataImpl::ncb_kern_t
ConvolutionBackwardDataImpl::AlgoMatrixMul::dispatch_kern(
ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
#define cb(dt) \
do { \
if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) { \
using ctype = DTypeTrait<dt>::ctype; \
return kern_matmul<ctype, ctype, ctype>; \
} \
#define cb(dt, midout_tag) \
do { \
if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) { \
MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \
using ctype = DTypeTrait<dt>::ctype; \
return kern_matmul<ctype, ctype, ctype>; \
} \
MIDOUT_END(); \
} \
} while (0);
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
cb(dtype::Float32, "FLOAT"_hash);
MEGDNN_INC_FLOAT16(cb(dtype::Float16, "FLOAT16"_hash));
MEGDNN_INC_FLOAT16(cb(dtype::BFloat16, "BFLOAT16"_hash));
#undef cb
#define cb(dt_src, dt_dst) \
do { \
if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \
param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \
param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \
return kern_matmul<DTypeTrait<dt_src>::ctype, \
DTypeTrait<dt_src>::ctype, \
DTypeTrait<dt_dst>::ctype>; \
} \
#define cb(dt_src, dt_dst, midout_tag) \
do { \
if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \
param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \
param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \
MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \
return kern_matmul<DTypeTrait<dt_src>::ctype, \
DTypeTrait<dt_src>::ctype, \
DTypeTrait<dt_dst>::ctype>; \
} \
MIDOUT_END(); \
} \
} while (0)
cb(dtype::Int8, dtype::Int32);
cb(dtype::QuantizedS8, dtype::QuantizedS32);
cb(dtype::Quantized8Asymm, dtype::QuantizedS32);
cb(dtype::Int8, dtype::Int32, "INT8x8x32"_hash);
cb(dtype::QuantizedS8, dtype::QuantizedS32, "QINT8x8x32"_hash);
cb(dtype::Quantized8Asymm, dtype::QuantizedS32, "QUINT8x8x32"_hash);
megdnn_throw("unsupported data type on matrix mul");
#undef cb
}
......
......@@ -24,7 +24,6 @@
#include <cstring>
MIDOUT_DECL(megdnn_fb_conv_float)
MIDOUT_DECL(megdnn_fb_convbwd_float)
using namespace megdnn;
......
......@@ -53,13 +53,20 @@ bool MatrixMulImpl::AlgoF32K8x12x1::usable(
size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
const KernSizeParam& kern_size_param) const {
auto M = kern_size_param.M, N = kern_size_param.N, K = kern_size_param.K;
matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type,
kern_size_param.B_type,
kern_size_param.C_type);
return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
M, N, K, kern_size_param.trA, kern_size_param.trB, strategy)
.get_workspace_size();
MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern,
midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) {
auto M = kern_size_param.M, N = kern_size_param.N,
K = kern_size_param.K;
matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type,
kern_size_param.B_type,
kern_size_param.C_type);
return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
M, N, K, kern_size_param.trA, kern_size_param.trB,
strategy)
.get_workspace_size();
}
MIDOUT_END();
return 0;
}
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
......
......@@ -23,10 +23,16 @@ namespace naive {
size_t MatrixMulForwardImpl::get_workspace_in_bytes(const TensorLayout& A,
const TensorLayout& B,
const TensorLayout&) {
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
return (A.span().dist_elem() + B.span().dist_elem()) * sizeof(uint8_t);
MIDOUT_BEGIN(
megdnn_naive_matmul,
midout_iv("MatrixMulForwardImpl::get_workspace_in_bytes"_hash)) {
if (A.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
return (A.span().dist_elem() + B.span().dist_elem()) *
sizeof(uint8_t);
}
return 0;
}
return 0;
MIDOUT_END();
}
template <bool TA, bool TB>
......@@ -127,7 +133,8 @@ void MatrixMulForwardImpl::exec_internal(_megdnn_tensor_in A,
void MatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
_megdnn_tensor_out C,
_megdnn_workspace workspace) {
MIDOUT_BEGIN(megdnn_naive_matmul) {
MIDOUT_BEGIN(megdnn_naive_matmul,
midout_iv("MatrixMulForwardImpl::exec"_hash)) {
check_exec(A.layout, B.layout, C.layout, workspace.size);
auto p = param();
MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal(A, B, C, workspace, p));
......
......@@ -17,6 +17,7 @@
#cmakedefine01 MGB_ENABLE_DEBUG_UTIL
#cmakedefine01 MGB_ENABLE_LOGGING
#cmakedefine01 MGB_ENABLE_GRAD
#cmakedefine01 MGB_ENABLE_CPUINFO
#cmakedefine01 MGB_VERBOSE_TYPEINFO_NAME
#cmakedefine01 MGB_BUILD_SLIM_SERVING
#cmakedefine01 MGB_ENABLE_EXCEPTION
......@@ -80,6 +81,16 @@
#define MGB_ENABLE_GRAD 1
#endif
// whether to enable cpuinfo
#ifndef MGB_ENABLE_CPUINFO
#define MGB_ENABLE_CPUINFO 1
#endif
#ifdef IOS
#undef MGB_ENABLE_CPUINFO
#define MGB_ENABLE_CPUINFO 0
#endif
// whether to include actual class name in mgb::Typeinfo object; if this is
// disabled, mgb::serialization::OprRegistry::find_opr_by_name would not work.
#ifndef MGB_VERBOSE_TYPEINFO_NAME
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册