Commit a5060a2b authored by Megvii Engine Team, committed by huangxinda

feat(mgb/opr): add check_has_inf kernel and opr

GitOrigin-RevId: 0d042dbfce8baa51245f4189e197bf800347c6b9
Parent: 3597a6db
......@@ -1317,6 +1317,27 @@ protected:
TensorLayout& exec_workspace,
TensorLayout& exec_src, TensorLayout& exec_dst);
};
/*!
 * \brief Check whether the input tensor contains any infinite value.
*/
class CheckHasInf: public OperatorBase {
DEF_OPR_PARAM(Empty);
DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1);
public:
virtual size_t get_workspace_in_bytes(const TensorLayout &src,
const TensorLayout &dst) = 0;
void deduce_layout(const TensorLayout &src, TensorLayout &dst);
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) = 0;
protected:
void check_exec(const TensorLayout &src, const TensorLayout &dst,
size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
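A minimal usage sketch (editor's illustration, not part of the commit) of driving the operator declared above through a megdnn handle; `handle` (a megdnn::Handle*) and a contiguous float32 TensorND `src` are assumed to exist, and storage for `dst` plus the workspace must be allocated by the caller:

    auto opr = handle->create_operator<megdnn::CheckHasInf>();
    megdnn::TensorLayout dst_layout;
    opr->deduce_layout(src.layout, dst_layout);  // dst: shape {1}, dtype Int32
    size_t ws = opr->get_workspace_in_bytes(src.layout, dst_layout);
    // after allocating dst and a workspace of ws bytes:
    // opr->exec(src, dst, workspace);  // dst[0] == 1 iff src contains an inf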
/**
* \file dnn/src/common/check_has_inf.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst,
size_t workspace_in_bytes) {
megdnn_assert_contiguous(src);
megdnn_assert_contiguous(dst);
megdnn_assert(src.ndim == 1);
megdnn_assert(src.dtype == dtype::Float32());
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) {
dst.shape[0] = 1;
dst.ndim = 1;
dst.dtype = dtype::Int32();
dst.init_contiguous_stride();
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -207,7 +207,8 @@ private:
cb(FakeQuantForward) \
cb(FakeQuantBackward) \
cb(TQTForward) \
cb(TQTBackward)
cb(TQTBackward) \
cb(CheckHasInf)
/*!
* \brief specialize HandleImpl::create_operator for a single opr type;
......
......@@ -120,6 +120,7 @@ DEF(PowC, 2, false, true);
DEF(UniformRNG, 1, true, true);
DEF(GaussianRNG, 1, true, true);
DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true);
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -4,9 +4,9 @@
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/dtype.h"
......@@ -151,6 +151,33 @@ struct MaxOp {
: INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
};
template <typename src_ctype, typename dst_ctype, typename wtype_>
struct CheckHasInfOp {
typedef wtype_ wtype;
const wtype INIT;
src_ctype* src;
dst_ctype* dst;
const size_t B;
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
#if defined(__CUDA_ARCH__)
return isinf(src[idx]);
#else
return std::isinf(src[idx]);
#endif
}
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
dst[idx] = val;
}
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
return lhs | rhs;
}
MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src, dst_ctype* dst,
size_t B)
: INIT(wtype(0)), src(src), dst(dst), B(B) {}
};
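For intuition, here is a standalone host-side sketch (illustration only, not the CUDA path) of the fold that CheckHasInfOp describes: read() maps element i to an isinf flag, apply() combines partial flags with bitwise OR, and write() stores the final flag.

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Sequential equivalent of the OR-reduction defined by CheckHasInfOp.
    int32_t fold_has_inf(const std::vector<float>& src) {
        int32_t acc = 0;                                 // INIT == wtype(0)
        for (float v : src)
            acc |= static_cast<int32_t>(std::isinf(v));  // read(), then apply()
        return acc;                                      // write(0, acc)
    }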
#if MEGDNN_CC_HOST
void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C,
size_t axis);
......
/**
* \file dnn/src/cuda/check_has_inf/kern.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/reduce_helper.h"
#include "megdnn/dtype.h"
#include "src/cuda/reduce_helper.cuh"
namespace megdnn {
namespace cuda {
#define COMMA ,
INST_REDUCE(reduce::CheckHasInfOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false);
#undef COMMA
} // namespace cuda
} // namespace megdnn
// vim: ft=cpp syntax=cpp.doxygen
/**
* \file dnn/src/cuda/check_has_inf/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/check_has_inf/opr_impl.h"
#include "src/cuda/reduce_helper.cuh"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "src/common/reduce_helper.h"
namespace megdnn {
namespace cuda {
using reduce::CheckHasInfOp;
size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) {
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
return get_reduce_workspace_in_bytes<Op>(1, src.total_nr_elems(), 1);
}
void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
typedef CheckHasInfOp<dt_float32, dt_int32, dt_int32> Op;
auto stream = cuda_stream(this->handle());
    auto B = src.layout.total_nr_elems();
    // reduce the whole tensor in one pass: A = C = 1, B = element count
    return run_reduce<Op, false>(
workspace.ptr<dt_int32>(), 1, B, 1, stream,
Op(src.ptr<dt_float32>(), dst.ptr<dt_int32>(), B));
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/check_has_inf/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs/utils.h"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
class CheckHasInfImpl final : public CheckHasInf {
public:
using CheckHasInf::CheckHasInf;
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) override;
bool is_thread_safe() const override { return true; }
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -17,6 +17,7 @@
#include "src/cuda/argsort/opr_impl.h"
#include "src/cuda/batch_normalization/opr_impl.h"
#include "src/cuda/batched_matrix_mul/opr_impl.h"
#include "src/cuda/check_has_inf/opr_impl.h"
#include "src/cuda/checksum/opr_impl.h"
#include "src/cuda/concat/opr_impl.h"
#include "src/cuda/cond_take/opr_impl.h"
......
......@@ -18,15 +18,15 @@ namespace cuda {
using namespace reduce;
#define COMMOA ,
#define COMMA ,
#define INST(sctype, dctype, wtype) \
INST_REDUCE(SumOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(SumSqrOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(ProdOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(MinOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(MaxOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(MeanOp<sctype COMMOA dctype COMMOA wtype>, false);
INST_REDUCE(SumOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(SumSqrOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(ProdOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(MinOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(MaxOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(MeanOp<sctype COMMA dctype COMMA wtype>, false);
#define cb(_dt) \
INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype)
......@@ -40,6 +40,7 @@ INST(int, float, float)
#undef cb
#undef INST
#undef COMMA
} // namespace cuda
} // namespace megdnn
......
/**
* \file dnn/src/naive/check_has_inf/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/naive/check_has_inf/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace {
using namespace megdnn;
using src_ctype = dt_float32;
using wtype = dt_int32;

// Pairwise (divide-and-conquer) OR-fold over [0, size); mirrors the
// combination order of the generic reduce helpers used by other backends.
void reduce_fwd(const src_ctype* sptr, wtype* dptr, size_t size) {
    std::function<wtype(size_t, size_t)> func;
    func = [&](size_t l, size_t r) -> wtype {
if (l + 1 < r) {
size_t mid = l + (r - l) / 2;
return func(l, mid) | func(mid, r);
} else {
return static_cast<wtype>(std::isinf(sptr[l]));
}
};
dptr[0] = func(0, size);
}
} // namespace
namespace megdnn {
namespace naive {
size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&) {
return 0;
}
void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
auto handle = static_cast<HandleImpl*>(this->handle());
MEGDNN_DISPATCH_CPU_KERN(
handle, reduce_fwd(src.ptr<dt_float32>(), dst.ptr<dt_int32>(),
src.layout.total_nr_elems()));
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/naive/check_has_inf/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace naive {
class CheckHasInfImpl final : public CheckHasInf {
public:
using CheckHasInf::CheckHasInf;
bool is_thread_safe() const override { return true; }
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst) override;
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -21,6 +21,7 @@
#include "src/naive/batch_conv_bias/opr_impl.h"
#include "src/naive/batch_normalization/opr_impl.h"
#include "src/naive/batched_matrix_mul/opr_impl.h"
#include "src/naive/check_has_inf/opr_impl.h"
#include "src/naive/checksum/opr_impl.h"
#include "src/naive/concat/opr_impl.h"
#include "src/naive/cond_take/opr_impl.h"
......
......@@ -18,15 +18,15 @@ namespace rocm {
using namespace reduce;
#define COMMOA ,
#define COMMA ,
#define INST(sctype, dctype, wtype) \
INST_REDUCE(SumOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(SumSqrOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(ProdOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(MinOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(MaxOp<sctype COMMOA dctype COMMOA wtype>, false); \
INST_REDUCE(MeanOp<sctype COMMOA dctype COMMOA wtype>, false);
INST_REDUCE(SumOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(SumSqrOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(ProdOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(MinOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(MaxOp<sctype COMMA dctype COMMA wtype>, false); \
INST_REDUCE(MeanOp<sctype COMMA dctype COMMA wtype>, false);
#define cb(_dt) \
INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype)
......@@ -39,6 +39,7 @@ INST(float, dt_float16, float)
INST(int, float, float)
#undef cb
#undef INST
#undef COMMA
} // namespace rocm
} // namespace megdnn
......
......@@ -23,7 +23,7 @@ namespace {
::testing::AssertionResult assert_tensor_eq_with_iter(
const char *expr0, const char *expr1,
Iter it0, Iter it1, const TensorLayout &layout,
float maxerr, float maxerr_avg, float maxerr_avg_biased) {
float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) {
auto nr_elem = layout.total_nr_elems();
double error_sum = 0;
......@@ -33,8 +33,8 @@ namespace {
float err = diff(iv0, iv1);
error_sum += std::abs(err);
error_sum_biased += err;
if (!good_float(iv0) || !good_float(iv1) ||
std::abs(err) > maxerr) {
if (!allow_invalid && (!good_float(iv0) || !good_float(iv1) ||
std::abs(err) > maxerr)) {
Index index(layout, i);
return ::testing::AssertionFailure()
<< "Unequal value\n"
......@@ -82,14 +82,14 @@ namespace {
::testing::AssertionResult assert_tensor_eq_with_dtype(
const char *expr0, const char *expr1,
const TensorND &v0, const TensorND &v1,
float maxerr, float maxerr_avg, float maxerr_avg_biased) {
float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) {
if (!std::is_same<ctype, dt_qint4>::value &&
!std::is_same<ctype, dt_quint4>::value) {
if (v0.layout.is_physical_contiguous() &&
v1.layout.is_physical_contiguous()) {
return assert_tensor_eq_with_iter<ctype>(
expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(),
v0.layout, maxerr, maxerr_avg, maxerr_avg_biased);
v0.layout, maxerr, maxerr_avg, maxerr_avg_biased, allow_invalid);
}
}
......@@ -98,7 +98,7 @@ namespace {
return assert_tensor_eq_with_iter<ctype>(expr0, expr1, it0, it1,
v0.layout, maxerr, maxerr_avg,
maxerr_avg_biased);
maxerr_avg_biased, allow_invalid);
}
template<class Impl>
......@@ -136,7 +136,7 @@ namespace {
const char* /*expr_maxerr_avg*/,
const char* /*expr_maxerr_avg_biased*/,
const TensorND &v0, const TensorND &v1,
float maxerr, float maxerr_avg, float maxerr_avg_biased) {
float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) {
if (!v0.layout.eq_shape(v1.layout)) {
return ::testing::AssertionFailure()
......@@ -160,7 +160,7 @@ namespace {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: \
return assert_tensor_eq_with_dtype<DTypeTrait<_dt>::ctype>( \
expr0, expr1, v0, v1, maxerr, maxerr_avg, maxerr_avg_biased);
expr0, expr1, v0, v1, maxerr, maxerr_avg, maxerr_avg_biased, allow_invalid);
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
//! In order to avoid an unnecessary increase in binary size, we just
......@@ -174,6 +174,17 @@ namespace {
}
::testing::AssertionResult test::__assert_tensor_eq_allow_invalid(
const char* expr0, const char* expr1, const char* expr_maxerr,
const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased,
const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg,
float maxerr_avg_biased) {
return __assert_tensor_eq(expr0, expr1, expr_maxerr, expr_maxerr_avg,
expr_maxerr_avg_biased, v0, v1, maxerr,
maxerr_avg, maxerr_avg_biased, true);
}
CheckerHelper::CheckerHelper(Handle *handle, bool check_dispatch):
m_handle_cur(handle),
m_default_rng(new NormalRNG())
......@@ -411,9 +422,15 @@ void CheckerHelper::check_tensors(const TensorValueArray& expected,
for (size_t i = 0; i < expected.size(); ++i) {
if (expected[i].layout.ndim == 0)
continue;
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(expected[i], computed[i], m_epsilon,
m_max_avg_error,
m_max_avg_biased_error);
if (m_allow_invalid_check) {
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID(
expected[i], computed[i], m_epsilon, m_max_avg_error,
m_max_avg_biased_error);
} else {
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(expected[i], computed[i], m_epsilon,
m_max_avg_error,
m_max_avg_biased_error);
}
}
}
......
......@@ -79,6 +79,7 @@ protected:
bool m_no_naive_and_check = false;
bool m_stable_check = false;
bool m_force_deduce_dst = true;
bool m_allow_invalid_check = false;
/**
* the offset from the start of malloc memory
*
......@@ -248,6 +249,11 @@ public:
return *this;
}
Checker& set_allow_invalid_check(bool allow_invalid_check) {
m_allow_invalid_check = allow_invalid_check;
return *this;
}
//! load input tensors from file for next run
Checker& load_input_tensors(const char* fpath) {
m_input_tensors_fpath = fpath;
......@@ -326,6 +332,12 @@ private:
};
::testing::AssertionResult __assert_tensor_eq(
const char* expr0, const char* expr1, const char* expr_maxerr,
const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased,
const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg,
float maxerr_avg_biased, bool allow_invalid = false);
::testing::AssertionResult __assert_tensor_eq_allow_invalid(
const char* expr0, const char* expr1, const char* expr_maxerr,
const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased,
const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg,
......@@ -336,6 +348,11 @@ private:
ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq, v0, v1, maxerr, \
maxerr_avg, maxerr_avg_biased)
#define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( \
v0, v1, maxerr, maxerr_avg, maxerr_avg_biased) \
ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq_allow_invalid, v0, \
v1, maxerr, maxerr_avg, maxerr_avg_biased)
#define MEGDNN_ASSERT_TENSOR_EQ_EPS(v0, v1, maxerr) \
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr, maxerr)
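The allow_invalid flag threads down to assert_tensor_eq_with_iter, where it skips the good_float() validity check, so tensors deliberately seeded with inf or nan can still be compared elementwise. A test opts in via the checker (a sketch mirroring the CUDA test later in this commit):

    Checker<CheckHasInf> checker(handle);
    checker.set_allow_invalid_check(true);  // inputs intentionally contain inf
    checker.execs({{512 * 16}, {1}});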
......@@ -435,7 +452,7 @@ TensorND TensorValue(const TensorShape& shape, T dtype,
template <typename T, typename U>
TensorND TensorValueLowbit4(const TensorShape& shape, T dtype,
std::vector<U> values) {
std::vector<U> values) {
TensorND tensor;
tensor.layout = {shape, dtype};
tensor.raw_ptr =
......
......@@ -38,6 +38,22 @@ struct ExecProxy<Opr, 8, true> {
}
};
template <typename Opr>
struct ExecProxy<Opr, 7, true> {
WorkspaceWrapper W;
void exec(Opr* opr, const TensorNDArray& tensors) {
if (!W.valid()) {
W = WorkspaceWrapper(opr->handle(), 0);
}
W.update(opr->get_workspace_in_bytes(
tensors[0].layout, tensors[1].layout, tensors[2].layout,
tensors[3].layout, tensors[4].layout, tensors[5].layout,
tensors[6].layout));
opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
tensors[5], tensors[6], W.workspace());
}
};
template <typename Opr>
struct ExecProxy<Opr, 6, true> {
WorkspaceWrapper W;
......@@ -149,24 +165,6 @@ struct ExecProxy<Opr, 2, false> {
}
};
template <typename Opr>
struct ExecProxy<Opr, 7, true> {
WorkspaceWrapper W;
void exec(Opr* opr, const TensorNDArray& tensors) {
if (!W.valid()) {
W = WorkspaceWrapper(opr->handle(), 0);
}
W.update(opr->get_workspace_in_bytes(
tensors[0].layout, tensors[1].layout, tensors[2].layout,
tensors[3].layout, tensors[4].layout, tensors[5].layout,
tensors[6].layout));
opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
tensors[5], tensors[6], W.workspace());
}
};
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -301,9 +301,8 @@ void UniformFloatNonZeroRNG::fill_fast_float32(dt_float32* dest, size_t size) {
}
}
void UniformFloatWithZeroRNG::fill_fast_float32(dt_float32 *dest, size_t size) {
void UniformFloatWithValueRNG::fill_fast_float32(dt_float32 *dest, size_t size) {
RNGxorshf gen{RandomState::generator()};
printf("a %f, b %f \n", m_dist.a(), m_dist.b());
auto k = double(m_dist.b() - m_dist.a()) /
double(RNGxorshf::max() - RNGxorshf::min() + 1.0);
auto b = m_dist.a() - RNGxorshf::min() * k;
......@@ -312,9 +311,8 @@ void UniformFloatWithZeroRNG::fill_fast_float32(dt_float32 *dest, size_t size) {
auto pb = 0.f - RNGxorshf::min() * p;
for (size_t i = 0; i < size; ++ i) {
float rnd = gen() * p + pb;
//printf("%.3f \n", rnd);
if(rnd < zero_val_proportion_) {
dest[i] = 0.f;
if(rnd < val_proportion_) {
dest[i] = val_;
} else {
dest[i] = gen() * k + b;
}
......
......@@ -11,10 +11,10 @@
#pragma once
#include "megdnn/dtype.h"
#include "test/common/utils.h"
#include "test/common/random_state.h"
#include <random>
#include <set>
#include "test/common/random_state.h"
#include "test/common/utils.h"
namespace megdnn {
namespace test {
......@@ -80,7 +80,8 @@ public:
}
void gen(const TensorND& tensor) override {
megdnn_assert(tensor.layout.dtype.enumv() == DTypeTrait<dt_bfloat16>::enumv);
megdnn_assert(tensor.layout.dtype.enumv() ==
DTypeTrait<dt_bfloat16>::enumv);
size_t nr_elems = tensor.layout.span().dist_elem();
auto offset = tensor.layout.span().low_elem;
for (size_t i = 0; i < nr_elems; ++i) {
......@@ -185,24 +186,31 @@ public:
void fill_fast_float32(dt_float32* dest, size_t size) override;
};
class UniformFloatWithZeroRNG final : public UniformFloatRNG {
class UniformFloatWithValueRNG : public UniformFloatRNG {
public:
UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b,
float zero_val_proportion)
: UniformFloatRNG(a, b) {
if (zero_val_proportion < 0.f)
zero_val_proportion_ = 0.f;
else if (zero_val_proportion > 1.f)
zero_val_proportion_ = 1.f;
UniformFloatWithValueRNG(dt_float32 a, dt_float32 b, float val_proportion,
float val)
: UniformFloatRNG(a, b), val_(val) {
if (val_proportion < 0.f)
val_proportion_ = 0.f;
else if (val_proportion > 1.f)
val_proportion_ = 1.f;
else
zero_val_proportion_ = zero_val_proportion;
val_proportion_ = val_proportion;
}
private:
float zero_val_proportion_;
float val_proportion_, val_;
void fill_fast_float32(dt_float32* dest, size_t size) override;
};
class UniformFloatWithZeroRNG final : public UniformFloatWithValueRNG {
public:
UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b,
float zero_val_proportion)
: UniformFloatWithValueRNG(a, b, zero_val_proportion, 0.f) {}
};
class BernoulliRNG final : public IIDRNG {
public:
BernoulliRNG(dt_float32 probability_);
......
/**
* \file dnn/test/cuda/check_has_inf.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/cuda/fixture.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, CHECK_HAS_INF_BASIC) {
Checker<CheckHasInf> checker(handle_cuda());
checker.set_allow_invalid_check(true);
const auto inf = std::numeric_limits<float>::infinity();
UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf);
checker.set_rng(0, &rng);
checker.execs({{512*16}, {1}});
rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf);
checker.set_rng(0, &rng);
checker.execs({{512*16}, {1}});
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file test/naive/check_has_inf.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/naive/fixture.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
TEST_F(NAIVE, CHECK_HAS_INF_BASIC) {
Checker<CheckHasInf> checker(handle(), false);
    checker.exect(Testcase{TensorValue({4}, dtype::Float32(),
                           {1.1f, 2.2f, 3.3f, 4.3f}),
                           {}},
                  Testcase{{}, TensorValue({1}, dtype::Int32(), {0})});
checker.exect(
Testcase{TensorValue({4}, dtype::Float32(),
{1.1f, 2.2f, 3.3f,
std::numeric_limits<float>::infinity()}),
{}},
Testcase{{}, TensorValue({1}, dtype::Int32(), {1})});
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -959,3 +959,16 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
    op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv)
    U, sigma, V = apply(op, inp)
    return U, sigma, V
def _has_inf(inp: Tensor) -> Tensor:
    """
    Check whether the input tensor contains any infinite value.

    :param inp: a tensor to be checked.
    :return: an int32 scalar tensor, 0 for False and 1 for True.
    """
    op = builtin.CheckHasInf()
    # the kernel expects a contiguous 1-d float32 input (see CheckHasInf::check_exec)
    (oup,) = apply(op, inp.reshape(-1).astype("float32"))
    oup._setscalar()
    return oup
......@@ -157,3 +157,14 @@ def test_sum_neg_axis():
    np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
    with pytest.raises(AssertionError):
        F.sum(tensor(data), axis=(-1, 1))
def test_has_inf():
    shape = (32, 3, 32, 32)
    data = np.random.random(shape).astype(np.float32)
    rst = F.math._has_inf(tensor(data))
    np.testing.assert_equal(rst.numpy(), [0])

    data[0][0][0][0] = float("inf")
    rst = F.math._has_inf(tensor(data))
    np.testing.assert_equal(rst.numpy(), [1])
/**
* \file imperative/src/impl/ops/tensor_manip.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "../op_trait.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/misc.h"
namespace mgb {
namespace imperative {
namespace check_has_inf {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = def.cast_final_safe<CheckHasInf>();
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::CheckHasInf::make(inputs[0], {}, config);
}
OP_TRAIT_REG(CheckHasInf, CheckHasInf)
.apply_on_var_node(apply_on_var_node)
.fallback();
} // namespace check_has_inf
} // namespace imperative
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -307,4 +307,6 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> {
def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;
def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>;
#endif // MGB_OPS
......@@ -437,4 +437,19 @@ MGB_IMPL_OPR_GRAD(TopK) {
}
#endif
/* ================= CheckHasInf ================= */
namespace mgb {
namespace opr {
namespace intl {
template<>
struct MegDNNOprInitPostCtor<CheckHasInf> {
static void apply(cg::OperatorNodeBase &opr) {
opr.output(0)->dtype(dtype::Int32());
}
};
} // namespace intl
} // namespace opr
} // namespace mgb
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf);
MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf")
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -73,6 +73,7 @@ namespace opr {
#if MGB_CUDA
MGB_SEREG_OPR(NvOf, 1);
#endif
MGB_SEREG_OPR(CheckHasInf, 1);
} // namespace opr
} // namespace mgb
......
......@@ -178,6 +178,8 @@ public:
const OperatorNodeConfig& config = {});
};
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf);
} // namespace opr
} // namespace mgb
......