提交 ea70d99b 编写于 作者: M Megvii Engine Team 提交者: huangxinda

fix(mge/convbias): make fallback convbias support nhwcd4 layout

GitOrigin-RevId: 1c306f867dc636b8e8349d4d42a17b914f160552
上级 497ef6c3
......@@ -382,19 +382,11 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
if (dst.type() == TensorFormat::Type::IMAGE2D_PACK4 &&
(
handle()->type() != Handle::HandleType::NAIVE)) {
#if MEGDNN_ENABLE_MANGLING
megdnn_throw(
"Only naive and opencl handle support "
"Image2DPack4TensorFormat, try build with debug for get more "
"info");
#else
megdnn_throw(
"Only naive and opencl handle support "
"Image2DPack4TensorFormat, try to export MGB_USE_MEGDNN_DBG=2 "
"and also export CUDA_VISIBLE_DEVICES=\'\' at CUDA env"
"to enable naive handle");
#endif
handle()->type() != Handle::HandleType::NAIVE &&
handle()->type() != Handle::HandleType::X86)) {
megdnn_throw(
"Dump with Image2DPack4TensorFormat is not available on CUDA compnode, "
"try export CUDA_VISIBLE_DEVICES=\'\'");
}
#undef CHECK_SRC
}
......
......@@ -297,6 +297,9 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
const AlgoAttribute& positive_attr,
const AlgoAttribute& negative_attr) {
if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
return nullptr;
}
auto algo_data_type = param.deduce_algo_data_type();
auto suggest_category_order = suggest_algo_category_order(param);
for (auto category : suggest_category_order) {
......@@ -346,7 +349,7 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
param().format == Param::Format::NCHW32 ||
param().format == Param::Format::NCHW64) {
spatial_pos = 2;
} else if (param().format == Param::Format::NHWC) {
} else if (param().format == Param::Format::NHWC || param().format == Param::Format::NHWCD4) {
spatial_pos = 1;
} else {
megdnn_assert(0, "invalid conv format %d",
......@@ -497,6 +500,9 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_from_desc(
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
const NCBKernSizeParam& param, size_t workspace_size) {
if (ConvBiasImpl::param().format == Param::Format::NHWCD4) {
return nullptr;
}
if (auto algo = get_algorithm_from_desc(execution_policy().algo)) {
return algo;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册