提交 2318ea3f 编写于 作者: M Megvii Engine Team

fix(dnn): fix naive average pooling overflow bug for int8 type

GitOrigin-RevId: b60a7b6cf824d104efbc7c22020220898cdff4fe
上级 f64a2e02
...@@ -47,6 +47,8 @@ struct MaxPooler { ...@@ -47,6 +47,8 @@ struct MaxPooler {
} }
}; };
//! WARNING:for Integer, if sum ctype_ set incorrectly may cause overflow such as
//! (stype_=ctype_ =int8_t)
template <typename stype_, typename ctype_> template <typename stype_, typename ctype_>
struct MeanIncludePoolerBase { struct MeanIncludePoolerBase {
using stype = stype_; using stype = stype_;
...@@ -65,6 +67,7 @@ struct MeanIncludePooler : public MeanIncludePoolerBase<T, T> { ...@@ -65,6 +67,7 @@ struct MeanIncludePooler : public MeanIncludePoolerBase<T, T> {
ctype get_ans() { return this->sum / this->count; } ctype get_ans() { return this->sum / this->count; }
}; };
//! WARNING: the result is truncated
template <> template <>
struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t> { struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t> {
using MeanIncludePoolerBase::MeanIncludePoolerBase; using MeanIncludePoolerBase::MeanIncludePoolerBase;
...@@ -74,7 +77,9 @@ struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t> ...@@ -74,7 +77,9 @@ struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t>
std::numeric_limits<int8_t>::max()); std::numeric_limits<int8_t>::max());
} }
}; };
/*!
* average pooling with zero point for quint8
/*/
template <> template <>
struct MeanIncludePooler<dt_quint8> { struct MeanIncludePooler<dt_quint8> {
int32_t sum; int32_t sum;
...@@ -107,7 +112,7 @@ struct MeanIncludePooler<dt_quint8> { ...@@ -107,7 +112,7 @@ struct MeanIncludePooler<dt_quint8> {
/*! /*!
* \brief Average pooling operation within a single window. * \brief Average pooling operation within a single window.
* Works on integers. Rounds toward +INF. * Works on integers. Rounds toward nearest Integer
* \tparam T input data type * \tparam T input data type
* \tparam U convert input data type to U before accumulating * \tparam U convert input data type to U before accumulating
* \tparam ICType data type for intermediate result * \tparam ICType data type for intermediate result
...@@ -228,10 +233,11 @@ struct MeanExcludePooler { ...@@ -228,10 +233,11 @@ struct MeanExcludePooler {
/*! /*!
* \brief Average pooling operation within a single window. * \brief Average pooling operation within a single window.
* Works on integers. Rounds toward +INF. * Works on integers. Rounds toward nearest Integer
* \tparam T input data type * \tparam T input data type
* \tparam U convert input data type to U before accumulating * \tparam U convert input data type to U before accumulating
* \tparam ICType data type for intermediate result * \tparam ICType data type for intermediate result
* WARNING:for Integer, if type U or ICType set incorrectly may cause overflow
*/ */
template <typename T, typename U, typename ICType = U> template <typename T, typename U, typename ICType = U>
struct MeanExcludeRoundedPooler { struct MeanExcludeRoundedPooler {
...@@ -256,6 +262,10 @@ struct MeanExcludeRoundedPooler { ...@@ -256,6 +262,10 @@ struct MeanExcludeRoundedPooler {
} }
}; };
template <>
struct MeanExcludePooler<int8_t> : MeanExcludeRoundedPooler<int8_t, int8_t, int32_t> {
using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler;
};
template <> template <>
struct MeanExcludePooler<dt_quint8> struct MeanExcludePooler<dt_quint8>
: MeanExcludeRoundedPooler<dt_quint8, uint8_t, uint32_t> { : MeanExcludeRoundedPooler<dt_quint8, uint8_t, uint32_t> {
......
...@@ -100,4 +100,35 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) { ...@@ -100,4 +100,35 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) {
TensorValueLowbit4({1, 1, 2, 2}, u4_dt, u8_avg_exclu_dst_vec)}); TensorValueLowbit4({1, 1, 2, 2}, u4_dt, u8_avg_exclu_dst_vec)});
} }
} }
TEST_F(NAIVE, POOLING_INT_AVERAGE) {
using Mode = Pooling::Param::Mode;
Checker<Pooling> checker(handle(), /* check_dispatch */ false);
auto dt = dtype::Int8();
Pooling::Param param = {Mode::AVERAGE, 0, 0, 1, 1, 2, 2};
Testcase input_positive{
TensorValue(
{1, 1, 3, 3}, dt, {127, 127, 127, 127, 127, 127, 127, 127, 127}),
{}};
Testcase input_negative{
TensorValue(
{1, 1, 3, 3}, dt,
{-127, -127, -127, -127, -127, -127, -127, -127, -127}),
{}};
checker.set_param(param).exect(
input_positive,
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})});
checker.set_param(param).exect(
input_negative,
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})});
param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 0, 0, 1, 1, 2, 2};
checker.set_param(param).exect(
input_positive,
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})});
checker.set_param(param).exect(
input_negative,
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})});
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册