提交 ca855d8d 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn/fallback): fuse im2col and packb 4x4x16

GitOrigin-RevId: 123920899dd396a67a88dedcf8a2791f214a19c3
上级 4f77509e
......@@ -41,48 +41,56 @@ enum class StrategyType : uint32_t {
};
//! Key used to look up a cached im2col strategy. Each member is one
//! dimension of the key; StrategyHashParamHash and StrategyHashParamEqual
//! operate on these members.
struct StrategyHashParam {
    bool is_xcorr;
    bool is_square;  //! kernel_h == kernel_w, stride_h = stride_w
    //! blocking sizes of the strategy
    size_t block_m;
    size_t block_n;
    size_t block_k;
    //! filled from filter_meta.spatial[0] / filter_meta.stride[0] by the
    //! caller, so they only describe the first spatial dimension
    size_t kernel;
    size_t stride;
    fallback::ConvBiasImpl::NCBKernSizeParam param;
    param::ConvBias::Format format;
    fallback::MatrixMulImpl::AlgoBase::PackMode packmode;
};
struct StrategyHashParamHash {
std::size_t operator()(const StrategyHashParam& sparam) const {
constexpr size_t base = 1; //! avoid hashkey is zero
std::size_t result =
static_cast<std::size_t>(sparam.param.src_type.enumv()) + base;
uint64_t operator()(const StrategyHashParam& sparam) const {
constexpr uint64_t base = 1; //! avoid hashkey is zero
uint64_t result =
static_cast<uint64_t>(sparam.param.src_type.enumv()) + base;
result = result ^
((static_cast<std::size_t>(sparam.param.dst_type.enumv()) +
base)
((static_cast<uint64_t>(sparam.param.dst_type.enumv()) + base)
<< 3);
result = result ^
((static_cast<std::size_t>(sparam.param.filter_type.enumv()) +
((static_cast<uint64_t>(sparam.param.filter_type.enumv()) +
base)
<< 6);
result = result ^
((static_cast<std::size_t>(sparam.param.bias_type.enumv()) +
base)
((static_cast<uint64_t>(sparam.param.bias_type.enumv()) + base)
<< 9);
result = result ^ ((static_cast<uint64_t>(sparam.format) + base) << 12);
result = result ^
((static_cast<std::size_t>(sparam.format) + base) << 12);
result = result ^
((static_cast<std::size_t>(sparam.packmode) + base) << 15);
result = result ^
((static_cast<std::size_t>(sparam.block_m) + base) << 18);
((static_cast<uint64_t>(sparam.packmode) + base) << 15);
result =
result ^ ((static_cast<uint64_t>(sparam.block_m) + base) << 18);
result =
result ^ ((static_cast<uint64_t>(sparam.block_n) + base) << 22);
result =
result ^ ((static_cast<uint64_t>(sparam.block_k) + base) << 26);
result = result ^ ((static_cast<uint64_t>(sparam.kernel) + base) << 30);
result = result ^ ((static_cast<uint64_t>(sparam.stride) + base) << 34);
result = result ^
((static_cast<std::size_t>(sparam.block_n) + base) << 22);
((static_cast<uint64_t>(sparam.is_square) + base) << 35);
result = result ^
((static_cast<std::size_t>(sparam.block_k) + base) << 26);
((static_cast<uint64_t>(sparam.is_xcorr) + base) << 36);
return result;
};
};
struct StrategyHashParamEqual {
std::size_t operator()(const StrategyHashParam& param1,
const StrategyHashParam& param2) const {
bool operator()(const StrategyHashParam& param1,
const StrategyHashParam& param2) const {
bool flags = true;
flags = param1.param.src_type == param2.param.src_type && flags;
flags = param1.param.filter_type == param2.param.filter_type && flags;
......@@ -93,6 +101,10 @@ struct StrategyHashParamEqual {
flags = param1.block_m == param2.block_m && flags;
flags = param1.block_n == param2.block_n && flags;
flags = param1.block_k == param2.block_k && flags;
flags = param1.kernel == param2.kernel && flags;
flags = param1.stride == param2.stride && flags;
flags = param1.is_square == param2.is_square && flags;
flags = param1.is_xcorr == param2.is_xcorr && flags;
return flags;
};
};
......@@ -484,10 +496,15 @@ Strategy* StrategyDelegationStorage::get(
sparam.block_m = block_m;
sparam.block_n = block_n;
sparam.block_k = block_k;
//! The key only records the first kernel size / stride, so is_square has to
//! really compare both spatial dimensions; otherwise a non-square filter
//! would map onto the strategy cached for a square one.
sparam.kernel = param.filter_meta.spatial[0];
sparam.stride = param.filter_meta.stride[0];
sparam.is_square =
        param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
        param.filter_meta.stride[0] == param.filter_meta.stride[1];
//! NOTE(review): is_xcorr stores should_flip verbatim; it is only used as a
//! key discriminator here, but confirm the polarity matches the field name.
sparam.is_xcorr = param.filter_meta.should_flip;
MEGDNN_LOCK_GUARD(m_mtx);
if (map_strategys.find(sparam) == map_strategys.end()) {
MEGDNN_LOCK_GUARD(m_mtx);
auto strategy = Factory::make_strategy(matmul_algo, packmode,
param, stype);
auto strategy =
Factory::make_strategy(matmul_algo, packmode, param, stype);
map_strategys[sparam] = std::move(strategy);
}
return static_cast<Strategy*>(map_strategys[sparam].get());
......
......@@ -293,3 +293,5 @@ public:
WorkspaceBundle bundle_thread, const StrategyParam& sparam);
};
} // namespace megdnn
// vim: syntax=cpp.doxygen
文件模式从 100755 更改为 100644
/**
* \file dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/fallback/conv_bias/im2col/strategy_base.h"
// vim: syntax=cpp.doxygen
......@@ -1209,6 +1209,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
#undef cb
}
#if MEGDNN_AARCH64
//! Runs the conv-bias checker on the NCHW44 argument set built with
//! get_nchw44_conv_bias_args({3}, 1), using quantized int8 src/filter/dst and
//! a quantized int32 bias, against the fused im2col + 4x4x16 matmul
//! algorithm named below.
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_FUSE) {
    UniformIntRNG rng{-50, 50};
    float epsilon = 0.001;
    //! single algorithm under test, so call the checker directly
    checker_conv_bias(get_nchw44_conv_bias_args({3}, 1), handle(), &rng,
                      epsilon, dtype::QuantizedS8(2.5f),
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),
                      dtype::QuantizedS8(60.25f),
                      "IM2COLMATMUL:AARCH64_INT8X8X32_MK4_4X4X16:96");
}
#endif
#endif
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册