Commit 11d75fec authored by: Megvii Engine Team

feat(dnn/check_non_finite): add batch check_non_finite

GitOrigin-RevId: e108133282cb2c9129292715ae6eab1e396cd0bc
Parent 7a023c05
......@@ -1345,22 +1345,23 @@ protected:
*/
class CheckNonFinite : public OperatorBase {
DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(CheckNonFinite, OperatorBase, 1, 1);
DEF_OPR_IMPL(CheckNonFinite, OperatorBase, -1, 1);
size_t m_size = 0;
public:
virtual size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) = 0;
const TensorNDArray& srcs, const TensorLayout& dst) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst);
void deduce_layout(const TensorLayoutArray& srcs, TensorLayout& dst);
virtual void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
protected:
void check_exec(
const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes);
const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes);
virtual size_t _get_workspace_in_bytes() = 0;
};
/*!
......
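The hunk above widens `CheckNonFinite` from a single input (`DEF_OPR_IMPL(..., 1, 1)`) to a variadic input list (`-1`), so `exec` now scans a whole `TensorNDArray` and writes one int32 flag. As a rough standalone sketch of the intended semantics (plain C++ containers standing in for the MegDNN tensor types, not the actual API), the batched check is an OR of `!isfinite` over every element of every input:

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the batched operator: return 1 if any element of
// any input buffer is Inf or NaN, otherwise 0 (the int32 scalar output).
int32_t check_non_finite_batch(const std::vector<std::vector<float>>& srcs) {
    int32_t flag = 0;
    for (const auto& src : srcs)
        for (float v : src)
            flag |= static_cast<int32_t>(!std::isfinite(v));
    return flag;
}

// e.g. check_non_finite_batch({{1.f, 2.f}, {3.f, NAN}}) == 1
```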
......@@ -15,16 +15,15 @@
namespace megdnn {
void CheckNonFinite::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(dst);
megdnn_assert(src.ndim == 1);
megdnn_assert(src.dtype == dtype::Float32());
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
const TensorNDArray& srcs, const TensorND& dst, size_t workspace_in_bytes) {
megdnn_assert_contiguous(dst.layout);
megdnn_assert(srcs.size() > 0);
megdnn_assert(srcs.begin()->layout.dtype == dtype::Float32());
auto required_workspace_in_bytes = _get_workspace_in_bytes();
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void CheckNonFinite::deduce_layout(const TensorLayout&, TensorLayout& dst) {
void CheckNonFinite::deduce_layout(const TensorLayoutArray&, TensorLayout& dst) {
dst.shape[0] = 1;
dst.ndim = 1;
dst.dtype = dtype::Int32();
......
......@@ -156,21 +156,35 @@ struct MaxOp<src_ctype, dst_ctype, dt_float32> {
: INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
};
template <typename src_ctype, typename dst_ctype, typename wtype_>
template <typename src_ctype, typename index_ctype, typename dst_ctype, typename wtype_>
struct CheckNonFiniteOp {
typedef wtype_ wtype;
const wtype INIT;
RefPtr src;
RefPtr* srcs;
RefPtr srcs_total_nr_elems;
RefPtr dst;
const size_t B;
wtype read(uint32_t idx) { return !std::isfinite(src.ptr<src_ctype>()[idx]); }
wtype read(uint32_t idx) {
size_t x = idx / B;
size_t y = idx % B;
if (y < srcs_total_nr_elems.ptr<index_ctype>()[x]) {
RefPtr src = srcs[x];
return !std::isfinite(src.ptr<src_ctype>()[y]);
}
return 0;
}
void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; }
MEGDNN_HOST MEGDNN_DEVICE
CheckNonFiniteOp(const RefPtr& src, const RefPtr& dst, size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
CheckNonFiniteOp(
RefPtr* srcs, const RefPtr& srcs_total_nr_elems, const RefPtr& dst,
size_t B)
: INIT(wtype(0)),
srcs(srcs),
srcs_total_nr_elems(srcs_total_nr_elems),
dst(dst),
B(B) {}
};
void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis);
......
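Both flavours of `CheckNonFiniteOp` (this `RefPtr`-based one and the CUDA pointer-based one in the next hunks) expose the whole batch to the reducer as one flat index space of `m_size * B` slots, where `B` is the fixed chunk length. `read(idx)` recovers the chunk as `x = idx / B` and the in-chunk offset as `y = idx % B`; offsets beyond the chunk's valid length return 0, so padding can never raise the flag. A minimal host-only sketch of that mapping, with plain pointers in place of `RefPtr` (field names reused from the diff, everything else hypothetical):

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>

// srcs[x] points at chunk x (at most B contiguous floats) and
// srcs_total_nr_elems[x] holds how many of those B slots carry real data.
struct BatchedNonFiniteView {
    float** srcs;
    size_t* srcs_total_nr_elems;
    size_t B;  // fixed chunk length, e.g. total_nr_elems_max

    int32_t read(size_t idx) const {
        size_t x = idx / B;  // which chunk
        size_t y = idx % B;  // offset inside that chunk
        if (y < srcs_total_nr_elems[x])
            return !std::isfinite(srcs[x][y]);
        return 0;  // padding past the chunk tail never sets the flag
    }
};
```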
......@@ -185,28 +185,41 @@ struct MaxOp<src_ctype, dst_ctype, dt_float32> {
: INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
};
template <typename src_ctype, typename dst_ctype, typename wtype_>
template <typename src_ctype, typename index_ctype, typename dst_ctype, typename wtype_>
struct CheckNonFiniteOp {
typedef wtype_ wtype;
const wtype INIT;
src_ctype* src;
src_ctype** srcs;
index_ctype* srcs_total_nr_elems;
dst_ctype* dst;
const size_t B;
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
size_t x = idx / B;
size_t y = idx % B;
if (y < srcs_total_nr_elems[x]) {
#if defined(__CUDA_ARCH__)
return !isfinite(src[idx]);
wtype val = isfinite(srcs[x][y]);
#else
return !std::isfinite(src[idx]);
wtype val = std::isfinite(srcs[x][y]);
#endif
return !val;
}
return 0;
}
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs | rhs;
}
MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(
src_ctype** srcs, index_ctype* srcs_total_nr_elems, dst_ctype* dst,
size_t B)
: INIT(wtype(0)),
srcs(srcs),
srcs_total_nr_elems(srcs_total_nr_elems),
dst(dst),
B(B) {}
};
} // namespace device_reduce
......
......@@ -19,7 +19,8 @@ namespace cuda {
#define COMMA ,
INST_REDUCE(
device_reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>,
device_reduce::CheckNonFiniteOp<
dt_float32 COMMA size_t COMMA dt_int32 COMMA dt_int32>,
false);
#undef COMMA
......
......@@ -21,22 +21,83 @@ namespace megdnn {
namespace cuda {
using device_reduce::CheckNonFiniteOp;
#define total_nr_elems_max 2048
size_t CheckNonFiniteImpl::_get_workspace_in_bytes() {
// reuse the m_size computed in get_workspace_in_bytes so the required
// workspace can be queried again (e.g. by check_exec) without looping over the inputs
typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
megdnn_assert(m_size > 0);
WorkspaceBundle bundle(
nullptr, {
sizeof(dt_float32*) * m_size,
sizeof(size_t) * m_size,
});
return get_reduce_workspace_in_bytes<Op>(1, m_size * total_nr_elems_max, 1) +
bundle.total_size_in_bytes();
}
size_t CheckNonFiniteImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1);
const TensorNDArray& srcs, const TensorLayout&) {
m_size = 0;
for (const auto& src : srcs) {
m_size += DIVUP(src.layout.total_nr_elems(), total_nr_elems_max);
}
return _get_workspace_in_bytes();
}
void CheckNonFiniteImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
typedef CheckNonFiniteOp<dt_float32, dt_int32, dt_int32> Op;
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(srcs, dst, workspace.size);
typedef CheckNonFiniteOp<dt_float32, size_t, dt_int32, dt_int32> Op;
auto stream = cuda_stream(this->handle());
auto B = src.layout.total_nr_elems();
SmallVector<size_t> workspace_sizes{
sizeof(dt_float32*) * m_size,
sizeof(size_t) * m_size,
};
WorkspaceBundle workspace_cpu(nullptr, workspace_sizes),
workspace_gpu(nullptr, workspace_sizes);
auto total_workspace_size = workspace_cpu.total_size_in_bytes();
void* workspace_cpu_raw = malloc(total_workspace_size);
megdnn_assert_internal(workspace_cpu_raw);
void* workspace_gpu_raw = workspace.raw_ptr;
workspace_cpu = WorkspaceBundle(workspace_cpu_raw, workspace_sizes);
workspace_gpu = WorkspaceBundle(workspace_gpu_raw, workspace_sizes);
auto srcs_cpu = static_cast<dt_float32**>(workspace_cpu.get(0));
auto srcs_gpu = static_cast<dt_float32**>(workspace_gpu.get(0));
auto srcs_total_nr_elems_cpu = static_cast<size_t*>(workspace_cpu.get(1));
auto srcs_total_nr_elems_gpu = static_cast<size_t*>(workspace_gpu.get(1));
// srcs
// cut each input tensor into slices of at most total_nr_elems_max elements
size_t i = 0;
for (const auto& src : srcs) {
size_t src_nr_elems = src.layout.total_nr_elems();
size_t nr_elems = DIVUP(src_nr_elems, total_nr_elems_max);
for (size_t j = 0; j < nr_elems; ++j, ++i) {
srcs_cpu[i] = src.ptr<dt_float32>() + j * total_nr_elems_max;
if (j + 1 == nr_elems && src_nr_elems % total_nr_elems_max) {
srcs_total_nr_elems_cpu[i] = src_nr_elems % total_nr_elems_max;
} else {
srcs_total_nr_elems_cpu[i] = total_nr_elems_max;
}
}
}
for (size_t i = 0; i < workspace_cpu.nr_workspace(); ++i) {
cuda_check(cudaMemcpyAsync(
workspace_gpu.get(i), workspace_cpu.get(i), workspace_cpu.get_size(i),
cudaMemcpyHostToDevice, stream));
}
cuda_check(cudaStreamAddCallback(
stream, callback_free, static_cast<void*>(workspace_cpu_raw), 0));
return run_reduce<Op, false>(
workspace.ptr<dt_int32>(), 1, B, 1, stream,
Op(src.ptr<dt_float32>(), dst.ptr<dt_int32>(), B));
static_cast<dt_int32*>(
(void*)((char*)workspace_gpu_raw +
workspace_gpu.total_size_in_bytes())),
1, m_size * total_nr_elems_max, 1, stream,
Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(),
total_nr_elems_max));
}
} // namespace cuda
......
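The new `exec` above slices every input into chunks of `total_nr_elems_max` elements, records one device pointer and one valid length per chunk in a small host-side workspace, copies both arrays to the GPU, and then launches a single reduce over `m_size * total_nr_elems_max` virtual elements. A hedged CPU-only sketch of that slicing arithmetic, with `std::vector` standing in for the MegDNN tensors and hypothetical names:

```cpp
#include <cstddef>
#include <vector>

constexpr size_t kChunk = 2048;  // mirrors total_nr_elems_max above

struct Chunk {
    const float* ptr;  // start of this slice
    size_t valid;      // number of real elements in the slice (<= kChunk)
};

// Split each input buffer into kChunk-sized slices; the last slice of a
// buffer keeps only the remainder, matching the exec() loop above.
std::vector<Chunk> build_chunks(const std::vector<std::vector<float>>& srcs) {
    std::vector<Chunk> chunks;
    for (const auto& src : srcs) {
        size_t n = src.size();
        size_t nr_chunks = (n + kChunk - 1) / kChunk;  // DIVUP(n, kChunk)
        for (size_t j = 0; j < nr_chunks; ++j) {
            size_t valid = (j + 1 == nr_chunks && n % kChunk) ? n % kChunk : kChunk;
            chunks.push_back({src.data() + j * kChunk, valid});
        }
    }
    return chunks;  // chunks.size() corresponds to m_size
}
```

With this layout the workspace only has to hold the pointer array, the length array, and the usual reduce scratch space, which is what the `WorkspaceBundle` in `_get_workspace_in_bytes` reserves.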
......@@ -18,16 +18,18 @@ namespace megdnn {
namespace cuda {
class CheckNonFiniteImpl final : public CheckNonFinite {
size_t _get_workspace_in_bytes() override;
public:
using CheckNonFinite::CheckNonFinite;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;
const TensorNDArray& srcs, const TensorLayout& dst) override;
bool is_thread_safe() const override { return true; }
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
};
......
......@@ -17,21 +17,25 @@
namespace {
using namespace megdnn;
#define src_ctype dt_float32
#define wtype dt_int32
void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
std::function<wtype(size_t, size_t)> func;
func = [&](size_t l, size_t r) -> wtype {
if (l + 1 < r) {
size_t mid = l + (r - l) / 2;
return func(l, mid) | func(mid, r);
} else {
return static_cast<wtype>(!std::isfinite(sptr[l]));
}
};
dptr[0] = func(0, size);
#define wtype dt_int32
void reduce_fwd(const TensorNDArray& srcs, wtype* dptr) {
dptr[0] = 0;
for (auto src : srcs) {
auto sptr = src.ptr<dt_float32>();
size_t size = src.layout.total_nr_elems();
std::function<wtype(wtype, wtype)> func;
func = [&](wtype l, wtype r) -> wtype {
if (l + 1 < r) {
wtype mid = l + (r - l) / 2;
return func(l, mid) | func(mid, r);
} else {
auto val = std::isfinite(sptr[l]);
return static_cast<wtype>(!val);
}
};
dptr[0] |= func(0, size);
}
}
} // namespace
......@@ -39,20 +43,13 @@ void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
namespace megdnn {
namespace naive {
size_t CheckNonFiniteImpl::get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&) {
return 0;
}
void CheckNonFiniteImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(srcs, dst, workspace.size);
auto handle = static_cast<HandleImpl*>(this->handle());
MEGDNN_DISPATCH_CPU_KERN(
handle, reduce_fwd(
src.ptr<dt_float32>(), dst.ptr<dt_int32>(),
src.layout.total_nr_elems()));
MEGDNN_DISPATCH_CPU_KERN(handle, reduce_fwd(srcs, dst.ptr<dt_int32>()));
}
} // namespace naive
} // namespace megdnn
......
......@@ -17,16 +17,20 @@ namespace megdnn {
namespace naive {
class CheckNonFiniteImpl final : public CheckNonFinite {
size_t _get_workspace_in_bytes() override { return 0; }
public:
using CheckNonFinite::CheckNonFinite;
bool is_thread_safe() const override { return true; }
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;
size_t get_workspace_in_bytes(const TensorNDArray&, const TensorLayout&) override {
m_size = 0;
return _get_workspace_in_bytes();
}
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
};
......
......@@ -202,6 +202,27 @@ struct OprProxy<ConcatForward> {
}
};
template <>
struct OprProxy<CheckNonFinite> {
static void deduce_layout(CheckNonFinite* opr, TensorLayoutArray& layouts) {
megdnn_assert(layouts.size() >= 2);
auto inp = layouts;
inp.pop_back();
opr->deduce_layout(inp, layouts.back());
}
static void exec(CheckNonFinite* opr, const TensorNDArray& tensors) {
megdnn_assert(tensors.size() >= 2);
auto inps = tensors;
inps.pop_back();
WorkspaceWrapper W(
opr->handle(),
opr->get_workspace_in_bytes(inps, tensors.back().layout));
opr->exec(inps, tensors.back(), W.workspace());
}
};
template <>
struct OprProxy<SplitForward> : DeduceLayoutProxy<SplitForward, 0, false> {
WorkspaceWrapper W;
......
......@@ -22,13 +22,16 @@ TEST_F(CUDA, CHECK_NON_FINITE_BASIC) {
const auto nan = std::numeric_limits<float>::quiet_NaN();
UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf);
checker.set_rng(0, &rng);
checker.execs({{512 * 16}, {1}});
checker.execs({{512 * 4}, {4}, {1}});
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf);
checker.set_rng(0, &rng);
checker.execs({{512 * 16}, {1}});
checker.execs({{4}, {512 * 4}, {1}});
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, nan);
checker.set_rng(0, &rng);
checker.execs({{512 * 16}, {1}});
checker.execs({{32}, {256}, {1}});
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.f, nan);
checker.set_rng(0, &rng);
checker.execs({{16}, {16}, {2}, {1}});
}
} // namespace test
......
......@@ -20,23 +20,28 @@ namespace test {
TEST_F(NAIVE, CHECK_NON_FINITE_BASIC) {
Checker<CheckNonFinite> checker(handle(), false);
checker.exect(
Testcase{TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}), {}},
Testcase{{}, TensorValue({1}, dtype::Int32(), {0})});
Testcase{
TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
{}},
Testcase{{}, {}, TensorValue({1}, dtype::Int32(), {0})});
checker.exect(
Testcase{
TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
TensorValue(
{4}, dtype::Float32(),
{1.1f, 2.2f, 3.3f, std::numeric_limits<float>::infinity()}),
{}},
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
Testcase{{}, {}, TensorValue({1}, dtype::Int32(), {1})});
checker.exect(
Testcase{
TensorValue({4}, dtype::Float32(), {1.1, 2.2, 3.3, 4.3}),
TensorValue(
{4}, dtype::Float32(),
{1.1f, 2.2f, 3.3f,
std::numeric_limits<float>::quiet_NaN()}),
{}},
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
Testcase{{}, {}, TensorValue({1}, dtype::Int32(), {1})});
}
} // namespace test
......
......@@ -128,21 +128,22 @@ class GradScaler:
grad_tensors: Tensors needed to unscale grads. Should be all tensors
that are affected by ``target`` tensor in GradManager's backward.
"""
# use float64 for better precision
inv_scale = Tensor(1.0 / self.scale_factor)
for tensor in grad_tensors:
if tensor is None or getattr(tensor, "grad", None) is None:
continue
# to support tracing, _check_gradients should be applied to every grad.
if self._check_gradients(tensor.grad):
self._found_non_finite = True
tensor.grad *= inv_scale
# to support tracing, _check_gradients should be applied to every grad.
if self._check_gradients([x.grad for x in grad_tensors]):
self._found_non_finite = True
if self._found_non_finite:
for tensor in grad_tensors:
if tensor is None or getattr(tensor, "grad", None) is None:
continue
tensor.grad = None
else:
# use float64 for better precision
inv_scale = Tensor(1.0 / self.scale_factor)
for tensor in grad_tensors:
if tensor is None or getattr(tensor, "grad", None) is None:
continue
tensor.grad *= inv_scale
return self
def _check_gradients(self, grad):
......
......@@ -9,7 +9,7 @@
import collections
import math
from functools import lru_cache
from typing import Optional, Sequence, Tuple, Union
from typing import Iterable, Optional, Sequence, Tuple, Union
from ..core import _config
from ..core._imperative_rt.core2 import apply, dtype_promotion
......@@ -1183,7 +1183,7 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
return U, sigma, V
def _check_non_finite(inp: Tensor) -> Tensor:
def _check_non_finite(inps: Iterable[Tensor]) -> Tensor:
r"""Check whether input contains infinite or nan value.
Args:
......@@ -1193,6 +1193,6 @@ def _check_non_finite(inp: Tensor) -> Tensor:
an int32 scalar tensor, 0 for False and 1 for True.
"""
op = builtin.CheckNonFinite()
(oup,) = apply(op, inp.reshape(-1).astype("float32"))
(oup,) = apply(op, *inps)
oup._setscalar()
return oup
......@@ -10,21 +10,26 @@ import numpy as np
import megengine as mge
from megengine.amp import GradScaler
from megengine.autodiff import GradManager
from megengine.jit import trace
def test_grad_scaler():
gm = GradManager()
scaler = GradScaler()
def f():
gm = GradManager()
scaler = GradScaler()
x = mge.tensor(1.0)
for _ in range(3):
with gm:
y = x + 1
gm.attach(y)
loss = y + 1
scaler.backward(gm, loss, unscale_grad=False)
np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor)
x = mge.tensor(1.0)
for _ in range(3):
with gm:
y = x + 1
gm.attach(y)
loss = y + 1
scaler.backward(gm, loss, unscale_grad=False)
np.testing.assert_equal(y.grad.numpy(), scaler.scale_factor)
scaler.unscale(gm.attached_tensors())
np.testing.assert_equal(y.grad.numpy(), 1)
# test handling of None elements
scaler.unscale(gm.attached_tensors())
np.testing.assert_equal(y.grad.numpy(), 1)
# test handling of None elements
scaler.unscale(gm.attached_tensors())
f()
trace(f)()
......@@ -191,16 +191,17 @@ def test_sum_neg_axis():
def test_non_finite():
shape = (32, 3, 32, 32)
data = np.random.random(shape).astype(np.float32)
rst = F.math._check_non_finite(tensor(data))
data1 = np.random.random(shape).astype(np.float32)
data2 = np.random.random(shape).astype(np.float32)
rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
np.testing.assert_equal(rst.numpy(), [0])
data[0][0][0][0] = float("inf")
rst = F.math._check_non_finite(tensor(data))
data2[0][0][0][0] = float("inf")
rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
np.testing.assert_equal(rst.numpy(), [1])
data[0][0][0][0] = float("nan")
rst = F.math._check_non_finite(tensor(data))
data2[0][0][0][0] = float("nan")
rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
np.testing.assert_equal(rst.numpy(), [1])
......
......@@ -17,14 +17,56 @@ namespace mgb {
namespace imperative {
namespace check_non_finite {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
SymbolVar apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<CheckNonFinite>();
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::CheckNonFinite::make(inputs[0], {}, config);
return opr::CheckNonFinite::make(inputs, {}, config);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
size_t size = inputs.size();
auto dest = Tensor::make(
TensorLayout(TensorShape({1}), dtype::Int32()), inputs[0]->comp_node());
auto cn = dest->comp_node();
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::CheckNonFinite>(cn);
size_t wk_size = 0;
SmallVector<megdnn::TensorND> srcs(size);
for (size_t i = 0; i < size; ++i) {
srcs[i] = inputs[i]->dev_tensor().as_megdnn();
}
wk_size = dnn_opr->get_workspace_in_bytes(srcs, dest->layout());
auto wk = Blob::make(cn, wk_size);
megdnn::Workspace dnn_wk(wk->storage().get(), wk_size);
dnn_opr->exec(srcs, dest->dev_tensor().as_megdnn(), dnn_wk);
return {dest};
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
SmallVector<LogicalTensorDesc> dests(1);
dests[0].comp_node = inputs[0].comp_node;
dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
return {dests, true};
}
SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<LogicalTensorDesc> dests(1);
dests[0].comp_node = inputs[0]->comp_node();
dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
return dests;
}
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
return {{}, {}};
}
OP_TRAIT_REG(CheckNonFinite, CheckNonFinite)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.infer_output_mem_desc(infer_output_mem_desc)
.fallback();
} // namespace check_non_finite
......
......@@ -482,18 +482,74 @@ MGB_IMPL_OPR_GRAD(TopK) {
#endif
/* ================= CheckNonFinite ================= */
namespace mgb {
namespace opr {
namespace intl {
template <>
struct MegDNNOprInitPostCtor<CheckNonFinite> {
static void apply(cg::OperatorNodeBase& opr) {
opr.output(0)->dtype(dtype::Int32());
}
};
} // namespace intl
} // namespace opr
} // namespace mgb
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckNonFinite);
MEGDNN_OPR_INIT1(CheckNonFinite, "check_non_finite")
CheckNonFinite::CheckNonFinite(
const VarNodeArrayView& inp, const Param& param,
const OperatorNodeConfig& config)
: Super(OperatorNodeBaseCtorParam{
inp[0]->owner_graph(), config, "check_non_finite", inp}) {
mgb_assert(!inp.empty());
for (auto&& i : inp) {
add_input({i});
}
add_output(None)->dtype(dtype::Int32()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
cg::add_workspace_output(this);
}
SymbolVar CheckNonFinite::make(
const VarNodeArrayView& inp, const Param& param,
const OperatorNodeConfig& config) {
mgb_assert(!inp.empty());
intl::BatchedDTypePromotion dtp{inp};
return SymbolVar{inp[0]}.insert_single_output_opr<CheckNonFinite>(
dtp.get_vars(), param, config);
}
void CheckNonFinite::scn_do_execute() {
megdnn::TensorNDArray inp_arr(input().size());
for (size_t i = 0; i < input().size(); ++i) {
inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
}
megdnn_opr()->exec(
inp_arr, output(0)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output(1)));
}
void CheckNonFinite::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
auto infer_oshp = [](TensorShape& dest, const InpVal& iv) {
TensorLayout dst;
dst.shape[0] = 1;
dst.ndim = 1;
dst.dtype = dtype::Int32();
dst.init_contiguous_stride();
dest = dst;
return true;
};
DepVal deps;
for (auto i : input())
deps.push_back({i, DepType::SHAPE});
mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_oshp});
auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
dest.ndim = 1;
megdnn::TensorNDArray inp_arr(input().size());
for (size_t i = 0; i < input().size(); ++i) {
inp_arr[i] = {NULL, {inp.val.at(i).shape(), input(0)->dtype()}};
}
dest.shape[0] = megdnn_opr()->get_workspace_in_bytes(
inp_arr, {output(0)->shape(), output(0)->dtype()});
return true;
};
mgr.register_shape_infer(output(1), {SourceType::DEP, deps, infer_wk});
}
void CheckNonFinite::add_input_layout_constraint() {
for (auto i : input()) {
i->add_layout_constraint_contiguous();
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -55,6 +55,10 @@ struct OprMaker<opr::TopK, 2> {
}
};
template <>
struct OprMaker<opr::CheckNonFinite, 0> : public OprMakerVariadic<opr::CheckNonFinite> {
};
} // namespace serialization
namespace opr {
......@@ -72,7 +76,7 @@ MGB_SEREG_OPR(CumsumV1, 1);
#if MGB_CUDA
MGB_SEREG_OPR(NvOf, 1);
#endif
MGB_SEREG_OPR(CheckNonFinite, 1);
MGB_SEREG_OPR(CheckNonFinite, 0);
} // namespace opr
} // namespace mgb
......
......@@ -142,6 +142,8 @@ using CondTakeBase = cg::SingleCNOperatorNode<
cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::CondTake>>;
using TopKBase = cg::SingleCNOperatorNode<
cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::TopK>>;
using CheckNonFiniteBase = cg::SingleCNOperatorNode<
cg::OperatorNodeBase, mixin::MegDNNOprHolderImpl<megdnn::CheckNonFinite>>;
} // namespace intl
/*!
......@@ -181,7 +183,19 @@ public:
const OperatorNodeConfig& config = {});
};
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckNonFinite);
MGB_DEFINE_OPR_CLASS(CheckNonFinite, intl::CheckNonFiniteBase) //{
void scn_do_execute() override;
void init_output_static_infer_desc() override;
void add_input_layout_constraint() override;
public:
MGE_WIN_DECLSPEC_FUC CheckNonFinite(
const VarNodeArrayView& inp, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVar make(
const VarNodeArrayView& inp, const Param& param = {},
const OperatorNodeConfig& config = {});
};
} // namespace opr
} // namespace mgb
......