direct_dotprod_nchw44.h 3.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
/**
 * \file dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#if __ARM_FEATURE_DOTPROD

#pragma once

#include "src/arm_common/conv_bias/opr_impl.h"

namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {

//! Local shorthand for the conv-bias operator's bias-mode enum.
using BiasMode = ConvBiasForward::BiasMode;

/**
 * @brief : int8 direct convolution kernels for 2x2, 3x3, 5x5 and 7x7
 *          filters, in NCHW44 layout, with no side effect.
 *          input buffer's size is  [ih, iw]
 *          output buffer's size is [oh, ow]
 *          filter layout is [OC/4, IC/4, FH, FW, 4, 4]
 *
 * @tparam  dst_type   -> element type of dst
 *          bias_mode  -> a ConvBiasForward::BiasMode value
 *          Op         -> type of the post process operator `op`
 *          stride     -> presumably the convolution stride (TODO confirm)
 *
 * @param : [output ptr] dst
 *          [input]      oh         -> dst rows
 *          [input]      ow         -> dst cols
 *          [input ptr]  src
 *          [input]      ih         -> rows of src used by this kernel
 *          [input]      iw         -> src step in elements [iw2]
 *          [input ptr]  weight     -> filter, laid out as described above
 *          [input ptr]  bias
 *          [input]      oh_size    -> rows of result generated by this kernel
 *          [input]      oc         -> output channels
 *          [input]      ic         -> input channels
 *          [input]      op         -> post process operator
 * @return  none
 *
 * The four filter sizes are declared explicitly rather than through a
 * helper macro, so every kernel name can be found with a plain text search.
 */
template <typename dst_type, BiasMode bias_mode, typename Op, int stride>
void conv_direct_2x2_int8_nchw44(dst_type* dst, const int oh, const int ow,
                                 const int8_t* src, const int ih, const int iw,
                                 const int8_t* weight, const int32_t* bias,
                                 const int oh_size, const int oc, const int ic,
                                 const Op& op);

template <typename dst_type, BiasMode bias_mode, typename Op, int stride>
void conv_direct_3x3_int8_nchw44(dst_type* dst, const int oh, const int ow,
                                 const int8_t* src, const int ih, const int iw,
                                 const int8_t* weight, const int32_t* bias,
                                 const int oh_size, const int oc, const int ic,
                                 const Op& op);

template <typename dst_type, BiasMode bias_mode, typename Op, int stride>
void conv_direct_5x5_int8_nchw44(dst_type* dst, const int oh, const int ow,
                                 const int8_t* src, const int ih, const int iw,
                                 const int8_t* weight, const int32_t* bias,
                                 const int oh_size, const int oc, const int ic,
                                 const Op& op);

template <typename dst_type, BiasMode bias_mode, typename Op, int stride>
void conv_direct_7x7_int8_nchw44(dst_type* dst, const int oh, const int ow,
                                 const int8_t* src, const int ih, const int iw,
                                 const int8_t* weight, const int32_t* bias,
                                 const int oh_size, const int oc, const int ic,
                                 const Op& op);

/**
 * @brief : copy data from src to dst (applying the given row/column
 *          padding counts) for direct conv, with no side effect.
 *
 * @tparam  stride     -> presumably the stride of the consuming conv kernel
 *                        (TODO confirm)
 *
 * @param : [output ptr] dst
 *          [input]      dst_step   -> step of dst in number of elements
 *          [input ptr]  src
 *          [input]      src_step   -> step of src in number of elements
 *          [input]      ic         -> input channels
 *          [input]      ic_step    -> step of ic in number of elements
 *          [input]      ih         -> total rows to copy
 *          [input]      pad_left   -> cols of padding at left
 *          [input]      pad_right  -> cols of padding at right
 *          [input]      pad_top    -> rows of padding at top
 *          [input]      pad_bottom -> rows of padding at bottom
 * @return  none
 */
template <int stride>
void copy_packed_src_int8_nchw44(
        int8_t* dst, const int dst_step, const int8_t* src, const int src_step,
        const int ic, const int ic_step, const int ih, const int pad_left,
        const int pad_right, const int pad_top, const int pad_bottom);

}  // namespace direct_dotprod_nchw44
}  // namespace arm_common
}  // namespace megdnn

#endif

//vim: syntax=cpp.doxygen