Commit f5cb21ed authored by Megvii Engine Team

fix(mgb/opr): add non finite check

GitOrigin-RevId: a9fcd0a3509681f4c596b6322691d5485353ce2e
Parent bde5cf35
......@@ -1319,11 +1319,11 @@ protected:
};
/*!
* \brief check whether input contains inf value.
* \brief check whether input contains inf or nan value.
*/
class CheckHasInf: public OperatorBase {
class CheckNonFinite: public OperatorBase {
DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1);
DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1);
public:
virtual size_t get_workspace_in_bytes(const TensorLayout &src,
......
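For context: the renamed CheckNonFinite operator keeps the 1-input/1-output contract of CheckHasInf but now flags NaN as well as Inf. A minimal NumPy sketch of the intended semantics (the function name here is illustrative, not part of megdnn):

```python
import numpy as np

def check_non_finite_reference(src: np.ndarray) -> np.ndarray:
    # Matches deduce_layout below: output is a 1-element Int32 tensor,
    # 1 if any element is inf or nan, 0 otherwise.
    return np.array([int(not np.isfinite(src).all())], dtype=np.int32)
```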
/**
* \file dnn/src/common/check_has_inf.cpp
* \file dnn/src/common/check_non_finite.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -14,7 +14,7 @@
namespace megdnn {
void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst,
void CheckNonFinite::check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes) {
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(dst);
......@@ -24,7 +24,7 @@ void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst,
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) {
void CheckNonFinite::deduce_layout(const TensorLayout&, TensorLayout& dst) {
dst.shape[0] = 1;
dst.ndim = 1;
dst.dtype = dtype::Int32();
......
......@@ -216,7 +216,7 @@ private:
cb(FakeQuantBackward) \
cb(TQTForward) \
cb(TQTBackward) \
cb(CheckHasInf) \
cb(CheckNonFinite) \
cb(LSQForward) \
cb(LSQBackward) \
cb(Fill) \
......
......@@ -131,7 +131,7 @@ DEF(PermutationRNG, 1, true, true);
DEF(ShuffleRNGForward, 3, true, true);
DEF(ShuffleRNGBackward, 3, true, false);
DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true);
DEF(CheckNonFinite, 2, true, true);
DEF(LSQForward, 5, true, true);
DEF(LSQBackward, 7, true, false);
DEF(Fill, 1, true, false);
......
......@@ -152,7 +152,7 @@ struct MaxOp {
};
template <typename src_ctype, typename dst_ctype, typename wtype_>
struct CheckHasInfOp {
struct CheckNonFiniteOp {
typedef wtype_ wtype;
const wtype INIT;
......@@ -162,9 +162,9 @@ struct CheckHasInfOp {
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
#if defined(__CUDA_ARCH__)
return isinf(src[idx]);
return !isfinite(src[idx]);
#else
return std::isinf(src[idx]);
return !std::isfinite(src[idx]);
#endif
}
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
......@@ -173,7 +173,7 @@ struct CheckHasInfOp {
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs | rhs;
}
MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src, dst_ctype* dst,
MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst,
size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
};
......
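The reduce op's read() now maps each element to "is it non-finite?" rather than "is it inf?", and apply() folds the flags with bitwise OR starting from INIT = 0. An illustrative Python mirror (names are ours, not megdnn's):

```python
import math
from functools import reduce

def check_non_finite_op(values):
    # read(): 1 for inf or nan; apply(): lhs | rhs; accumulator INIT = 0.
    read = lambda x: int(not math.isfinite(x))
    return reduce(lambda lhs, rhs: lhs | rhs, map(read, values), 0)
```

A single inf or nan anywhere in the input is enough to drive the result to 1.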
/**
* \file dnn/src/cuda/check_has_inf/kern.cu
* \file dnn/src/cuda/check_non_finite/kern.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -18,7 +18,7 @@ namespace cuda {
#define COMMA ,
INST_REDUCE(reduce::CheckHasInfOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false);
INST_REDUCE(reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false);
#undef COMMA
} // namespace cuda
......
/**
* \file dnn/src/cuda/check_has_inf/opr_impl.cpp
* \file dnn/src/cuda/check_non_finite/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -9,7 +9,7 @@
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/check_has_inf/opr_impl.h"
#include "src/cuda/check_non_finite/opr_impl.h"
#include "src/cuda/reduce_helper.cuh"
#include "src/cuda/handle.h"
......@@ -20,18 +20,18 @@
namespace megdnn {
namespace cuda {
using reduce::CheckHasInfOp;
using reduce::CheckNonFiniteOp;
size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src,
size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) {
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1);
}
void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
auto stream = cuda_stream(this->handle());
auto B = src.layout.total_nr_elems();
return run_reduce<Op, false>(
......
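The CUDA implementation reuses the generic reduce helper by viewing the flat input as an (A=1, B=total_nr_elems, C=1) reduction, so one kernel ORs all per-element flags. A NumPy view of the same computation (illustrative only):

```python
import numpy as np

def cuda_exec_equivalent(src: np.ndarray) -> np.ndarray:
    # View the flat input as shape (1, B, 1) and reduce over axis B
    # with a logical OR of the per-element non-finite flags.
    flags = ~np.isfinite(src.reshape(1, -1, 1))
    return flags.any(axis=1).astype(np.int32).reshape(1)
```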
/**
* \file dnn/src/cuda/check_has_inf/opr_impl.h
* \file dnn/src/cuda/check_non_finite/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -17,9 +17,9 @@
namespace megdnn {
namespace cuda {
class CheckHasInfImpl final : public CheckHasInf {
class CheckNonFiniteImpl final : public CheckNonFinite {
public:
using CheckHasInf::CheckHasInf;
using CheckNonFinite::CheckNonFinite;
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) override;
......
......@@ -20,7 +20,7 @@
#include "src/cuda/batch_conv_bias/opr_impl.h"
#include "src/cuda/batch_normalization/opr_impl.h"
#include "src/cuda/batched_matrix_mul/opr_impl.h"
#include "src/cuda/check_has_inf/opr_impl.h"
#include "src/cuda/check_non_finite/opr_impl.h"
#include "src/cuda/checksum/opr_impl.h"
#include "src/cuda/concat/opr_impl.h"
#include "src/cuda/cond_take/opr_impl.h"
......
/**
* \file dnn/src/naive/check_has_inf/opr_impl.cpp
* \file dnn/src/naive/check_non_finite/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -9,7 +9,7 @@
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/check_has_inf/opr_impl.h"
#include "src/naive/check_non_finite/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
......@@ -27,7 +27,7 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
size_t mid = l + (r - l) / 2;
return func(l, mid) | func(mid, r);
} else {
return static_cast<wtype>(std::isinf(sptr[l]));
return static_cast<wtype>(!std::isfinite(sptr[l]));
}
};
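The naive kernel reduces pairwise rather than with a linear scan, which keeps the recursion depth logarithmic in the input size. A Python rendering of the lambda above, assuming the elided branch condition splits until a single element remains:

```python
import math

def reduce_fwd(sptr, l, r):
    # Pairwise OR-reduction over [l, r): split at the midpoint and
    # combine the halves; a single element maps to !isfinite(x).
    if r - l > 1:
        mid = l + (r - l) // 2
        return reduce_fwd(sptr, l, mid) | reduce_fwd(sptr, mid, r)
    return int(not math.isfinite(sptr[l]))
```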
......@@ -39,12 +39,12 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
namespace megdnn {
namespace naive {
size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout&,
size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&) {
return 0;
}
void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
......
/**
* \file dnn/src/naive/check_has_inf/opr_impl.h
* \file dnn/src/naive/check_non_finite/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -16,9 +16,9 @@
namespace megdnn {
namespace naive {
class CheckHasInfImpl final : public CheckHasInf {
class CheckNonFiniteImpl final : public CheckNonFinite {
public:
using CheckHasInf::CheckHasInf;
using CheckNonFinite::CheckNonFinite;
bool is_thread_safe() const override { return true; }
......
......@@ -22,7 +22,7 @@
#include "src/naive/batch_conv_bias/opr_impl.h"
#include "src/naive/batch_normalization/opr_impl.h"
#include "src/naive/batched_matrix_mul/opr_impl.h"
#include "src/naive/check_has_inf/opr_impl.h"
#include "src/naive/check_non_finite/opr_impl.h"
#include "src/naive/checksum/opr_impl.h"
#include "src/naive/concat/opr_impl.h"
#include "src/naive/cond_take/opr_impl.h"
......
/**
* \file dnn/test/cuda/check_has_inf.cpp
* \file dnn/test/cuda/check_non_finite.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -15,16 +15,20 @@
namespace megdnn {
namespace test {
TEST_F(CUDA, CHECK_HAS_INF_BASIC) {
Checker<CheckHasInf> checker(handle_cuda());
TEST_F(CUDA, CHECK_NON_FINITE_BASIC) {
Checker<CheckNonFinite> checker(handle_cuda());
checker.set_allow_invalid_check(true);
const auto inf = std::numeric_limits<float>::infinity();
const auto nan = std::numeric_limits<float>::quiet_NaN();
UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf);
checker.set_rng(0, &rng);
checker.execs({{512*16}, {1}});
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf);
checker.set_rng(0, &rng);
checker.execs({{512*16}, {1}});
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, nan);
checker.set_rng(0, &rng);
checker.execs({{512*16}, {1}});
}
} // namespace test
......
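The CUDA test now seeds three cases: inf sprinkled in at a 10% rate, all-inf, and all-nan. Assuming UniformFloatWithValueRNG(lo, hi, p, v) draws uniform floats and replaces each entry with v at probability p, an equivalent NumPy generator would be:

```python
import numpy as np

def uniform_with_value(shape, p, v, lo=-1.0, hi=1.0):
    # Assumed semantics of UniformFloatWithValueRNG: uniform samples
    # in [lo, hi), each replaced by v with probability p.
    out = np.random.uniform(lo, hi, shape).astype(np.float32)
    out[np.random.random(shape) < p] = v
    return out
```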
/**
* \file test/naive/check_has_inf.cpp
* \file test/naive/check_non_finite.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -17,8 +17,8 @@
namespace megdnn {
namespace test {
TEST_F(NAIVE, CHECK_HAS_INF_BASIC) {
Checker<CheckHasInf> checker(handle(), false);
TEST_F(NAIVE, CHECK_NON_FINITE_BASIC) {
Checker<CheckNonFinite> checker(handle(), false);
checker.exect(Testcase{TensorValue({4}, dtype::Float32(),
{1.1, 2.2, 3.3, 4.3}),
{}},
......@@ -29,6 +29,12 @@ TEST_F(NAIVE, CHECK_HAS_INF_BASIC) {
std::numeric_limits<float>::infinity()}),
{}},
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
checker.exect(
Testcase{TensorValue({4}, dtype::Float32(),
{1.1f, 2.2f, 3.3f,
std::numeric_limits<float>::quiet_NaN()}),
{}},
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
}
} // namespace test
......
......@@ -11,7 +11,7 @@ import numpy as np
from ..autodiff import GradManager
from ..functional import full_like
from ..functional.math import _has_inf
from ..functional.math import _check_non_finite
from ..tensor import Tensor
......@@ -76,7 +76,7 @@ class GradScaler:
self.growth_interval = growth_interval
self._growth_tracker = 0
self._found_inf = False
self._found_non_finite = False
def backward(
self,
......@@ -135,10 +135,10 @@ class GradScaler:
continue
# to support tracing, _check_gradients should be applied to every grad.
if self._check_gradients(tensor.grad):
self._found_inf = True
self._found_non_finite = True
tensor.grad *= inv_scale
if self._found_inf:
if self._found_non_finite:
for tensor in grad_tensors:
if tensor is None or getattr(tensor, "grad", None) is None:
continue
......@@ -148,7 +148,7 @@ class GradScaler:
def _check_gradients(self, grad):
if self.growth_interval == 0:
return False
return _has_inf(grad)
return _check_non_finite(grad)
def update(self, new_scale: float = None):
r"""Update the scale factor according to whether encountered overflow grad.
......@@ -160,7 +160,7 @@ class GradScaler:
if new_scale is not None:
self.scale_factor = float(new_scale)
else:
if self._found_inf:
if self._found_non_finite:
self.scale_factor *= self.backoff_factor
self._growth_tracker = 0
else:
......@@ -168,7 +168,7 @@ class GradScaler:
if self._growth_tracker >= self.growth_interval:
self.scale_factor *= self.growth_factor
self._growth_tracker = 0
self._found_inf = False
self._found_non_finite = False
def state_dict(self):
return {
......
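End to end, the rename changes no behavior. A hypothetical AMP training loop (model, opt, loss_fn, and batches are placeholders) exercises the two methods touched here: backward(), which flags non-finite grads via _check_gradients, and update(), which backs the scale off when the flag is set:

```python
from megengine.amp import GradScaler
from megengine.autodiff import GradManager

gm = GradManager()
gm.attach(model.parameters())
scaler = GradScaler()

for inputs, labels in batches:
    with gm:
        loss = loss_fn(model(inputs), labels)
        scaler.backward(gm, loss)  # scales loss, unscales and checks grads
    opt.step()
    opt.clear_grad()
    scaler.update()  # backoff if non-finite grads were found this step
```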
......@@ -1181,8 +1181,8 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
return U, sigma, V
def _has_inf(inp: Tensor) -> Tensor:
r"""Check whether input contains infinite value.
def _check_non_finite(inp: Tensor) -> Tensor:
r"""Check whether input contains infinite or nan value.
Args:
inp: a tensor to be checked.
......@@ -1190,7 +1190,7 @@ def _has_inf(inp: Tensor) -> Tensor:
Returns:
a int32 scalar tensor, 0 for False and 1 for True.
"""
op = builtin.CheckHasInf()
op = builtin.CheckNonFinite()
(oup,) = apply(op, inp.reshape(-1).astype("float32"))
oup._setscalar()
return oup
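A quick sanity check of the renamed helper, mirroring the updated unit test below:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.array([1.0, 2.0, float("nan")], dtype=np.float32))
np.testing.assert_equal(F.math._check_non_finite(x).numpy(), [1])
```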
......@@ -185,14 +185,18 @@ def test_sum_neg_axis():
F.sum(tensor(data), axis=(-1, 1))
def test_has_inf():
def test_non_finite():
shape = (32, 3, 32, 32)
data = np.random.random(shape).astype(np.float32)
rst = F.math._has_inf(tensor(data))
rst = F.math._check_non_finite(tensor(data))
np.testing.assert_equal(rst.numpy(), [0])
data[0][0][0][0] = float("inf")
rst = F.math._has_inf(tensor(data))
rst = F.math._check_non_finite(tensor(data))
np.testing.assert_equal(rst.numpy(), [1])
data[0][0][0][0] = float("nan")
rst = F.math._check_non_finite(tensor(data))
np.testing.assert_equal(rst.numpy(), [1])
......
......@@ -16,17 +16,17 @@
namespace mgb {
namespace imperative {
namespace check_has_inf {
namespace check_non_finite {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<CheckHasInf>();
auto&& op = def.cast_final_safe<CheckNonFinite>();
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::CheckHasInf::make(inputs[0], {}, config);
return opr::CheckNonFinite::make(inputs[0], {}, config);
}
OP_TRAIT_REG(CheckHasInf, CheckHasInf)
OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace check_has_inf
} // namespace check_non_finite
} // namespace imperative
} // namespace mgb
......
......@@ -390,7 +390,7 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> {
def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;
def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>;
def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [EmptyParam]>;
def FastpathCopy: MgbHashableOp<"FastpathCopy">;
......
......@@ -491,12 +491,12 @@ MGB_IMPL_OPR_GRAD(TopK) {
}
#endif
/* ================= CheckHasInf ================= */
/* ================= CheckNonFinite ================= */
namespace mgb {
namespace opr {
namespace intl {
template<>
struct MegDNNOprInitPostCtor<CheckHasInf> {
struct MegDNNOprInitPostCtor<CheckNonFinite> {
static void apply(cg::OperatorNodeBase &opr) {
opr.output(0)->dtype(dtype::Int32());
}
......@@ -504,6 +504,6 @@ struct MegDNNOprInitPostCtor<CheckHasInf> {
}
}
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf);
MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf")
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite);
MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite")
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -72,7 +72,7 @@ namespace opr {
#if MGB_CUDA
MGB_SEREG_OPR(NvOf, 1);
#endif
MGB_SEREG_OPR(CheckHasInf, 1);
MGB_SEREG_OPR(CheckNonFinite, 1);
} // namespace opr
} // namespace mgb
......
......@@ -185,7 +185,7 @@ public:
const OperatorNodeConfig& config = {});
};
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf);
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckNonFinite);
} // namespace opr
} // namespace mgb
......