提交 d345c862 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn/fallback): add im2col none dot int8 nchw44 support

GitOrigin-RevId: d326035202c9a2a81e51f475273aef23f7489c09
上级 b8cbd451
......@@ -17,18 +17,14 @@
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/winograd/strategy.h"
#include "src/naive/convolution/helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#endif
#include "midout.h"
MIDOUT_DECL(megdnn_fallback_im2col)
using namespace megdnn;
using namespace fallback;
using namespace im2col;
#if MEGDNN_X86
using namespace x86;
#endif
/*======================== AlgoIm2col=======================*/
/*!
......@@ -47,8 +43,8 @@ using Pack_Mode=fallback::MatrixMulImpl::AlgoBase::PackMode;
static void copy_padding_kern(WorkspaceBundle bundle,
const ConvBiasImpl::NCBKernParam& param,
const ConvBiasImpl::NCBKernIndex& ncb_index,
StrategyBase* im2colstrategy) {
im2colstrategy->copy_padding_kern(bundle, param, ncb_index);
StrategyBase* im2colstrategy, size_t pack_oc_size) {
im2colstrategy->copy_padding_kern(bundle, param, ncb_index, pack_oc_size);
}
//! packA_kern
......@@ -57,9 +53,9 @@ static void packA_kern(WorkspaceBundle bundle,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
StrategyBase* im2colstrategy) {
StrategyBase* im2colstrategy, size_t pack_oc_size) {
im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo,
ncb_index);
ncb_index, pack_oc_size);
}
/*!
......@@ -129,14 +125,17 @@ public:
size_t oc_tile_size) {
size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
size_t pack_oc_size = 1;
size_t im2col = 0, packb = 0, bias_temp = 0;
bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT;
megdnn_assert(default_pack, "only support default packa");
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
pack_oc_size = 4;
}
size_t im2col_dst_size =
IC * FH * FW * ohw_tile_size * sizeof(param.src_type);
size_t matmul_dst_size =
oc_tile_size * ohw_tile_size * sizeof(param.bias_type);
size_t matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size *
sizeof(param.bias_type);
//! matmul_dst and im2col_dst use the same memory
WorkspaceBundle wb = matmul_algo->get_bundle(im2col_kern_param);
packb = wb.get_size(1);
......@@ -318,17 +317,18 @@ public:
}
};
#undef FILL_IM2COL_STRATEGY_PARAM
fallback::MatrixMulImpl::KernSizeParam
ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
size_t ohw_tile_size,
size_t oc_tile_size) const {
bool is_nchw44 =
param.filter_meta.format == param::ConvBias::Format::NCHW44;
size_t M = oc_tile_size;
size_t N = ohw_tile_size;
size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] *
param.filter_meta.spatial[1];
size_t LDA = K, LDB = N, LDC = N;
size_t pack_oc_size = is_nchw44 ? 4 : 1;
size_t LDA = pack_oc_size * K, LDB = pack_oc_size * N, LDC = N;
bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
......@@ -345,7 +345,8 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
false,
false,
param::MatrixMul::ComputeMode::DEFAULT,
param::MatrixMul::Format::DEFAULT};
is_nchw44 ? param::MatrixMul::Format::MK4
: param::MatrixMul::Format::DEFAULT};
}
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
......@@ -405,6 +406,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
size_t GROUP = param.filter_meta.group;
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT;
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
if (need_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, inner_block.m, inner_block.n, need_pack);
......@@ -421,16 +423,19 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
need_pack);
packa_group_size = 0;
}
if (no_need_pading) {
padding = 0; //! not need padding
} else {
padding = (GROUP * N * IC * IH2 * IW2) *
sizeof(param.src_type); //! for padding
}
packa_size = GROUP * packa_group_size; //! for packA size = GROUP * a_size
WorkspaceBundle ws = {nullptr, {}};
auto im2col_kern_param =
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size);
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) {
Im2colKerns<Pack_Mode::DEFAULT> defaultkern;
ws = defaultkern.get_thread_bundle(param, im2col_kern_param,
......@@ -447,6 +452,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
m_matmul_algo, m_ohw_tile_size,
m_oc_tile_size);
}
return {nullptr,
{padding, packa_size, ws.total_size_in_bytes() * nr_threads}};
}
......@@ -461,7 +467,7 @@ size_t ConvBiasImpl::AlgoIm2col::get_workspace(
}
SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
ConvBiasImpl* opr, const NCBKernSizeParam& param) const {
ConvBiasImpl*, const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 1) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(SH);
......@@ -473,7 +479,6 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
size_t ohw = OH * OW;
size_t ohw_parallel_times = div_ceil(ohw, m_ohw_tile_size);
size_t GROUP = param.filter_meta.group;
WorkspaceBundle bundle = get_bundle(param);
WorkspaceBundle bundle_thread = {nullptr, {}};
size_t oc_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size);
......@@ -483,11 +488,14 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
size_t packa_parallel_times = 0;
size_t pack_oc_size =
(param.filter_meta.format == param::ConvBias::Format::NCHW ? 1
: 4);
if (only_packA) {
packa_parallel_times = div_ceil<size_t>(OC, m_oc_tile_size);
} else if (default_pack) {
packa_parallel_times = div_ceil<size_t>(
OC, m_matmul_algo->get_inner_block_size().m);
OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size);
}
auto matmul_param = get_matmul_kern_param(
......@@ -520,25 +528,29 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
strategyparam.skip_copy_dst =
strategyparam.is_ohw_size_bigger && !strategyparam.is_dst_8bit;
strategyparam.oc_tile_size = m_oc_tile_size;
strategyparam.pack_oc_size = pack_oc_size;
SmallVector<ConvBiasImpl::NCBKern> ret_kern;
MIDOUT_BEGIN(
megdnn_fallback_im2col,
midout_iv("ConvBiasImpl::AlgoIm2col::dispatch_kerns"_hash)) {
StrategyBase* im2colstrategy = Factory::get_im2col_strategy(
param, m_matmul_algo, opr->param().format);
auto kern_padding = [bundle, im2colstrategy](
StrategyBase* im2colstrategy =
Factory::get_im2col_strategy(param, m_matmul_algo);
auto kern_padding = [bundle, im2colstrategy,
pack_oc_size = pack_oc_size](
const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
copy_padding_kern(bundle, param, ncb_index, im2colstrategy);
copy_padding_kern(bundle, param, ncb_index, im2colstrategy,
pack_oc_size);
};
auto kern_packA = [bundle, matmul_algo = m_matmul_algo,
matmul_param,
im2colstrategy](const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
matmul_param, im2colstrategy,
pack_oc_size = pack_oc_size](
const NCBKernParam& param,
const NCBKernIndex& ncb_index) {
packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index,
im2colstrategy);
im2colstrategy, pack_oc_size);
};
if (default_pack) {
auto kern_compute_default =
......@@ -556,7 +568,8 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});
if (need_padding) {
ret_kern.push_back({kern_padding, {param.n, GROUP, IC}});
ret_kern.push_back({kern_padding,
{param.n, GROUP, IC / pack_oc_size}});
}
ret_kern.push_back(
{kern_compute_default,
......@@ -629,19 +642,25 @@ bool ConvBiasImpl::AlgoIm2col::usable(
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
//! current now im2col only support int8 quantized s8 nchw44
if (opr->param().format == param::ConvBias::Format::NCHW44 &&
(param.src_type.enumv() == param.filter_type.enumv() &&
(param.src_type.enumv() != DTypeEnum::Int8) &&
(param.src_type.enumv() != DTypeEnum::QuantizedS8))) {
return false;
}
fallback::MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, m_ohw_tile_size, m_oc_tile_size);
bool matmulusable = m_matmul_algo->usable(matmul_param);
return matmulusable &&
(opr->param().format == param::ConvBias::Format::NCHW) &&
((param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
(param.filter_meta.spatial[0] <= 7) &&
(param.filter_meta.spatial[0] >= 2)) ||
(param.filter_meta.spatial[0] != param.filter_meta.spatial[1] &&
(param.filter_meta.spatial[0] <= 7) &&
(param.filter_meta.spatial[0] >= 1) &&
(param.filter_meta.spatial[1] <= 7) &&
(param.filter_meta.spatial[1] >= 1))) &&
(opr->param().format == param::ConvBias::Format::NCHW ||
opr->param().format == param::ConvBias::Format::NCHW44) &&
(!(param.filter_meta.spatial[0] ==
param.filter_meta.spatial[1] &&
(param.filter_meta.spatial[0] == 1) &&
param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1)) &&
(param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) &&
......
......@@ -36,7 +36,6 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase {
const NCBKernSizeParam& param, size_t ohw_tile_size,
size_t oc_tile_size) const;
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
WorkspaceBundle get_thread_bundle(const NCBKernSizeParam& param) const;
void choice_ohw_oc_block(const NCBKernSizeParam& param, size_t block_m,
size_t block_n, bool pack_default) const;
......
......@@ -14,6 +14,7 @@
namespace megdnn {
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
using FormatMode = param::ConvBias::Format;
struct StrategyParam {
size_t batch_id;
......@@ -28,6 +29,7 @@ struct StrategyParam {
size_t block_m;
size_t block_n;
size_t block_k;
size_t pack_oc_size;
bool skip_copy_dst;
bool is_dst_8bit;
bool is_ohw_size_bigger;
......@@ -40,13 +42,15 @@ public:
virtual void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0;
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) = 0;
virtual void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0;
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) = 0;
virtual void exec_im2col(
WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
......@@ -70,14 +74,16 @@ public:
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode, PackMode packmode>
megdnn::PostprocessMode postprocess_mode, PackMode packmode,
FormatMode format>
class Strategy;
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT> : public StrategyBase {
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>
: public StrategyBase {
public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
......@@ -85,24 +91,26 @@ public:
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2;
Strategy();
Strategy() = default;
void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
virtual void exec_im2col(
WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param,
......@@ -132,7 +140,32 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK> : public StrategyBase {
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44>
: public Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW> {
public:
const size_t BUNDLE_PADDING_INDEX = 0;
const size_t BUNDLE_PACKA_INDEX = 1;
const size_t THREAD_BUNDLE_PACKB_INDEX = 0;
const size_t THREAD_BUNDLE_IM2COL_INDEX = 1;
const size_t THREAD_BUNDLE_BIAS_INDEX = 2;
Strategy() = default;
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>
: public StrategyBase {
public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
......@@ -141,19 +174,20 @@ public:
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2;
constexpr static size_t THREAD_BUNDLE_MATCOMP_INDEX = 3;
Strategy();
Strategy() = default;
void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void exec_matmul(
const fallback::ConvBiasImpl::NCBKernParam& param,
......@@ -197,7 +231,8 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase {
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>
: public StrategyBase {
public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
......@@ -206,19 +241,20 @@ public:
constexpr static size_t THREAD_BUNDLE_MATMULDST_INDEX = 2;
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 3;
Strategy();
Strategy() = default;
void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
void packA_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
......
......@@ -8,8 +8,6 @@
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
......@@ -22,22 +20,15 @@ using namespace x86;
#endif
namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::Strategy()
: StrategyBase() {}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
copy_padding_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_oc_size) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC);
......@@ -53,9 +44,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t channel_id = ncb_index.ndrange_id[2];
size_t PH_SIZE = PH * IW2 * pack_oc_size;
PW = PW * pack_oc_size;
IW = IW * pack_oc_size;
size_t padding_group_size = IH2 * IW2 * IC;
size_t workspace_channel_offset = IH2 * IW2 * channel_id;
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id;
size_t workspace_group_offset = group_id * padding_group_size;
size_t workspace_batch_offset =
param.filter_meta.group * batch_id * padding_group_size;
......@@ -65,8 +60,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
}
src_ctype* src = const_cast<src_ctype*>(
param.src<src_ctype>(batch_id, group_id, channel_id));
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>(
batch_id, group_id, channel_id, 1, pack_oc_size));
src_ctype* src2;
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
workspace_group_offset + workspace_batch_offset +
......@@ -74,8 +69,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
src_ctype* src2_ptr = src2;
const src_ctype* src_ptr = src;
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
src2_ptr += PH * IW2;
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
src2_ptr += PH_SIZE;
}
rep(ih, IH) {
if (PW != 0)
......@@ -87,8 +82,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
rep(pw, PW) * (src2_ptr++) = src_zp;
}
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
src2_ptr += PH * IW2;
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
src2_ptr += PH_SIZE;
}
}
......@@ -96,12 +91,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_oc_size) {
bundle.set(param.workspace_ptr);
fallback::MatrixMulImpl::KernParam matmul_param;
size_t group_id = ncb_index.ndrange_id[0];
......@@ -114,38 +110,38 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
matmul_algo->get_packA_type_size();
size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size;
int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) +
group_id * packA_group_size + a_panel_offset;
group_id * packA_group_size +
(pack_oc_size == 4 ? 0 : a_panel_offset);
matmul_param.A_ptr =
const_cast<src_ctype*>(param.filter<src_ctype>(group_id));
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1],
matmul_algo->get_inner_block_size().m);
matmul_algo->get_inner_block_size().m * pack_oc_size);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo
) {
size_t m_sh = param.filter_meta.stride[0];
size_t m_sw = param.filter_meta.stride[1];
size_t m_oc = param.filter_meta.ocpg;
size_t m_oh = param.osz[0];
size_t m_ow = param.osz[1];
size_t m_ic = param.filter_meta.icpg;
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t m_fh = param.filter_meta.spatial[0];
size_t m_fw = param.filter_meta.spatial[1];
size_t m_is_xcorr = !param.filter_meta.should_flip;
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg;
size_t oh = param.osz[0];
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t fh = param.filter_meta.spatial[0];
size_t fw = param.filter_meta.spatial[1];
size_t is_xcorr = !param.filter_meta.should_flip;
size_t input_offset =
m_ih * m_iw * m_ic *
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(src_ctype);
......@@ -160,26 +156,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
}
src_ctype* im2col_dst = static_cast<src_ctype*>(
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
if (m_sh == 1 && m_sw == 1) {
if (m_is_xcorr) {
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
if (sh == 1 && sw == 1) {
if (is_xcorr) {
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
} else {
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
}
} else {
if (m_is_xcorr) {
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih,
m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
if (is_xcorr) {
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size);
} else {
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic,
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size);
}
}
......@@ -199,7 +191,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam) {
......@@ -218,7 +210,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
......@@ -240,11 +232,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
src_ctype* b_panel =
reinterpret_cast<src_ctype*>(reinterpret_cast<uintptr_t>(
bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)));
size_t pack_oc_size = sparam.pack_oc_size;
matmul_param.M = sparam.output_block_oc_size;
matmul_param.N = sparam.output_block_size;
matmul_param.LDB = sparam.output_block_size;
matmul_param.LDC = sparam.output_block_size;
matmul_param.LDB = pack_oc_size * sparam.output_block_size;
matmul_param.LDC = pack_oc_size * sparam.output_block_size;
matmul_param.C_ptr = matmul_dst;
auto matmul_kern_naked = matmul_algo->get_kern_naked(matmul_param);
......@@ -255,7 +247,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) {
......@@ -274,7 +266,8 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode,
param.nonlineMode, param.bias_type, param.dst_type, 1_z,
sparam.output_block_oc_size, 1_z, sparam.output_block_size);
sparam.output_block_oc_size, 1_z, sparam.output_block_size,
sparam.pack_oc_size);
copy_dst(param, matmul_dst, sparam);
}
......@@ -282,20 +275,24 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) {
size_t pack_oc_size = sparam.pack_oc_size;
dst_ctype* dst_tmp_ptr =
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
dst_ctype* dst =
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index;
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) {
sparam.oc_cur_index * sparam.ohw +
sparam.ohw_cur_index * pack_oc_size;
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size;
for (size_t oc = 0; oc < oc_loop; oc++) {
std::memcpy(dst, dst_tmp_ptr,
sizeof(dst_ctype) * sparam.output_block_size);
dst_tmp_ptr += sparam.output_block_size;
dst += sparam.ohw;
sizeof(dst_ctype) * sparam.output_block_size *
pack_oc_size);
dst_tmp_ptr += sparam.output_block_size * pack_oc_size;
dst += sparam.ohw * pack_oc_size;
}
}
}
......@@ -304,7 +301,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread) {
bias_ctype* bias_tmp_ptr =
......@@ -319,7 +316,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::DEFAULT>::
postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
......@@ -340,31 +337,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
}
}
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, PackMode::DEFAULT>;
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, PackMode::DEFAULT, \
FormatMode::NCHW>;
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS)
#endif
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
......
/**
* \file dnn/src/fallback/conv_bias/im2col/strategy_default.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn {
//! NCHW44 + default-pack im2col: expand the current output block of
//! \p param into a column matrix laid out in packed groups of 4 channels,
//! then hand that matrix to the matmul algorithm's pack_B so the main GEMM
//! can consume it.
//!
//! \param bundle         whole-kernel workspace (holds the padded input).
//! \param bundle_thread  per-thread workspace (im2col dst, packed B panel).
//! \param sparam         block/batch/group bookkeeping for this invocation.
//! \param param          conv_bias kernel parameters (shapes, strides, ptrs).
//! \param matmul_param   matmul kern param, patched in place with M/N/LD*.
//! \param matmul_algo    matmul implementation whose pack_B is invoked.
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44>::
        exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
                    const StrategyParam& sparam,
                    const fallback::ConvBiasImpl::NCBKernParam& param,
                    fallback::MatrixMulImpl::KernParam matmul_param,
                    fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
    size_t sh = param.filter_meta.stride[0];
    size_t sw = param.filter_meta.stride[1];
    size_t oc = param.filter_meta.ocpg;
    size_t oh = param.osz[0];
    size_t ow = param.osz[1];
    size_t ic = param.filter_meta.icpg;
    //! input extents include the padding on both sides
    size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
    size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
    size_t fh = param.filter_meta.spatial[0];
    size_t fw = param.filter_meta.spatial[1];
    //! should_flip distinguishes true convolution from cross-correlation;
    //! the value is purely boolean, so store it as bool (was size_t)
    bool is_xcorr = !param.filter_meta.should_flip;
    constexpr static size_t pack_size = 4;  //! NCHW44 channel pack width
    //! byte offset of this (batch, group)'s input inside the padded-src
    //! workspace
    size_t input_offset =
            ih * iw * ic *
            (sparam.group_id + param.filter_meta.group * sparam.batch_id) *
            sizeof(src_ctype);
    src_ctype* src2 = reinterpret_cast<src_ctype*>(
            reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) +
            input_offset);
    //! with zero padding the original input can be read in place, skipping
    //! the copy-padding workspace entirely
    bool is_phpwzero = param.filter_meta.padding[0] == 0 &&
                       param.filter_meta.padding[1] == 0;
    if (is_phpwzero) {
        src2 = const_cast<src_ctype*>(
                param.src<src_ctype>(sparam.batch_id, sparam.group_id));
    }
    src_ctype* im2col_dst = static_cast<src_ctype*>(
            bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
    //! unit stride has a dedicated kernel; both variants expand only the
    //! output pixels [ohw_cur_index, ohw_cur_index + output_block_size)
    if (is_xcorr) {
        if (sh == sw && sh == 1) {
            img2col_nchw4<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
                                fw, sh, sw, sparam.ohw_cur_index,
                                sparam.output_block_size);
        } else {
            img2col_stride_nchw4<true>(src2, im2col_dst, oc, oh, ow, ic, ih,
                                       iw, fh, fw, sh, sw,
                                       sparam.ohw_cur_index,
                                       sparam.output_block_size);
        }
    } else {
        if (sh == sw && sh == 1) {
            img2col_nchw4<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
                                 fw, sh, sw, sparam.ohw_cur_index,
                                 sparam.output_block_size);
        } else {
            img2col_stride_nchw4<false>(
                    src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw, sh, sw,
                    sparam.ohw_cur_index, sparam.output_block_size);
        }
    }
    //! leading dimensions carry the pack factor: each logical column holds
    //! pack_size interleaved channel values
    matmul_param.M = sparam.output_block_oc_size;
    matmul_param.N = sparam.output_block_size;
    matmul_param.LDB = pack_size * sparam.output_block_size;
    matmul_param.LDC = pack_size * sparam.output_block_size;
    matmul_param.B_ptr = im2col_dst;
    //! bundle_thread.get returns void*, so a plain static_cast suffices
    //! (matches the im2col_dst cast above; the old reinterpret_cast chain
    //! through uintptr_t added nothing)
    src_ctype* b_panel = static_cast<src_ctype*>(
            bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX));
    matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N);
}
//! Explicit instantiations of the NCHW44 + default-pack im2col strategy.
//! Each INSTANTIAL_CLASS(...) line materialises Strategy<> for one
//! (src, bias, dst, op) dtype combination together with the postprocess
//! mode conv_bias selects for it.
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype,  \
                         _op_dtype, _postprocess_mode)                    \
    template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
                            _op_dtype, _postprocess_mode, PackMode::DEFAULT, \
                            FormatMode::NCHW44>;
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
                 megdnn::PostprocessMode::FLOAT)
#if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
                 megdnn::PostprocessMode::NO_PROCESS)
#endif
//! int8 variants: quantized int8 output, plus raw int32/int16 accumulator
//! outputs that skip post-processing
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
                 megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
                 megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
                 megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
                 megdnn::PostprocessMode::NO_PROCESS)
#undef INSTANTIAL_CLASS
}  // namespace megdnn
......@@ -9,8 +9,6 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h"
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
......@@ -22,22 +20,16 @@ using namespace megdnn;
using namespace x86;
#endif
namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::Strategy()
: StrategyBase() {}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
copy_padding_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC);
......@@ -96,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
MEGDNN_MARK_USED_VAR(bundle);
MEGDNN_MARK_USED_VAR(param);
MEGDNN_MARK_USED_VAR(matmulparam);
......@@ -115,7 +108,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam) {
......@@ -134,7 +127,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
......@@ -167,29 +160,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo
) {
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
MEGDNN_MARK_USED_VAR(matmul_param);
MEGDNN_MARK_USED_VAR(matmul_algo);
size_t m_sh = param.filter_meta.stride[0];
size_t m_sw = param.filter_meta.stride[1];
size_t m_oc = param.filter_meta.ocpg;
size_t m_oh = param.osz[0];
size_t m_ow = param.osz[1];
size_t m_ic = param.filter_meta.icpg;
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t m_fh = param.filter_meta.spatial[0];
size_t m_fw = param.filter_meta.spatial[1];
size_t m_is_xcorr = !param.filter_meta.should_flip;
size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg;
size_t oh = param.osz[0];
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t fh = param.filter_meta.spatial[0];
size_t fw = param.filter_meta.spatial[1];
size_t is_xcorr = !param.filter_meta.should_flip;
size_t input_offset =
m_ih * m_iw * m_ic *
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(src_ctype);
......@@ -205,26 +197,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
}
src_ctype* im2col_dst = static_cast<src_ctype*>(
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
if (m_sh == 1 && m_sw == 1) {
if (m_is_xcorr) {
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
if (sh == 1 && sw == 1) {
if (is_xcorr) {
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
} else {
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
}
} else {
if (m_is_xcorr) {
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih,
m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
if (is_xcorr) {
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size);
} else {
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic,
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size);
}
}
......@@ -234,7 +222,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) {
......@@ -262,7 +250,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) {
......@@ -284,7 +272,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::NO_PACK>::
postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
......@@ -305,31 +293,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
}
}
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, PackMode::NO_PACK>;
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, PackMode::NO_PACK, \
FormatMode::NCHW>;
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS)
#endif
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 does not have a uint8 matmul, so only armv7/armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
......
......@@ -9,7 +9,6 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/opr_param_defs.h"
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
......@@ -21,22 +20,16 @@ using namespace megdnn;
using namespace x86;
#endif
namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::Strategy()
: StrategyBase() {}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
copy_padding_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC);
......@@ -95,12 +88,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) {
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
bundle.set(param.workspace_ptr);
fallback::MatrixMulImpl::KernParam matmul_param;
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
......@@ -128,7 +122,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam) {
......@@ -147,7 +141,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle,
WorkspaceBundle bundle_thread,
......@@ -185,29 +179,28 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
fallback::MatrixMulImpl::AlgoBase* matmul_algo
) {
fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
MEGDNN_MARK_USED_VAR(matmul_param);
MEGDNN_MARK_USED_VAR(matmul_algo);
size_t m_sh = param.filter_meta.stride[0];
size_t m_sw = param.filter_meta.stride[1];
size_t m_oc = param.filter_meta.ocpg;
size_t m_oh = param.osz[0];
size_t m_ow = param.osz[1];
size_t m_ic = param.filter_meta.icpg;
size_t m_ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t m_iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t m_fh = param.filter_meta.spatial[0];
size_t m_fw = param.filter_meta.spatial[1];
size_t m_is_xcorr = !param.filter_meta.should_flip;
size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1];
size_t oc = param.filter_meta.ocpg;
size_t oh = param.osz[0];
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t fh = param.filter_meta.spatial[0];
size_t fw = param.filter_meta.spatial[1];
size_t is_xcorr = !param.filter_meta.should_flip;
size_t input_offset =
m_ih * m_iw * m_ic *
ih * iw * ic *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(src_ctype);
......@@ -222,26 +215,22 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
}
src_ctype* im2col_dst = static_cast<src_ctype*>(
bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
if (m_sh == 1 && m_sw == 1) {
if (m_is_xcorr) {
img2col<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
if (sh == 1 && sw == 1) {
if (is_xcorr) {
img2col<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
} else {
img2col<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih, m_iw,
m_fh, m_fw, sparam.ohw_cur_index,
sparam.output_block_size);
img2col<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh, fw,
sparam.ohw_cur_index, sparam.output_block_size);
}
} else {
if (m_is_xcorr) {
img2col_stride<true>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic, m_ih,
m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
if (is_xcorr) {
img2col_stride<true>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size);
} else {
img2col_stride<false>(src2, im2col_dst, m_oc, m_oh, m_ow, m_ic,
m_ih, m_iw, m_fh, m_fw, m_sh, m_sw,
sparam.ohw_cur_index,
img2col_stride<false>(src2, im2col_dst, oc, oh, ow, ic, ih, iw, fh,
fw, sh, sw, sparam.ohw_cur_index,
sparam.output_block_size);
}
}
......@@ -251,7 +240,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) {
......@@ -292,7 +281,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode,PackMode::ONLY_PACKA>::
postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) {
......@@ -310,31 +299,20 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
}
}
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, \
_op_ctype, _op_dtype, _postprocess_mode,PackMode::ONLY_PACKA>;
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode, \
PackMode::ONLY_PACKA, FormatMode::NCHW>;
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS)
#endif
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 does not have a uint8 matmul, so only armv7/armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
......
......@@ -9,7 +9,6 @@
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/utils.h"
namespace {
template <bool is_xcorr, typename dtype>
......@@ -41,7 +40,326 @@ void img2col_stride(const dtype* __restrict src, dtype* __restrict dst,
}
}
//! added for multi-threaded im2col matmul
//
//! im2col for the NCHW4/NCHW44 layout with arbitrary stride SH/SW.
//! Expands only the output pixels in [cur_index, cur_index + block_size)
//! into the column matrix \p dst; channels travel in packed groups of 4, so
//! every copy below moves one whole 4-element group.
//! For 1-byte dtypes a packed group is exactly one uint32 and is moved with
//! a single 32-bit load/store; otherwise the 4 elements are copied one by
//! one.
//! \tparam is_xcorr true for cross-correlation; false flips the filter taps
//!                  (true convolution).
template <bool is_xcorr, typename dtype>
void img2col_stride_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                          const int OC, const int OH, const int OW,
                          const int IC, const int IH, const int IW,
                          const int FH, const int FW, const int SH,
                          const int SW, const int cur_index,
                          const int block_size) {
    MEGDNN_MARK_USED_VAR(OC);
    MEGDNN_MARK_USED_VAR(OH);
    //! split the flat output-pixel range into (row, column) boundaries
    int start_h = cur_index / OW;
    int cur_remain_w = cur_index % OW;
    int end_h = (cur_index + block_size) / OW;
    int end_remain_w = (cur_index + block_size) % OW;
    bool same_line = false;
    if (start_h == end_h) {
        same_line = true;
    }
    //! channels come in packed groups of 4
    size_t newIC = IC / 4;
    size_t i = 0;
    if (sizeof(dtype) != 1) {
        //! generic element-wise path
        if (same_line) {
            //! the whole block lies inside one output row
            rep(ic, newIC) {
                rep(fh, FH) {
                    rep(fw, FW) {
                        int fh2, fw2;
                        if (is_xcorr) {
                            fh2 = fh;
                            fw2 = fw;
                        } else {
                            //! convolution: use the flipped filter tap
                            fh2 = FH - fh - 1;
                            fw2 = FW - fw - 1;
                        }
                        for (int w = cur_remain_w; w < end_remain_w; w++) {
                            //! index is in elements; x4 for the pack width
                            size_t index = 4 * (ic * IH * IW +
                                                (start_h * SH + fh2) * IW +
                                                (w * SW + fw2));
                            dst[i++] = src[index];
                            dst[i++] = src[index + 1];
                            dst[i++] = src[index + 2];
                            dst[i++] = src[index + 3];
                        }
                    }
                }
            }
        } else {
            //! block spans several rows: tail of the first row, full middle
            //! rows, then the head of the last row
            rep(ic, newIC) {
                rep(fh, FH) {
                    rep(fw, FW) {
                        int fh2, fw2;
                        if (is_xcorr) {
                            fh2 = fh;
                            fw2 = fw;
                        } else {
                            fh2 = FH - fh - 1;
                            fw2 = FW - fw - 1;
                        }
                        for (int w = cur_remain_w; w < OW; w++) {
                            size_t index = 4 * (ic * IH * IW +
                                                (start_h * SH + fh2) * IW +
                                                (w * SW + fw2));
                            dst[i++] = src[index + 0];
                            dst[i++] = src[index + 1];
                            dst[i++] = src[index + 2];
                            dst[i++] = src[index + 3];
                        }
                        for (int h = start_h + 1; h < end_h; h++) {
                            rep(ow, OW) {
                                size_t index = 4 * (ic * IH * IW +
                                                    (h * SH + fh2) * IW +
                                                    (ow * SW + fw2));
                                dst[i++] = src[index + 0];
                                dst[i++] = src[index + 1];
                                dst[i++] = src[index + 2];
                                dst[i++] = src[index + 3];
                            }
                        }
                        for (int w = 0; w < end_remain_w; w++) {
                            size_t index = 4 * (ic * IH * IW +
                                                (end_h * SH + fh2) * IW +
                                                (w * SW + fw2));
                            dst[i++] = src[index + 0];
                            dst[i++] = src[index + 1];
                            dst[i++] = src[index + 2];
                            dst[i++] = src[index + 3];
                        }
                    }
                }
            }
        }
    } else {
        //! 1-byte dtype: copy packed groups as uint32; here `index` counts
        //! uint32 units (one unit == one 4-channel group) and is advanced by
        //! SW per output pixel instead of being recomputed
        uint32_t* output = nullptr;
        const uint32_t* uint32_src =
                static_cast<const uint32_t*>(static_cast<const void*>(src));
        output = static_cast<uint32_t*>(static_cast<void*>(dst));
        if (same_line) {
            rep(ic, newIC) {
                rep(fh, FH) {
                    rep(fw, FW) {
                        int fh2, fw2;
                        if (is_xcorr) {
                            fh2 = fh;
                            fw2 = fw;
                        } else {
                            fh2 = FH - fh - 1;
                            fw2 = FW - fw - 1;
                        }
                        size_t index =
                                (ic * IH * IW + (start_h * SH + fh2) * IW +
                                 (cur_remain_w * SW + fw2));
                        for (int w = cur_remain_w; w < end_remain_w; w++) {
                            output[i++] = uint32_src[index];
                            index += SW;
                        }
                    }
                }
            }
        } else {
            rep(ic, newIC) {
                rep(fh, FH) {
                    rep(fw, FW) {
                        int fh2, fw2;
                        if (is_xcorr) {
                            fh2 = fh;
                            fw2 = fw;
                        } else {
                            fh2 = FH - fh - 1;
                            fw2 = FW - fw - 1;
                        }
                        //! tail of the first (partial) output row
                        size_t index = ic * IH * IW +
                                       (start_h * SH + fh2) * IW +
                                       cur_remain_w * SW + fw2;
                        for (int w = cur_remain_w; w < OW; w++) {
                            output[i++] = uint32_src[index];
                            index += SW;
                        }
                        //! full middle rows
                        for (int h = start_h + 1; h < end_h; h++) {
                            index = ic * IH * IW + (h * SH + fh2) * IW + fw2;
                            rep(ow, OW) {
                                output[i++] = uint32_src[index];
                                index += SW;
                            }
                        }
                        //! head of the last (partial) output row
                        index = ic * IH * IW + (end_h * SH + fh2) * IW + fw2;
                        for (int w = 0; w < end_remain_w; w++) {
                            output[i++] = uint32_src[index];
                            index += SW;
                        }
                    }
                }
            }
        }
    }
}
//! im2col for the NCHW4/NCHW44 layout, specialised for unit stride
//! (caller guarantees SH == SW == 1; both are accepted for signature
//! symmetry with img2col_stride_nchw4 but ignored).
//! Expands only the output pixels in [cur_index, cur_index + block_size)
//! into the column matrix \p dst; channels travel in packed groups of 4.
//! For 1-byte dtypes a packed group is exactly one uint32 and is moved with
//! a single 32-bit load/store; otherwise the 4 elements are copied one by
//! one.
//! \tparam is_xcorr true for cross-correlation; false flips the filter taps
//!                  (true convolution).
template <bool is_xcorr, typename dtype>
void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                   const int OC, const int OH, const int OW, const int IC,
                   const int IH, const int IW, const int FH, const int FW,
                   const int SH, const int SW, const int cur_index,
                   const int block_size) {
    MEGDNN_MARK_USED_VAR(OC);
    MEGDNN_MARK_USED_VAR(OH);
    MEGDNN_MARK_USED_VAR(SH);
    MEGDNN_MARK_USED_VAR(SW);
    //! split the flat output-pixel range into (row, column) boundaries
    int start_h = cur_index / OW;
    int cur_remain_w = cur_index % OW;
    int end_h = (cur_index + block_size) / OW;
    int end_remain_w = (cur_index + block_size) % OW;
    bool same_line = false;
    if (start_h == end_h) {
        same_line = true;
    }
    //! channels come in packed groups of 4
    size_t newIC = IC / 4;
    size_t i = 0;
    if (sizeof(dtype) != 1) {
        //! generic element-wise path
        if (same_line) {
            //! the whole block lies inside one output row
            rep(ic, newIC) {
                rep(fh, FH) {
                    rep(fw, FW) {
                        int fh2, fw2;
                        if (is_xcorr) {
                            fh2 = fh;
                            fw2 = fw;
                        } else {
                            //! convolution: use the flipped filter tap
                            fh2 = FH - fh - 1;
                            fw2 = FW - fw - 1;
                        }
                        for (int w = cur_remain_w; w < end_remain_w; w++) {
                            //! index is in elements; x4 for the pack width
                            size_t index =
                                    4 * (ic * IH * IW + (start_h + fh2) * IW +
                                         (w + fw2));
                            dst[i++] = src[index];
                            dst[i++] = src[index + 1];
                            dst[i++] = src[index + 2];
                            dst[i++] = src[index + 3];
                        }
                    }
                }
            }
        } else {
            //! block spans several rows: tail of the first row, full middle
            //! rows, then the head of the last row
            rep(ic, newIC) {
                rep(fh, FH) {
                    rep(fw, FW) {
                        int fh2, fw2;
                        if (is_xcorr) {
                            fh2 = fh;
                            fw2 = fw;
                        } else {
                            fh2 = FH - fh - 1;
                            fw2 = FW - fw - 1;
                        }
                        for (int w = cur_remain_w; w < OW; w++) {
                            size_t index = ic * IH * IW + (start_h + fh2) * IW +
                                           (w + fw2);
                            dst[i++] = src[4 * index];
                            dst[i++] = src[4 * index + 1];
                            dst[i++] = src[4 * index + 2];
                            dst[i++] = src[4 * index + 3];
                        }
                        for (int h = start_h + 1; h < end_h; h++) {
                            rep(ow, OW) {
                                size_t index =
                                        4 * (ic * IH * IW + (h + fh2) * IW +
                                             (ow + fw2));
                                dst[i++] = src[index + 0];
                                dst[i++] = src[index + 1];
                                dst[i++] = src[index + 2];
                                dst[i++] = src[index + 3];
                            }
                        }
                        for (int w = 0; w < end_remain_w; w++) {
                            size_t index = 4 * (ic * IH * IW +
                                                (end_h + fh2) * IW + (w + fw2));
                            dst[i++] = src[index + 0];
                            dst[i++] = src[index + 1];
                            dst[i++] = src[index + 2];
                            dst[i++] = src[index + 3];
                        }
                    }
                }
            }
        }
    } else {
        //! 1-byte dtype: copy packed groups as uint32; `index` counts uint32
        //! units, i.e. one unit per 4-channel group
        uint32_t* output = nullptr;
        const uint32_t* uint32_src =
                static_cast<const uint32_t*>(static_cast<const void*>(src));
        output = static_cast<uint32_t*>(static_cast<void*>(dst));
        if (same_line) {
            rep(ic, newIC) {
                rep(fh, FH) {
                    rep(fw, FW) {
                        int fh2, fw2;
                        if (is_xcorr) {
                            fh2 = fh;
                            fw2 = fw;
                        } else {
                            fh2 = FH - fh - 1;
                            fw2 = FW - fw - 1;
                        }
                        for (int w = cur_remain_w; w < end_remain_w; w++) {
                            size_t index = (ic * IH * IW +
                                            (start_h + fh2) * IW + (w + fw2));
                            output[i++] = uint32_src[index];
                        }
                    }
                }
            }
        } else {
            rep(ic, newIC) {
                rep(fh, FH) {
                    rep(fw, FW) {
                        int fh2, fw2;
                        if (is_xcorr) {
                            fh2 = fh;
                            fw2 = fw;
                        } else {
                            fh2 = FH - fh - 1;
                            fw2 = FW - fw - 1;
                        }
                        //! tail of the first (partial) output row
                        for (int w = cur_remain_w; w < OW; w++) {
                            size_t index = ic * IH * IW + (start_h + fh2) * IW +
                                           (w + fw2);
                            output[i++] = uint32_src[index];
                        }
                        //! full middle rows
                        for (int h = start_h + 1; h < end_h; h++) {
                            rep(ow, OW) {
                                size_t index = (ic * IH * IW + (h + fh2) * IW +
                                                (ow + fw2));
                                output[i++] = uint32_src[index];
                            }
                        }
                        //! head of the last (partial) output row
                        for (int w = 0; w < end_remain_w; w++) {
                            size_t index = (ic * IH * IW + (end_h + fh2) * IW +
                                            (w + fw2));
                            output[i++] = uint32_src[index];
                        }
                    }
                }
            }
        }
    }
}
template <bool is_xcorr, typename dtype>
void img2col_stride(const dtype* __restrict src, dtype* __restrict dst,
......
......@@ -124,7 +124,8 @@ struct PostProcess {
megdnn::ConvBiasForward::BiasMode bias_mode,
megdnn::param::ConvBias::NonlineMode nonlineMode,
DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW) {
size_t OH, size_t OW, size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
megdnn::param::Elemwise::Mode elem_mode =
megdnn::param::Elemwise::Mode::ADD;
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
......@@ -154,7 +155,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> {
megdnn::ConvBiasForward::BiasMode bias_mode,
megdnn::param::ConvBias::NonlineMode nonlineMode,
DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW) {
size_t OH, size_t OW, size_t pack_oc_size=1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
megdnn::param::Elemwise::Mode elem_mode =
megdnn::param::Elemwise::Mode::ADD;
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
......@@ -185,7 +187,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
megdnn::ConvBiasForward::BiasMode bias_mode,
megdnn::param::ConvBias::NonlineMode nonlineMode,
DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW) {
size_t OH, size_t OW,size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
MEGDNN_MARK_USED_VAR(conv_dst_ptr);
MEGDNN_MARK_USED_VAR(bias_ptr);
MEGDNN_MARK_USED_VAR(dst_ptr);
......@@ -292,7 +295,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> {
megdnn::ConvBiasForward::BiasMode bias_mode,
megdnn::param::ConvBiasV0::NonlineMode nonlineMode,
DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW) {
size_t OH, size_t OW, size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
megdnn::param::Elemwise::Mode elem_mode =
megdnn::param::Elemwise::Mode::ADD;
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册