Commit 11732057 authored by Megvii Engine Team, committed by Xinran Xu

fix(gopt): nchw_nchwxx usable and opt pass use nchw_nchwxx_valid

GitOrigin-RevId: 60942aca5b19af86a1210267f5af27c1558f1a03
Parent eb18eba8
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/internal/opr_header_prologue.h"
......@@ -314,8 +315,10 @@ public:
/**
* \param[in] src (n, ic, ih, iw) or (n, ih, iw, ic)
* \param[in] filter (oc, ic, fh, fw) or (oc, fh, fw, ic) or (oc/4, fh, fw,
* 4*ic) \param[in] bias (1, oc, 1, 1) \param[in] z same as dst \param[out]
* dst (n, oc, oh, ow) or (n, oh, ow, oc)
* 4 * ic)
* \param[in] bias (1, oc, 1, 1)
* \param[in] z same as dst
* \param[out] dst (n, oc, oh, ow) or (n, oh, ow, oc)
*
* \note if the format is NCHW_WINOGRAD, the filter layout is (alphah,
* alphaw, oc, ic)
......@@ -407,6 +410,26 @@ public:
*/
static WinogradParam parse_winograd_name(const std::string& algo_name);
/**
 * @brief find whether there is an nchw_nchwxx conv kernel optimized for the
 * given argument; nchw44 is used for ARM, nchw88 for x86
 *
 * @param src_dtype conv feature map data type
 * @param filter_dtype conv filter or weight data type
 * @param dst_dtype output data type
 * @param fm filter meta param
 * @param bias_mode bias mode: no_bias, broadcast or bias
 * @param nonline_mode identity, relu, h_swish or sigmoid
 * @return true if an optimized kernel is found
 * @return false if no kernel can be found
*/
static bool is_nchw_nchwxx_optimized(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const ConvBiasForward::BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode);
protected:
CanonizedFilterMeta check_exec(
const TensorLayout& src, const TensorLayout& filter,
......
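The new entry point gives the graph optimizer a single query against the same per-backend checks that each kernel's usable() performs. A minimal call-site sketch follows; the field values are hypothetical, describing a dense 3x3 stride-1 convolution, and the query must already carry the target packed format (NCHW44 here), just as the gopt pass below sets it before asking:

    // sketch, assuming the megdnn headers are on the include path
    using Format = megdnn::param::Convolution::Format;
    megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm;
    fm.format = Format::NCHW44;         // target packed format for the query
    fm.should_flip = false;             // cross-correlation, not convolution
    fm.group = 1;                       // dense conv only
    fm.spatial_ndim = 2;
    fm.ocpg = 16;                       // output channels, multiple of 4
    fm.icpg = 3;                        // input channels, < 4 for nchw_nchwxx
    fm.spatial[0] = fm.spatial[1] = 3;  // 3x3 filter
    fm.stride[0] = fm.stride[1] = 1;
    fm.padding[0] = fm.padding[1] = 1;
    fm.dilation[0] = fm.dilation[1] = 1;
    bool hit = megdnn::ConvBiasForward::is_nchw_nchwxx_optimized(
            megdnn::DTypeEnum::Float32, megdnn::DTypeEnum::Float32,
            megdnn::DTypeEnum::Float32, fm,
            megdnn::ConvBiasForward::BiasMode::NO_BIAS,
            megdnn::param::ConvBias::NonlineMode::RELU);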
......@@ -16,10 +16,10 @@
#include "src/arm_common/conv_bias/fp32/f32_direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/fp32/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h"
#include "midout.h"
using namespace megdnn;
using namespace arm_common;
using conv_fun = std::function<void(
......@@ -191,22 +191,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
bool ConvBiasImpl::AlgoF32DirectNCHWNCHW44::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto fh = fm.spatial[0];
int oc = fm.ocpg;
bool ok_type = ((param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
(param.dst_type.enumv() == DTypeEnum::Float32))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2);
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
return avaible;
return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>(
param.src_type.enumv(), param.filter_type.enumv(),
param.dst_type.enumv(), param.filter_meta, param.bias_mode,
param.nonlineMode);
}
size_t ConvBiasImpl::AlgoF32DirectNCHWNCHW44::get_workspace(
......
......@@ -15,6 +15,7 @@
#include "src/arm_common/conv_bias/int8/direct_nchw_nchw44_kern.h"
#include "src/arm_common/conv_bias/int8/strategy.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h"
#include "midout.h"
......@@ -214,26 +215,12 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
ow, op);
}
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(
const NCBKernSizeParam& param,
AlgoSelectionStrategy algo_selection_strategy) const {
MEGDNN_MARK_USED_VAR(algo_selection_strategy);
auto&& fm = param.filter_meta;
auto FH = fm.spatial[0];
auto OC = fm.ocpg;
bool avaible = //! src and filter are qint8, dst is qint8
fm.icpg < 4 && // must be nchw input
((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44) &&
(OC % 4 == 0 && OC >= 4) && !fm.should_flip && fm.group == 1 &&
fm.spatial_ndim == 2 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2) && FH == fm.spatial[1] &&
(FH == 2 || FH == 3 || FH == 5 || FH == 7) && fm.group == 1 &&
param.bias_mode != BiasMode::BIAS;
return avaible;
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::usable(const NCBKernSizeParam& param,
AlgoSelectionStrategy) const {
return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>(
param.src_type.enumv(), param.filter_type.enumv(),
param.dst_type.enumv(), param.filter_meta, param.bias_mode,
param.nonlineMode);
}
bool ConvBiasImpl::AlgoS8DirectNCHWNCHW44::is_preferred(
......
......@@ -16,6 +16,7 @@
#include "src/arm_common/conv_bias/int8/algos.h"
#include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "midout.h"
......@@ -174,23 +175,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto fh = fm.spatial[0];
int oc = fm.ocpg;
int ic = fm.icpg;
bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44_DOT);
bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4);
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 1 || fm.stride[0] == 2);
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
return avaible;
return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>(
param.src_type.enumv(), param.filter_type.enumv(),
param.dst_type.enumv(), param.filter_meta, param.bias_mode,
param.nonlineMode);
}
size_t ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::get_workspace(
......
......@@ -16,6 +16,7 @@
#include "src/arm_common/conv_bias/int8x8x16/algos.h"
#include "src/arm_common/conv_bias/int8x8x16/direct_nchw_nchw44_kern.h"
#include "src/arm_common/elemwise_op.h"
#include "src/common/nchw_nchwxx_valid.h"
#include "src/common/opr_delegate.h"
#include "midout.h"
......@@ -220,23 +221,10 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
bool ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::usable(
const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
auto&& fm = param.filter_meta;
auto fh = fm.spatial[0];
int oc = fm.ocpg;
bool ok_type = ((param.src_type.enumv() == DTypeEnum::Int8 &&
param.filter_type.enumv() == DTypeEnum::Int8 &&
(param.dst_type.enumv() == DTypeEnum::Int16))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 2 || fm.stride[0] == 1);
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS &&
param.nonlineMode == param::ConvBias::NonlineMode::IDENTITY;
bool avaible = ok_type && ok_src_dst && ok_filter && ok_slide && ok_conv;
return avaible;
return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>(
param.src_type.enumv(), param.filter_type.enumv(),
param.dst_type.enumv(), param.filter_meta, param.bias_mode,
param.nonlineMode);
}
size_t ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44::get_workspace(
......
/**
* \file dnn/src/common/nchw_nchwxx_valid.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn/oprs/nn.h"
#include "src/common/nchw_nchwxx_valid.h"
using namespace megdnn;
namespace {
using NchwNchwxxFuncInterface = std::function<bool(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const ConvBiasForward::BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode)>;
static SmallVector<NchwNchwxxFuncInterface> g_func_vec{
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_FP32>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_INT8_INT16>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>,
nchw_nchwxx_valid<NchwNchwxxType::NCHW88>,
};
} // namespace
bool ConvBiasForward::is_nchw_nchwxx_optimized(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const ConvBiasForward::BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
for (auto& func : g_func_vec) {
if (func(src_dtype, filter_dtype, dst_dtype, fm, bias_mode,
nonline_mode)) {
return true;
}
}
return false;
}
\ No newline at end of file
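The registry above is a first-match walk over a table of std::function predicates; each template instantiation contributes one candidate validator. The same pattern in isolation, as a self-contained runnable sketch (the names here are illustrative, not MegEngine APIs):

    #include <cstdio>
    #include <functional>
    #include <vector>

    // First-match dispatch: walk the predicate table, stop at the first
    // validator that accepts the query, report failure if none does.
    using Pred = std::function<bool(int /*pack_size*/)>;

    int main() {
        std::vector<Pred> table{
                [](int p) { return p == 8; },  // stands in for the NCHW88 check
                [](int p) { return p == 4; },  // stands in for the NCHW44 checks
        };
        int query = 4;
        for (auto& pred : table) {
            if (pred(query)) {
                std::printf("found an optimized kernel\n");
                return 0;
            }
        }
        std::printf("no optimized kernel\n");
        return 0;
    }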
/**
* \file dnn/src/common/nchw_nchwxx_valid.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/fallback/conv_bias/opr_impl.h"
namespace megdnn {
namespace {
enum NchwNchwxxType {
NCHW44_FP32,
NCHW44_INT8,
NCHW44_INT8_INT8_INT16,
NCHW44_INT8_DOT,
NCHW88,
};
template <NchwNchwxxType T>
static inline bool nchw_nchwxx_valid(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode);
template <>
inline bool nchw_nchwxx_valid<NCHW44_FP32>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type = ((src_dtype == DTypeEnum::Float32 &&
filter_dtype == DTypeEnum::Float32 &&
(dst_dtype == DTypeEnum::Float32))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
nonline_mode == param::ConvBias::NonlineMode::RELU ||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                    fm.stride[0] == fm.stride[1] &&
                    (fm.stride[0] == 1 || fm.stride[0] == 2);
    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
                     ok_slide && ok_conv;
    return available;
}
template <>
inline bool nchw_nchwxx_valid<NCHW44_INT8>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 &&
filter_dtype == DTypeEnum::QuantizedS8 &&
(dst_dtype == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
nonline_mode == param::ConvBias::NonlineMode::RELU ||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                    fm.stride[0] == fm.stride[1] &&
                    (fm.stride[0] == 1 || fm.stride[0] == 2);
    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
                     ok_slide && ok_conv;
    return available;
}
template <>
inline bool nchw_nchwxx_valid<NCHW44_INT8_INT8_INT16>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type =
((src_dtype == DTypeEnum::Int8 && filter_dtype == DTypeEnum::Int8 &&
(dst_dtype == DTypeEnum::Int16))) &&
(fm.format == param::Convolution::Format::NCHW44);
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == fm.stride[1] &&
(fm.stride[0] == 2 || fm.stride[0] == 1);
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
                     ok_slide && ok_conv;
    return available;
}
template <>
inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type = ((src_dtype == DTypeEnum::QuantizedS8 &&
filter_dtype == DTypeEnum::QuantizedS8 &&
(dst_dtype == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44_DOT);
bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY ||
nonline_mode == param::ConvBias::NonlineMode::RELU ||
nonline_mode == param::ConvBias::NonlineMode::H_SWISH;
bool ok_src_dst =
fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
(fm.spatial[0] == 2 || fm.spatial[0] == 3 ||
fm.spatial[0] == 5 || fm.spatial[0] == 7);
    bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                    fm.stride[0] == fm.stride[1] &&
                    (fm.stride[0] == 1 || fm.stride[0] == 2);
    bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
    bool available = ok_type && ok_nonline && ok_src_dst && ok_filter &&
                     ok_slide && ok_conv;
    return available;
}
template <>
inline bool nchw_nchwxx_valid<NCHW88>(
const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
const DTypeEnum dst_dtype,
const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
const BiasMode bias_mode,
const param::ConvBias::NonlineMode nonline_mode) {
bool ok_type = ((src_dtype == DTypeEnum::Float32 &&
filter_dtype == DTypeEnum::Float32 &&
(dst_dtype == DTypeEnum::Float32))) &&
(fm.format == param::Convolution::Format::NCHW88);
bool ok_src_dst =
fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1;
bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1;
    bool available = ok_type && ok_src_dst && ok_slide && ok_conv;
    return available;
}
} // namespace
} // namespace megdnn
\ No newline at end of file
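Backends instantiate the template directly with their format tag, as the arm_common diffs above show. The tags differ mainly in the accepted dtypes and nonlinearities; for instance, the int8x8x16 path only admits IDENTITY. A hedged sketch, valid only inside a translation unit that includes this header (fm is a CanonizedFilterMeta filled as in the earlier example, but with format NCHW44 and Int8/Int16 dtypes; BiasMode is the alias visible through the header's includes):

    using NL = megdnn::param::ConvBias::NonlineMode;
    bool id_ok = nchw_nchwxx_valid<NCHW44_INT8_INT8_INT16>(
            megdnn::DTypeEnum::Int8, megdnn::DTypeEnum::Int8,
            megdnn::DTypeEnum::Int16, fm, BiasMode::NO_BIAS, NL::IDENTITY);
    // RELU is rejected on this path regardless of shapes, since its
    // ok_nonline check only accepts IDENTITY:
    bool relu_ok = nchw_nchwxx_valid<NCHW44_INT8_INT8_INT16>(
            megdnn::DTypeEnum::Int8, megdnn::DTypeEnum::Int8,
            megdnn::DTypeEnum::Int16, fm, BiasMode::NO_BIAS, NL::RELU);
    // relu_ok == false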
......@@ -11,6 +11,7 @@
*/
#pragma once
#include "src/common/nchw_nchwxx_valid.h"
#include "src/x86/conv_bias/opr_impl.h"
using namespace megdnn;
......@@ -29,6 +30,7 @@ class ConvBiasImpl::AlgoDirect final : public AlgoBase {
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
......@@ -61,6 +63,7 @@ class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase {
const NCBKernParam& kern_param,
const NCBKernIndex& ncb_index,
const CpuNDRange& workspace_ids);
public:
bool is_reproducible() const override { return true; }
const char* name() const override {
......@@ -163,13 +166,19 @@ public:
AlgoSelectionStrategy) const override {
auto&& fm = param.filter_meta;
bool ok = (fm.format == param::ConvBias::Format::NCHW88) &&
fm.spatial_ndim == 2 &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1;
return ok;
bool nchw_nchw88_ok = nchw_nchwxx_valid<NchwNchwxxType::NCHW88>(
param.src_type.enumv(), param.filter_type.enumv(),
param.dst_type.enumv(), param.filter_meta, param.bias_mode,
param.nonlineMode);
bool normal_conv_ok = (fm.format == param::ConvBias::Format::NCHW88) &&
fm.spatial_ndim == 2 &&
param.src_type.enumv() == DTypeEnum::Float32 &&
param.filter_type.enumv() == DTypeEnum::Float32 &&
param.dst_type.enumv() == DTypeEnum::Float32 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1;
return nchw_nchw88_ok || normal_conv_ok;
};
size_t get_workspace(const NCBKernSizeParam&) const override { return 0; }
......
......@@ -1816,155 +1816,67 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
}
template <typename OprType>
static inline bool nchw_nchwxx_valid(const OprType& opr,
const VarNodeArray& new_inp,
const size_t pack_size, bool is_dense,
bool is_dot = false);
template <>
inline bool nchw_nchwxx_valid<opr::ConvolutionForward>(
const opr::ConvolutionForward& opr, const VarNodeArray& new_inp,
const size_t pack_size, bool is_dense, bool is_dot) {
auto& filter_shape = new_inp[1]->shape();
auto filter_dtype = new_inp[1]->dtype();
bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
filter_dtype.enumv() == DTypeEnum::Int8;
const size_t oc = filter_shape[0];
const size_t ic = filter_shape[1];
bool is_like_nchw_nchwxx =
is_dense && oc % pack_size == 0 && ic < pack_size;
if (!is_like_nchw_nchwxx) {
static inline bool nchw_nchwxx_valid(
const OprType& opr, const VarNodeArray& new_inp, const size_t pack_size,
megdnn::param::ConvBias::NonlineMode nonline_mode =
megdnn::param::ConvBias::NonlineMode::IDENTITY,
bool is_dot = false) {
auto& src_node = new_inp[0];
auto& filter_node = new_inp[1];
auto dst_node = opr.output(0);
if (filter_node->shape().ndim != 4) {
return false;
}
SmallVector<TensorLayout> layouts;
//! src
layouts.push_back(
{new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
//! weight
layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
filter_shape[3], filter_shape[1], pack_size},
new_inp[1]->dtype(),
new_inp[1]->format()});
auto out0 = opr.output(0);
auto& out_shape = out0->shape();
//! FIXME: return false if oc is invalid
layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
out_shape[3], pack_size},
out0->dtype(),
out0->format()});
auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
->create_operator<megdnn::ConvolutionForward>();
megdnn_conv.get()->param() = opr.param();
//! set by dtype
switch (pack_size) {
case 4:
if (is_dot && is_int8) {
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW44_DOT;
} else {
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW44;
}
break;
case 8:
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW88;
break;
default:
break;
}
bool find_valid_algo = false;
auto algos = megdnn_conv.get()->get_all_algorithms(layouts[0], layouts[1],
layouts[2]);
for (auto i : algos) {
if (i->type() != nullptr) {
find_valid_algo = true;
megdnn::ConvolutionBase<megdnn::param::Convolution>::CanonizedFilterMeta fm;
fm.format = megdnn::param::Convolution::Format::NCHW;
fm.should_flip =
opr.param().mode == megdnn::ConvBiasForward::Mode::CONVOLUTION;
fm.group = 1;
fm.spatial_ndim = 2;
fm.ocpg = filter_node->shape()[0];
fm.icpg = filter_node->shape()[1];
fm.spatial[0] = filter_node->shape()[2];
fm.spatial[1] = filter_node->shape()[3];
fm.stride[0] = opr.param().stride_h;
fm.stride[1] = opr.param().stride_w;
fm.padding[0] = opr.param().pad_h;
fm.padding[1] = opr.param().pad_w;
fm.dilation[0] = opr.param().dilate_h;
fm.dilation[1] = opr.param().dilate_w;
megdnn::ConvBiasForward::BiasMode bias_mode =
megdnn::ConvBiasForward::BiasMode::NO_BIAS;
if (std::is_same<OprType, opr::ConvBiasForward>::value) {
auto& bias_shape = new_inp[2]->shape();
if (bias_shape.ndim == 0) {
bias_mode = megdnn::ConvBiasForward::BiasMode::NO_BIAS;
} else if (bias_shape.eq_shape(dst_node->shape())) {
bias_mode = megdnn::ConvBiasForward::BiasMode::BIAS;
} else {
//! just check the ndim, the detail shape check is in check_exec
mgb_assert(bias_shape.ndim == dst_node->shape().ndim);
bias_mode =
megdnn::ConvBiasForward::BiasMode::BROADCAST_CHANNEL_BIAS;
}
}
return find_valid_algo;
}
template <>
inline bool nchw_nchwxx_valid<opr::ConvBiasForward>(
const opr::ConvBiasForward& opr, const VarNodeArray& new_inp,
const size_t pack_size, bool is_dense, bool is_dot) {
auto& filter_shape = new_inp[1]->shape();
auto filter_dtype = new_inp[1]->dtype();
bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
filter_dtype.enumv() == DTypeEnum::Int8;
const size_t oc = filter_shape[0];
const size_t ic = filter_shape[1];
bool is_like_nchw_nchwxx =
is_dense && oc % pack_size == 0 && ic < pack_size;
if (!is_like_nchw_nchwxx) {
return false;
}
SmallVector<TensorLayout> layouts;
//! src
layouts.push_back(
{new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
//! weight
layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
filter_shape[3], filter_shape[1], pack_size},
new_inp[1]->dtype(),
new_inp[1]->format()});
auto& bias_shape = new_inp[2]->shape();
layouts.push_back({{bias_shape[0], bias_shape[1] / pack_size, bias_shape[2],
bias_shape[3], pack_size},
new_inp[2]->dtype(),
new_inp[2]->format()});
auto out0 = opr.output(0);
auto& out_shape = out0->shape();
//! FIXME: return false if oc is invalid
layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
out_shape[3], pack_size},
out0->dtype(),
out0->format()});
// megdnn::ConvolutionForward
auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
->create_operator<megdnn::ConvBiasForward>();
megdnn_conv.get()->param() = opr.param();
//! FIXME: set by dtype
switch (pack_size) {
case 4:
if (is_dot && is_int8) {
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW44_DOT;
} else {
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW44;
}
break;
case 8:
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW88;
break;
default:
break;
}
bool find_valid_algo = false;
auto algos = megdnn_conv.get()->get_all_algorithms(
layouts[0], layouts[1], layouts[2], {}, layouts[3]);
for (auto i : algos) {
if (i->type() != nullptr) {
find_valid_algo = true;
if (pack_size == 4) {
if (is_dot && filter_node->dtype().enumv() == DTypeEnum::QuantizedS8) {
fm.format = megdnn::param::Convolution::Format::NCHW44_DOT;
} else {
fm.format = megdnn::param::Convolution::Format::NCHW44;
}
} else if (pack_size == 8) {
fm.format = megdnn::param::Convolution::Format::NCHW88;
} else {
mgb_assert(0, "only support nchw44 nchw88");
}
return find_valid_algo;
return megdnn::ConvBiasForward::is_nchw_nchwxx_optimized(
src_node->dtype().enumv(), filter_node->dtype().enumv(),
dst_node->dtype().enumv(), fm, bias_mode, nonline_mode);
}
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
using RelayoutMode = RelayoutPlaceholder::LayoutType;
using TestFilterResult = std::pair<TransType, RelayoutMode>;
......@@ -1984,19 +1896,6 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
megdnn::param::Pooling::Format pooling_format =
megdnn::param::Pooling::Format::NCHW88;
std::string convter_pass_name = "conv_format_nchw88";
#if MEGDNN_AARCH64 || MEGDNN_ARMv7
if (pack_c_size == 8) {
mgb_log_error(
"runtime backend is ARM, but nchw88 only support X86, you may "
"have performance loss\n");
}
#elif MEGDNN_X86
if (pack_c_size == 4) {
mgb_log_error(
"runtime backend is X86, but nchw44 only support arm, you may "
"have performance loss\n");
}
#endif
if (pack_c_size == 4) {
weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
......@@ -2053,10 +1952,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
mgb_assert(conv_opr.param().format ==
megdnn::param::Convolution::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
bool is_dense = conv_opr.param().sparse ==
megdnn::param::Convolution::Sparse::DENSE;
bool valid_nchw_nchw44 =
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size);
auto is_trans = test_trans_nchwxx(
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
conv_opr.param().stride_w, valid_nchw_nchw44);
......@@ -2133,10 +2030,9 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
bool is_dense = conv_bias_opr.param().sparse ==
megdnn::param::Convolution::Sparse::DENSE;
bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
pack_c_size, is_dense);
bool valid_nchw_nchw44 =
nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size,
conv_bias_opr.param().nonlineMode);
auto is_trans = test_trans_nchwxx(
conv_bias_opr.param().sparse, new_inp[1],
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
......@@ -2371,13 +2267,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
MIDOUT_B("EnableNchw44DotPass::make")
auto ret = std::make_unique<EnableNchw44DotPass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
//! First is whether the conv can trans to nchwxx, second is the filter
//! trans mode
#if MEGDNN_X86
mgb_log_error(
"backend is X86, but nchw44_dot only support arm, you may have "
"performance loss\n");
#endif
//! First is whether the conv can trans to nchwxx, second is the filter
//! trans mode
using RelayoutMode = RelayoutPlaceholder::LayoutType;
struct TestTransResult {
......@@ -2453,14 +2344,12 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
megdnn::param::Convolution::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to "
"NCHW44_DOT");
bool is_dense = conv_opr.param().sparse ==
megdnn::param::Convolution::Sparse::DENSE;
bool valid_nchw_nchw44 =
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
bool valid_nchw_nchw44 = nchw_nchwxx_valid(
conv_opr, new_inp, pack_c_size,
megdnn::param::ConvBias::NonlineMode::IDENTITY, true);
auto is_trans = test_trans_nchw44_dot(
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
conv_opr.param().stride_w, valid_nchw_nchw44);
//! can not trans to nchwxx
if (is_trans.trans_type == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 ||
......@@ -2533,10 +2422,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX");
bool is_dense = conv_bias_opr.param().sparse ==
megdnn::param::Convolution::Sparse::DENSE;
bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
pack_c_size, is_dense);
bool valid_nchw_nchw44 =
nchw_nchwxx_valid(conv_bias_opr, new_inp, pack_c_size,
conv_bias_opr.param().nonlineMode, true);
auto is_trans = test_trans_nchw44_dot(
conv_bias_opr.param().sparse, new_inp[1],
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
......
......@@ -2913,7 +2913,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param_conv);
conv1 = opr::Convolution::make(x, w1, param_conv, {},
OperatorNodeConfig("conv1"));
//! channel wise
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
......@@ -2954,7 +2955,8 @@ TEST(TestGoptInference, ConvertFormatNCHW88) {
options.enable_nchw88();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88,
find_opr<opr::Convolution>(y_opt, "conv1").param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW88,
find_opr<opr::ConvBias>(y_opt).param().format);
......@@ -3084,13 +3086,8 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
options.enable_nchw44();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::Convolution>(y_opt, "conv1").param().format);
#else
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::Convolution>(y_opt, "conv1").param().format);
#endif
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
......@@ -3325,17 +3322,10 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
options.enable_nchw44_dot();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::Convolution>(y_opt, "conv1").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
#else
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::Convolution>(y_opt, "conv1").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
#endif
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
......
......@@ -611,11 +611,11 @@ public:
"%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s "
"workspace=%.2fMiB reproducible=%d",
mgb_opr->dyn_typeinfo()->name,
layouts[0].TensorShape::to_string().c_str(),
layouts[0].to_string().c_str(),
layouts[0].dtype.name(),
layouts[1].TensorShape::to_string().c_str(),
layouts[1].to_string().c_str(),
layouts[1].dtype.name(),
layouts[layouts.size() - 1].TensorShape::to_string().c_str(),
layouts[layouts.size() - 1].to_string().c_str(),
layouts[layouts.size() - 1].dtype.name(),
algo->name(),
workspace / (1024 * 1024.0), algo->is_reproducible());
......