/**
 * \file dnn/src/fallback/conv_bias/common.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

#include <stdint.h>
#include "megdnn/oprs.h"
#include "src/common/postprocess.h"
#include "src/common/utils.h"

namespace megdnn {
// Short names for two enums used throughout the conv-bias fallback code:
// BiasMode is ConvBiasForward's bias-mode enum (referenced unqualified by the
// *_BIAS dispatch macros in this header), and NonlineMode is the nonlinearity
// enum carried by the ConvBias parameter.
using BiasMode = ConvBiasForward::BiasMode;
using NonlineMode = ConvBias::Param::NonlineMode;

// Expands to a `switch` on `param.nonlineMode`, so a variable named `param`
// carrying that field must be in scope where the macro is used. Each supported
// nonlinearity forwards to DISPATCH_GEMM_STRATEGY together with a bare op
// token and a numeric midout tag:
//   H_SWISH -> hswish, 2    RELU -> relu, 1    IDENTITY -> identity, 0
// DISPATCH_GEMM_STRATEGY is not defined here; it must be defined before this
// macro is expanded. Any other mode reaches megdnn_assert(0).
#define DISPATCH_GEMM_NONLINE(gemm_, gemm_tag_, bias_, bias_tag_)          \
    switch (param.nonlineMode) {                                       \
        case param::ConvBias::NonlineMode::H_SWISH: {                  \
            DISPATCH_GEMM_STRATEGY(gemm_, gemm_tag_, bias_, bias_tag_, \
                                   hswish, 2);                         \
        } break;                                                       \
        case param::ConvBias::NonlineMode::RELU: {                     \
            DISPATCH_GEMM_STRATEGY(gemm_, gemm_tag_, bias_, bias_tag_, \
                                   relu, 1);                           \
        } break;                                                       \
        case param::ConvBias::NonlineMode::IDENTITY: {                 \
            DISPATCH_GEMM_STRATEGY(gemm_, gemm_tag_, bias_, bias_tag_, \
                                   identity, 0);                       \
        } break;                                                       \
        default:                                                       \
            megdnn_assert(0);                                          \
            break;                                                     \
    }

// Expands to a `switch` on `param.bias_mode` (a `param` variable must be in
// scope). Each supported bias mode expands DISPATCH_GEMM_NONLINE with a bare
// bias token and a numeric midout tag:
//   BROADCAST_CHANNEL_BIAS -> bias_channel, 1    NO_BIAS -> nobias, 0
// Any other bias mode reaches megdnn_assert(0).
#define DISPATCH_GEMM_BIAS(gemm_, gemm_tag_)                               \
    switch (param.bias_mode) {                                             \
        case BiasMode::BROADCAST_CHANNEL_BIAS: {                           \
            DISPATCH_GEMM_NONLINE(gemm_, gemm_tag_, bias_channel, 1)       \
        } break;                                                           \
        case BiasMode::NO_BIAS: {                                          \
            DISPATCH_GEMM_NONLINE(gemm_, gemm_tag_, nobias, 0)             \
        } break;                                                           \
        default:                                                           \
            megdnn_assert(0);                                              \
            break;                                                         \
    }

// Expands to a `switch` on `param.nonlineMode` (a `param` variable must be in
// scope). Each supported nonlinearity expands DISPATCH_CONV_STRATEGY with an
// op type instantiated as XxxOp<dt_qint32, dst_> plus a numeric midout tag:
//   H_SWISH -> HSwishOp, 2    RELU -> ReluOp, 1    IDENTITY -> TypeCvtOp, 0
// MEGDNN_COMMA stands in for the comma of the template argument list so the
// whole op type is passed as a single macro argument (presumably it expands
// to `,`). Any other mode reaches megdnn_assert(0).
#define DISPATCH_CONV_NONLINE(i_, tag_, stride_, conv_, bias_mode_, dst_)      \
    switch (param.nonlineMode) {                                               \
        case param::ConvBias::NonlineMode::H_SWISH: {                          \
            DISPATCH_CONV_STRATEGY(i_, tag_, stride_, conv_, bias_mode_,       \
                                   HSwishOp<dt_qint32 MEGDNN_COMMA dst_>, 2);  \
        } break;                                                               \
        case param::ConvBias::NonlineMode::RELU: {                             \
            DISPATCH_CONV_STRATEGY(i_, tag_, stride_, conv_, bias_mode_,       \
                                   ReluOp<dt_qint32 MEGDNN_COMMA dst_>, 1);    \
        } break;                                                               \
        case param::ConvBias::NonlineMode::IDENTITY: {                         \
            DISPATCH_CONV_STRATEGY(i_, tag_, stride_, conv_, bias_mode_,       \
                                   TypeCvtOp<dt_qint32 MEGDNN_COMMA dst_>, 0); \
        } break;                                                               \
        default:                                                               \
            megdnn_assert(0);                                                  \
            break;                                                             \
    }

// Expands to a `switch` on `param.bias_mode` (a `param` variable must be in
// scope). For NO_BIAS and BROADCAST_CHANNEL_BIAS it expands
// DISPATCH_CONV_NONLINE with the matching BiasMode enumerator as the bias
// argument. Any other bias mode reaches megdnn_assert(0).
#define DISPATCH_CONV_BIAS(i_, tag_, stride_, conv_, dst_)                 \
    switch (param.bias_mode) {                                             \
        case BiasMode::BROADCAST_CHANNEL_BIAS: {                           \
            DISPATCH_CONV_NONLINE(i_, tag_, stride_, conv_,                \
                                  BiasMode::BROADCAST_CHANNEL_BIAS, dst_)  \
        } break;                                                           \
        case BiasMode::NO_BIAS: {                                          \
            DISPATCH_CONV_NONLINE(i_, tag_, stride_, conv_,                \
                                  BiasMode::NO_BIAS, dst_)                 \
        } break;                                                           \
        default:                                                           \
            megdnn_assert(0);                                              \
            break;                                                         \
    }

// Emits one kernel choice wrapped in a MIDOUT region keyed by
// (midout_tag, i, stride, midout_iv(BIAS_MODE), _nonline_midout_enum).
// Inside the region it `return`s a braced list with a single entry that pairs
// conv<i, BIAS_MODE, Op> with {1_z, 1_z, 1_z}. It must therefore be expanded
// inside a function whose return type accepts that initializer.
// NOTE(review): `1_z` looks like a size_t literal, and the triple looks like a
// kernel dispatch size. Confirm against the NCBKern definition.
// Macros are expanded at their point of use, so being defined after
// DISPATCH_CONV_NONLINE (which references it) is fine.
#define DISPATCH_CONV_STRATEGY(i, midout_tag, stride, conv, BIAS_MODE, Op, \
                               _nonline_midout_enum)                       \
    MIDOUT_BEGIN(midout_tag, i, stride, midout_iv(BIAS_MODE),              \
                 _nonline_midout_enum) {                                   \
        return {{conv<i, BIAS_MODE, Op>, {1_z, 1_z, 1_z}}};                \
    }                                                                      \
    MIDOUT_END()

// Turns the runtime value of `filter` into a literal by invoking
// `kern(N, ...)` for N in {2, 3, 5, 7}. Any extra macro arguments are
// forwarded after N. Any other value reaches megdnn_assert(0).
//
// This uses the standard variadic-macro form (`...` / `__VA_ARGS__`) rather
// than the GNU-only named form (`arg...` / `##arg`), which e.g. MSVC rejects.
// `, ##__VA_ARGS__` still drops the comma when no extra arguments are given,
// so `DISPATCH_FILTER(f, kern)` expands to `kern(N)`.
#define DISPATCH_FILTER(filter, kern, ...) \
    switch (filter) {                      \
        case 2:                            \
            kern(2, ##__VA_ARGS__);        \
            break;                         \
        case 3:                            \
            kern(3, ##__VA_ARGS__);        \
            break;                         \
        case 5:                            \
            kern(5, ##__VA_ARGS__);        \
            break;                         \
        case 7:                            \
            kern(7, ##__VA_ARGS__);        \
            break;                         \
        default:                           \
            megdnn_assert(0);              \
            break;                         \
    }

// Channel-wise variant of the filter dispatcher. It turns the runtime value of
// `filter` into a literal by invoking `kern(N, ...)` for N in {2, 3, 5} only;
// any other value (including 7) reaches megdnn_assert(0). Extra macro
// arguments are forwarded after N.
//
// This uses the standard `...` / `__VA_ARGS__` form instead of the GNU-only
// named variadic form. `, ##__VA_ARGS__` drops the comma when no extra
// arguments are given.
#define DISPATCH_FILTER_CHANNEL_WISE(filter, kern, ...) \
    switch (filter) {                                   \
        case 2:                                         \
            kern(2, ##__VA_ARGS__);                     \
            break;                                      \
        case 3:                                         \
            kern(3, ##__VA_ARGS__);                     \
            break;                                      \
        case 5:                                         \
            kern(5, ##__VA_ARGS__);                     \
            break;                                      \
        default:                                        \
            megdnn_assert(0);                           \
            break;                                      \
    }

// Declares the member functions shared by Winograd conv-bias algorithm
// classes, then appends their private data members. It must be expanded inside
// a class body whose base class declares these virtual functions (every one is
// marked `override`).
//
// NOTE: the expansion ends inside a `private:` section. Members written after
// the macro in the enclosing class are therefore private unless another access
// specifier follows.
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE()                                     \
    bool is_reproducible() const override { return true; }                     \
    bool usable(const NCBKernSizeParam& param,                                 \
                AlgoSelectionStrategy algo_selection_strategy) const override; \
    size_t get_workspace(const NCBKernSizeParam& param) const override;        \
    SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param)         \
            const override;                                                    \
    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(               \
            const NCBKernSizeParam& param) const override;                     \
    size_t get_preprocess_workspace(const NCBKernSizeParam& param)             \
            const override;                                                    \
    SmallVector<NCBKern> dispatch_preprocess_kerns(                            \
            const NCBKernSizeParam& param) const override;                     \
                                                                               \
private:                                                                       \
    /* Non-owning; null until the enclosing class's constructor sets it. */    \
    fallback::MatrixMulImpl::AlgoBase* m_matmul_algo = nullptr;                \
    mutable std::string m_name;                                                \
    uint32_t m_tile_size = 0;

}  // namespace megdnn

// vim: syntax=cpp.doxygen