Commit f5cb21ed authored by Megvii Engine Team

fix(mgb/opr): add non finite check

GitOrigin-RevId: a9fcd0a3509681f4c596b6322691d5485353ce2e
Parent bde5cf35
@@ -1319,11 +1319,11 @@ protected:
 };
 /*!
- * \brief check whether input contains inf value.
+ * \brief check whether input contains inf or nan value.
  */
-class CheckHasInf: public OperatorBase {
+class CheckNonFinite: public OperatorBase {
     DEF_OPR_PARAM(Empty);
-    DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1);
+    DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1);
 public:
     virtual size_t get_workspace_in_bytes(const TensorLayout &src,
......
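Semantically, the renamed operator reduces the entire input to a single int32 flag: 1 if any element is inf or NaN, 0 otherwise. A minimal NumPy reference of that contract, for illustration only (check_non_finite_ref is a made-up name, not part of the megdnn API):

    import numpy as np

    def check_non_finite_ref(x: np.ndarray) -> np.ndarray:
        # Mirrors CheckNonFinite's deduced output: a 1-element int32 tensor,
        # 1 if any element of x is inf or NaN, else 0.
        return np.array([int(not np.isfinite(x).all())], dtype=np.int32)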
/**
- * \file dnn/src/common/check_has_inf.cpp
+ * \file dnn/src/common/check_non_finite.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -14,7 +14,7 @@
 namespace megdnn {
-void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst,
+void CheckNonFinite::check_exec(const TensorLayout& src, const TensorLayout& dst,
                              size_t workspace_in_bytes) {
     megdnn_assert_contiguous(src);
     megdnn_assert_contiguous(dst);
@@ -24,7 +24,7 @@ void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst,
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }
-void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) {
+void CheckNonFinite::deduce_layout(const TensorLayout&, TensorLayout& dst) {
     dst.shape[0] = 1;
     dst.ndim = 1;
     dst.dtype = dtype::Int32();
......
@@ -216,7 +216,7 @@ private:
     cb(FakeQuantBackward) \
     cb(TQTForward) \
     cb(TQTBackward) \
-    cb(CheckHasInf) \
+    cb(CheckNonFinite) \
     cb(LSQForward) \
     cb(LSQBackward) \
     cb(Fill) \
......
@@ -131,7 +131,7 @@ DEF(PermutationRNG, 1, true, true);
 DEF(ShuffleRNGForward, 3, true, true);
 DEF(ShuffleRNGBackward, 3, true, false);
 DEF(ChecksumForward, 1, true, false);
-DEF(CheckHasInf, 2, true, true);
+DEF(CheckNonFinite, 2, true, true);
 DEF(LSQForward, 5, true, true);
 DEF(LSQBackward, 7, true, false);
 DEF(Fill, 1, true, false);
......
@@ -152,7 +152,7 @@ struct MaxOp {
 };
 template <typename src_ctype, typename dst_ctype, typename wtype_>
-struct CheckHasInfOp {
+struct CheckNonFiniteOp {
     typedef wtype_ wtype;
     const wtype INIT;
@@ -162,9 +162,9 @@ struct CheckHasInfOp {
     MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
 #if defined(__CUDA_ARCH__)
-        return isinf(src[idx]);
+        return !isfinite(src[idx]);
 #else
-        return std::isinf(src[idx]);
+        return !std::isfinite(src[idx]);
 #endif
     }
     MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
@@ -173,7 +173,7 @@ struct CheckHasInfOp {
     static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
         return lhs | rhs;
     }
-    MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src, dst_ctype* dst,
+    MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst,
                                             size_t B)
             : INIT(wtype(0)), src(src), dst(dst), B(B) {}
 };
......
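The substantive change in this op is the predicate: isinf only detects infinities, while !isfinite is true for both infinities and NaN, so the OR-reduction now flags NaN as well. The same distinction, verifiable in plain Python:

    import math

    inf, nan = float("inf"), float("nan")
    assert not math.isinf(nan)       # why the old isinf-based check missed NaN
    assert not math.isfinite(inf)    # inf is non-finite
    assert not math.isfinite(nan)    # NaN is non-finite too
    assert math.isfinite(1.0)        # ordinary values still pass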
/**
- * \file dnn/src/cuda/check_has_inf/kern.cu
+ * \file dnn/src/cuda/check_non_finite/kern.cu
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -18,7 +18,7 @@ namespace cuda {
 #define COMMA ,
-INST_REDUCE(reduce::CheckHasInfOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false);
+INST_REDUCE(reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false);
 #undef COMMA
 } // namespace cuda
......
/**
- * \file dnn/src/cuda/check_has_inf/opr_impl.cpp
+ * \file dnn/src/cuda/check_non_finite/opr_impl.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,7 +9,7 @@
  * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#include "src/cuda/check_has_inf/opr_impl.h"
+#include "src/cuda/check_non_finite/opr_impl.h"
 #include "src/cuda/reduce_helper.cuh"
 #include "src/cuda/handle.h"
@@ -20,18 +20,18 @@
 namespace megdnn {
 namespace cuda {
-using reduce::CheckHasInfOp;
+using reduce::CheckNonFiniteOp;
-size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src,
+size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout& src,
                                                const TensorLayout& dst) {
-    typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
     return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1);
 }
-void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                            _megdnn_workspace workspace) {
     check_exec(src.layout, dst.layout, workspace.size);
-    typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
     auto stream = cuda_stream(this->handle());
     auto B = src.layout.total_nr_elems();
     return run_reduce<Op, false>(
......
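The exec path flattens the input to B elements and launches one reduction shaped (1, B, 1), so a single kernel writes the one-element result. A NumPy analogue of that shape convention (illustrative only; run_reduce's real signature lives in reduce_helper.cuh):

    import numpy as np

    def exec_analogue(src: np.ndarray) -> np.ndarray:
        # View the flattened tensor as the middle axis of (1, B, 1) and
        # OR together the per-element non-finite flags along that axis.
        flags = (~np.isfinite(src.reshape(1, -1, 1))).astype(np.int32)
        return np.bitwise_or.reduce(flags, axis=1).reshape(1)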
/**
- * \file dnn/src/cuda/check_has_inf/opr_impl.h
+ * \file dnn/src/cuda/check_non_finite/opr_impl.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -17,9 +17,9 @@
 namespace megdnn {
 namespace cuda {
-class CheckHasInfImpl final : public CheckHasInf {
+class CheckNonFiniteImpl final : public CheckNonFinite {
 public:
-    using CheckHasInf::CheckHasInf;
+    using CheckNonFinite::CheckNonFinite;
     size_t get_workspace_in_bytes(const TensorLayout& src,
                                   const TensorLayout& dst) override;
......
@@ -20,7 +20,7 @@
 #include "src/cuda/batch_conv_bias/opr_impl.h"
 #include "src/cuda/batch_normalization/opr_impl.h"
 #include "src/cuda/batched_matrix_mul/opr_impl.h"
-#include "src/cuda/check_has_inf/opr_impl.h"
+#include "src/cuda/check_non_finite/opr_impl.h"
 #include "src/cuda/checksum/opr_impl.h"
 #include "src/cuda/concat/opr_impl.h"
 #include "src/cuda/cond_take/opr_impl.h"
......
/**
- * \file dnn/src/naive/check_has_inf/opr_impl.cpp
+ * \file dnn/src/naive/check_non_finite/opr_impl.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -9,7 +9,7 @@
  * ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#include "src/naive/check_has_inf/opr_impl.h"
+#include "src/naive/check_non_finite/opr_impl.h"
 #include "src/common/utils.h"
 #include "src/naive/handle.h"
@@ -27,7 +27,7 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
         size_t mid = l + (r - l) / 2;
         return func(l, mid) | func(mid, r);
     } else {
-        return static_cast<wtype>(std::isinf(sptr[l]));
+        return static_cast<wtype>(!std::isfinite(sptr[l]));
     }
 };
@@ -39,12 +39,12 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
 namespace megdnn {
 namespace naive {
-size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout&,
+size_t CheckNonFiniteImpl::get_workspace_in_bytes(const TensorLayout&,
                                                const TensorLayout&) {
     return 0;
 }
-void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+void CheckNonFiniteImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                            _megdnn_workspace workspace) {
     check_exec(src.layout, dst.layout, workspace.size);
......
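The naive backend's reduce_fwd (first hunk of this file above) performs a divide-and-conquer OR-reduction via a local recursive lambda, only partially visible here. A sketch of the same idea as a standalone function, reconstructed for illustration:

    import math
    from typing import Sequence

    def reduce_non_finite(values: Sequence[float]) -> int:
        # OR-reduce over [l, r): 1 if any value is inf or NaN.
        # Pairwise splitting keeps the recursion depth logarithmic in len(values).
        def func(l: int, r: int) -> int:
            if r - l > 1:
                mid = l + (r - l) // 2
                return func(l, mid) | func(mid, r)
            return int(not math.isfinite(values[l]))
        return func(0, len(values)) if values else 0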
/**
- * \file dnn/src/naive/check_has_inf/opr_impl.h
+ * \file dnn/src/naive/check_non_finite/opr_impl.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -16,9 +16,9 @@
 namespace megdnn {
 namespace naive {
-class CheckHasInfImpl final : public CheckHasInf {
+class CheckNonFiniteImpl final : public CheckNonFinite {
 public:
-    using CheckHasInf::CheckHasInf;
+    using CheckNonFinite::CheckNonFinite;
     bool is_thread_safe() const override { return true; }
......
@@ -22,7 +22,7 @@
 #include "src/naive/batch_conv_bias/opr_impl.h"
 #include "src/naive/batch_normalization/opr_impl.h"
 #include "src/naive/batched_matrix_mul/opr_impl.h"
-#include "src/naive/check_has_inf/opr_impl.h"
+#include "src/naive/check_non_finite/opr_impl.h"
 #include "src/naive/checksum/opr_impl.h"
 #include "src/naive/concat/opr_impl.h"
 #include "src/naive/cond_take/opr_impl.h"
......
/**
- * \file dnn/test/cuda/check_has_inf.cpp
+ * \file dnn/test/cuda/check_non_finite.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -15,16 +15,20 @@
 namespace megdnn {
 namespace test {
-TEST_F(CUDA, CHECK_HAS_INF_BASIC) {
-    Checker<CheckHasInf> checker(handle_cuda());
+TEST_F(CUDA, CHECK_NON_FINITE_BASIC) {
+    Checker<CheckNonFinite> checker(handle_cuda());
     checker.set_allow_invalid_check(true);
     const auto inf = std::numeric_limits<float>::infinity();
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
     UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf);
     checker.set_rng(0, &rng);
     checker.execs({{512*16}, {1}});
     rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf);
     checker.set_rng(0, &rng);
     checker.execs({{512*16}, {1}});
+    rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, nan);
+    checker.set_rng(0, &rng);
+    checker.execs({{512*16}, {1}});
 }
 } // namespace test
......
/**
- * \file test/naive/check_has_inf.cpp
+ * \file test/naive/check_non_finite.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
  * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -17,8 +17,8 @@
 namespace megdnn {
 namespace test {
-TEST_F(NAIVE, CHECK_HAS_INF_BASIC) {
-    Checker<CheckHasInf> checker(handle(), false);
+TEST_F(NAIVE, CHECK_NON_FINITE_BASIC) {
+    Checker<CheckNonFinite> checker(handle(), false);
     checker.exect(Testcase{TensorValue({4}, dtype::Float32(),
                                        {1.1, 2.2, 3.3, 4.3}),
                            {}},
@@ -29,6 +29,12 @@ TEST_F(NAIVE, CHECK_HAS_INF_BASIC) {
                                         std::numeric_limits<float>::infinity()}),
                            {}},
                   Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
+    checker.exect(
+            Testcase{TensorValue({4}, dtype::Float32(),
+                                 {1.1f, 2.2f, 3.3f,
+                                  std::numeric_limits<float>::quiet_NaN()}),
+                     {}},
+            Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
 }
 } // namespace test
......
@@ -11,7 +11,7 @@ import numpy as np
 from ..autodiff import GradManager
 from ..functional import full_like
-from ..functional.math import _has_inf
+from ..functional.math import _check_non_finite
 from ..tensor import Tensor
@@ -76,7 +76,7 @@ class GradScaler:
         self.growth_interval = growth_interval
         self._growth_tracker = 0
-        self._found_inf = False
+        self._found_non_finite = False
     def backward(
         self,
@@ -135,10 +135,10 @@ class GradScaler:
                 continue
             # to support tracing, _check_gradients should be applied to every grad.
             if self._check_gradients(tensor.grad):
-                self._found_inf = True
+                self._found_non_finite = True
             tensor.grad *= inv_scale
-        if self._found_inf:
+        if self._found_non_finite:
             for tensor in grad_tensors:
                 if tensor is None or getattr(tensor, "grad", None) is None:
                     continue
@@ -148,7 +148,7 @@ class GradScaler:
     def _check_gradients(self, grad):
         if self.growth_interval == 0:
             return False
-        return _has_inf(grad)
+        return _check_non_finite(grad)
     def update(self, new_scale: float = None):
         r"""Update the scale factor according to whether encountered overflow grad.
@@ -160,7 +160,7 @@
         if new_scale is not None:
             self.scale_factor = float(new_scale)
         else:
-            if self._found_inf:
+            if self._found_non_finite:
                 self.scale_factor *= self.backoff_factor
                 self._growth_tracker = 0
             else:
@@ -168,7 +168,7 @@
                 if self._growth_tracker >= self.growth_interval:
                     self.scale_factor *= self.growth_factor
                     self._growth_tracker = 0
-        self._found_inf = False
+        self._found_non_finite = False
     def state_dict(self):
         return {
......
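For context, update() implements the standard dynamic loss-scaling policy: back the scale off when a non-finite gradient was found, and grow it after growth_interval consecutive clean steps. A condensed sketch of that rule under the renamed flag (standalone; names chosen to match the diff, not GradScaler's actual methods):

    def update_scale(scale, found_non_finite, growth_tracker,
                     backoff_factor=0.5, growth_factor=2.0, growth_interval=2000):
        # Returns (new_scale, new_growth_tracker).
        if found_non_finite:
            return scale * backoff_factor, 0   # overflow: shrink scale, reset streak
        growth_tracker += 1
        if growth_tracker >= growth_interval:
            return scale * growth_factor, 0    # long clean streak: grow scale
        return scale, growth_tracker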
@@ -1181,8 +1181,8 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
     return U, sigma, V
-def _has_inf(inp: Tensor) -> Tensor:
-    r"""Check whether input contains infinite value.
+def _check_non_finite(inp: Tensor) -> Tensor:
+    r"""Check whether input contains infinite or nan value.
     Args:
         inp: a tensor to be checked.
@@ -1190,7 +1190,7 @@ def _has_inf(inp: Tensor) -> Tensor:
     Returns:
         a int32 scalar tensor, 0 for False and 1 for True.
     """
-    op = builtin.CheckHasInf()
+    op = builtin.CheckNonFinite()
     (oup,) = apply(op, inp.reshape(-1).astype("float32"))
     oup._setscalar()
     return oup
@@ -185,14 +185,18 @@ def test_sum_neg_axis():
         F.sum(tensor(data), axis=(-1, 1))
-def test_has_inf():
+def test_non_finite():
     shape = (32, 3, 32, 32)
     data = np.random.random(shape).astype(np.float32)
-    rst = F.math._has_inf(tensor(data))
+    rst = F.math._check_non_finite(tensor(data))
     np.testing.assert_equal(rst.numpy(), [0])
     data[0][0][0][0] = float("inf")
-    rst = F.math._has_inf(tensor(data))
+    rst = F.math._check_non_finite(tensor(data))
+    np.testing.assert_equal(rst.numpy(), [1])
+    data[0][0][0][0] = float("nan")
+    rst = F.math._check_non_finite(tensor(data))
     np.testing.assert_equal(rst.numpy(), [1])
......
@@ -16,17 +16,17 @@
 namespace mgb {
 namespace imperative {
-namespace check_has_inf {
+namespace check_non_finite {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = def.cast_final_safe<CheckHasInf>();
+    auto&& op = def.cast_final_safe<CheckNonFinite>();
     mgb_assert(inputs.size() == 1);
     OperatorNodeConfig config{op.make_name()};
-    return opr::CheckHasInf::make(inputs[0], {}, config);
+    return opr::CheckNonFinite::make(inputs[0], {}, config);
 }
-OP_TRAIT_REG(CheckHasInf, CheckHasInf)
+OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
         .apply_on_var_node(apply_on_var_node)
         .fallback();
-} // namespace check_has_inf
+} // namespace check_non_finite
 } // namespace imperative
 } // namespace mgb
......
@@ -390,7 +390,7 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> {
 def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;
-def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>;
+def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [EmptyParam]>;
 def FastpathCopy: MgbHashableOp<"FastpathCopy">;
......
@@ -491,12 +491,12 @@ MGB_IMPL_OPR_GRAD(TopK) {
 }
 #endif
-/* ================= CheckHasInf ================= */
+/* ================= CheckNonFinite ================= */
 namespace mgb {
 namespace opr {
 namespace intl {
 template<>
-struct MegDNNOprInitPostCtor<CheckHasInf> {
+struct MegDNNOprInitPostCtor<CheckNonFinite> {
     static void apply(cg::OperatorNodeBase &opr) {
         opr.output(0)->dtype(dtype::Int32());
     }
@@ -504,6 +504,6 @@ struct MegDNNOprInitPostCtor<CheckHasInf> {
 }
 }
 }
-MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf);
-MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf")
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite);
+MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite")
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -72,7 +72,7 @@ namespace opr {
 #if MGB_CUDA
     MGB_SEREG_OPR(NvOf, 1);
 #endif
-    MGB_SEREG_OPR(CheckHasInf, 1);
+    MGB_SEREG_OPR(CheckNonFinite, 1);
 } // namespace opr
 } // namespace mgb
......
@@ -185,7 +185,7 @@ public:
             const OperatorNodeConfig& config = {});
 };
-MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf);
+MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckNonFinite);
 } // namespace opr
 } // namespace mgb
......