From f5cb21ed3a547ca806dd0daf099a7e38114ff59b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 14 Sep 2021 10:54:25 +0800 Subject: [PATCH] fix(mgb/opr): add non finite check GitOrigin-RevId: a9fcd0a3509681f4c596b6322691d5485353ce2e --- dnn/include/megdnn/oprs/general.h | 6 +++--- .../{check_has_inf.cpp => check_non_finite.cpp} | 6 +++--- dnn/src/common/handle_impl.h | 2 +- dnn/src/common/opr_trait.h | 2 +- dnn/src/common/reduce_helper.h | 8 ++++---- .../{check_has_inf => check_non_finite}/kern.cu | 4 ++-- .../opr_impl.cpp | 14 +++++++------- .../{check_has_inf => check_non_finite}/opr_impl.h | 6 +++--- dnn/src/cuda/handle_create.cpp | 2 +- .../opr_impl.cpp | 10 +++++----- .../{check_has_inf => check_non_finite}/opr_impl.h | 6 +++--- dnn/src/naive/handle.cpp | 2 +- .../{check_has_inf.cpp => check_non_finite.cpp} | 10 +++++++--- .../{check_has_inf.cpp => check_non_finite.cpp} | 12 +++++++++--- imperative/python/megengine/amp/grad_scaler.py | 14 +++++++------- imperative/python/megengine/functional/math.py | 6 +++--- .../python/test/unit/functional/test_math.py | 10 +++++++--- imperative/src/impl/ops/misc.cpp | 10 +++++----- src/core/include/megbrain/ir/ops.td | 2 +- src/opr/impl/misc.cpp | 8 ++++---- src/opr/impl/misc.sereg.h | 2 +- src/opr/include/megbrain/opr/misc.h | 2 +- 22 files changed, 79 insertions(+), 65 deletions(-) rename dnn/src/common/{check_has_inf.cpp => check_non_finite.cpp} (82%) rename dnn/src/cuda/{check_has_inf => check_non_finite}/kern.cu (82%) rename dnn/src/cuda/{check_has_inf => check_non_finite}/opr_impl.cpp (73%) rename dnn/src/cuda/{check_has_inf => check_non_finite}/opr_impl.h (85%) rename dnn/src/naive/{check_has_inf => check_non_finite}/opr_impl.cpp (81%) rename dnn/src/naive/{check_has_inf => check_non_finite}/opr_impl.h (84%) rename dnn/test/cuda/{check_has_inf.cpp => check_non_finite.cpp} (74%) rename dnn/test/naive/{check_has_inf.cpp => check_non_finite.cpp} (72%) diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 20a9f1bf1..d57bfcbb2 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -1319,11 +1319,11 @@ protected: }; /*! - * \brief check whether input contains inf value. + * \brief check whether input contains inf or nan value. */ -class CheckHasInf: public OperatorBase { +class CheckNonFinite: public OperatorBase { DEF_OPR_PARAM(Empty); - DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1); + DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1); public: virtual size_t get_workspace_in_bytes(const TensorLayout &src, diff --git a/dnn/src/common/check_has_inf.cpp b/dnn/src/common/check_non_finite.cpp similarity index 82% rename from dnn/src/common/check_has_inf.cpp rename to dnn/src/common/check_non_finite.cpp index 66f1a63c9..64e4657d3 100644 --- a/dnn/src/common/check_has_inf.cpp +++ b/dnn/src/common/check_non_finite.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/common/check_has_inf.cpp + * \file dnn/src/common/check_non_finite.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -14,7 +14,7 @@ namespace megdnn { -void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst, +void CheckNonFinite::check_exec(const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { megdnn_assert_contiguous(src); megdnn_assert_contiguous(dst); @@ -24,7 +24,7 @@ void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst, megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); } -void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) { +void CheckNonFinite::deduce_layout(const TensorLayout&, TensorLayout& dst) { dst.shape[0] = 1; dst.ndim = 1; dst.dtype = dtype::Int32(); diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 816964809..e77fc8653 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -216,7 +216,7 @@ private: cb(FakeQuantBackward) \ cb(TQTForward) \ cb(TQTBackward) \ - cb(CheckHasInf) \ + cb(CheckNonFinite) \ cb(LSQForward) \ cb(LSQBackward) \ cb(Fill) \ diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index fbc02c75c..8999b736d 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -131,7 +131,7 @@ DEF(PermutationRNG, 1, true, true); DEF(ShuffleRNGForward, 3, true, true); DEF(ShuffleRNGBackward, 3, true, false); DEF(ChecksumForward, 1, true, false); -DEF(CheckHasInf, 2, true, true); +DEF(CheckNonFinite, 2, true, true); DEF(LSQForward, 5, true, true); DEF(LSQBackward, 7, true, false); DEF(Fill, 1, true, false); diff --git a/dnn/src/common/reduce_helper.h b/dnn/src/common/reduce_helper.h index 45da96e3c..a5dbaefc0 100644 --- a/dnn/src/common/reduce_helper.h +++ b/dnn/src/common/reduce_helper.h @@ -152,7 +152,7 @@ struct MaxOp { }; template -struct CheckHasInfOp { +struct CheckNonFiniteOp { typedef wtype_ wtype; const wtype INIT; @@ -162,9 +162,9 @@ struct CheckHasInfOp { MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { #if defined(__CUDA_ARCH__) - return isinf(src[idx]); + return !isfinite(src[idx]); #else - return std::isinf(src[idx]); + return !std::isfinite(src[idx]); #endif } MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { @@ -173,7 +173,7 @@ struct CheckHasInfOp { static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; } - MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src, dst_ctype* dst, + MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, size_t B) : INIT(wtype(0)), src(src), dst(dst), B(B) {} }; diff --git a/dnn/src/cuda/check_has_inf/kern.cu b/dnn/src/cuda/check_non_finite/kern.cu similarity index 82% rename from dnn/src/cuda/check_has_inf/kern.cu rename to dnn/src/cuda/check_non_finite/kern.cu index cb3d10495..f688d61f2 100644 --- a/dnn/src/cuda/check_has_inf/kern.cu +++ b/dnn/src/cuda/check_non_finite/kern.cu @@ -1,5 +1,5 @@ /** - * \file dnn/src/cuda/check_has_inf/kern.cu + * \file dnn/src/cuda/check_non_finite/kern.cu * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -18,7 +18,7 @@ namespace cuda { #define COMMA , -INST_REDUCE(reduce::CheckHasInfOp, false); +INST_REDUCE(reduce::CheckNonFiniteOp, false); #undef COMMA } // namespace cuda diff --git a/dnn/src/cuda/check_has_inf/opr_impl.cpp b/dnn/src/cuda/check_non_finite/opr_impl.cpp similarity index 73% rename from dnn/src/cuda/check_has_inf/opr_impl.cpp rename to dnn/src/cuda/check_non_finite/opr_impl.cpp index bf44be610..3b548f7dc 100644 --- a/dnn/src/cuda/check_has_inf/opr_impl.cpp +++ b/dnn/src/cuda/check_non_finite/opr_impl.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/cuda/check_has_inf/opr_impl.cpp + * \file dnn/src/cuda/check_non_finite/opr_impl.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,7 +9,7 @@ * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/cuda/check_has_inf/opr_impl.h" +#include "src/cuda/check_non_finite/opr_impl.h" #include "src/cuda/reduce_helper.cuh" #include "src/cuda/handle.h" @@ -20,18 +20,18 @@ namespace megdnn { namespace cuda { -using reduce::CheckHasInfOp; +using reduce::CheckNonFiniteOp; -size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src, +size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst) { - typedef CheckHasInfOp Op; + typedef CheckNonFiniteOp Op; return get_reduce_workspace_in_bytes(1, src.total_nr_elems(), 1); } -void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, +void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); - typedef CheckHasInfOp Op; + typedef CheckNonFiniteOp Op; auto stream = cuda_stream(this->handle()); auto B = src.layout.total_nr_elems(); return run_reduce( diff --git a/dnn/src/cuda/check_has_inf/opr_impl.h b/dnn/src/cuda/check_non_finite/opr_impl.h similarity index 85% rename from dnn/src/cuda/check_has_inf/opr_impl.h rename to dnn/src/cuda/check_non_finite/opr_impl.h index 32d60f66e..5ab9d6350 100644 --- a/dnn/src/cuda/check_has_inf/opr_impl.h +++ b/dnn/src/cuda/check_non_finite/opr_impl.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/cuda/check_has_inf/opr_impl.h + * \file dnn/src/cuda/check_non_finite/opr_impl.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -17,9 +17,9 @@ namespace megdnn { namespace cuda { -class CheckHasInfImpl final : public CheckHasInf { +class CheckNonFiniteImpl final : public CheckNonFinite { public: - using CheckHasInf::CheckHasInf; + using CheckNonFinite::CheckNonFinite; size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst) override; diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 2e816a98b..2b96c782b 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -20,7 +20,7 @@ #include "src/cuda/batch_conv_bias/opr_impl.h" #include "src/cuda/batch_normalization/opr_impl.h" #include "src/cuda/batched_matrix_mul/opr_impl.h" -#include "src/cuda/check_has_inf/opr_impl.h" +#include "src/cuda/check_non_finite/opr_impl.h" #include "src/cuda/checksum/opr_impl.h" #include "src/cuda/concat/opr_impl.h" #include "src/cuda/cond_take/opr_impl.h" diff --git a/dnn/src/naive/check_has_inf/opr_impl.cpp b/dnn/src/naive/check_non_finite/opr_impl.cpp similarity index 81% rename from dnn/src/naive/check_has_inf/opr_impl.cpp rename to dnn/src/naive/check_non_finite/opr_impl.cpp index 1a30910c9..df34dd90f 100644 --- a/dnn/src/naive/check_has_inf/opr_impl.cpp +++ b/dnn/src/naive/check_non_finite/opr_impl.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/src/naive/check_has_inf/opr_impl.cpp + * \file dnn/src/naive/check_non_finite/opr_impl.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -9,7 +9,7 @@ * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include "src/naive/check_has_inf/opr_impl.h" +#include "src/naive/check_non_finite/opr_impl.h" #include "src/common/utils.h" #include "src/naive/handle.h" @@ -27,7 +27,7 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) { size_t mid = l + (r - l) / 2; return func(l, mid) | func(mid, r); } else { - return static_cast(std::isinf(sptr[l])); + return static_cast(!std::isfinite(sptr[l])); } }; @@ -39,12 +39,12 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) { namespace megdnn { namespace naive { -size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout&, +size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) { return 0; } -void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, +void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { check_exec(src.layout, dst.layout, workspace.size); diff --git a/dnn/src/naive/check_has_inf/opr_impl.h b/dnn/src/naive/check_non_finite/opr_impl.h similarity index 84% rename from dnn/src/naive/check_has_inf/opr_impl.h rename to dnn/src/naive/check_non_finite/opr_impl.h index 53e9c6351..cdf846e23 100644 --- a/dnn/src/naive/check_has_inf/opr_impl.h +++ b/dnn/src/naive/check_non_finite/opr_impl.h @@ -1,5 +1,5 @@ /** - * \file dnn/src/naive/check_has_inf/opr_impl.h + * \file dnn/src/naive/check_non_finite/opr_impl.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -16,9 +16,9 @@ namespace megdnn { namespace naive { -class CheckHasInfImpl final : public CheckHasInf { +class CheckNonFiniteImpl final : public CheckNonFinite { public: - using CheckHasInf::CheckHasInf; + using CheckNonFinite::CheckNonFinite; bool is_thread_safe() const override { return true; } diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index b492fd934..4c05817e3 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -22,7 +22,7 @@ #include "src/naive/batch_conv_bias/opr_impl.h" #include "src/naive/batch_normalization/opr_impl.h" #include "src/naive/batched_matrix_mul/opr_impl.h" -#include "src/naive/check_has_inf/opr_impl.h" +#include "src/naive/check_non_finite/opr_impl.h" #include "src/naive/checksum/opr_impl.h" #include "src/naive/concat/opr_impl.h" #include "src/naive/cond_take/opr_impl.h" diff --git a/dnn/test/cuda/check_has_inf.cpp b/dnn/test/cuda/check_non_finite.cpp similarity index 74% rename from dnn/test/cuda/check_has_inf.cpp rename to dnn/test/cuda/check_non_finite.cpp index 1e4525145..64c678bd9 100644 --- a/dnn/test/cuda/check_has_inf.cpp +++ b/dnn/test/cuda/check_non_finite.cpp @@ -1,5 +1,5 @@ /** - * \file dnn/test/cuda/check_has_inf.cpp + * \file dnn/test/cuda/check_non_finite.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -15,16 +15,20 @@ namespace megdnn { namespace test { -TEST_F(CUDA, CHECK_HAS_INF_BASIC) { - Checker checker(handle_cuda()); +TEST_F(CUDA, CHECK_NON_FINITE_BASIC) { + Checker checker(handle_cuda()); checker.set_allow_invalid_check(true); const auto inf = std::numeric_limits::infinity(); + const auto nan = std::numeric_limits::quiet_NaN(); UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf); checker.set_rng(0, &rng); checker.execs({{512*16}, {1}}); rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf); checker.set_rng(0, &rng); checker.execs({{512*16}, {1}}); + rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, nan); + checker.set_rng(0, &rng); + checker.execs({{512*16}, {1}}); } } // namespace test diff --git a/dnn/test/naive/check_has_inf.cpp b/dnn/test/naive/check_non_finite.cpp similarity index 72% rename from dnn/test/naive/check_has_inf.cpp rename to dnn/test/naive/check_non_finite.cpp index 1532a7c3c..8a8fc4359 100644 --- a/dnn/test/naive/check_has_inf.cpp +++ b/dnn/test/naive/check_non_finite.cpp @@ -1,5 +1,5 @@ /** - * \file test/naive/check_has_inf.cpp + * \file test/naive/check_non_finite.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. @@ -17,8 +17,8 @@ namespace megdnn { namespace test { -TEST_F(NAIVE, CHECK_HAS_INF_BASIC) { - Checker checker(handle(), false); +TEST_F(NAIVE, CHECK_NON_FINITE_BASIC) { + Checker checker(handle(), false); checker.exect(Testcase{TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}), {}}, @@ -29,6 +29,12 @@ TEST_F(NAIVE, CHECK_HAS_INF_BASIC) { std::numeric_limits::infinity()}), {}}, Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); + checker.exect( + Testcase{TensorValue({4}, dtype::Float32(), + {1.1f, 2.2f, 3.3f, + std::numeric_limits::quiet_NaN()}), + {}}, + Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); } } // namespace test diff --git a/imperative/python/megengine/amp/grad_scaler.py b/imperative/python/megengine/amp/grad_scaler.py index b23103b0c..f1b64cfe9 100644 --- a/imperative/python/megengine/amp/grad_scaler.py +++ b/imperative/python/megengine/amp/grad_scaler.py @@ -11,7 +11,7 @@ import numpy as np from ..autodiff import GradManager from ..functional import full_like -from ..functional.math import _has_inf +from ..functional.math import _check_non_finite from ..tensor import Tensor @@ -76,7 +76,7 @@ class GradScaler: self.growth_interval = growth_interval self._growth_tracker = 0 - self._found_inf = False + self._found_non_finite = False def backward( self, @@ -135,10 +135,10 @@ class GradScaler: continue # to support tracing, _check_gradients should be applied to every grad. if self._check_gradients(tensor.grad): - self._found_inf = True + self._found_non_finite = True tensor.grad *= inv_scale - if self._found_inf: + if self._found_non_finite: for tensor in grad_tensors: if tensor is None or getattr(tensor, "grad", None) is None: continue @@ -148,7 +148,7 @@ class GradScaler: def _check_gradients(self, grad): if self.growth_interval == 0: return False - return _has_inf(grad) + return _check_non_finite(grad) def update(self, new_scale: float = None): r"""Update the scale factor according to whether encountered overflow grad. @@ -160,7 +160,7 @@ class GradScaler: if new_scale is not None: self.scale_factor = float(new_scale) else: - if self._found_inf: + if self._found_non_finite: self.scale_factor *= self.backoff_factor self._growth_tracker = 0 else: @@ -168,7 +168,7 @@ class GradScaler: if self._growth_tracker >= self.growth_interval: self.scale_factor *= self.growth_factor self._growth_tracker = 0 - self._found_inf = False + self._found_non_finite = False def state_dict(self): return { diff --git a/imperative/python/megengine/functional/math.py b/imperative/python/megengine/functional/math.py index 70e316ea7..16adc50ad 100644 --- a/imperative/python/megengine/functional/math.py +++ b/imperative/python/megengine/functional/math.py @@ -1181,8 +1181,8 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: return U, sigma, V -def _has_inf(inp: Tensor) -> Tensor: - r"""Check whether input contains infinite value. +def _check_non_finite(inp: Tensor) -> Tensor: + r"""Check whether input contains infinite or nan value. Args: inp: a tensor to be checked. @@ -1190,7 +1190,7 @@ def _has_inf(inp: Tensor) -> Tensor: Returns: a int32 scalar tensor, 0 for False and 1 for True. """ - op = builtin.CheckHasInf() + op = builtin.CheckNonFinite() (oup,) = apply(op, inp.reshape(-1).astype("float32")) oup._setscalar() return oup diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py index 428e65fff..e5cc03811 100644 --- a/imperative/python/test/unit/functional/test_math.py +++ b/imperative/python/test/unit/functional/test_math.py @@ -185,14 +185,18 @@ def test_sum_neg_axis(): F.sum(tensor(data), axis=(-1, 1)) -def test_has_inf(): +def test_non_finite(): shape = (32, 3, 32, 32) data = np.random.random(shape).astype(np.float32) - rst = F.math._has_inf(tensor(data)) + rst = F.math._check_non_finite(tensor(data)) np.testing.assert_equal(rst.numpy(), [0]) data[0][0][0][0] = float("inf") - rst = F.math._has_inf(tensor(data)) + rst = F.math._check_non_finite(tensor(data)) + np.testing.assert_equal(rst.numpy(), [1]) + + data[0][0][0][0] = float("nan") + rst = F.math._check_non_finite(tensor(data)) np.testing.assert_equal(rst.numpy(), [1]) diff --git a/imperative/src/impl/ops/misc.cpp b/imperative/src/impl/ops/misc.cpp index d08e09b80..a29f0ef66 100644 --- a/imperative/src/impl/ops/misc.cpp +++ b/imperative/src/impl/ops/misc.cpp @@ -16,17 +16,17 @@ namespace mgb { namespace imperative { -namespace check_has_inf { +namespace check_non_finite { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { - auto&& op = def.cast_final_safe(); + auto&& op = def.cast_final_safe(); mgb_assert(inputs.size() == 1); OperatorNodeConfig config{op.make_name()}; - return opr::CheckHasInf::make(inputs[0], {}, config); + return opr::CheckNonFinite::make(inputs[0], {}, config); } -OP_TRAIT_REG(CheckHasInf, CheckHasInf) +OP_TRAIT_REG(CheckNonFinite, CheckNonFinite) .apply_on_var_node(apply_on_var_node) .fallback(); -} // namespace check_has_inf +} // namespace check_non_finite } // namespace imperative } // namespace mgb diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 0fcc13ee5..f4d873e47 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -390,7 +390,7 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> { def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; -def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>; +def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [EmptyParam]>; def FastpathCopy: MgbHashableOp<"FastpathCopy">; diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index f4650515b..b0a010878 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -491,12 +491,12 @@ MGB_IMPL_OPR_GRAD(TopK) { } #endif -/* ================= CheckHasInf ================= */ +/* ================= CheckNonFinite ================= */ namespace mgb { namespace opr { namespace intl { template<> -struct MegDNNOprInitPostCtor { +struct MegDNNOprInitPostCtor { static void apply(cg::OperatorNodeBase &opr) { opr.output(0)->dtype(dtype::Int32()); } @@ -504,6 +504,6 @@ struct MegDNNOprInitPostCtor { } } } -MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf); -MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf") +MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite); +MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite") // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/impl/misc.sereg.h b/src/opr/impl/misc.sereg.h index f85c231a8..1f2cb2231 100644 --- a/src/opr/impl/misc.sereg.h +++ b/src/opr/impl/misc.sereg.h @@ -72,7 +72,7 @@ namespace opr { #if MGB_CUDA MGB_SEREG_OPR(NvOf, 1); #endif - MGB_SEREG_OPR(CheckHasInf, 1); + MGB_SEREG_OPR(CheckNonFinite, 1); } // namespace opr } // namespace mgb diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h index 51cade1d9..2a0d7b40c 100644 --- a/src/opr/include/megbrain/opr/misc.h +++ b/src/opr/include/megbrain/opr/misc.h @@ -185,7 +185,7 @@ public: const OperatorNodeConfig& config = {}); }; -MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf); +MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckNonFinite); } // namespace opr } // namespace mgb -- GitLab