#pragma once

#include "src/arm_common/conv_bias/opr_impl.h"
#if MGB_ENABLE_DOT

namespace megdnn {
namespace arm_common {
namespace direct_dotprod_nchw44 {

// Short alias for ConvBiasForward::BiasMode; used below as a non-type template
// parameter of conv_direct_sdot_int8_nchw44.
using BiasMode = ConvBiasForward::BiasMode;

/**
 * @brief : do direct conv with no side effect
 *          input buffer's size is  [ih, iw]
 *          output buffer's size is [oh, ow]
 *          filter layout is [OC/4, IC/4, FH, FW, 4, 4]
 *
 *          This header only declares the template; no definition is visible
 *          here.
 *
 * @tparam: dst_type    -> element type of the dst buffer
 *          stride      -> NOTE(review): presumably the convolution stride;
 *                         confirm against the definition
 *          bias_mode   -> a ConvBiasForward::BiasMode value; NOTE(review):
 *                         presumably selects how bias is applied - confirm
 *          Op          -> type of the post process operator `op`
 *          filter_size -> NOTE(review): presumably the FH/FW extent of the
 *                         filter - confirm against the definition
 *
 * @param : [output ptr] dst
 *          [input]      oh         -> dst rows
 *          [input]      ow         -> dst cols
 *          [input ptr]  src
 *          [input]      ih         -> rows of src used by this kernel
 *          [input]      iw         -> src step in elements [iw2]
 *          [input ptr]  filter
 *          [input ptr]  bias
 *          [input]      oh_size    -> rows of result generated by this kernel
 *          [input]      oc         -> output channels
 *          [input]      ic         -> input channels
 *          [input]      op         -> post process operator
 * @return  none
 */

template <
        typename dst_type, int stride, BiasMode bias_mode, typename Op, int filter_size>
void conv_direct_sdot_int8_nchw44(
        dst_type* dst, const int oh, const int ow, const int8_t* src, const int ih,
        const int iw, const int8_t* filter, const int32_t* bias, const int oh_size,
        const int oc, const int ic, const Op& op);
/**
 * @brief : copy data from src to dst for direct conv with no side effect
 *
 *          This header only declares the template; no definition is visible
 *          here.
 *
 * @tparam: stride      -> NOTE(review): presumably the convolution stride the
 *                         packed layout is prepared for; confirm against the
 *                         definition
 *
 * @param : [output ptr] dst
 *          [input]      dst_step   -> step of dst in number of elements
 *          [input ptr]  src
 *          [input]      src_step   -> step of src in number of elements
 *          [input]      ic         -> input channels
 *          [input]      ic_step    -> step of ic in number of elements
 *          [input]      ih         -> total rows to copy
 *          [input]      pad_left   -> cols padding at left
 *          [input]      pad_right  -> cols padding at right
 *          [input]      pad_top    -> rows padding at top
 *          [input]      pad_bottom -> rows padding at bottom
 * @return  none
 */
template <int stride>
void copy_packed_src_int8_nchw44(
        int8_t* dst, const int dst_step, const int8_t* src, const int src_step,
        const int ic, const int ic_step, const int ih, const int pad_left,
        const int pad_right, const int pad_top, const int pad_bottom);

}  // namespace direct_dotprod_nchw44
}  // namespace arm_common
}  // namespace megdnn

#endif

// vim: syntax=cpp.doxygen
