opr_impl.cpp 5.9 KB
Newer Older
1 2
#include "src/arm_common/pooling/opr_impl.h"
#include "src/arm_common/pooling/algo.h"
3
#include "src/common/algo_chooser.h"
M
Megvii Engine Team 已提交
4
#include "src/common/metahelper.h"
5 6 7 8 9

using namespace megdnn;
using namespace arm_common;

//! Holds one instance of every arm_common pooling algorithm together with
//! the containers used to enumerate them and to look them up by descriptor.
class PoolingImpl::AlgoPack : NonCopyableObj {
private:
    //! Maps algo->info().desc to the algorithm; filled by the constructor.
    AlgoBase::Mapper m_all_algos_map;

    AlgoFilterxModexStride1 algo_filterx_modex_stride1;
    AlgoFilter2ModexStride2 algo_filter2_modex_stride2;
    AlgoFilter3MaxStride2 algo_filter3_max_stride2;
    AlgoFilter3AverageStride2 algo_filter3_average_stride2;
    AlgoFilter4MaxStride2 algo_filter4_max_stride2;
    AlgoFilter5MaxStride2 algo_filter5_max_stride2;
    AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2;
    AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2;

    AlgoFilter2ModexStridexNCHW44 algo_filter2_modex_stridex_nchw4;
    AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw4;
    AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4;
    AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4;

    AlgoFallback algo_fallback;

public:
    //! Registers every algorithm in all_algos and indexes it by descriptor.
    //! The registration order is significant: get_algorithm_heuristic()
    //! walks all_algos front to back and returns the first match, so the
    //! fallback is registered last.
    AlgoPack() {
        all_algos.emplace_back(&algo_filterx_modex_stride1);
        all_algos.emplace_back(&algo_filter2_modex_stride2);
        all_algos.emplace_back(&algo_filter3_max_stride2);
        all_algos.emplace_back(&algo_filter3_average_stride2);
        all_algos.emplace_back(&algo_filter4_max_stride2);
        all_algos.emplace_back(&algo_filter5_max_stride2);
        all_algos.emplace_back(&algo_int8_filter2_max_stride2);
        all_algos.emplace_back(&algo_int8_filter3_max_stride2);

        //! NOTE(review): the filter3 NCHW44 variant is registered ahead of
        //! filter2, unlike the declaration order above; kept as-is because
        //! the order affects selection -- confirm whether it is intended.
        all_algos.emplace_back(&algo_filter3_modex_stridex_nchw4);
        all_algos.emplace_back(&algo_filter2_modex_stridex_nchw4);
        all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4);
        all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4);

        all_algos.emplace_back(&algo_fallback);

        for (auto&& algo : all_algos) {
            m_all_algos_map.emplace(algo->info().desc, algo);
        }
    }

    //! Non-owning pointers to the member algorithms, in registration order.
    SmallVector<AlgoBase*> all_algos;

    //! Read-only access to the descriptor -> algorithm map.
    const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
};

50 51
//! Out-of-class definition of the static AlgoPack member of PoolingImpl.
//! Its constructor registers all algorithms when this object is initialized.
PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack;

M
Megvii Engine Team 已提交
52 53
size_t PoolingImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& dst) {
54
    auto param = make_pooling_kern_szie_param(this, src, dst);
55 56 57
    auto algo = get_algorithm(this, src, dst);
    if (!is_fallback_algo(algo)) {
        size_t arm_common_workspace = 0;
58

59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
        //! When multi-thread, every thread has its own workspace
        size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                    ->megcore_dispatcher()
                                    ->nr_threads();
        if ((param.src_type.category() == DTypeCategory::FLOAT ||
             param.src_type == dtype::Int8{} ||
             param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
            param.filter[0] == param.filter[1] &&
            (param.filter[0] == 3 || param.filter[0] == 5) &&
            param.format == Param::Format::NCHW &&
            (param.mode == Mode::MAX ||
             (param.mode == Mode::AVERAGE && param.filter[0] == 3)) &&
            param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 &&
            param.isz[1] >= 2) {
            WorkspaceBundle ws = get_bundle(param);
            arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
        }
77

78 79 80 81 82 83
        if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
             param.src_type.enumv() == DTypeEnum::Int8) &&
            (param.format == param::Pooling::Format::NCHW44)) {
            WorkspaceBundle ws = get_bundle_nchw44(param);
            arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
        }
84 85 86 87 88 89 90 91
        return arm_common_workspace;
    } else {
        auto fallback_worksapce =
                fallback::PoolingImpl::get_workspace_in_bytes(src, dst);
        return fallback_worksapce;
    }
}

M
Megvii Engine Team 已提交
92 93
//! Runs pooling from \p src into \p dst using \p workspace.
//!
//! The layouts and workspace size are validated with check_exec() first.
//! The algorithm returned by get_algorithm() is then executed directly,
//! unless it is the fallback algorithm, in which case the call is forwarded
//! to fallback::PoolingImpl::exec().
void PoolingImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    auto param = make_pooling_kern_param(this, src, dst, workspace);
    auto algo = get_algorithm(this, src.layout, dst.layout);
    if (is_fallback_algo(algo)) {
        fallback::PoolingImpl::exec(src, dst, workspace);
        return;
    }
    algo->exec(param);
}

//! NOTE(review): presumably expands to the definition of the
//! descriptor -> algorithm lookup for PoolingImpl (likely backed by
//! AlgoPack::all_algos_map()) -- confirm against the macro definition.
MEGDNN_DEF_GET_ALGO_FROM_DESC(PoolingImpl);

std::vector<Algorithm*> PoolingImpl::get_all_algorithms(
        const TensorLayout& src, const TensorLayout& dst) {
    auto param = make_pooling_kern_szie_param(this, src, dst);
    std::vector<Algorithm*> ret;
    ret.reserve(algo_pack().all_algos.size());
    for (auto i : algo_pack().all_algos) {
        if (i->usable(param)) {
            ret.push_back(i);
        }
    }
    return ret;
}
118 119
std::vector<Algorithm*> PoolingImpl::get_all_algorithms_safe(
        const TensorLayout& src, const TensorLayout& dst) {
M
Megvii Engine Team 已提交
120
    auto ret_safe = get_all_algorithms(src, dst);
121 122 123
    megdnn_assert(!ret_safe.empty(), "no usable pooling fwd algorithm");
    return ret_safe;
}
124 125 126 127 128 129 130 131 132 133 134

//! Picks the first algorithm, in registration order, for which
//! is_available_attribute() accepts the kernel-size parameters of
//! \p src / \p dst together with \p positive_attr and \p negative_attr.
//!
//! \p workspace_limit_in_bytes is not consulted by this implementation.
//! If no algorithm matches, the failure is reported through megdnn_throw
//! with a message naming the requested attributes.
Algorithm* PoolingImpl::get_algorithm_heuristic(
        const TensorLayout& src, const TensorLayout& dst,
        size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
        const AlgoAttribute& negative_attr) {
    MEGDNN_MARK_USED_VAR(workspace_limit_in_bytes);

    auto param = make_pooling_kern_szie_param(this, src, dst);
    for (auto&& iter : sm_algo_pack.all_algos) {
        if (iter->is_available_attribute(param, positive_attr, negative_attr)) {
            return iter;
        }
    }
    megdnn_throw(ssprintf(
            "require algorithm with attribute(%s) and without "
            "attribute(%s), but can't get suitable algo.\n",
            Algorithm::attribute_str(positive_attr).c_str(),
            Algorithm::attribute_str(negative_attr).c_str()));
    //! NOTE(review): presumably unreachable once megdnn_throw fires; kept so
    //! that every control path returns a value.
    return nullptr;
}

// vim: syntax=cpp.doxygen