algo.h 4.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
/**
 * \file dnn/src/fallback/pooling/gi/algo.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "src/common/utils.h"
#include "src/fallback/pooling/opr_impl.h"

#include "pooling_helper.h"

#include "src/naive/handle.h"
#include "src/naive/pooling/opr_impl.h"

namespace megdnn {
namespace fallback {

// Shorthand for PoolingImpl's algorithm base class, used as the public base of
// every algorithm class declared in this header.
using AlgoBase = PoolingImpl::AlgoBase;

/*!
 * \brief GI pooling algorithm registered as "GI_POOLING_STRIDE1".
 *
 * The class name suggests arbitrary filter size and pooling mode with stride 1;
 * the authoritative constraints are whatever usable() (defined out of line)
 * checks.
 */
class PoolingImpl::AlgoGiFilterxModexStride1 final : public AlgoBase {
public:
    const char* name() const override { return "GI_POOLING_STRIDE1"; }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    //! Whether this algorithm accepts \p param; defined out of line.
    bool usable(const PoolingKernSizeParam& param) const override;

    //! Executes the algorithm for \p param; defined out of line.
    void exec(const PoolingKernParam& param) const override;

    MEGDNN_DECL_ALGO_TYPE(GI_FilterxModexStride1)
};

/*!
 * \brief GI pooling algorithm registered as "GI_POOLING_STRIDE2".
 *
 * The class name suggests a 2-wide filter, any pooling mode, stride 2;
 * the authoritative constraints are whatever usable() (defined out of line)
 * checks.
 */
class PoolingImpl::AlgoGiFilter2ModexStride2 final : public AlgoBase {
public:
    const char* name() const override { return "GI_POOLING_STRIDE2"; }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    //! Whether this algorithm accepts \p param; defined out of line.
    bool usable(const PoolingKernSizeParam& param) const override;

    //! Executes the algorithm for \p param; defined out of line.
    void exec(const PoolingKernParam& param) const override;

    MEGDNN_DECL_ALGO_TYPE(GI_Filter2ModexStride2)
};
/*!
 * \brief GI pooling algorithm registered as "GI_POOLING_FILTER3_MAX".
 *
 * The class name suggests a 3-wide filter, max pooling, stride 2;
 * the authoritative constraints are whatever usable() (defined out of line)
 * checks.
 */
class PoolingImpl::AlgoGiFilter3MaxStride2 final : public AlgoBase {
public:
    const char* name() const override { return "GI_POOLING_FILTER3_MAX"; }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    //! Whether this algorithm accepts \p param; defined out of line.
    bool usable(const PoolingKernSizeParam& param) const override;

    //! Executes the algorithm for \p param; defined out of line.
    void exec(const PoolingKernParam& param) const override;

    MEGDNN_DECL_ALGO_TYPE(GI_Filter3MaxStride2)
};

/*!
 * \brief GI pooling algorithm registered as "GI_POOLING_FILTER3_AVERAGE".
 *
 * The class name suggests a 3-wide filter, average pooling, stride 2;
 * the authoritative constraints are whatever usable() (defined out of line)
 * checks.
 */
class PoolingImpl::AlgoGiFilter3AverageStride2 final : public AlgoBase {
public:
    const char* name() const override { return "GI_POOLING_FILTER3_AVERAGE"; }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    //! Whether this algorithm accepts \p param; defined out of line.
    bool usable(const PoolingKernSizeParam& param) const override;

    //! Executes the algorithm for \p param; defined out of line.
    void exec(const PoolingKernParam& param) const override;

    MEGDNN_DECL_ALGO_TYPE(GI_Filter3AverageStride2)
};

/*!
 * \brief GI pooling algorithm registered as "GI_POOLING_FILTER4_MAX".
 *
 * The class name suggests a 4-wide filter, max pooling, stride 2;
 * the authoritative constraints are whatever usable() (defined out of line)
 * checks.
 */
class PoolingImpl::AlgoGiFilter4MaxStride2 final : public AlgoBase {
public:
    const char* name() const override { return "GI_POOLING_FILTER4_MAX"; }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    //! Whether this algorithm accepts \p param; defined out of line.
    bool usable(const PoolingKernSizeParam& param) const override;

    //! Executes the algorithm for \p param; defined out of line.
    void exec(const PoolingKernParam& param) const override;

    MEGDNN_DECL_ALGO_TYPE(GI_Filter4MaxStride2)
};

/*!
 * \brief GI pooling algorithm registered as "GI_POOLING_FILTER5_MAX".
 *
 * The class name suggests a 5-wide filter, max pooling, stride 2;
 * the authoritative constraints are whatever usable() (defined out of line)
 * checks.
 */
class PoolingImpl::AlgoGiFilter5MaxStride2 final : public AlgoBase {
public:
    const char* name() const override { return "GI_POOLING_FILTER5_MAX"; }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    //! Whether this algorithm accepts \p param; defined out of line.
    bool usable(const PoolingKernSizeParam& param) const override;

    //! Executes the algorithm for \p param; defined out of line.
    void exec(const PoolingKernParam& param) const override;

    MEGDNN_DECL_ALGO_TYPE(GI_Filter5MaxStride2)
};

/*!
 * \brief GI pooling algorithm registered as
 * "GI_POOLING_FP32_MODEX_STRIDEX_NCHW44".
 *
 * The class name suggests fp32 data in NCHW44 layout with any pooling mode and
 * stride; the authoritative constraints are whatever usable() (defined out of
 * line) checks.
 */
class PoolingImpl::AlgoGiFp32ModexStridexNCHW44 final : public AlgoBase {
public:
    const char* name() const override {
        return "GI_POOLING_FP32_MODEX_STRIDEX_NCHW44";
    }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    //! Whether this algorithm accepts \p param; defined out of line.
    bool usable(const PoolingKernSizeParam& param) const override;

    //! Executes the algorithm for \p param; defined out of line.
    void exec(const PoolingKernParam& param) const override;

    MEGDNN_DECL_ALGO_TYPE(GI_Fp32ModexStridexNCHW44)
};

/*!
 * \brief Sentinel algorithm registered as "FALLBACK_NOT_GI_POOLING".
 *
 * usable() accepts every size param unconditionally, but exec() must never be
 * reached: it fails through megdnn_assert. NOTE(review): presumably the
 * operator routes to a non-GI implementation before calling exec() when this
 * algorithm is chosen -- confirm against the PoolingImpl dispatcher.
 */
class PoolingImpl::AlgoFallback final : public AlgoBase {
public:
    const char* name() const override { return "FALLBACK_NOT_GI_POOLING"; }

    AlgoAttribute attribute() const override {
        return AlgoAttribute::REPRODUCIBLE;
    }

    //! Always true: this sentinel claims every configuration.
    bool usable(const PoolingKernSizeParam& /*param*/) const override {
        return true;
    }

    //! Must be unreachable; any call indicates a dispatch bug.
    void exec(const PoolingKernParam&) const override {
        megdnn_assert(false, "code issue happened!!");
    }

    MEGDNN_DECL_ALGO_TYPE(FallbackNotGI)
};

/*!
 * Builds the WorkspaceBundle for the given pooling size \p param; defined out
 * of line. NOTE(review): presumably shared by the GI algorithms above to size
 * their scratch memory -- confirm at the definition site.
 */
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param);

}  // namespace fallback
}  // namespace megdnn

// vim: syntax=cpp.doxygen