Commit 89186edc authored by Megvii Engine Team

fix(dnn): correct reduce/argmxx/fakequant calculation with nan

GitOrigin-RevId: 7e78bdae9106186c5d1a1b8ee2ab337ed69db21b
Parent 68cdabd2
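Background for the fix: in IEEE-754 arithmetic, every ordered comparison involving NaN evaluates to false, so a plain `x > y ? x : y` fold silently discards NaN (and `fminf`/`fmaxf` return the non-NaN operand), which is why min/max/argmax reductions used to return finite values for NaN inputs. A minimal host-only sketch of the failure mode and of the NaN-propagating comparator this patch introduces (illustration only, not part of the commit):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    const float nan = std::nanf("");
    // Both ordered comparisons are false for NaN ...
    std::printf("%d %d\n", nan < 1.0f, nan > 1.0f);  // prints: 0 0
    // ... so a naive fold keeps whichever operand lands in the "else"
    // branch, making the result depend on argument order:
    std::printf("%f\n", std::max(1.0f, nan));  // 1.000000 -- NaN lost
    std::printf("%f\n", std::max(nan, 1.0f));  // nan
    // The fix: test the accumulator for NaN first, so NaN always wins.
    auto nan_max = [](float l, float r) { return (std::isnan(l) || l > r) ? l : r; };
    std::printf("%f\n", nan_max(1.0f, nan));  // nan (l > r is false -> r)
    std::printf("%f\n", nan_max(nan, 1.0f));  // nan (isnan(l) -> l)
}
```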
@@ -78,6 +78,72 @@ struct ArgmxxOp {
     const wtype INIT;
 };
 
+template <bool is_max>
+struct ArgmxxOp<dt_float32, is_max> {
+    using stype_ = dt_float32;
+    struct wtype {
+        stype_ key;
+        dt_int32 val;
+        MEGDNN_HOST MEGDNN_DEVICE wtype() {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val)
+                : key(key), val(val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& rhs) : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(const wtype& rhs) volatile {
+            this->key = rhs.key;
+            this->val = rhs.val;
+            return *this;
+        }
+    };
+    MEGDNN_HOST MEGDNN_DEVICE
+    ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
+            : src(src),
+              dst(dst),
+              A(A),
+              B(B),
+              C(C),
+              INIT(wtype(
+                      is_max ? DTypeTrait<stype_>::min() : DTypeTrait<stype_>::max(),
+                      0)) {}
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
+        wtype res;
+        res.key = src[idx];
+        res.val = idx / C % B;
+        return res;
+    }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
+        dst[idx] = val.val;
+    }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        if (isnan(lhs.key))
+#else
+        if (std::isnan(lhs.key))
+#endif
+            return lhs;
+        if (is_max) {
+            if (lhs.key > rhs.key)
+                return lhs;
+            else
+                return rhs;
+        } else {
+            if (lhs.key < rhs.key)
+                return lhs;
+            else
+                return rhs;
+        }
+    }
+    stype_* src;
+    dt_int32* dst;
+    uint32_t A, B, C;
+    const wtype INIT;
+};
+
 }  // namespace argmxx
 }  // namespace megdnn
 
 // vim: syntax=cpp.doxygen
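A rough host-side sketch of the left-to-right fold that `ArgmxxOp<dt_float32, is_max>::apply` induces, showing that the first NaN encountered wins the argmax (a simplified key/index pair stands in for `wtype`; illustration only):

```cpp
#include <cmath>
#include <cstdio>

struct KV {
    float key;
    int val;  // index of the candidate element
};

// Mirrors ArgmxxOp<dt_float32, true>::apply: a NaN accumulator is sticky,
// and a NaN on the right also wins because `lhs.key > rhs.key` is false.
KV apply_max(KV lhs, KV rhs) {
    if (std::isnan(lhs.key))
        return lhs;
    return lhs.key > rhs.key ? lhs : rhs;
}

int main() {
    const float data[] = {1.0f, 2.0f, std::nanf(""), 4.0f};
    KV acc{data[0], 0};
    for (int i = 1; i < 4; ++i)
        acc = apply_max(acc, KV{data[i], i});
    std::printf("argmax index = %d, key = %f\n", acc.val, acc.key);  // 2, nan
}
```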
@@ -119,6 +119,28 @@ struct MinOp {
             : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
 };
 
+template <typename src_ctype, typename dst_ctype>
+struct MinOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+    const wtype INIT;
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
+};
+
 template <typename src_ctype, typename dst_ctype, typename wtype_>
 struct MaxOp {
     typedef wtype_ wtype;
@@ -141,6 +163,28 @@ struct MaxOp {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };
 
+template <typename src_ctype, typename dst_ctype>
+struct MaxOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+    const wtype INIT;
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
+};
+
 template <typename src_ctype, typename dst_ctype, typename wtype_>
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
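The same comparator shape works for an arbitrary reduction tree, not just a sequential fold, because NaN wins from either side: a NaN `lhs` is caught by `isnan`, and a NaN `rhs` is returned since the ordered comparison is false. A host-side sketch mirroring `MinOp<..., dt_float32>::apply` (illustration only):

```cpp
#include <cmath>
#include <cstdio>

// Host-side mirror of MinOp<src_ctype, dst_ctype, dt_float32>::apply.
float nan_min(float lhs, float rhs) {
    return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
}

int main() {
    const float nan = std::nanf("");
    std::printf("%f\n", nan_min(nan, 3.0f));  // nan: isnan(lhs)
    std::printf("%f\n", nan_min(3.0f, nan));  // nan: lhs < rhs is false -> rhs
    // A CUDA-style pairwise tree gives the same answer as a serial fold:
    std::printf("%f\n", nan_min(nan_min(1.0f, 2.0f), nan_min(nan, 4.0f)));  // nan
}
```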
@@ -30,6 +30,10 @@ struct FakeQuantKernOp {
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
+        if (isnan(x)) {
+            output[idx] = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output[idx] = (x - zero_point) * scale;
     }
@@ -54,7 +58,7 @@ struct FakeQuantBwdKernOp {
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
-        grad[idx] = x <= qmax && x >= qmin ? diff[idx] : 0.0;
+        grad[idx] = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff[idx] : 0.0;
     }
 #if MEGDNN_CC_HOST
@@ -77,6 +81,10 @@ struct FakeQuantKernOpNonContig {
     __device__ void operator()(
             uint32_t, ctype& output, ctype input, ctype scale, ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
+        if (isnan(x)) {
+            output = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output = (x - zero_point) * scale;
     }
@@ -96,7 +104,7 @@ struct FakeQuantBwdKernOpNonContig {
             uint32_t, ctype& grad, ctype diff, ctype input, ctype scale,
             ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
-        grad = x <= qmax && x >= qmin ? diff : 0.0;
+        grad = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff : 0.0;
     }
 #if MEGDNN_CC_HOST
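For reference, the forward kernel computes quantize, clamp, dequantize; without the early return, `fminf`/`fmaxf` would clamp a NaN `x` to a finite bound (they return the non-NaN operand), so a NaN input would come out as an ordinary quantized value. A host-only sketch of the intended behavior (illustration only, not the kernel itself):

```cpp
#include <cmath>
#include <cstdio>

// Quantize -> clamp -> dequantize, with the NaN pass-through this
// commit adds. qmin/qmax are the bounds of the fake-quant range.
float fake_quant(float x, float scale, float zero_point, float qmin, float qmax) {
    float q = std::round(x / scale) + zero_point;
    if (std::isnan(q))
        return NAN;  // propagate instead of clamping to qmin/qmax
    q = std::fmax(std::fmin(q, qmax), qmin);
    return (q - zero_point) * scale;
}

int main() {
    std::printf("%f\n", fake_quant(0.33f, 0.1f, 0.0f, -128.0f, 127.0f));  // 0.3
    std::printf("%f\n", fake_quant(NAN, 0.1f, 0.0f, -128.0f, 127.0f));    // nan
}
```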
@@ -26,14 +26,18 @@ struct traits;
 template <>
 struct traits<true> {
     static const float init;
-    static bool better_than(float lhs, float rhs) { return lhs > rhs; }
+    static bool better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs > rhs;
+    }
 };
 const float traits<true>::init = std::numeric_limits<float>::lowest();
 
 template <>
 struct traits<false> {
     static const float init;
-    static float better_than(float lhs, float rhs) { return lhs < rhs; }
+    static float better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs < rhs;
+    }
 };
 const float traits<false>::init = std::numeric_limits<float>::max();
@@ -73,25 +73,35 @@ const ctype Trait<Mode::PRODUCT, ctype>::INIT = ctype(1);
 
 template <typename ctype>
 struct Trait<Mode::MIN, ctype> {
-    static const ctype INIT;
     static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
     static ctype visit(ctype x) { return x; }
     static ctype write(ctype x, size_t) { return x; }
 };
-template <typename ctype>
-const ctype Trait<Mode::MIN, ctype>::INIT = DTypeTrait<ctype>::max();
+
+template <>
+struct Trait<Mode::MIN, dt_float32> {
+    using ctype = dt_float32;
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x < y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};
 
 template <typename ctype>
 struct Trait<Mode::MAX, ctype> {
-    static const ctype INIT;
     static ctype apply(ctype x, ctype y) { return x > y ? x : y; }
     static ctype visit(ctype x) { return x; }
     static ctype write(ctype x, size_t) { return x; }
 };
-template <typename ctype>
-const ctype Trait<Mode::MAX, ctype>::INIT = DTypeTrait<ctype>::min();
+
+template <>
+struct Trait<Mode::MAX, dt_float32> {
+    using ctype = dt_float32;
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x > y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};
 
 template <Mode mode, typename ctype>
 void reduce_fwd(
@@ -21,7 +21,9 @@ using namespace fake_quant;
 TEST_F(CUDA, FAKE_QUANT) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
-    std::unique_ptr<RNG> rng;
+    UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
 
     for (auto&& arg : args) {
         auto param = arg.param;
@@ -35,6 +37,17 @@ TEST_F(CUDA, FAKE_QUANT) {
                 .set_dtype(2, dtype)
                 .set_dtype(3, dtype)
                 .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -53,12 +66,25 @@ TEST_F(CUDA, FAKE_QUANT) {
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }
 
 TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
+    UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
 
     for (auto&& arg : args) {
         auto param = arg.param;
@@ -74,6 +100,19 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                 .set_dtype(4, dtype)
                 .execs(TensorShapeArray{
                         ishape, ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .set_dtype(4, dtype)
+                .execs(TensorShapeArray{
+                        ishape, ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -93,6 +132,17 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout,
+                 ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }
@@ -54,6 +54,20 @@ TEST_F(CUDA, REDUCE) {
     // very large reduce
     checker.execs({{1, 4194304, 1}, {}});
 
+    // inputs have nan
+    {
+        const auto nan = std::numeric_limits<float>::quiet_NaN();
+        UniformFloatWithValueRNG rng1 =
+                UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
+        checker.set_allow_invalid_check(true).set_rng(0, &rng1);
+        for (auto mode : {Mode::MIN, Mode::MAX}) {
+            checker.set_param({mode, 1});
+            checker.execs({{2, 64, 32}, {}});
+        }
+        checker.set_allow_invalid_check(false);
+    }
+    checker.set_rng(0, &rng);
+
     auto check = [&](Reduce::Mode mode, DType src_dtype, DType dst_dtype,
                      Reduce::DataType data_type) {
         for (int32_t axis : {0, 1, 2, 3}) {
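Presumably the checker validates these cases against a NaN-aware reference; a minimal reference for a min-reduction over axis 1 of an (A, B, C) tensor, consistent with the `{2, 64, 32}` shapes above, might look like the following (a sketch under that assumption, not the test framework's actual oracle):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// NaN-propagating min over axis 1 of a row-major A x B x C tensor,
// matching the semantics of MinOp<..., dt_float32>::apply.
std::vector<float> ref_min_axis1(
        const std::vector<float>& src, size_t A, size_t B, size_t C) {
    std::vector<float> dst(A * C);
    for (size_t a = 0; a < A; ++a) {
        for (size_t c = 0; c < C; ++c) {
            float acc = src[(a * B) * C + c];
            for (size_t b = 1; b < B; ++b) {
                float v = src[(a * B + b) * C + c];
                acc = (std::isnan(acc) || acc < v) ? acc : v;
            }
            dst[a * C + c] = acc;
        }
    }
    return dst;
}
```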
@@ -21,7 +21,11 @@ def common_test_reduce(opr, ref_opr):
     data2_shape = (2, 9, 12)
     data1 = np.random.random(data1_shape).astype(np.float32)
     data2 = np.random.random(data2_shape).astype(np.float32)
-    cases = [{"input": data1}, {"input": data2}]
+    cases = [
+        {"input": data1},
+        {"input": data2},
+        {"input": np.array([[[1, 2, np.nan, 4], [8, 6, 5, 2], [2, 3, 4, 5]]])},
+    ]
     if opr not in (F.argmin, F.argmax):
         # test default axis
@@ -143,6 +143,11 @@ def test_fakequant():
     assert np.allclose(x.grad.numpy(), x1.grad.numpy())
     assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)
 
+    # test nan
+    x = F.full((1, 32, 3, 3), np.nan)
+    y = fake_quant_tensor(x, qparams).numpy()
+    assert np.isnan(y).all()
+
     zero_point = tensor([1.0], dtype=np.float32)
     scale = tensor([4.0], dtype=np.float32)
     run(zero_point, scale)