提交 c357db01 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mgb/arm_common): add 8x8x16 nchw44 max pooling

GitOrigin-RevId: ed460adb7a47930546f8c9e13b729d9e1306c55f
上级 7f5f375f
......@@ -58,7 +58,8 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) {
WorkspaceBundle get_bundle_nchw44(
const PoolingImpl::PoolingKernSizeParam& param) {
megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8) &&
megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
(param.format == param::Pooling::Format::NCHW44));
auto IH = param.isz[0];
auto IW = param.isz[1];
......@@ -605,10 +606,15 @@ bool PoolingImpl::AlgoFilter3ModexStridexNCHW44::usable(
auto FH = param.filter[0];
auto FW = param.filter[1];
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
FH == 3 && FW == 3 && SW == SH && (SH == 1 || SW == 2);
//! Int8 not support average, because its round mode is different form
//! quint8
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
param.mode == Mode::AVERAGE);
return avaible;
}
......@@ -693,10 +699,15 @@ bool PoolingImpl::AlgoFilter2ModexStridexNCHW44::usable(
auto FH = param.filter[0];
auto FW = param.filter[1];
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
FH == 2 && FW == 2 && SH == SW && (SW == 1 || SW == 2);
//! Int8 not support average, because its round mode is different form
//! quint8
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
param.mode == Mode::AVERAGE);
return avaible;
}
......@@ -781,10 +792,16 @@ bool PoolingImpl::AlgoFilter4ModexStridexNCHW44::usable(
auto FH = param.filter[0];
auto FW = param.filter[1];
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
FH == 4 && FW == 4 && SH == SW && (SW == 1 || SW == 2);
//! Int8 not support average, because its round mode is different form
//! quint8
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
param.mode == Mode::AVERAGE);
return avaible;
}
......@@ -869,10 +886,15 @@ bool PoolingImpl::AlgoFilter5ModexStridexNCHW44::usable(
auto FH = param.filter[0];
auto FW = param.filter[1];
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
bool avaible = (param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
param.format == Param::Format::NCHW44 &&
(param.mode == Mode::MAX || param.mode == Mode::AVERAGE) &&
FH == 5 && FW == 5 && SH == SW && (SW == 1 || SW == 2);
//! Int8 not support average, because its round mode is different form
//! quint8
avaible &= !(param.src_type.enumv() == DTypeEnum::Int8 &&
param.mode == Mode::AVERAGE);
return avaible;
}
......
......@@ -119,7 +119,8 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
}
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8) &&
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8 ||
param.src_type.enumv() == DTypeEnum::Int8) &&
(param.format == param::Pooling::Format::NCHW44)) {
WorkspaceBundle ws = get_bundle_nchw44(param);
arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
......
......@@ -118,18 +118,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) {
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {3, 5, 10})
for (size_t iw: {3, 5, 7, 9, 15, 20})
for (size_t ph: {0, 1, 2})
for (size_t pw: {0, 1, 2})
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE})
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()})
if (ih+2*ph >= 3 && iw+2*pw >= 3)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
checker.set_dtype(0, data_type);
param::Pooling param;
param.mode = mode;
......@@ -149,18 +150,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44)
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {2, 5, 10, 17})
for (size_t iw: {2, 6, 8, 16, 26})
for (size_t ph: {0, 1})
for (size_t pw: {0, 1})
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE})
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE})
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()})
if (ih+2*ph >= 2 && iw+2*pw >= 2)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
checker.set_dtype(0, data_type);
checker.set_dtype(1, data_type);
param::Pooling param;
param.mode = mode;
......@@ -179,18 +182,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44)
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {4, 10, 18, 25, 30})
for (size_t iw: {4, 12, 17, 20, 25})
for (size_t ph: {0, 1, 2})
for (size_t pw: {0, 1, 2})
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()})
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE})
if (ih+2*ph >= 4 && iw+2*pw >= 4)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
checker.set_dtype(0, data_type);
param::Pooling param;
param.mode = mode;
......@@ -208,18 +212,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44)
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W5x5_NCHW44)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_rng(0, &rng);
// clang-format off
for (size_t ih: {5, 9, 19, 20, 39})
for (size_t iw: {5, 12, 23, 27, 39})
for (size_t ph: {0, 1, 2})
for (size_t pw: {0, 1, 2})
for(auto data_type: SmallVector<DType>{dtype::QuantizedS8(1.1f), dtype::Int8()})
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE})
if (ih+2*ph >= 5 && iw+2*pw >= 5)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
checker.set_dtype(0, data_type);
param::Pooling param;
param.mode = mode;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册