提交 2318ea3f 编写于 作者: M Megvii Engine Team

fix(dnn): fix naive average pooling overflow bug for int8 type

GitOrigin-RevId: b60a7b6cf824d104efbc7c22020220898cdff4fe
上级 f64a2e02
......@@ -47,6 +47,8 @@ struct MaxPooler {
}
};
//! WARNING:for Integer, if sum ctype_ set incorrectly may cause overflow such as
//! (stype_=ctype_ =int8_t)
template <typename stype_, typename ctype_>
struct MeanIncludePoolerBase {
using stype = stype_;
......@@ -65,6 +67,7 @@ struct MeanIncludePooler : public MeanIncludePoolerBase<T, T> {
ctype get_ans() { return this->sum / this->count; }
};
//! WARNING: the result is truncated
template <>
struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t> {
using MeanIncludePoolerBase::MeanIncludePoolerBase;
......@@ -74,7 +77,9 @@ struct MeanIncludePooler<int8_t> : public MeanIncludePoolerBase<int8_t, int32_t>
std::numeric_limits<int8_t>::max());
}
};
/*!
* average pooling with zero point for quint8
/*/
template <>
struct MeanIncludePooler<dt_quint8> {
int32_t sum;
......@@ -107,7 +112,7 @@ struct MeanIncludePooler<dt_quint8> {
/*!
* \brief Average pooling operation within a single window.
* Works on integers. Rounds toward +INF.
* Works on integers. Rounds toward nearest Integer
* \tparam T input data type
* \tparam U convert input data type to U before accumulating
* \tparam ICType data type for intermediate result
......@@ -228,10 +233,11 @@ struct MeanExcludePooler {
/*!
* \brief Average pooling operation within a single window.
* Works on integers. Rounds toward +INF.
* Works on integers. Rounds toward nearest Integer
* \tparam T input data type
* \tparam U convert input data type to U before accumulating
* \tparam ICType data type for intermediate result
* WARNING:for Integer, if type U or ICType set incorrectly may cause overflow
*/
template <typename T, typename U, typename ICType = U>
struct MeanExcludeRoundedPooler {
......@@ -256,6 +262,10 @@ struct MeanExcludeRoundedPooler {
}
};
template <>
struct MeanExcludePooler<int8_t> : MeanExcludeRoundedPooler<int8_t, int8_t, int32_t> {
using MeanExcludeRoundedPooler::MeanExcludeRoundedPooler;
};
template <>
struct MeanExcludePooler<dt_quint8>
: MeanExcludeRoundedPooler<dt_quint8, uint8_t, uint32_t> {
......
......@@ -100,4 +100,35 @@ TEST_F(NAIVE, POOLING_QUANTIZED_Q4) {
TensorValueLowbit4({1, 1, 2, 2}, u4_dt, u8_avg_exclu_dst_vec)});
}
}
TEST_F(NAIVE, POOLING_INT_AVERAGE) {
using Mode = Pooling::Param::Mode;
Checker<Pooling> checker(handle(), /* check_dispatch */ false);
auto dt = dtype::Int8();
Pooling::Param param = {Mode::AVERAGE, 0, 0, 1, 1, 2, 2};
Testcase input_positive{
TensorValue(
{1, 1, 3, 3}, dt, {127, 127, 127, 127, 127, 127, 127, 127, 127}),
{}};
Testcase input_negative{
TensorValue(
{1, 1, 3, 3}, dt,
{-127, -127, -127, -127, -127, -127, -127, -127, -127}),
{}};
checker.set_param(param).exect(
input_positive,
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})});
checker.set_param(param).exect(
input_negative,
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})});
param = {Mode::AVERAGE_COUNT_EXCLUDE_PADDING, 0, 0, 1, 1, 2, 2};
checker.set_param(param).exect(
input_positive,
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {127, 127, 127, 127})});
checker.set_param(param).exect(
input_negative,
Testcase{{}, TensorValue({1, 1, 2, 2}, dt, {-127, -127, -127, -127})});
}
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册