algos.h 7.6 KB
Newer Older
1 2 3 4 5 6 7 8
/**
 * \file dnn/src/x86/conv_bias/f32/algos.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
9 10
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
11 12 13 14 15 16 17 18 19 20 21 22 23
 */

#pragma once
#include "src/x86/conv_bias/opr_impl.h"

using namespace megdnn;
using namespace x86;

/* ===================== direct algo ===================== */
/**
 * Direct conv-bias algorithm reported under the
 * "X86_CONV_BIAS_DIRECT_STRIDE1_*" names.
 *
 * The constructor flag only selects which of the two names (LARGE_GROUP or
 * SMALL_GROUP) name() returns as far as this header shows; usable(),
 * get_workspace(), type() and the private helpers are defined elsewhere.
 */
class ConvBiasImpl::AlgoDirect final : public AlgoBase {
public:
    AlgoDirect(bool large_group) : m_large_group(large_group) {}

    //! This algorithm always reports itself as reproducible.
    bool is_reproducible() const override { return true; }

    //! Algorithm name; depends on the large-group flag given at construction.
    const char* name() const override {
        if (m_large_group) {
            return "X86_CONV_BIAS_DIRECT_STRIDE1_LARGE_GROUP";
        }
        return "X86_CONV_BIAS_DIRECT_STRIDE1_SMALL_GROUP";
    }

    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;

    size_t get_workspace(const NCBKernSizeParam& param) const override;

    //! Kernels to run for \p param; simply forwards to get_kimpls().
    SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override {
        return get_kimpls(param);
    }

    void* type() const override;

private:
    //! Builds the kernel list handed out by dispatch_kerns().
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;

    static void copy_padding_kern(const WorkspaceBundle& bundle,
                                  const NCBKernParam& kern_param,
                                  const NCBKernIndex& ncb_index,
                                  const CpuNDRange& workspace_ids);
    static void do_conv_kern(const WorkspaceBundle& bundle,
                             const NCBKernParam& kern_param,
                             const NCBKernIndex& ncb_index,
                             const CpuNDRange& workspace_ids);

    //! true selects the LARGE_GROUP name, false the SMALL_GROUP name.
    bool m_large_group;
};

/* ===================== direct-stride2 algo ===================== */
/**
 * Direct conv-bias algorithm reported under the
 * "X86_CONV_BIAS_DIRECT_STRIDE2_*" names.
 *
 * The constructor flag only selects which of the two names (LARGE_GROUP or
 * SMALL_GROUP) name() returns as far as this header shows; usable(),
 * get_workspace(), type() and the private helpers are defined elsewhere.
 */
class ConvBiasImpl::AlgoDirectStride2 final : public AlgoBase {
public:
    AlgoDirectStride2(bool large_group) : m_large_group(large_group) {}

    //! This algorithm always reports itself as reproducible.
    bool is_reproducible() const override { return true; }

    //! Algorithm name; depends on the large-group flag given at construction.
    const char* name() const override {
        if (m_large_group) {
            return "X86_CONV_BIAS_DIRECT_STRIDE2_LARGE_GROUP";
        }
        return "X86_CONV_BIAS_DIRECT_STRIDE2_SMALL_GROUP";
    }

    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;

    size_t get_workspace(const NCBKernSizeParam& param) const override;

    //! Kernels to run for \p param; simply forwards to get_kimpls().
    SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override {
        return get_kimpls(param);
    }

    void* type() const override;

private:
    //! Builds the kernel list handed out by dispatch_kerns().
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;

    static void copy_padding_kern(const WorkspaceBundle& bundle,
                                  const NCBKernParam& kern_param,
                                  const NCBKernIndex& ncb_index,
                                  const CpuNDRange& workspace_ids);
    static void do_conv_kern(const WorkspaceBundle& bundle,
                             const NCBKernParam& kern_param,
                             const NCBKernIndex& ncb_index,
                             const CpuNDRange& workspace_ids);

    //! true selects the LARGE_GROUP name, false the SMALL_GROUP name.
    bool m_large_group;
};
/* =========================== winograd ======================== */
/**
 * FP32 winograd algorithm whose name is derived from the wrapped matmul
 * algorithm's name and the WinogradParam {8, 6, tile_size}.
 *
 * NOTE(review): m_name, m_matmul_algo and m_tile_size are not declared in
 * this block -- presumably they come from MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
 * confirm against the macro definition.
 */
class ConvBiasImpl::AlgoFP32WinogradF63_8x8 final : public AlgoBase {
public:
    AlgoFP32WinogradF63_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                            uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}

    void* type() const override;

    //! Returns the algorithm name, building and storing it in m_name while
    //! m_name is still empty.
    const char* name() const override {
        if (!m_name.empty()) {
            return m_name.c_str();
        }
        m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                m_matmul_algo->name(), {8, 6, m_tile_size});
        return m_name.c_str();
    }

    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
};

/**
 * FP32 winograd algorithm whose name is derived from the wrapped matmul
 * algorithm's name and the WinogradParam {8, 2, tile_size}.
 *
 * NOTE(review): m_name, m_matmul_algo and m_tile_size are not declared in
 * this block -- presumably they come from MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
 * confirm against the macro definition.
 */
class ConvBiasImpl::AlgoFP32WinogradF23_8x8 final : public AlgoBase {
public:
    AlgoFP32WinogradF23_8x8(fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                            uint32_t tile_size)
            : m_matmul_algo{matmul_algo}, m_tile_size{tile_size} {}

    void* type() const override;

    //! Returns the algorithm name, building and storing it in m_name while
    //! m_name is still empty.
    const char* name() const override {
        if (!m_name.empty()) {
            return m_name.c_str();
        }
        m_name = ConvBiasImpl::algo_name<ConvBias::WinogradParam>(
                m_matmul_algo->name(), {8, 2, m_tile_size});
        return m_name.c_str();
    }

    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
};

/* ===================== matmul algo ===================== */
/**
 * Conv-bias algorithm reported as "X86_CONV_BIAS_MATMUL".
 *
 * usable() accepts only NCHW, 2-D spatial, all-Float32, dilation-1 problems
 * running on a single thread. One kernel (kimpl) is dispatched over a range
 * of {group, 1, 1}.
 */
class ConvBiasImpl::AlgoMatrixMul final : public AlgoBase {
public:
    //! This algorithm always reports itself as reproducible.
    bool is_reproducible() const override { return true; }

    const char* name() const override { return "X86_CONV_BIAS_MATMUL"; }

    //! Checks the layout, dtypes, dilation and thread count of \p param;
    //! the selection strategy argument is ignored.
    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy) const override {
        auto&& meta = param.filter_meta;

        const bool layout_ok =
                meta.format == Param::Format::NCHW && meta.spatial_ndim == 2;
        if (!layout_ok) {
            return false;
        }

        const bool all_fp32 =
                param.src_type.enumv() == DTypeEnum::Float32 &&
                param.filter_type.enumv() == DTypeEnum::Float32 &&
                param.dst_type.enumv() == DTypeEnum::Float32;
        if (!all_fp32) {
            return false;
        }

        const bool no_dilation =
                meta.dilation[0] == 1 && meta.dilation[1] == 1;
        if (!no_dilation) {
            return false;
        }

        //! The matmul opr is only used in single thread
        //! TODO:support the no pack matmul algo in fallback im2col +
        //! matmul
        return param.nr_threads == 1_z;
    }

    bool is_preferred(const NCBKernSizeParam&) const override;

    //! Workspace requirement is the total size of the bundle for \p param.
    size_t get_workspace(const NCBKernSizeParam& param) const override {
        return get_bundle(param).total_size_in_bytes();
    }

    //! Single kernel whose first range dimension is the filter group count.
    SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override {
        const size_t nr_groups = param.filter_meta.group;
        return {{kimpl, {nr_groups, 1_z, 1_z}}};
    }

    void* type() const override;

private:
    static MatrixMul* get_matmul_opr();
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
    static void kimpl(const NCBKernParam& param, const NCBKernIndex&);
};

161
#if MEGDNN_X86_WITH_MKL_DNN
162 163 164 165 166 167 168 169
/**
 * Conv-bias algorithm reported as "MKLDNN_CONV_FP32".
 *
 * usable() accepts only NCHW88, 2-D spatial, all-Float32, dilation-1
 * problems. No workspace is requested, and a single kernel
 * (kern_mkldnn_fp32) is dispatched over a range of {1, 1, 1}.
 */
class ConvBiasImpl::AlgoMkldnnConv final : public AlgoBase {
    static void kern_mkldnn_fp32(const NCBKernParam& param,
                                 const NCBKernIndex&);

public:
    AlgoMkldnnConv() {}

    //! This algorithm always reports itself as reproducible.
    bool is_reproducible() const override { return true; }

    const char* name() const override { return "MKLDNN_CONV_FP32"; }

    //! Checks the layout, dtypes and dilation of \p param; the selection
    //! strategy argument is ignored.
    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy) const override {
        auto&& fm = param.filter_meta;

        return (fm.format == param::ConvBias::Format::NCHW88) &&
               fm.spatial_ndim == 2 &&
               param.src_type.enumv() == DTypeEnum::Float32 &&
               param.filter_type.enumv() == DTypeEnum::Float32 &&
               param.dst_type.enumv() == DTypeEnum::Float32 &&
               fm.dilation[0] == 1 && fm.dilation[1] == 1;
    }

    //! This algorithm requests no workspace.
    size_t get_workspace(const NCBKernSizeParam&) const override { return 0; }

    //! Single kernel over a {1, 1, 1} range. The static kernel is passed
    //! directly, in the same way AlgoMatrixMul passes kimpl.
    SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& /*param*/) const override {
        return {{kern_mkldnn_fp32, {1_z, 1_z, 1_z}}};
    }

    void* type() const override;
};
#endif
// vim: syntax=cpp.doxygen