Commit 11d75fec authored by Megvii Engine Team

feat(dnn/check_non_finite): add batch check_non_finite

GitOrigin-RevId: e108133282cb2c9129292715ae6eab1e396cd0bc
Parent: 7a023c05
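
At the Python layer, this change lets one CheckNonFinite call scan several tensors at once and return a single int32 flag, instead of reshaping and checking each tensor separately. A minimal usage sketch, mirroring the updated tests further down (the helper is private and the tensor values are made up for illustration):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    # two gradient-like tensors; the second one carries an inf
    a = np.random.random((32, 3, 32, 32)).astype(np.float32)
    b = np.random.random((32, 3, 32, 32)).astype(np.float32)
    b[0, 0, 0, 0] = float("inf")

    # batched check: an int32 scalar, 1 if any element of any input is inf/nan
    rst = F.math._check_non_finite([tensor(a), tensor(b)])
    print(rst.numpy())  # 1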
@@ -1345,22 +1345,23 @@ protected:
  */
 class CheckNonFinite : public OperatorBase {
     DEF_OPR_PARAM(Empty);
-    DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1);
+    DEF_OPR_IMPL(CheckNonFinite, OperatorBase, -1, 1);
+    size_t m_size = 0;

 public:
     virtual size_t get_workspace_in_bytes(
-            const TensorLayout& src, const TensorLayout& dst) = 0;
-    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
+            const TensorNDArray& srcs, const TensorLayout& dst) = 0;
+    void deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst);
     virtual void exec(
-            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
             _megdnn_workspace workspace) = 0;

 protected:
     void check_exec(
-            const TensorLayout& src, const TensorLayout& dst,
-            size_t workspace_in_bytes);
+            const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes);
+    virtual size_t _get_workspace_in_bytes() = 0;
 };

 /*!
......
@@ -15,16 +15,15 @@
 namespace megdnn {

 void CheckNonFinite::check_exec(
-        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
-    megdnn_assert_contiguous(src);
-    megdnn_assert_contiguous(dst);
-    megdnn_assert(src.ndim == 1);
-    megdnn_assert(src.dtype == dtype::Float32());
-    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
+        const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes) {
+    megdnn_assert_contiguous(dst.layout);
+    megdnn_assert(srcs.size() > 0);
+    megdnn_assert(srcs.begin()->layout.dtype == dtype::Float32());
+    auto required_workspace_in_bytes = _get_workspace_in_bytes();
     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
 }

-void CheckNonFinite::deduce_layout(const TensorLayout&, TensorLayout& dst) {
+void CheckNonFinite::deduce_layout(const TensorLayoutArray&, TensorLayout& dst) {
     dst.shape[0] = 1;
     dst.ndim = 1;
     dst.dtype = dtype::Int32();
......
@@ -156,21 +156,35 @@ struct MaxOp<src_ctype, dst_ctype, dt_float32> {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };

-template <typename src_ctype, typename dst_ctype, typename wtype_>
+template <typename src_ctype, typename index_ctype, typename dst_ctype, typename wtype_>
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
     const wtype INIT;

-    RefPtr src;
+    RefPtr* srcs;
+    RefPtr srcs_total_nr_elems;
     RefPtr dst;
     const size_t B;

-    wtype read(uint32_t idx) { return !std::isfinite(src.ptr<src_ctype>()[idx]); }
+    wtype read(uint32_t idx) {
+        size_t x = idx / B;
+        size_t y = idx % B;
+        if (y < srcs_total_nr_elems.ptr<index_ctype>()[x]) {
+            RefPtr src = srcs[x];
+            return !std::isfinite(src.ptr<src_ctype>()[y]);
+        }
+        return 0;
+    }
     void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
     static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; }
-    CheckNonFiniteOp(const RefPtr& src, const RefPtr& dst, size_t B)
-            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
+    MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(
+            RefPtr* srcs, const RefPtr& srcs_total_nr_elems, const RefPtr& dst,
+            size_t B)
+            : INIT(wtype(0)),
+              srcs(srcs),
+              srcs_total_nr_elems(srcs_total_nr_elems),
+              dst(dst),
+              B(B) {}
 };

 void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis);
......
@@ -185,28 +185,41 @@ struct MaxOp<src_ctype, dst_ctype, dt_float32> {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };

-template <typename src_ctype, typename dst_ctype, typename wtype_>
+template <typename src_ctype, typename index_ctype, typename dst_ctype, typename wtype_>
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
     const wtype INIT;

-    src_ctype* src;
+    src_ctype** srcs;
+    index_ctype* srcs_total_nr_elems;
     dst_ctype* dst;
     const size_t B;

     MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
+        size_t x = idx / B;
+        size_t y = idx % B;
+        if (y < srcs_total_nr_elems[x]) {
 #if defined(__CUDA_ARCH__)
-        return !isfinite(src[idx]);
+            wtype val = isfinite(srcs[x][y]);
 #else
-        return !std::isfinite(src[idx]);
+            wtype val = std::isfinite(srcs[x][y]);
 #endif
+            return !val;
+        }
+        return 0;
     }
     MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
     static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
         return lhs | rhs;
     }
-    MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, size_t B)
-            : INIT(wtype(0)), src(src), dst(dst), B(B) {}
+    MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(
+            src_ctype** srcs, index_ctype* srcs_total_nr_elems, dst_ctype* dst,
+            size_t B)
+            : INIT(wtype(0)),
+              srcs(srcs),
+              srcs_total_nr_elems(srcs_total_nr_elems),
+              dst(dst),
+              B(B) {}
 };

 } // namespace device_reduce
......
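
Both the host and device versions of CheckNonFiniteOp now reduce over a flat index space of m_size * B positions, where B is a fixed chunk length and chunk x only holds srcs_total_nr_elems[x] valid elements; padded positions contribute 0, so they never flip the result. A small Python model of read()/apply() under those assumptions (the helper names below are made up, not part of the codebase):

    import math

    def read(srcs, srcs_total_nr_elems, B, idx):
        # map the flat reduce index onto (chunk, offset inside the chunk)
        x, y = idx // B, idx % B
        if y < srcs_total_nr_elems[x]:      # offset lies inside the chunk's valid range
            return 0 if math.isfinite(srcs[x][y]) else 1
        return 0                            # padded tail positions are ignored

    def check_non_finite(srcs, srcs_total_nr_elems, B):
        out = 0
        for idx in range(len(srcs) * B):
            out |= read(srcs, srcs_total_nr_elems, B, idx)  # apply() is a bitwise OR
        return out

    # two chunks of capacity B = 4; the second holds only 2 valid elements
    chunks = [[1.0, 2.0, 3.0, 4.0], [float("inf"), 5.0, 0.0, 0.0]]
    print(check_non_finite(chunks, [4, 2], 4))  # 1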
@@ -19,7 +19,8 @@ namespace cuda {
 #define COMMA ,

 INST_REDUCE(
-        device_reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>,
+        device_reduce::CheckNonFiniteOp<
+                dt_float32 COMMA size_t COMMA dt_int32 COMMA dt_int32>,
         false);

 #undef COMMA
......
@@ -21,22 +21,83 @@ namespace megdnn {
 namespace cuda {

 using device_reduce::CheckNonFiniteOp;
+#define total_nr_elems_max 2048
+
+size_t CheckNonFiniteImpl::_get_workspace_in_bytes() {
+    // Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes
+    typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
+    megdnn_assert(m_size > 0);
+    WorkspaceBundle bundle(
+            nullptr, {
+                             sizeof(dt_float32*) * m_size,
+                             sizeof(size_t) * m_size,
+                     });
+    return get_reduce_workspace_in_bytes<Op>(1, m_size * total_nr_elems_max, 1) +
+           bundle.total_size_in_bytes();
+}
+
 size_t CheckNonFiniteImpl::get_workspace_in_bytes(
-        const TensorLayout& src, const TensorLayout& dst) {
-    typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
-    return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1);
+        const TensorNDArray& srcs, const TensorLayout&) {
+    m_size = 0;
+    for (const auto& src : srcs) {
+        m_size += DIVUP(src.layout.total_nr_elems(), total_nr_elems_max);
+    }
+    return _get_workspace_in_bytes();
 }

 void CheckNonFiniteImpl::exec(
-        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
-    check_exec(src.layout, dst.layout, workspace.size);
-    typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
+        _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
+        _megdnn_workspace workspace) {
+    check_exec(srcs, dst, workspace.size);
+    typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
     auto stream = cuda_stream(this->handle());
-    auto B = src.layout.total_nr_elems();
+    SmallVector<size_t> workspace_sizes{
+            sizeof(dt_float32*) * m_size,
+            sizeof(size_t) * m_size,
+    };
+    WorkspaceBundle workspace_cpu(nullptr, workspace_sizes),
+            workspace_gpu(nullptr, workspace_sizes);
+    auto total_workspace_size = workspace_cpu.total_size_in_bytes();
+    void* workspace_cpu_raw = malloc(total_workspace_size);
+    megdnn_assert_internal(workspace_cpu_raw);
+    void* workspace_gpu_raw = workspace.raw_ptr;
+    workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
+    workspace_gpu = WorkspaceBundle(workspace_gpu_raw, workspace_sizes);
+    auto srcs_cpu = static_cast<dt_float32**>(workspace_cpu.get(0));
+    auto srcs_gpu = static_cast<dt_float32**>(workspace_gpu.get(0));
+    auto srcs_total_nr_elems_cpu = static_cast<size_t*>(workspace_cpu.get(1));
+    auto srcs_total_nr_elems_gpu = static_cast<size_t*>(workspace_gpu.get(1));
+    // srcs
+    // cut the tensor to a fixed length of total_nr_elems_max
+    size_t i = 0;
+    for (const auto& src : srcs) {
+        size_t src_nr_elems = src.layout.total_nr_elems();
+        size_t nr_elems = DIVUP(src_nr_elems, total_nr_elems_max);
+        for (size_t j = 0; j < nr_elems; ++j, ++i) {
+            srcs_cpu[i] = src.ptr<dt_float32>() + j * total_nr_elems_max;
+            if (j + 1 == nr_elems && src_nr_elems % total_nr_elems_max) {
+                srcs_total_nr_elems_cpu[i] = src_nr_elems % total_nr_elems_max;
+            } else {
+                srcs_total_nr_elems_cpu[i] = total_nr_elems_max;
+            }
+        }
+    }
+    for (size_t i = 0; i < workspace_cpu.nr_workspace(); ++i) {
+        cuda_check(cudaMemcpyAsync(
+                workspace_gpu.get(i), workspace_cpu.get(i), workspace_cpu.get_size(i),
+                cudaMemcpyHostToDevice, stream));
+    }
+    cuda_check(cudaStreamAddCallback(
+            stream, callback_free, static_cast<void*>(workspace_cpu_raw), 0));
     return run_reduce<Op, false>(
-            workspace.ptr<dt_int32>(), 1, B, 1, stream,
-            Op(src.ptr<dt_float32>(), dst.ptr<dt_int32>(), B));
+            static_cast<dt_int32*>(
+                    (void*)((char*)workspace_gpu_raw +
+                            workspace_gpu.total_size_in_bytes())),
+            1, m_size * total_nr_elems_max, 1, stream,
+            Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(),
+               total_nr_elems_max));
 }

 } // namespace cuda
......
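
The CUDA backend slices every input into chunks of total_nr_elems_max (2048) elements, records one pointer and one valid length per chunk in a host-side table, copies the table into the GPU workspace, and runs a single reduction over m_size * total_nr_elems_max positions. A sketch of that bookkeeping only, assuming the same DIVUP rounding as the C++ above (build_chunk_table is a made-up name):

    TOTAL_NR_ELEMS_MAX = 2048

    def divup(a, b):
        return (a + b - 1) // b

    def build_chunk_table(src_sizes):
        # src_sizes: element count of each input tensor
        offsets, valid_lens = [], []
        for size in src_sizes:
            nr_chunks = divup(size, TOTAL_NR_ELEMS_MAX)
            for j in range(nr_chunks):
                offsets.append(j * TOTAL_NR_ELEMS_MAX)  # srcs_cpu[i] points here
                if j + 1 == nr_chunks and size % TOTAL_NR_ELEMS_MAX:
                    valid_lens.append(size % TOTAL_NR_ELEMS_MAX)  # short tail chunk
                else:
                    valid_lens.append(TOTAL_NR_ELEMS_MAX)
        return offsets, valid_lens

    offsets, valid_lens = build_chunk_table([5000, 2048])
    print(len(offsets), valid_lens)  # 4 [2048, 2048, 904, 2048]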
@@ -18,16 +18,18 @@ namespace megdnn {
 namespace cuda {

 class CheckNonFiniteImpl final : public CheckNonFinite {
+    size_t _get_workspace_in_bytes() override;
+
 public:
     using CheckNonFinite::CheckNonFinite;

     size_t get_workspace_in_bytes(
-            const TensorLayout& src, const TensorLayout& dst) override;
+            const TensorNDArray& srcs, const TensorLayout& dst) override;

     bool is_thread_safe() const override { return true; }

     void exec(
-            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
             _megdnn_workspace workspace) override;
 };
......
@@ -17,21 +17,25 @@
 namespace {

 using namespace megdnn;

-#define src_ctype dt_float32
 #define wtype dt_int32

-void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
-    std::function<wtype(size_t, size_t)> func;
-    func = [&](size_t l, size_t r) -> wtype {
-        if (l + 1 < r) {
-            size_t mid = l + (r - l) / 2;
-            return func(l, mid) | func(mid, r);
-        } else {
-            return static_cast<wtype>(!std::isfinite(sptr[l]));
-        }
-    };
-    dptr[0] = func(0, size);
+void reduce_fwd(const TensorNDArray& srcs, wtype* dptr) {
+    dptr[0] = 0;
+    for (auto src : srcs) {
+        auto sptr = src.ptr<dt_float32>();
+        size_t size = src.layout.total_nr_elems();
+        std::function<wtype(wtype, wtype)> func;
+        func = [&](wtype l, wtype r) -> wtype {
+            if (l + 1 < r) {
+                wtype mid = l + (r - l) / 2;
+                return func(l, mid) | func(mid, r);
+            } else {
+                auto val = std::isfinite(sptr[l]);
+                return static_cast<wtype>(!val);
+            }
+        };
+        dptr[0] |= func(0, size);
+    }
 }

 } // namespace
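
The naive backend now loops over the tensor list and ORs each tensor's result into dptr[0]; within one tensor it keeps the existing divide-and-conquer OR-reduction rather than a linear scan. An equivalent Python sketch (function names are illustrative only):

    import math

    def any_non_finite(values, l, r):
        # OR-reduce over [l, r) by halving the range
        if l + 1 < r:
            mid = l + (r - l) // 2
            return any_non_finite(values, l, mid) | any_non_finite(values, mid, r)
        return int(not math.isfinite(values[l]))

    def reduce_fwd(tensors):
        out = 0
        for t in tensors:
            out |= any_non_finite(t, 0, len(t))
        return out

    print(reduce_fwd([[1.0, 2.0], [3.0, float("nan")]]))  # 1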
@@ -39,20 +43,13 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
 namespace megdnn {
 namespace naive {

-size_t CheckNonFiniteImpl::get_workspace_in_bytes(
-        const TensorLayout&, const TensorLayout&) {
-    return 0;
-}
-
 void CheckNonFiniteImpl::exec(
-        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
-    check_exec(src.layout, dst.layout, workspace.size);
+        _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
+        _megdnn_workspace workspace) {
+    check_exec(srcs, dst, workspace.size);
     auto handle = static_cast<HandleImpl*>(this->handle());
-    MEGDNN_DISPATCH_CPU_KERN(
-            handle, reduce_fwd(
-                            src.ptr<dt_float32>(), dst.ptr<dt_int32>(),
-                            src.layout.total_nr_elems()));
+    MEGDNN_DISPATCH_CPU_KERN(handle, reduce_fwd(srcs, dst.ptr<dt_int32>()));
 }

 } // namespace naive
 } // namespace megdnn
......
@@ -17,16 +17,20 @@ namespace megdnn {
 namespace naive {

 class CheckNonFiniteImpl final : public CheckNonFinite {
+    size_t _get_workspace_in_bytes() override { return 0; }
+
 public:
     using CheckNonFinite::CheckNonFinite;

     bool is_thread_safe() const override { return true; }

-    size_t get_workspace_in_bytes(
-            const TensorLayout& src, const TensorLayout& dst) override;
+    size_t get_workspace_in_bytes(const TensorNDArray&, const TensorLayout&) override {
+        m_size = 0;
+        return _get_workspace_in_bytes();
+    }

     void exec(
-            _megdnn_tensor_in src, _megdnn_tensor_out dst,
+            _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
             _megdnn_workspace workspace) override;
 };
......
@@ -202,6 +202,27 @@ struct OprProxy<ConcatForward> {
     }
 };

+template <>
+struct OprProxy<CheckNonFinite> {
+    static void deduce_layout(CheckNonFinite* opr, TensorLayoutArray& layouts) {
+        megdnn_assert(layouts.size() >= 2);
+        auto inp = layouts;
+        inp.pop_back();
+        opr->deduce_layout(inp, layouts.back());
+    }
+
+    static void exec(CheckNonFinite* opr, const TensorNDArray& tensors) {
+        megdnn_assert(tensors.size() >= 2);
+        auto inps = tensors;
+        inps.pop_back();
+        WorkspaceWrapper W(
+                opr->handle(),
+                opr->get_workspace_in_bytes(inps, tensors.back().layout));
+        opr->exec(inps, tensors.back(), W.workspace());
+    }
+};
+
 template <>
 struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
     WorkspaceWrapper W;
......
@@ -22,13 +22,16 @@ TEST_F(CUDA, CHECK_NON_FINITE_BASIC) {
     const auto nan = std::numeric_limits<float>::quiet_NaN();
     UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf);
     checker.set_rng(0, &rng);
-    checker.execs({{512 * 16}, {1}});
+    checker.execs({{512 * 4}, {4}, {1}});
     rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf);
     checker.set_rng(0, &rng);
-    checker.execs({{512 * 16}, {1}});
+    checker.execs({{4}, {512 * 4}, {1}});
     rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, nan);
     checker.set_rng(0, &rng);
-    checker.execs({{512 * 16}, {1}});
+    checker.execs({{32}, {256}, {1}});
+    rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.f, nan);
+    checker.set_rng(0, &rng);
+    checker.execs({{16}, {16}, {2}, {1}});
 }

 } // namespace test
......
@@ -20,23 +20,28 @@ namespace test {
 TEST_F(NAIVE, CHECK_NON_FINITE_BASIC) {
     Checker<CheckNonFinite> checker(handle(), false);
     checker.exect(
-            Testcase{TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}), {}},
-            Testcase{{}, TensorValue({1}, dtype::Int32(), {0})});
+            Testcase{
+                    TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
+                    TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
+                    {}},
+            Testcase{{}, {}, TensorValue({1}, dtype::Int32(), {0})});
     checker.exect(
             Testcase{
+                    TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
                     TensorValue(
                             {4}, dtype::Float32(),
                             {1.1f, 2.2f, 3.3f, std::numeric_limits<float>::infinity()}),
                     {}},
-            Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
+            Testcase{{}, {}, TensorValue({1}, dtype::Int32(), {1})});
     checker.exect(
             Testcase{
+                    TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
                     TensorValue(
                             {4}, dtype::Float32(),
                             {1.1f, 2.2f, 3.3f,
                              std::numeric_limits<float>::quiet_NaN()}),
                     {}},
-            Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
+            Testcase{{}, {}, TensorValue({1}, dtype::Int32(), {1})});
 }

 } // namespace test
......
@@ -128,21 +128,22 @@ class GradScaler:
             grad_tensors: Tensors needed to unscale grads. Should be all tensors
                 that are affected by ``target`` tensor in GradManager's backward.
         """
-        # use float64 for better precision
-        inv_scale = Tensor(1.0 / self.scale_factor)
-        for tensor in grad_tensors:
-            if tensor is None or getattr(tensor, "grad", None) is None:
-                continue
-            # to support tracing, _check_gradients should be applied to every grad.
-            if self._check_gradients(tensor.grad):
-                self._found_non_finite = True
-            tensor.grad *= inv_scale
+        # to support tracing, _check_gradients should be applied to every grad.
+        if self._check_gradients([x.grad for x in grad_tensors]):
+            self._found_non_finite = True
         if self._found_non_finite:
             for tensor in grad_tensors:
                 if tensor is None or getattr(tensor, "grad", None) is None:
                     continue
                 tensor.grad = None
+        else:
+            # use float64 for better precision
+            inv_scale = Tensor(1.0 / self.scale_factor)
+            for tensor in grad_tensors:
+                if tensor is None or getattr(tensor, "grad", None) is None:
+                    continue
+                tensor.grad *= inv_scale
         return self

     def _check_gradients(self, grad):
......
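
With the batched operator, GradScaler.unscale performs one finiteness check over all gradients (so tracing sees a single op) and only divides by the scale factor when everything is finite; otherwise every grad is dropped for that step. A condensed sketch of that control flow, using a hypothetical TinyScaler and a stand-in _check_gradients rather than the real MegEngine classes:

    class TinyScaler:
        # illustration of the unscale() control flow only, not the real GradScaler
        def __init__(self, scale_factor=4.0):
            self.scale_factor = scale_factor
            self._found_non_finite = False

        def _check_gradients(self, grads):
            # stands in for the device-side batched check (F.math._check_non_finite)
            return any(g != g or g in (float("inf"), float("-inf")) for g in grads)

        def unscale(self, grads):
            if self._check_gradients(grads):    # one batched check, one branch
                self._found_non_finite = True
            if self._found_non_finite:
                grads[:] = [None] * len(grads)  # drop all grads for this step
            else:
                inv_scale = 1.0 / self.scale_factor
                grads[:] = [g * inv_scale for g in grads]
            return self

    s = TinyScaler(scale_factor=4.0)
    gs = [8.0, 12.0]
    s.unscale(gs)
    print(gs)  # [2.0, 3.0]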
@@ -9,7 +9,7 @@
 import collections
 import math
 from functools import lru_cache
-from typing import Optional, Sequence, Tuple, Union
+from typing import Iterable, Optional, Sequence, Tuple, Union

 from ..core import _config
 from ..core._imperative_rt.core2 import apply, dtype_promotion
@@ -1183,7 +1183,7 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
     return U, sigma, V


-def _check_non_finite(inp: Tensor) -> Tensor:
+def _check_non_finite(inps: Iterable[Tensor]) -> Tensor:
     r"""Check whether input contains infinite or nan value.

     Args:
@@ -1193,6 +1193,6 @@ def _check_non_finite(inp: Tensor) -> Tensor:
         a int32 scalar tensor, 0 for False and 1 for True.
     """
     op = builtin.CheckNonFinite()
-    (oup,) = apply(op, inp.reshape(-1).astype("float32"))
+    (oup,) = apply(op, *inps)
     oup._setscalar()
     return oup
@@ -10,9 +10,11 @@ import numpy as np
 import megengine as mge
 from megengine.amp import GradScaler
 from megengine.autodiff import GradManager
+from megengine.jit import trace


 def test_grad_scaler():
+    def f():
         gm = GradManager()
         scaler = GradScaler()
@@ -28,3 +30,6 @@ def test_grad_scaler():
         np.testing.assert_equal(y.grad.numpy(), 1)
         # test handle None elements
         scaler.unscale(gm.attached_tensors())
+
+    f()
+    trace(f)()
@@ -191,16 +191,17 @@ def test_sum_neg_axis():

 def test_non_finite():
     shape = (32, 3, 32, 32)
-    data = np.random.random(shape).astype(np.float32)
-    rst = F.math._check_non_finite(tensor(data))
+    data1 = np.random.random(shape).astype(np.float32)
+    data2 = np.random.random(shape).astype(np.float32)
+    rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
     np.testing.assert_equal(rst.numpy(), [0])

-    data[0][0][0][0] = float("inf")
-    rst = F.math._check_non_finite(tensor(data))
+    data2[0][0][0][0] = float("inf")
+    rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
     np.testing.assert_equal(rst.numpy(), [1])

-    data[0][0][0][0] = float("nan")
-    rst = F.math._check_non_finite(tensor(data))
+    data2[0][0][0][0] = float("nan")
+    rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
     np.testing.assert_equal(rst.numpy(), [1])
......
@@ -17,14 +17,56 @@ namespace mgb {
 namespace imperative {

 namespace check_non_finite {

-auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+SymbolVar apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = def.cast_final_safe<CheckNonFinite>();
-    mgb_assert(inputs.size() == 1);
     OperatorNodeConfig config{op.make_name()};
-    return opr::CheckNonFinite::make(inputs[0], {}, config);
+    return opr::CheckNonFinite::make(inputs, {}, config);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    size_t size = inputs.size();
+    auto dest = Tensor::make(
+            TensorLayout(TensorShape({1}), dtype::Int32()), inputs[0]->comp_node());
+    auto cn = dest->comp_node();
+    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::CheckNonFinite>(cn);
+    size_t wk_size = 0;
+    SmallVector<megdnn::TensorND> srcs(size);
+    for (size_t i = 0; i < size; ++i) {
+        srcs[i] = inputs[i]->dev_tensor().as_megdnn();
+    }
+    wk_size = dnn_opr->get_workspace_in_bytes(srcs, dest->layout());
+    auto wk = Blob::make(cn, wk_size);
+    megdnn::Workspace dnn_wk(wk->storage().get(), wk_size);
+    dnn_opr->exec(srcs, dest->dev_tensor().as_megdnn(), dnn_wk);
+    return {dest};
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    SmallVector<LogicalTensorDesc> dests(1);
+    dests[0].comp_node = inputs[0].comp_node;
+    dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
+    return {dests, true};
+}
+
+SmallVector<LogicalTensorDesc> infer_output_attrs(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
+    SmallVector<LogicalTensorDesc> dests(1);
+    dests[0].comp_node = inputs[0]->comp_node();
+    dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
+    return dests;
+}
+
+std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
+        const SmallVector<MemoryDesc>& inputs_mems) {
+    return {{}, {}};
 }

 OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
         .apply_on_var_node(apply_on_var_node)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .infer_output_mem_desc(infer_output_mem_desc)
         .fallback();

 } // namespace check_non_finite
......
@@ -482,18 +482,74 @@ MGB_IMPL_OPR_GRAD(TopK) {
 #endif

 /* ================= CheckNonFinite ================= */
-namespace mgb {
-namespace opr {
-namespace intl {
-template <>
-struct MegDNNOprInitPostCtor<CheckNonFinite> {
-    static void apply(cg::OperatorNodeBase& opr) {
-        opr.output(0)->dtype(dtype::Int32());
-    }
-};
-} // namespace intl
-} // namespace opr
-} // namespace mgb
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite);
-MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite")
+
+CheckNonFinite::CheckNonFinite(
+        const VarNodeArrayView& inp, const Param& param,
+        const OperatorNodeConfig& config)
+        : Super(OperatorNodeBaseCtorParam{
+                  inp[0]->owner_graph(), config, "check_non_finite", inp}) {
+    mgb_assert(!inp.empty());
+    for (auto&& i : inp) {
+        add_input({i});
+    }
+    add_output(None)->dtype(dtype::Int32()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    cg::add_workspace_output(this);
+}
+
+SymbolVar CheckNonFinite::make(
+        const VarNodeArrayView& inp, const Param& param,
+        const OperatorNodeConfig& config) {
+    mgb_assert(!inp.empty());
+    intl::BatchedDTypePromotion dtp{inp};
+    return SymbolVar{inp[0]}.insert_single_output_opr<CheckNonFinite>(
+            dtp.get_vars(), param, config);
+}
+
+void CheckNonFinite::scn_do_execute() {
+    megdnn::TensorNDArray inp_arr(input().size());
+    for (size_t i = 0; i < input().size(); ++i) {
+        inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
+    }
+    megdnn_opr()->exec(
+            inp_arr, output(0)->dev_tensor().as_megdnn(),
+            intl::get_megdnn_workspace_from_var(output(1)));
+}
+
+void CheckNonFinite::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+    auto infer_oshp = [](TensorShape& dest, const InpVal& iv) {
+        TensorLayout dst;
+        dst.shape[0] = 1;
+        dst.ndim = 1;
+        dst.dtype = dtype::Int32();
+        dst.init_contiguous_stride();
+        dest = dst;
+        return true;
+    };
+    DepVal deps;
+    for (auto i : input())
+        deps.push_back({i, DepType::SHAPE});
+    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_oshp});
+
+    auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
+        dest.ndim = 1;
+        megdnn::TensorNDArray inp_arr(input().size());
+        for (size_t i = 0; i < input().size(); ++i) {
+            inp_arr[i] = {NULL, {inp.val.at(i).shape(), input(0)->dtype()}};
+        }
+        dest.shape[0] = megdnn_opr()->get_workspace_in_bytes(
+                inp_arr, {output(0)->shape(), output(0)->dtype()});
+        return true;
+    };
+    mgr.register_shape_infer(output(1), {SourceType::DEP, deps, infer_wk});
+}
+
+void CheckNonFinite::add_input_layout_constraint() {
+    for (auto i : input()) {
+        i->add_layout_constraint_contiguous();
+    }
+}

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -55,6 +55,10 @@ struct OprMaker<opr::TopK, 2> {
     }
 };

+template <>
+struct OprMaker<opr::CheckNonFinite, 0> : public OprMakerVariadic<opr::CheckNonFinite> {
+};
+
 } // namespace serialization

 namespace opr {
@@ -72,7 +76,7 @@ MGB_SEREG_OPR(CumsumV1, 1);
 #if MGB_CUDA
 MGB_SEREG_OPR(NvOf, 1);
 #endif
-MGB_SEREG_OPR(CheckNonFinite, 1);
+MGB_SEREG_OPR(CheckNonFinite, 0);

 } // namespace opr
 } // namespace mgb
......
@@ -142,6 +142,8 @@ using CondTakeBase = cg::SingleCNOperatorNode<
         cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::CondTake>>;
 using TopKBase = cg::SingleCNOperatorNode<
         cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::TopK>>;
+using CheckNonFiniteBase = cg::SingleCNOperatorNode<
+        cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::CheckNonFinite>>;
 } // namespace intl

 /*!
@@ -181,7 +183,19 @@ public:
             const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckNonFinite);
+MGB_DEFINE_OPR_CLASS(CheckNonFinite, intl::CheckNonFiniteBase) //{
+    void scn_do_execute() override;
+    void init_output_static_infer_desc() override;
+    void add_input_layout_constraint() override;
+
+public:
+    MGE_WIN_DECLSPEC_FUC CheckNonFinite(
+            const VarNodeArrayView& inp, const Param& param,
+            const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
+            const VarNodeArrayView& inp, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+};

 } // namespace opr
 } // namespace mgb
......