// nchw_nchwxx_valid.h
#pragma once
#include "megdnn/oprs.h"
#if !MGE_BUILD_CL_ONLY
#include "src/fallback/conv_bias/opr_impl.h"
#endif

namespace megdnn {
#if MGE_BUILD_CL_ONLY
using NonlineMode = ConvBias::Param::NonlineMode;
using BiasMode = ConvBiasForward::BiasMode;
#endif
namespace {
//! Selects which explicit specialization of nchw_nchwxx_valid<> is used.
//! The per-enumerator notes restate the dtype/format test of the matching
//! specialization.
enum NchwNchwxxType {
    NCHW44_FP32,             //!< Float32 src/filter/dst, NCHW44 filter format
    NCHW44_INT8,             //!< QuantizedS8 src/filter/dst, NCHW44 filter format
    NCHW44_INT8_INT8_INT16,  //!< Int8 src/filter, Int16 dst, NCHW44 filter format
    NCHW44_INT8_DOT,         //!< QuantizedS8 src/filter/dst, NCHW44_DOT filter format
    NCHW88,                  //!< Float32 src/filter/dst, NCHW88 filter format
#if !MEGDNN_DISABLE_FLOAT16
    NCHW88_FP16,  //!< Float16 src/filter/dst, NCHW88 filter format
#endif
};
/*!
 * \brief Primary template of the per-variant validity check (declaration only).
 *
 * Every NchwNchwxxType value is expected to have an explicit specialization
 * that inspects the arguments and returns true when all of its checks pass.
 *
 * \param src_dtype dtype enum of the source tensor
 * \param filter_dtype dtype enum of the filter tensor
 * \param dst_dtype dtype enum of the destination tensor
 * \param fm canonized filter meta (format, channel counts, group, spatial
 *        size, stride, dilation, flip flag)
 * \param bias_mode bias mode of the conv-bias operator
 * \param nonline_mode nonlinearity applied by the conv-bias operator
 * \return whether the configuration is accepted for variant \p T
 */
template <NchwNchwxxType T>
static inline bool nchw_nchwxx_valid(
        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
        const DTypeEnum dst_dtype,
        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
        const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode);
/*!
 * \brief Validity check for the NCHW44_FP32 variant.
 *
 * Returns true only when all of the following hold:
 *  - src, filter and dst dtypes are Float32 and fm.format is NCHW44;
 *  - nonline_mode is IDENTITY, RELU or H_SWISH;
 *  - fm.icpg < 4, fm.ocpg is a multiple of 4 that is at least 4, and
 *    fm.group == 1;
 *  - the filter is 2-D and square with size 2, 3, 5 or 7;
 *  - dilation is 1 in both dimensions and both strides are equal to 1 or to 2;
 *  - fm.should_flip is false and bias_mode is not BiasMode::BIAS.
 */
template <>
inline bool nchw_nchwxx_valid<NCHW44_FP32>(
        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
        const DTypeEnum dst_dtype,
        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
        const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) {
    using Nonline = param::ConvBias::NonlineMode;

    const bool ok_type = src_dtype == DTypeEnum::Float32 &&
                         filter_dtype == DTypeEnum::Float32 &&
                         dst_dtype == DTypeEnum::Float32 &&
                         fm.format == param::Convolution::Format::NCHW44;
    const bool ok_nonline = nonline_mode == Nonline::IDENTITY ||
                            nonline_mode == Nonline::RELU ||
                            nonline_mode == Nonline::H_SWISH;
    const bool ok_src_dst =
            fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
    const bool ok_filter =
            fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
            (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
             fm.spatial[0] == 7);
    // Equal strides are required first, so testing stride[0] covers both axes.
    const bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                          fm.stride[0] == fm.stride[1] &&
                          (fm.stride[0] == 1 || fm.stride[0] == 2);
    const bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;

    return ok_type && ok_nonline && ok_src_dst && ok_filter && ok_slide && ok_conv;
}
/*!
 * \brief Validity check for the NCHW44_INT8 variant.
 *
 * Returns true only when all of the following hold:
 *  - src, filter and dst dtypes are QuantizedS8 and fm.format is NCHW44;
 *  - nonline_mode is IDENTITY, RELU or H_SWISH;
 *  - fm.icpg < 4, fm.ocpg is a multiple of 4 that is at least 4, and
 *    fm.group == 1;
 *  - the filter is 2-D and square with size 2, 3, 5 or 7, or size 1 when both
 *    strides are 1;
 *  - dilation is 1 in both dimensions and both strides are equal to 1 or to 2;
 *  - fm.should_flip is false and bias_mode is not BiasMode::BIAS.
 */
template <>
inline bool nchw_nchwxx_valid<NCHW44_INT8>(
        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
        const DTypeEnum dst_dtype,
        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
        const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) {
    using Nonline = param::ConvBias::NonlineMode;

    const bool ok_type = src_dtype == DTypeEnum::QuantizedS8 &&
                         filter_dtype == DTypeEnum::QuantizedS8 &&
                         dst_dtype == DTypeEnum::QuantizedS8 &&
                         fm.format == param::Convolution::Format::NCHW44;
    const bool ok_nonline = nonline_mode == Nonline::IDENTITY ||
                            nonline_mode == Nonline::RELU ||
                            nonline_mode == Nonline::H_SWISH;
    const bool ok_src_dst =
            fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
    // A 1x1 filter is accepted only together with stride 1 in both dimensions.
    const bool ok_filter =
            fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
            (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
             fm.spatial[0] == 7 ||
             (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1));
    // Equal strides are required first, so testing stride[0] covers both axes.
    const bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                          fm.stride[0] == fm.stride[1] &&
                          (fm.stride[0] == 1 || fm.stride[0] == 2);
    const bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;

    return ok_type && ok_nonline && ok_src_dst && ok_filter && ok_slide && ok_conv;
}
/*!
 * \brief Validity check for the NCHW44_INT8_INT8_INT16 variant.
 *
 * Returns true only when all of the following hold:
 *  - src and filter dtypes are Int8, dst dtype is Int16, and fm.format is
 *    NCHW44;
 *  - nonline_mode is IDENTITY (no other nonlinearity is accepted);
 *  - fm.icpg < 4, fm.ocpg is a multiple of 4 that is at least 4, and
 *    fm.group == 1;
 *  - the filter is 2-D and square with size 2, 3, 5 or 7;
 *  - dilation is 1 in both dimensions and both strides are equal to 1 or to 2;
 *  - fm.should_flip is false and bias_mode is not BiasMode::BIAS.
 */
template <>
inline bool nchw_nchwxx_valid<NCHW44_INT8_INT8_INT16>(
        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
        const DTypeEnum dst_dtype,
        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
        const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) {
    const bool ok_type = src_dtype == DTypeEnum::Int8 &&
                         filter_dtype == DTypeEnum::Int8 &&
                         dst_dtype == DTypeEnum::Int16 &&
                         fm.format == param::Convolution::Format::NCHW44;
    const bool ok_nonline = nonline_mode == param::ConvBias::NonlineMode::IDENTITY;
    const bool ok_src_dst =
            fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
    const bool ok_filter =
            fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
            (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
             fm.spatial[0] == 7);
    const bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                          fm.stride[0] == fm.stride[1] &&
                          (fm.stride[0] == 2 || fm.stride[0] == 1);
    const bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;

    return ok_type && ok_nonline && ok_src_dst && ok_filter && ok_slide && ok_conv;
}
/*!
 * \brief Validity check for the NCHW44_INT8_DOT variant.
 *
 * Returns true only when all of the following hold:
 *  - src, filter and dst dtypes are QuantizedS8 and fm.format is NCHW44_DOT;
 *  - nonline_mode is IDENTITY, RELU or H_SWISH;
 *  - fm.icpg < 4, fm.ocpg is a multiple of 4 that is at least 4, and
 *    fm.group == 1;
 *  - the filter is 2-D and square with size 2, 3, 5 or 7, or size 1 when both
 *    strides are 1;
 *  - dilation is 1 in both dimensions and both strides are equal to 1 or to 2;
 *  - fm.should_flip is false and bias_mode is not BiasMode::BIAS.
 */
template <>
inline bool nchw_nchwxx_valid<NCHW44_INT8_DOT>(
        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
        const DTypeEnum dst_dtype,
        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
        const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) {
    using Nonline = param::ConvBias::NonlineMode;

    const bool ok_type = src_dtype == DTypeEnum::QuantizedS8 &&
                         filter_dtype == DTypeEnum::QuantizedS8 &&
                         dst_dtype == DTypeEnum::QuantizedS8 &&
                         fm.format == param::Convolution::Format::NCHW44_DOT;
    const bool ok_nonline = nonline_mode == Nonline::IDENTITY ||
                            nonline_mode == Nonline::RELU ||
                            nonline_mode == Nonline::H_SWISH;
    const bool ok_src_dst =
            fm.icpg < 4 && (fm.ocpg % 4 == 0 && fm.ocpg >= 4) && fm.group == 1;
    // A 1x1 filter is accepted only together with stride 1 in both dimensions.
    const bool ok_filter =
            fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
            (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
             fm.spatial[0] == 7 ||
             (fm.spatial[0] == 1 && fm.stride[0] == 1 && fm.stride[1] == 1));
    // Equal strides are required first, so testing stride[0] covers both axes.
    const bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
                          fm.stride[0] == fm.stride[1] &&
                          (fm.stride[0] == 1 || fm.stride[0] == 2);
    const bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;

    return ok_type && ok_nonline && ok_src_dst && ok_filter && ok_slide && ok_conv;
}

/*!
 * \brief Validity check for the NCHW88 (Float32) variant.
 *
 * Returns true only when all of the following hold:
 *  - src, filter and dst dtypes are Float32 and fm.format is NCHW88;
 *  - fm.icpg < 8, fm.ocpg is a multiple of 8 that is at least 8, and
 *    fm.group == 1;
 *  - dilation is 1 in both dimensions;
 *  - fm.should_flip is false and bias_mode is not BiasMode::BIAS.
 *
 * The nonlinearity argument is unnamed and not inspected, and this variant
 * places no constraint on filter size or stride.
 */
template <>
inline bool nchw_nchwxx_valid<NCHW88>(
        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
        const DTypeEnum dst_dtype,
        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
        const BiasMode bias_mode, const param::ConvBias::NonlineMode) {
    const bool ok_type = src_dtype == DTypeEnum::Float32 &&
                         filter_dtype == DTypeEnum::Float32 &&
                         dst_dtype == DTypeEnum::Float32 &&
                         fm.format == param::Convolution::Format::NCHW88;
    const bool ok_src_dst =
            fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1;
    const bool ok_conv = !fm.should_flip && bias_mode != BiasMode::BIAS;
    const bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1;

    return ok_type && ok_src_dst && ok_slide && ok_conv;
}
#if !MEGDNN_DISABLE_FLOAT16
/*!
 * \brief Validity check for the NCHW88_FP16 variant.
 *
 * Accepts the configuration only if src/filter/dst are all Float16 in NCHW88
 * format, the nonlinearity is IDENTITY, RELU or H_SWISH, fm.icpg < 8 with
 * fm.ocpg a multiple of 8 that is at least 8 and fm.group == 1, the filter is
 * a square 2-D kernel of size 2, 3, 5 or 7, dilation is 1, the two strides
 * match, the filter is not flipped and bias_mode is not BiasMode::BIAS.
 */
template <>
inline bool nchw_nchwxx_valid<NCHW88_FP16>(
        const DTypeEnum src_dtype, const DTypeEnum filter_dtype,
        const DTypeEnum dst_dtype,
        const ConvolutionBase<param::Convolution>::CanonizedFilterMeta& fm,
        const BiasMode bias_mode, const param::ConvBias::NonlineMode nonline_mode) {
    using Nonline = param::ConvBias::NonlineMode;

    // Data types and layout.
    const bool all_fp16 = src_dtype == DTypeEnum::Float16 &&
                          filter_dtype == DTypeEnum::Float16 &&
                          dst_dtype == DTypeEnum::Float16;
    if (!(all_fp16 && fm.format == param::Convolution::Format::NCHW88)) {
        return false;
    }

    // Supported nonlinearities.
    const bool nonline_supported = nonline_mode == Nonline::IDENTITY ||
                                   nonline_mode == Nonline::RELU ||
                                   nonline_mode == Nonline::H_SWISH;
    if (!nonline_supported) {
        return false;
    }

    // Channel and group layout.
    const bool channels_fit =
            fm.icpg < 8 && (fm.ocpg % 8 == 0 && fm.ocpg >= 8) && fm.group == 1;
    if (!channels_fit) {
        return false;
    }

    // Square 2-D kernel of a supported size.
    const bool kernel_fits =
            fm.spatial_ndim == 2 && fm.spatial[0] == fm.spatial[1] &&
            (fm.spatial[0] == 2 || fm.spatial[0] == 3 || fm.spatial[0] == 5 ||
             fm.spatial[0] == 7);
    if (!kernel_fits) {
        return false;
    }

    // No dilation; matching strides of a supported size.
    const bool no_dilation = fm.dilation[0] == 1 && fm.dilation[1] == 1;
    const bool stride_fits = fm.stride[0] == fm.stride[1] &&
                             (fm.stride[0] == 1 || fm.stride[1] == 2);
    if (!(no_dilation && stride_fits)) {
        return false;
    }

    // No filter flipping, and bias_mode must not be BiasMode::BIAS.
    return !fm.should_flip && bias_mode != BiasMode::BIAS;
}
#endif
}  // namespace
}  // namespace megdnn