Commit 11732057 authored by Megvii Engine Team, committed by Xinran Xu

fix(gopt): nchw_nchwxx useable and opt pass use nchw_nchwxx_valid

GitOrigin-RevId: 60942aca5b19af86a1210267f5af27c1558f1a03
Parent: eb18eba8
@@ -6,7 +6,8 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #pragma once
 #include "megdnn/internal/opr_header_prologue.h"
@@ -314,8 +315,10 @@ public:
     /**
      * \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
      * \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
-     * 4*ic) \param[in] bias (1, oc, 1, 1) \param[in] z same as dst \param[out]
-     * dst (n, oc, oh, ow) or (n, oh, ow, oc)
+     * 4 * ic)
+     * \param[in] bias (1, oc, 1, 1)
+     * \param[in] z same as dst
+     * \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
      *
      * \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
      * alphaw, oc, ic)
@@ -407,6 +410,26 @@ public:
      */
     static WinogradParam parse_winograd_name(const std::string& algo_name);
+
+    /**
+     * @brief find whether there is an nchw_nchwxx conv kernel optimized for
+     * the given arguments; nchw44 is used on arm, nchw88 on x86
+     *
+     * @param src_dtype conv feature map data type
+     * @param filter_dtype conv filter or weight data type
+     * @param dst_dtype output data type
+     * @param fm filter meta param
+     * @param bias_mode bias mode: no_bias, broadcast or bias
+     * @param nonline_mode identity, relu, h_swish or sigmoid
+     * @return true, found a kernel
+     * @return false, no kernel can be found
+     */
+    static bool is_nchw_nchwxx_optimized(
+            const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
+            const DTypeEnum dst_dtype,
+            const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
+            const ConvBiasForward::BiasMode bias_mode,
+            const param::ConvBias::NonlineMode nonline_mode);
+
 protected:
     CanonizedFilterMeta check_exec(
             const TensorLayout& src, const TensorLayout& filter,
...
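Since the new entry point is static, a caller only needs a filled `CanonizedFilterMeta`. A minimal sketch of a call site (every concrete value below is illustrative, not taken from the commit): probe for a dense 3x3 stride-1 float conv that could be rewritten from NCHW to NCHW44.

#include "megdnn/oprs/nn.h"

// Hedged sketch: the real caller (the gopt pass further down) fills fm from
// the operator's param; the values here are hypothetical.
bool probe_nchw44_fp32() {
    using namespace megdnn;
    ConvolutionBase<param::Convolution>::CanonizedFilterMeta fm;
    fm.format = param::Convolution::Format::NCHW44;  // candidate target layout
    fm.should_flip = false;  // cross-correlation mode
    fm.group = 1;            // the validators only accept dense conv
    fm.spatial_ndim = 2;
    fm.ocpg = 8;             // oc % 4 == 0 and oc >= 4
    fm.icpg = 3;             // ic < 4: a genuine nchw -> nchw44 "head" conv
    fm.spatial[0] = fm.spatial[1] = 3;
    fm.stride[0] = fm.stride[1] = 1;
    fm.padding[0] = fm.padding[1] = 1;
    fm.dilation[0] = fm.dilation[1] = 1;
    return ConvBiasForward::is_nchw_nchwxx_optimized(
            DTypeEnum::Float32, DTypeEnum::Float32, DTypeEnum::Float32, fm,
            ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS,
            param::ConvBias::NonlineMode::RELU);
}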
@@ -16,10 +16,10 @@
 #include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
 #include "src/arm_common/conv_bias/fp32/strategy.h"
 #include "src/arm_common/elemwise_op.h"
+#include "src/common/nchw_nchwxx_valid.h"
 #include "src/common/opr_delegate.h"
 #include "midout.h"
 using namespace megdnn;
 using namespace arm_common;
 using conv_fun = std::function<void(
@@ -191,22 +191,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
 bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable(
         const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
-    auto&& fm = param.filter_meta;
-    auto fh = fm.spatial[0];
-    int oc = fm.ocpg;
-    bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 &&
-                     param.filter_type.enumv() == DTypeEnum::Float32 &&
-                     (param.dst_type.enumv() == DTypeEnum::Float32))) &&
-                   (fm.format == param::Convolution::Format::NCHW44);
-    bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
-    bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
-                     (fh == 2 || fh == 3 || fh == 5 || fh == 7);
-    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
-                    fm.stride[0] == fm.stride[1] &&
-                    (fm.stride[0] == 1 || fm.stride[0] == 2);
-    bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
-    bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
-    return avaible;
+    return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>(
+            param.src_type.enumv(), param.filter_type.enumv(),
+            param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+            param.nonlineMode);
 }
 size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace(
...
@@ -15,6 +15,7 @@
 #include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
 #include "src/arm_common/conv_bias/int8/strategy.h"
 #include "src/arm_common/elemwise_op.h"
+#include "src/common/nchw_nchwxx_valid.h"
 #include "src/common/opr_delegate.h"
 #include "midout.h"
@@ -214,26 +215,12 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
                          ow, op);
 }
-bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(
-        const NCBKernSizeParam& param,
-        AlgoSelectionStrategy algo_selection_strategy) const {
-    MEGDNN_MARK_USED_VAR(algo_selection_strategy);
-    auto&& fm = param.filter_meta;
-    auto FH = fm.spatial[0];
-    auto OC = fm.ocpg;
-    bool avaible =  //! src and filter are qint8, dst is qint8
-            fm.icpg < 4 &&  // must be nchw input
-            ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
-              param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
-              (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
-            (fm.format == param::Convolution::Format::NCHW44) &&
-            (OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 &&
-            fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
-            fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
-            (fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] &&
-            (FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 &&
-            param.bias_mode != BiasMode::BIAS;
-    return avaible;
+bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(const NCBKernSizeParam& param,
+                                                  AlgoSelectionStrategy) const {
+    return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>(
+            param.src_type.enumv(), param.filter_type.enumv(),
+            param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+            param.nonlineMode);
 }
 bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred(
...
@@ -16,6 +16,7 @@
 #include "src/arm_common/conv_bias/int8/algos.h"
 #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
 #include "src/arm_common/elemwise_op.h"
+#include "src/common/nchw_nchwxx_valid.h"
 #include "midout.h"
@@ -174,23 +175,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
 bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
         const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
-    auto&& fm = param.filter_meta;
-    auto fh = fm.spatial[0];
-    int oc = fm.ocpg;
-    int ic = fm.icpg;
-    bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
-                     param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
-                     (param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
-                   (fm.format == param::Convolution::Format::NCHW44_DOT);
-    bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4);
-    bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
-                     (fh == 2 || fh == 3 || fh == 5 || fh == 7);
-    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
-                    fm.stride[0] == fm.stride[1] &&
-                    (fm.stride[0] == 1 || fm.stride[0] == 2);
-    bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
-    bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
-    return avaible;
+    return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>(
+            param.src_type.enumv(), param.filter_type.enumv(),
+            param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+            param.nonlineMode);
 }
 size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace(
...
@@ -16,6 +16,7 @@
 #include "src/arm_common/conv_bias/int8x8x16/algos.h"
 #include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h"
 #include "src/arm_common/elemwise_op.h"
+#include "src/common/nchw_nchwxx_valid.h"
 #include "src/common/opr_delegate.h"
 #include "midout.h"
@@ -220,23 +221,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
 bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable(
         const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
-    auto&& fm = param.filter_meta;
-    auto fh = fm.spatial[0];
-    int oc = fm.ocpg;
-    bool ok_type = ((param.src_type.enumv() == DTypeEnum::Int8 &&
-                     param.filter_type.enumv() == DTypeEnum::Int8 &&
-                     (param.dst_type.enumv() == DTypeEnum::Int16))) &&
-                   (fm.format == param::Convolution::Format::NCHW44);
-    bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
-    bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
-                     (fh == 2 || fh == 3 || fh == 5 || fh == 7);
-    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
-                    fm.stride[0] == fm.stride[1] &&
-                    (fm.stride[0] == 2 || fm.stride[0] == 1);
-    bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS &&
-                   param.nonlineMode == param::ConvBias::NonlineMode::IDENTITY;
-    bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
-    return avaible;
+    return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>(
+            param.src_type.enumv(), param.filter_type.enumv(),
+            param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+            param.nonlineMode);
 }
 size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace(
...
/**
* \file dnn/src/common/nchw_nchwxx_valid.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs/nn.h"
#include "src/common/nchw_nchwxx_valid.h"
using namespace megdnn;
namespace {
using NchwNchwxxFuncInterface = std::function<bool(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const ConvBiasForward::BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode)>;
static SmallVector<NchwNchwxxFuncInterface> g_func_vec{
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW88>,
};
} // namespace
bool ConvBiasForward::is_nchw_nchwxx_optimized(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const ConvBiasForward::BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
for (auto& func : g_func_vec) {
if (func(src_dtype, filter_dtype, dst_dtype, fm, bias_mode,
nonline_mode)) {
return true;
}
}
return false;
}
\ No newline at end of file
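The new `.cpp` routes the public query through a table of every template instantiation instead of hard-coding per-backend branches, so supporting another packed layout means one enum value, one specialization in nchw_nchwxx_valid.h, and one `g_func_vec` entry. An equivalent formulation of the dispatch loop, for illustration only (it assumes the same arguments that `is_nchw_nchwxx_optimized` receives and is not part of the commit):

#include <algorithm>

// "Is there any registered validator that accepts these arguments?"
bool any_backend_hit = std::any_of(
        g_func_vec.begin(), g_func_vec.end(),
        [&](const NchwNchwxxFuncInterface& func) {
            return func(src_dtype, filter_dtype, dst_dtype, fm, bias_mode,
                        nonline_mode);
        });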
/**
* \file dnn/src/common/nchw_nchwxx_valid.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace {
enum NchwNchwxxType {
NCHW44_FP32,
NCHW44_INT8,
NCHW44_INT8_INT8_INT16,
NCHW44_INT8_DOT,
NCHW88,
};
template <NchwNchwxxType T>
static inline bool nchw_nchwxx_valid(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode);
template <>
inline bool nchw_nchwxx_valid<NCHW44_FP32>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type = ((src_dtype == DTypeEnum::Float32 &&
filter_dtype == DTypeEnum::Float32 &&
(dst_dtype == DTypeEnum::Float32))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
nonline_mode == param::ConvBias::NonlineMode::RELU ||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                    fm.stride[0] == fm.stride[1] &&
                    (fm.stride[0] == 1 || fm.stride[0] == 2);
    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
                     ok_slide && ok_conv;
    return available;
}
template <>
inline bool nchw_nchwxx_valid<NCHW44_INT8>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 &&
filter_dtype == DTypeEnum::QuantizedS8 &&
(dst_dtype == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
nonline_mode == param::ConvBias::NonlineMode::RELU ||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                    fm.stride[0] == fm.stride[1] &&
                    (fm.stride[0] == 1 || fm.stride[0] == 2);
    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
                     ok_slide && ok_conv;
    return available;
}
template <>
inline bool nchw_nchwxx_valid<NCHW44_INT8_INT8_INT16>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type =
((src_dtype == DTypeEnum::Int8 && filter_dtype == DTypeEnum::Int8 &&
(dst_dtype == DTypeEnum::Int16))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 2 || fm.stride[0] == 1);
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
                     ok_slide && ok_conv;
    return available;
}
template <>
inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 &&
filter_dtype == DTypeEnum::QuantizedS8 &&
(dst_dtype == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44_DOT);
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
nonline_mode == param::ConvBias::NonlineMode::RELU ||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                    fm.stride[0] == fm.stride[1] &&
                    (fm.stride[0] == 1 || fm.stride[0] == 2);
    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
                     ok_slide && ok_conv;
    return available;
}
template <>
inline bool nchw_nchwxx_valid<NCHW88>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type = ((src_dtype == DTypeEnum::Float32 &&
filter_dtype == DTypeEnum::Float32 &&
(dst_dtype == DTypeEnum::Float32))) &&
(fm.format == param::Convolution::Format::NCHW88);
bool ok_src_dst =
fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1;
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1;
    bool available = ok_type && ok_src_dst && ok_slide && ok_conv;
    return available;
}
} // namespace
} // namespace megdnn
\ No newline at end of file
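Because every specialization is `inline` in this header, any translation unit that includes it can also query a single layout type directly. A minimal sketch (all values hypothetical) confirming that the int8x8x16 path only fuses an IDENTITY nonlinearity:

#include "src/common/nchw_nchwxx_valid.h"

bool i8i8i16_rejects_relu() {
    using namespace megdnn;
    ConvolutionBase<param::Convolution>::CanonizedFilterMeta fm;
    fm.format = param::Convolution::Format::NCHW44;
    fm.should_flip = false;
    fm.group = 1;
    fm.spatial_ndim = 2;
    fm.ocpg = 4;   // oc % 4 == 0
    fm.icpg = 1;   // ic < 4
    fm.spatial[0] = fm.spatial[1] = 3;
    fm.stride[0] = fm.stride[1] = 2;
    fm.dilation[0] = fm.dilation[1] = 1;
    // ok_nonline in the NCHW44_INT8_INT8_INT16 specialization accepts only
    // IDENTITY, so a RELU query must come back false.
    return !nchw_nchwxx_valid<NCHW44_INT8_INT8_INT16>(
            DTypeEnum::Int8, DTypeEnum::Int8, DTypeEnum::Int16, fm,
            BiasMode::NO_BIAS, param::ConvBias::NonlineMode::RELU);
}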
@@ -11,6 +11,7 @@
  */
 #pragma once
+#include "src/common/nchw_nchwxx_valid.h"
 #include "src/x86/conv_bias/opr_impl.h"
 using namespace megdnn;
@@ -29,6 +30,7 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase {
                             const NCBKernParam& kern_param,
                             const NCBKernIndex& ncb_index,
                             const CpuNDRange& workspace_ids);
+
 public:
     bool is_reproducible() const override { return true; }
     const char* name() const override {
@@ -61,6 +63,7 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase {
                             const NCBKernParam& kern_param,
                             const NCBKernIndex& ncb_index,
                             const CpuNDRange& workspace_ids);
+
 public:
     bool is_reproducible() const override { return true; }
     const char* name() const override {
@@ -163,13 +166,19 @@ public:
                 AlgoSelectionStrategy) const override {
         auto&& fm = param.filter_meta;
-        bool ok = (fm.format == param::ConvBias::Format::NCHW88) &&
-                  fm.spatial_ndim == 2 &&
-                  param.src_type.enumv() == DTypeEnum::Float32 &&
-                  param.filter_type.enumv() == DTypeEnum::Float32 &&
-                  param.dst_type.enumv() == DTypeEnum::Float32 &&
-                  fm.dilation[0] == 1 && fm.dilation[1] == 1;
-        return ok;
+        bool nchw_nchw88_ok = nchw_nchwxx_valid<NchwNchwxxType::NCHW88>(
+                param.src_type.enumv(), param.filter_type.enumv(),
+                param.dst_type.enumv(), param.filter_meta, param.bias_mode,
+                param.nonlineMode);
+        bool normal_conv_ok = (fm.format == param::ConvBias::Format::NCHW88) &&
+                              fm.spatial_ndim == 2 &&
+                              param.src_type.enumv() == DTypeEnum::Float32 &&
+                              param.filter_type.enumv() == DTypeEnum::Float32 &&
+                              param.dst_type.enumv() == DTypeEnum::Float32 &&
+                              fm.dilation[0] == 1 && fm.dilation[1] == 1;
+        return nchw_nchw88_ok || normal_conv_ok;
     };
     size_t get_workspace(const NCBKernSizeParam&) const override { return 0; }
...
@@ -1816,155 +1816,67 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
 }
 template <typename OprType>
-static inline bool nchw_nchwxx_valid(const OprType& opr,
-                                     const VarNodeArray& new_inp,
-                                     const size_t pack_size, bool is_dense,
-                                     bool is_dot = false);
-template <>
-inline bool nchw_nchwxx_valid<opr::ConvolutionForward>(
-        const opr::ConvolutionForward& opr, const VarNodeArray& new_inp,
-        const size_t pack_size, bool is_dense, bool is_dot) {
-    auto& filter_shape = new_inp[1]->shape();
-    auto filter_dtype = new_inp[1]->dtype();
-    bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                   filter_dtype.enumv() == DTypeEnum::Int8;
-    const size_t oc = filter_shape[0];
-    const size_t ic = filter_shape[1];
-    bool is_like_nchw_nchwxx =
-            is_dense && oc % pack_size == 0 && ic < pack_size;
-    if (!is_like_nchw_nchwxx) {
-        return false;
-    }
-    SmallVector<TensorLayout> layouts;
-    //! src
-    layouts.push_back(
-            {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
-    //! weight
-    layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
-                        filter_shape[3], filter_shape[1], pack_size},
-                       new_inp[1]->dtype(),
-                       new_inp[1]->format()});
-    auto out0 = opr.output(0);
-    auto& out_shape = out0->shape();
-    //! FIXME: return false if oc is invalid
-    layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
-                        out_shape[3], pack_size},
-                       out0->dtype(),
-                       out0->format()});
-    auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
-                               ->create_operator<megdnn::ConvolutionForward>();
-    megdnn_conv.get()->param() = opr.param();
-    //! set by dtype
-    switch (pack_size) {
-        case 4:
-            if (is_dot && is_int8) {
-                megdnn_conv.get()->param().format =
-                        megdnn::param::Convolution::Format::NCHW44_DOT;
-            } else {
-                megdnn_conv.get()->param().format =
-                        megdnn::param::Convolution::Format::NCHW44;
-            }
-            break;
-        case 8:
-            megdnn_conv.get()->param().format =
-                    megdnn::param::Convolution::Format::NCHW88;
-            break;
-        default:
-            break;
-    }
-    bool find_valid_algo = false;
-    auto algos = megdnn_conv.get()->get_all_algorithms(layouts[0], layouts[1],
-                                                       layouts[2]);
-    for (auto i : algos) {
-        if (i->type() != nullptr) {
-            find_valid_algo = true;
-        }
-    }
-    return find_valid_algo;
-}
-template <>
-inline bool nchw_nchwxx_valid<opr::ConvBiasForward>(
-        const opr::ConvBiasForward& opr, const VarNodeArray& new_inp,
-        const size_t pack_size, bool is_dense, bool is_dot) {
-    auto& filter_shape = new_inp[1]->shape();
-    auto filter_dtype = new_inp[1]->dtype();
-    bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
-                   filter_dtype.enumv() == DTypeEnum::Int8;
-    const size_t oc = filter_shape[0];
-    const size_t ic = filter_shape[1];
-    bool is_like_nchw_nchwxx =
-            is_dense && oc % pack_size == 0 && ic < pack_size;
-    if (!is_like_nchw_nchwxx) {
-        return false;
-    }
-    SmallVector<TensorLayout> layouts;
-    //! src
-    layouts.push_back(
-            {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
-    //! weight
-    layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
-                        filter_shape[3], filter_shape[1], pack_size},
-                       new_inp[1]->dtype(),
-                       new_inp[1]->format()});
-    auto& bias_shape = new_inp[2]->shape();
-    layouts.push_back({{bias_shape[0], bias_shape[1] / pack_size, bias_shape[2],
-                        bias_shape[3], pack_size},
-                       new_inp[2]->dtype(),
-                       new_inp[2]->format()});
-    auto out0 = opr.output(0);
-    auto& out_shape = out0->shape();
-    //! FIXME: return false if oc is invalid
-    layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
-                        out_shape[3], pack_size},
-                       out0->dtype(),
-                       out0->format()});
-    // megdnn::ConvolutionForward
-    auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
-                               ->create_operator<megdnn::ConvBiasForward>();
-    megdnn_conv.get()->param() = opr.param();
-    //! FIXME: set by dtype
-    switch (pack_size) {
-        case 4:
-            if (is_dot && is_int8) {
-                megdnn_conv.get()->param().format =
-                        megdnn::param::Convolution::Format::NCHW44_DOT;
-            } else {
-                megdnn_conv.get()->param().format =
-                        megdnn::param::Convolution::Format::NCHW44;
-            }
-            break;
-        case 8:
-            megdnn_conv.get()->param().format =
-                    megdnn::param::Convolution::Format::NCHW88;
-            break;
-        default:
-            break;
-    }
-    bool find_valid_algo = false;
-    auto algos = megdnn_conv.get()->get_all_algorithms(
-            layouts[0], layouts[1], layouts[2], {}, layouts[3]);
-    for (auto i : algos) {
-        if (i->type() != nullptr) {
-            find_valid_algo = true;
-        }
-    }
-    return find_valid_algo;
-}
+static inline bool nchw_nchwxx_valid(
+        const OprType& opr, const VarNodeArray& new_inp, const size_t pack_size,
+        megdnn::param::ConvBias::NonlineMode nonline_mode =
+                megdnn::param::ConvBias::NonlineMode::IDENTITY,
+        bool is_dot = false) {
+    auto& src_node = new_inp[0];
+    auto& filter_node = new_inp[1];
+    auto dst_node = opr.output(0);
+    if (filter_node->shape().ndim != 4) {
+        return false;
+    }
+    megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm;
+    fm.format = megdnn::param::Convolution::Format::NCHW;
+    fm.should_flip =
+            opr.param().mode == megdnn::ConvBiasForward::Mode::CONVOLUTION;
+    fm.group = 1;
+    fm.spatial_ndim = 2;
+    fm.ocpg = filter_node->shape()[0];
+    fm.icpg = filter_node->shape()[1];
+    fm.spatial[0] = filter_node->shape()[2];
+    fm.spatial[1] = filter_node->shape()[3];
+    fm.stride[0] = opr.param().stride_h;
+    fm.stride[1] = opr.param().stride_w;
+    fm.padding[0] = opr.param().pad_h;
+    fm.padding[1] = opr.param().pad_w;
+    fm.dilation[0] = opr.param().dilate_h;
+    fm.dilation[1] = opr.param().dilate_w;
+    megdnn::ConvBiasForward::BiasMode bias_mode =
+            megdnn::ConvBiasForward::BiasMode::NO_BIAS;
+    if (std::is_same<OprType, opr::ConvBiasForward>::value) {
+        auto& bias_shape = new_inp[2]->shape();
+        if (bias_shape.ndim == 0) {
+            bias_mode = megdnn::ConvBiasForward::BiasMode::NO_BIAS;
+        } else if (bias_shape.eq_shape(dst_node->shape())) {
+            bias_mode = megdnn::ConvBiasForward::BiasMode::BIAS;
+        } else {
+            //! just check the ndim, the detail shape check is in check_exec
+            mgb_assert(bias_shape.ndim == dst_node->shape().ndim);
+            bias_mode =
+                    megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS;
+        }
+    }
+    if (pack_size == 4) {
+        if (is_dot && filter_node->dtype().enumv() == DTypeEnum::QuantizedS8) {
+            fm.format = megdnn::param::Convolution::Format::NCHW44_DOT;
+        } else {
+            fm.format = megdnn::param::Convolution::Format::NCHW44;
+        }
+    } else if (pack_size == 8) {
+        fm.format = megdnn::param::Convolution::Format::NCHW88;
+    } else {
+        mgb_assert(0, "only support nchw44 nchw88");
+    }
+    return megdnn::ConvBiasForward::is_nchw_nchwxx_optimized(
+            src_node->dtype().enumv(), filter_node->dtype().enumv(),
+            dst_node->dtype().enumv(), fm, bias_mode, nonline_mode);
+}
 void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
     using TestFilterResult = std::pair<TransType, RelayoutMode>;
@@ -1984,19 +1896,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     megdnn::param::Pooling::Format pooling_format =
             megdnn::param::Pooling::Format::NCHW88;
     std::string convter_pass_name = "conv_format_nchw88";
-#if MEGDNN_AARCH64 || MEGDNN_ARMv7
-    if (pack_c_size == 8) {
-        mgb_log_error(
-                "runtime backend is ARM, but nchw88 only support X86, you may "
-                "have performance loss\n");
-    }
-#elif MEGDNN_X86
-    if (pack_c_size == 4) {
-        mgb_log_error(
-                "runtime backend is X86, but nchw44 only support arm, you may "
-                "have performance loss\n");
-    }
-#endif
     if (pack_c_size == 4) {
         weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
@@ -2053,10 +1952,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         mgb_assert(conv_opr.param().format ==
                            megdnn::param::Convolution::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        bool is_dense = conv_opr.param().sparse ==
-                        megdnn::param::Convolution::Sparse::DENSE;
         bool valid_nchw_nchw44 =
-                nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
+                nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size);
         auto is_trans = test_trans_nchwxx(
                 conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
                 conv_opr.param().stride_w, valid_nchw_nchw44);
@@ -2133,10 +2030,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        bool is_dense = conv_bias_opr.param().sparse ==
-                        megdnn::param::Convolution::Sparse::DENSE;
-        bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
-                                                   pack_c_size, is_dense);
+        bool valid_nchw_nchw44 =
+                nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size,
+                                  conv_bias_opr.param().nonlineMode);
         auto is_trans = test_trans_nchwxx(
                 conv_bias_opr.param().sparse, new_inp[1],
                 conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
@@ -2371,13 +2267,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
     MIDOUT_B("EnableNchw44DotPass::make")
     auto ret = std::make_unique<EnableNchw44DotPass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     //! First is whether the conv can trans to nchwxx, second is the filter
     //! trans mode
-#if MEGDNN_X86
-    mgb_log_error(
-            "backend is X86, but nchw44_dot only support arm, you may have "
-            "performance loss\n");
-#endif
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
     struct TestTransResult {
@@ -2453,14 +2344,12 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                            megdnn::param::Convolution::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to "
                    "NCHW44_DOT");
-        bool is_dense = conv_opr.param().sparse ==
-                        megdnn::param::Convolution::Sparse::DENSE;
-        bool valid_nchw_nchw44 =
-                nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
+        bool valid_nchw_nchw44 = nchw_nchwxx_valid(
+                conv_opr, new_inp, pack_c_size,
+                megdnn::param::ConvBias::NonlineMode::IDENTITY, true);
         auto is_trans = test_trans_nchw44_dot(
                 conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
                 conv_opr.param().stride_w, valid_nchw_nchw44);
         //! can not trans to nchwxx
         if (is_trans.trans_type == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
@@ -2533,10 +2422,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        bool is_dense = conv_bias_opr.param().sparse ==
-                        megdnn::param::Convolution::Sparse::DENSE;
-        bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
-                                                   pack_c_size, is_dense);
+        bool valid_nchw_nchw44 =
+                nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size,
+                                  conv_bias_opr.param().nonlineMode, true);
         auto is_trans = test_trans_nchw44_dot(
                 conv_bias_opr.param().sparse, new_inp[1],
                 conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
...
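Because the pass now consults the same validator as the kernels, format conversion is decided by `is_nchw_nchwxx_optimized` rather than by compile-time platform guards, and the tests below drop their `#if MEGDNN_AARCH64 || MEGDNN_ARMV7` branches accordingly. The driver idiom those tests use (a sketch following the test code below; `y` is the network output SymbolVar):

SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw44();  // or enable_nchw88() / enable_nchw44_dot()
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);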
@@ -2913,7 +2913,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
     opr::Convolution::Param param_conv;
     param_conv.pad_h = param_conv.pad_w = 1;
     auto w1 = mkcvar("w1", {8, 3, 3, 3}),
-         conv1 = opr::Convolution::make(x, w1, param_conv);
+         conv1 = opr::Convolution::make(x, w1, param_conv, {},
+                                        OperatorNodeConfig("conv1"));
     //! channel wise
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
@@ -2954,7 +2955,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
         options.enable_nchw88();
         unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
     }
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88,
+              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
     ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88,
               find_opr<opr::ConvBias>(y_opt).param().format);
@@ -3084,13 +3086,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     options.enable_nchw44();
     unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
-#if MEGDNN_AARCH64 || MEGDNN_ARMV7
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
               find_opr<opr::Convolution>(y_opt, "conv1").param().format);
-#else
-    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
-              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
-#endif
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
               find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format);
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
@@ -3325,17 +3322,10 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     options.enable_nchw44_dot();
     unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
-#if MEGDNN_AARCH64 || MEGDNN_ARMV7
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
               find_opr<opr::Convolution>(y_opt, "conv1").param().format);
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
               find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
-#else
-    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
-              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
-    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
-              find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
-#endif
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
               find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format);
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
...
@@ -611,11 +611,11 @@ public:
                 "%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s "
                 "workspace=%.2fMiB reproducible=%d",
                 mgb_opr->dyn_typeinfo()->name,
-                layouts[0].TensorShape::to_string().c_str(),
+                layouts[0].to_string().c_str(),
                 layouts[0].dtype.name(),
-                layouts[1].TensorShape::to_string().c_str(),
+                layouts[1].to_string().c_str(),
                 layouts[1].dtype.name(),
-                layouts[layouts.size() - 1].TensorShape::to_string().c_str(),
+                layouts[layouts.size() - 1].to_string().c_str(),
                 layouts[layouts.size() - 1].dtype.name(),
                 algo->name(),
                 workspace / (1024 * 1024.0), algo->is_reproducible());
...