opr_impl.cpp 5.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
/**
 * \file dnn/src/arm_common/pooling/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/pooling/opr_impl.h"
#include "src/arm_common/pooling/algo.h"
#include "src/common/metahelper.h"

using namespace megdnn;
using namespace arm_common;

/**
 * \brief Owns one instance of every arm_common pooling algorithm and exposes
 * them through \c all_algos.
 *
 * PoolingImpl::get_workspace_in_bytes() and PoolingImpl::exec() walk
 * \c all_algos from front to back and stop at the first entry whose usable()
 * returns true, so the registration order in the constructor is also the
 * selection order.
 */
class PoolingImpl::AlgoPack : NonCopyableObj {
    AlgoFilterxModexStride1 algo_filterx_modex_stride1;
    AlgoFilter2ModexStride2 algo_filter2_modex_stride2;
    AlgoFilter3MaxStride2 algo_filter3_max_stride2;
    AlgoFilter3AverageStride2 algo_filter3_average_stride2;
    AlgoFilter4MaxStride2 algo_filter4_max_stride2;
    AlgoFilter5MaxStride2 algo_filter5_max_stride2;
    AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2;
    AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2;
    //! member names carry the same NCHW44 suffix as their algorithm types
    AlgoFilter2ModexStridexNCHW44 algo_filter2_modex_stridex_nchw44;
    AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw44;
    AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw44;
    AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw44;

public:
    AlgoPack() {
        all_algos.emplace_back(&algo_filterx_modex_stride1);
        all_algos.emplace_back(&algo_filter2_modex_stride2);
        all_algos.emplace_back(&algo_filter3_max_stride2);
        all_algos.emplace_back(&algo_filter3_average_stride2);
        all_algos.emplace_back(&algo_filter4_max_stride2);
        all_algos.emplace_back(&algo_filter5_max_stride2);
        all_algos.emplace_back(&algo_int8_filter2_max_stride2);
        all_algos.emplace_back(&algo_int8_filter3_max_stride2);
        //! filter3 is registered ahead of filter2 for the NCHW44 group; the
        //! existing order is kept on purpose so selection does not change
        all_algos.emplace_back(&algo_filter3_modex_stridex_nchw44);
        all_algos.emplace_back(&algo_filter2_modex_stridex_nchw44);
        all_algos.emplace_back(&algo_filter4_modex_stridex_nchw44);
        all_algos.emplace_back(&algo_filter5_modex_stridex_nchw44);
    }

    //! non-owning pointers to the member algorithms, in registration order
    SmallVector<AlgoBase*> all_algos;
};

//! Pack the first four shape entries of \p src / \p dst together with the
//! pooling parameters, dtypes and handle of \p opr into a
//! PoolingKernSizeParam. Every size is asserted to fit into uint32_t.
PoolingImpl::PoolingKernSizeParam PoolingImpl::make_pooling_kern_szie_param(
        fallback::PoolingImpl* opr, const TensorLayout& src,
        const TensorLayout& dst) {
    //! checked narrowing: asserts instead of silently truncating
    auto to_u32 = [](size_t value) -> uint32_t {
        megdnn_assert(value <= std::numeric_limits<uint32_t>::max(),
                      "value too large: %zu", value);
        return static_cast<uint32_t>(value);
    };
    const auto& pooling = opr->param();
    return {to_u32(src.shape[0]),
            to_u32(src.shape[1]),
            {{to_u32(src.shape[2]), to_u32(src.shape[3])}},
            {{to_u32(dst.shape[2]), to_u32(dst.shape[3])}},
            {{to_u32(pooling.pad_h), to_u32(pooling.pad_w)}},
            {{to_u32(pooling.window_h), to_u32(pooling.window_w)}},
            {{to_u32(pooling.stride_h), to_u32(pooling.stride_w)}},
            src.dtype,
            dst.dtype,
            opr->handle(),
            pooling.format,
            pooling.mode};
}

//! Build the full kernel parameter: the size part comes from
//! make_pooling_kern_szie_param(), then the raw tensor and workspace
//! pointers plus the workspace size are attached.
PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param(
        fallback::PoolingImpl* opr, _megdnn_tensor_in src,
        _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    PoolingKernParam kern_param;

    //! assign through a PoolingKernSizeParam reference so only that part of
    //! kern_param is overwritten
    PoolingKernSizeParam& size_part = kern_param;
    size_part = make_pooling_kern_szie_param(opr, src.layout, dst.layout);

    kern_param.src_ptr = src.raw_ptr;
    kern_param.dst_ptr = dst.raw_ptr;
    kern_param.workspace_ptr = workspace.raw_ptr;
    kern_param.workspace_size = workspace.size;
    return kern_param;
}

size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
                                           const TensorLayout& dst) {
    bool find_algo = false;
    static AlgoPack m_algo_pack;
    auto param = make_pooling_kern_szie_param(this, src, dst);
    for (auto& m_algo : m_algo_pack.all_algos) {
        if (m_algo->usable(param)) {
            find_algo = true;
            break;
        }
    }
    size_t arm_common_workspace = 0;

    //! When multi-thread, every thread has its own workspace
    size_t nr_threads = static_cast<naive::HandleImpl*>(handle())
                                ->megcore_dispatcher()
                                ->nr_threads();
    if ((param.src_type.category() == DTypeCategory::FLOAT ||
         param.src_type == dtype::Int8{} ||
         param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
         param.src_type.enumv() == DTypeEnum::Quantized8Asymm) &&
        param.filter[0] == param.filter[1] &&
        (param.filter[0] == 3 || param.filter[0] == 5) &&
        param.format == Param::Format::NCHW &&
        (param.mode == Mode::MAX ||
         (param.mode == Mode::AVERAGE && param.filter[0] == 3)) &&
        param.stride[0] == 2 && param.stride[1] == 2 && param.isz[0] >= 2 &&
        param.isz[1] >= 2) {
        WorkspaceBundle ws = get_bundle(param);
        arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
    }

120 121 122 123 124 125
    if ((param.src_type.enumv() == DTypeEnum::QuantizedS8) &&
        (param.format == param::Pooling::Format::NCHW44)) {
        WorkspaceBundle ws = get_bundle_nchw44(param);
        arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
    }

126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
    if (find_algo) {
        return arm_common_workspace;
    } else {
        auto fallback_worksapce =
                fallback::PoolingImpl::get_workspace_in_bytes(src, dst);
        return fallback_worksapce;
    }
}

//! Run pooling with the first algorithm in the AlgoPack that reports
//! usable() for this call; if none does, hand the call over to
//! fallback::PoolingImpl::exec().
void PoolingImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                       _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);

    auto kern_param = make_pooling_kern_param(this, src, dst, workspace);
    static AlgoPack algo_pack;
    for (auto* algo : algo_pack.all_algos) {
        if (!algo->usable(kern_param)) {
            continue;
        }
        //! first match wins; later entries are not consulted
        algo->exec(kern_param);
        return;
    }

    fallback::PoolingImpl::exec(src, dst, workspace);
}

// vim: syntax=cpp.doxygen