/**
 * \file dnn/src/arm_common/pooling/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include <array>
#include <cstddef>
#include <cstdint>

#include "src/fallback/pooling/opr_impl.h"

namespace megdnn {
namespace arm_common {

class PoolingImpl final : public fallback::PoolingImpl {
public:
    using fallback::PoolingImpl::PoolingImpl;
    void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
              _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&,
                                  const TensorLayout&) override;

    static size_t constexpr MAX_SPATIAL_DIM = 2;

    struct PoolingKernSizeParam {
        uint32_t n, ic;
        std::array<uint32_t, MAX_SPATIAL_DIM> isz, osz;
        std::array<uint32_t, MAX_SPATIAL_DIM> padding, filter, stride;
        DType src_type, dst_type;
        Handle* handle;
        Param::Format format;
        Mode mode;
    };

    struct PoolingKernParam : public PoolingKernSizeParam {
        void* src_ptr;
        void* dst_ptr;
        void* workspace_ptr;
        size_t workspace_size;

        template <typename T>
        const T* src() const {
            src_type.assert_is_compatible_ctype<T>();
            return static_cast<const T*>(src_ptr);
        }

        template <typename T>
        T* dst() const {
            dst_type.assert_is_compatible_ctype<T>();
            return static_cast<T*>(dst_ptr);
        }

        template <typename T>
        T* workspace() const {
            return static_cast<T*>(workspace_ptr);
        }
    };

    PoolingKernSizeParam make_pooling_kern_szie_param(
            fallback::PoolingImpl* opr, const TensorLayout& src,
            const TensorLayout& dst);

    PoolingKernParam make_pooling_kern_param(fallback::PoolingImpl* opr,
                                             _megdnn_tensor_in src,
                                             _megdnn_tensor_out dst,
                                             _megdnn_workspace workspace);
    class AlgoBase : public detail::Algorithm {
    public:
        virtual ~AlgoBase() = default;
        virtual bool usable(const PoolingKernSizeParam& param) const = 0;
        virtual void exec(const PoolingKernParam& param) const = 0;
    };

private:
    class AlgoFilterxModexStride1;
    class AlgoFilter2ModexStride2;
    class AlgoFilter3MaxStride2;
    class AlgoFilter3AverageStride2;
    class AlgoFilter4MaxStride2;
    class AlgoFilter5MaxStride2;
    class AlgoInt8Filter2MaxStride2;
    class AlgoInt8Filter3MaxStride2;
    class AlgoPack;
};
}  // namespace arm_common
}  // namespace megdnn

// vim: syntax=cpp.doxygen