Commit 2881934c authored by Megvii Engine Team

feat(dnn/check_non_finite): add mul scale to check_non_finite opr

GitOrigin-RevId: c35a219e52f852b37a311c6de505c6eb08e2ab8a
Parent b8ccc6a2
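Summary of the change: megdnn's CheckNonFinite operator gains a float `scale` parameter. The kernel still OR-reduces "is this element inf/nan?" into a single int32 flag, but every finite element is now also multiplied by `scale` in place, so GradScaler can fuse gradient unscaling with the overflow check in one pass. A minimal usage sketch of the updated Python entry point (assuming a MegEngine build that includes this change; output comments are illustrative):

import numpy as np
import megengine.functional as F
from megengine import tensor

grads = [tensor(np.ones((2, 2), dtype=np.float32)) for _ in range(2)]
flag = F.math._check_non_finite(grads, 0.5)  # int32 scalar: 1 if any inf/nan
print(flag.numpy())      # 0: everything was finite
print(grads[0].numpy())  # the inputs were multiplied by 0.5 in place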
@@ -1344,7 +1344,7 @@ protected:
  * \brief check whether input contains inf or nan value.
  */
 class CheckNonFinite : public OperatorBase {
-    DEF_OPR_PARAM(Empty);
+    DEF_OPR_PARAM(CheckNonFinite);
     DEF_OPR_IMPL(CheckNonFinite, OperatorBase, -1, 1);
     size_t m_size = 0;
......
@@ -1176,6 +1176,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
 )

 pdef('Fill').add_fields('float32', 'value', '0')
+
+pdef('CheckNonFinite').add_fields('float32', 'scale', '1.0')

 PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
                  Doc('REFLECT = 1', 'fedcba|abcdefgh|hgfedcb'),
......
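The new pdef generates a `CheckNonFinite` param struct with a single float field `scale` (default 1.0). On the C++ side it is set through the usual megdnn param accessor, roughly as below (a sketch, assuming the usual `Handle::create_operator` factory and an already sized workspace; this mirrors how the graph and imperative layers later in this diff drive the operator):

#include "megdnn/oprs.h"

void run_check(megdnn::Handle* handle, const megdnn::TensorNDArray& srcs,
               const megdnn::TensorND& dst, const megdnn::Workspace& wk) {
    auto opr = handle->create_operator<megdnn::CheckNonFinite>();
    opr->param().scale = 0.5f;  // finite elements are multiplied by 0.5 in place
    opr->exec(srcs, dst, wk);   // dst: int32 flag, 1 if any inf/nan was found
}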
@@ -156,37 +156,6 @@ struct MaxOp<src_ctype, dst_ctype, dt_float32> {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };

-template <typename src_ctype, typename index_ctype, typename dst_ctype, typename wtype_>
-struct CheckNonFiniteOp {
-    typedef wtype_ wtype;
-    const wtype INIT;
-
-    RefPtr* srcs;
-    RefPtr srcs_total_nr_elems;
-    RefPtr dst;
-    const size_t B;
-
-    wtype read(uint32_t idx) {
-        size_t x = idx / B;
-        size_t y = idx % B;
-        if (y < srcs_total_nr_elems.ptr<index_ctype>()[x]) {
-            RefPtr src = srcs[x];
-            return !std::isfinite(src.ptr<src_ctype>()[y]);
-        }
-        return 0;
-    }
-    void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; }
-    static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; }
-    CheckNonFiniteOp(
-            RefPtr* srcs, const RefPtr& srcs_total_nr_elems, const RefPtr& dst,
-            size_t B)
-            : INIT(wtype(0)),
-              srcs(srcs),
-              srcs_total_nr_elems(srcs_total_nr_elems),
-              dst(dst),
-              B(B) {}
-};
-
 void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis);

 } // namespace reduce
......
@@ -194,6 +194,7 @@ struct CheckNonFiniteOp {
     index_ctype* srcs_total_nr_elems;
     dst_ctype* dst;
     const size_t B;
+    const src_ctype scale;

     MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
         size_t x = idx / B;
@@ -204,6 +205,8 @@ struct CheckNonFiniteOp {
 #else
             wtype val = std::isfinite(srcs[x][y]);
 #endif
+            if (val)
+                srcs[x][y] *= scale;
             return !val;
         }
         return 0;
@@ -214,12 +217,13 @@ struct CheckNonFiniteOp {
     }
     MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(
             src_ctype** srcs, index_ctype* srcs_total_nr_elems, dst_ctype* dst,
-            size_t B)
+            size_t B, src_ctype scale)
             : INIT(wtype(0)),
              srcs(srcs),
              srcs_total_nr_elems(srcs_total_nr_elems),
              dst(dst),
-              B(B) {}
+              B(B),
+              scale(scale) {}
 };

 } // namespace device_reduce
......
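For readers unfamiliar with the reduce helper: an Op like `CheckNonFiniteOp` only supplies `read` (map one flat index to a value, here the inf/nan flag, with the in-place scaling as a side effect), `apply` (combine two partial results, here bitwise OR) and `write` (store the final value); the framework folds these over the index space. A serial reference of that contract, just to make the data flow explicit (a sketch, not the actual reduction kernel):

#include <cstddef>
#include <cstdint>

template <typename Op>
void run_reduce_serial(Op op, size_t nr_idx) {
    auto acc = op.INIT;                    // identity element, 0 for the OR-reduction
    for (uint32_t i = 0; i < nr_idx; ++i)
        acc = op.apply(acc, op.read(i));   // read() also rescales finite elements in place
    op.write(0, acc);                      // single scalar result: 1 iff any inf/nan
}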
@@ -97,7 +97,7 @@ void CheckNonFiniteImpl::exec(
                     workspace_gpu.total_size_in_bytes())),
             1, m_size * total_nr_elems_max, 1, stream,
             Op(srcs_gpu, srcs_total_nr_elems_gpu, dst.ptr<dt_int32>(),
-               total_nr_elems_max));
+               total_nr_elems_max, param().scale));
 }

 } // namespace cuda
......
@@ -19,7 +19,7 @@ using namespace megdnn;

 #define wtype dt_int32

-void reduce_fwd(const TensorNDArray& srcs, wtype* dptr) {
+void reduce_fwd(const TensorNDArray& srcs, wtype* dptr, dt_float32 scale) {
     dptr[0] = 0;
     for (auto src : srcs) {
         auto sptr = src.ptr<dt_float32>();
@@ -31,6 +31,8 @@ void reduce_fwd(const TensorNDArray& srcs, wtype* dptr) {
                 return func(l, mid) | func(mid, r);
             } else {
                 auto val = std::isfinite(sptr[l]);
+                if (val)
+                    sptr[l] *= scale;
                 return static_cast<wtype>(!val);
             }
         };
@@ -47,9 +49,9 @@ void CheckNonFiniteImpl::exec(
         _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
         _megdnn_workspace workspace) {
     check_exec(srcs, dst, workspace.size);
+    float scale = param().scale;
     auto handle = static_cast<HandleImpl*>(this->handle());
-    MEGDNN_DISPATCH_CPU_KERN(handle, reduce_fwd(srcs, dst.ptr<dt_int32>()));
+    MEGDNN_DISPATCH_CPU_KERN(handle, reduce_fwd(srcs, dst.ptr<dt_int32>(), scale));
 }

 } // namespace naive
 } // namespace megdnn
......
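Written out without the recursive divide-and-conquer used above, the per-element semantics that both the naive and CUDA backends now implement are simply "flag non-finite values, rescale finite ones". A plain C++ reference (a sketch for clarity, not code from the tree):

#include <cmath>
#include <cstdint>
#include <vector>

// Returns 1 if any element of any tensor is inf/nan; finite elements are
// multiplied by `scale` in place (matching the fused kernel's behaviour).
int32_t check_non_finite_ref(std::vector<std::vector<float>>& srcs, float scale) {
    int32_t found = 0;
    for (auto& src : srcs) {
        for (float& v : src) {
            if (std::isfinite(v))
                v *= scale;
            else
                found = 1;
        }
    }
    return found;
}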
@@ -128,28 +128,28 @@ class GradScaler:
             grad_tensors: Tensors needed to unscale grads. Should be all tensors
                 that are affected by ``target`` tensor in GradManager's backward.
         """
-        # to support tracing, _check_gradients should be applied to every grad.
-        if self._check_gradients([x.grad for x in grad_tensors]):
-            self._found_non_finite = True
-        if self._found_non_finite:
-            for tensor in grad_tensors:
-                if tensor is None or getattr(tensor, "grad", None) is None:
-                    continue
-                tensor.grad = None
-        else:
+        if self.growth_interval == 0:
             # use float64 for better precision
             inv_scale = Tensor(1.0 / self.scale_factor)
             for tensor in grad_tensors:
                 if tensor is None or getattr(tensor, "grad", None) is None:
                     continue
                 tensor.grad *= inv_scale
+            return self
+        # to support tracing, _check_gradients should be applied to every grad.
+        if self._check_gradients(
+            [x.grad for x in grad_tensors], 1.0 / self.scale_factor
+        ):
+            self._found_non_finite = True
+            for tensor in grad_tensors:
+                if tensor is None or getattr(tensor, "grad", None) is None:
+                    continue
+                tensor.grad = None
         return self

-    def _check_gradients(self, grad):
-        if self.growth_interval == 0:
-            return False
-        return _check_non_finite(grad)
+    def _check_gradients(self, grad, scale):
+        return _check_non_finite(grad, scale)

     def update(self, new_scale: float = None):
         r"""Update the scale factor according to whether encountered overflow grad.
......
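Restated outside the class for readability: with overflow checking disabled (`growth_interval == 0`) the grads are simply divided by the scale factor; otherwise a single fused `_check_non_finite` call rescales every finite grad by `1 / scale_factor` in place and reports whether any inf/nan was seen, in which case the grads are dropped. A plain-Python sketch of that flow (a hypothetical helper, not the GradScaler source):

def unscale_sketch(scaler, grad_tensors):
    # Sketch only: mirrors the control flow of GradScaler.unscale after this change.
    if scaler.growth_interval == 0:
        inv_scale = 1.0 / scaler.scale_factor
        for t in grad_tensors:
            if t is not None and getattr(t, "grad", None) is not None:
                t.grad *= inv_scale
        return scaler
    # fused path: rescales finite grads in place and returns the overflow flag
    if scaler._check_gradients([t.grad for t in grad_tensors], 1.0 / scaler.scale_factor):
        scaler._found_non_finite = True
        for t in grad_tensors:
            if t is not None and getattr(t, "grad", None) is not None:
                t.grad = None
    return scaler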
@@ -1183,7 +1183,7 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor:
     return U, sigma, V

-def _check_non_finite(inps: Iterable[Tensor]) -> Tensor:
+def _check_non_finite(inps: Iterable[Tensor], scale=1.0) -> Tensor:
     r"""Check whether input contains infinite or nan value.

     Args:
@@ -1192,7 +1192,11 @@ def _check_non_finite(inps: Iterable[Tensor]) -> Tensor:
     Returns:
         a int32 scalar tensor, 0 for False and 1 for True.
     """
-    op = builtin.CheckNonFinite()
-    (oup,) = apply(op, *inps)
-    oup._setscalar()
-    return oup
+    op = builtin.CheckNonFinite(scale=scale)
+    oups = apply(op, *inps)
+    out = oups[-1]
+    for i in range(len(inps)):
+        inps[i]._reset(oups[i])
+    out._setscalar()
+    return out
@@ -191,17 +191,21 @@ def test_sum_neg_axis():

 def test_non_finite():
     shape = (32, 3, 32, 32)
-    data1 = np.random.random(shape).astype(np.float32)
-    data2 = np.random.random(shape).astype(np.float32)
-    rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
+    data = []
+    for i in range(2):
+        data.append(np.random.random(shape).astype(np.float32))
+    tensorList = [tensor(x) for x in data]
+    rst = F.math._check_non_finite(tensorList, 0.7)
     np.testing.assert_equal(rst.numpy(), [0])
+    for i in range(len(tensorList)):
+        np.testing.assert_allclose(tensorList[i].numpy() / 0.7, data[i], rtol=1e-6)

-    data2[0][0][0][0] = float("inf")
-    rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
+    data[1][0][0][0][0] = float("inf")
+    rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
     np.testing.assert_equal(rst.numpy(), [1])

-    data2[0][0][0][0] = float("nan")
-    rst = F.math._check_non_finite([tensor(data1), tensor(data2)])
+    data[1][0][0][0][0] = float("nan")
+    rst = F.math._check_non_finite([tensor(x) for x in data], 0.7)
     np.testing.assert_equal(rst.numpy(), [1])
......
@@ -17,44 +17,62 @@ namespace mgb {
 namespace imperative {

 namespace check_non_finite {

-SymbolVar apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+SymbolVarArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& op = def.cast_final_safe<CheckNonFinite>();
     OperatorNodeConfig config{op.make_name()};
-    return opr::CheckNonFinite::make(inputs, {}, config);
+    return opr::CheckNonFinite::make(inputs, op.param(), config);
 }

 SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
     size_t size = inputs.size();
-    auto dest = Tensor::make(
+    auto&& op = def.cast_final_safe<CheckNonFinite>();
+    SmallVector<TensorPtr> outputs(size + 1);
+    outputs[size] = Tensor::make(
             TensorLayout(TensorShape({1}), dtype::Int32()), inputs[0]->comp_node());
+    auto dest = outputs[size];
     auto cn = dest->comp_node();
     auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::CheckNonFinite>(cn);
     size_t wk_size = 0;
     SmallVector<megdnn::TensorND> srcs(size);
+    // copy an outputs to the dnn for inplace
     for (size_t i = 0; i < size; ++i) {
-        srcs[i] = inputs[i]->dev_tensor().as_megdnn();
+        outputs[i] = Tensor::make(inputs[i]->layout(), inputs[0]->comp_node());
+        outputs[i]->dev_tensor().copy_from_fixlayout(inputs[i]->dev_tensor());
+        srcs[i] = outputs[i]->dev_tensor().as_megdnn();
     }
+    megdnn::CheckNonFinite::Param param({op.scale});
+    dnn_opr->param() = param;
     wk_size = dnn_opr->get_workspace_in_bytes(srcs, dest->layout());
     auto wk = Blob::make(cn, wk_size);
     megdnn::Workspace dnn_wk(wk->storage().get(), wk_size);
     dnn_opr->exec(srcs, dest->dev_tensor().as_megdnn(), dnn_wk);
-    return {dest};
+    return outputs;
 }

 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
-    SmallVector<LogicalTensorDesc> dests(1);
-    dests[0].comp_node = inputs[0].comp_node;
-    dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
+    size_t size = inputs.size();
+    SmallVector<LogicalTensorDesc> dests(size + 1);
+    for (size_t i = 0; i < size; ++i) {
+        dests[i].comp_node = inputs[i].comp_node;
+        dests[i].layout = inputs[i].layout;
+    }
+    dests[size].comp_node = inputs[0].comp_node;
+    dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
     return {dests, true};
 }

 SmallVector<LogicalTensorDesc> infer_output_attrs(
         const OpDef& def, const SmallVector<TensorPtr>& inputs) {
-    SmallVector<LogicalTensorDesc> dests(1);
-    dests[0].comp_node = inputs[0]->comp_node();
-    dests[0].layout = TensorLayout(TensorShape({1}), dtype::Int32());
+    size_t size = inputs.size();
+    SmallVector<LogicalTensorDesc> dests(size + 1);
+    for (size_t i = 0; i < size; ++i) {
+        dests[i].comp_node = inputs[i]->comp_node();
+        dests[i].layout = inputs[i]->layout();
+    }
+    dests[size].comp_node = inputs[0]->comp_node();
+    dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32());
     return dests;
 }

 std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
......
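The convention fixed here carries through the rest of the diff: for N input tensors the op yields N rescaled copies followed by one int32 flag. Seen through the low-level dispatch API (an illustrative sketch using internal modules; the import paths mirror those used by the Python functional change above):

import numpy as np
from megengine import tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops import builtin

x = tensor(np.ones((2,), dtype=np.float32))
y = tensor(np.ones((3,), dtype=np.float32))
*scaled, flag = apply(builtin.CheckNonFinite(scale=0.5), x, y)
# scaled[0], scaled[1]: copies of x and y multiplied by 0.5
# flag: int32 tensor, 1 if any element of x or y was inf/nan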
@@ -397,7 +397,7 @@ def MagicMindRuntime: MgbHashableOp<"MagicMindRuntime"> {

 def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>;

-def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [EmptyParam]>;
+def CheckNonFinite: MgbHashableOp<"CheckNonFinite", [CheckNonFiniteParam]>;

 def FastpathCopy: MgbHashableOp<"FastpathCopy">;
......
@@ -487,39 +487,60 @@ CheckNonFinite::CheckNonFinite(
         const VarNodeArrayView& inp, const Param& param,
         const OperatorNodeConfig& config)
         : Super(OperatorNodeBaseCtorParam{
-                  inp[0]->owner_graph(), config, "check_non_finite", inp}) {
+                  inp[0]->owner_graph(), config, "check_non_finite", inp}),
+          m_scale(param.scale) {
     mgb_assert(!inp.empty());
     for (auto&& i : inp) {
         add_input({i});
+        add_output(None)
+                ->dtype(dtype::Float32())
+                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     }
     add_output(None)->dtype(dtype::Int32()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     cg::add_workspace_output(this);
 }

-SymbolVar CheckNonFinite::make(
+SymbolVarArray CheckNonFinite::make(
         const VarNodeArrayView& inp, const Param& param,
         const OperatorNodeConfig& config) {
     mgb_assert(!inp.empty());
     intl::BatchedDTypePromotion dtp{inp};
-    return SymbolVar{inp[0]}.insert_single_output_opr<CheckNonFinite>(
-            dtp.get_vars(), param, config);
+    auto outputs =
+            inp[0]->owner_graph()
+                    ->insert_opr(std::make_unique<CheckNonFinite>(inp, param, config))
+                    ->output();
+    mgb_assert(outputs.size() == inp.size() + 2);
+    SymbolVarArray ret(outputs.size() - 1);
+    for (size_t i = 0; i < ret.size(); ++i)
+        ret[i] = outputs[i];
+    return ret;
 }

 void CheckNonFinite::scn_do_execute() {
-    megdnn::TensorNDArray inp_arr(input().size());
-    for (size_t i = 0; i < input().size(); ++i) {
-        inp_arr[i] = input()[i]->dev_tensor().as_megdnn();
+    size_t size = input().size();
+    megdnn::TensorNDArray oup_arr(size);
+    // copy an outputs to the dnn for inplace
+    for (size_t i = 0; i < size; ++i) {
+        oup_arr[i] = output(i)
+                             ->dev_tensor()
+                             .copy_from_fixlayout(input(i)->dev_tensor())
+                             .as_megdnn();
     }
+    megdnn_opr()->param().scale = m_scale;
     megdnn_opr()->exec(
-            inp_arr, output(0)->dev_tensor().as_megdnn(),
-            intl::get_megdnn_workspace_from_var(output(1)));
+            oup_arr, output(size)->dev_tensor().as_megdnn(),
+            intl::get_megdnn_workspace_from_var(output(size + 1)));
 }

 void CheckNonFinite::init_output_static_infer_desc() {
     using namespace cg::static_infer;
     auto&& mgr = owner_graph()->static_infer_manager();
+    size_t size = input().size();
+    for (size_t i = 0; i < size; ++i) {
+        mgr.register_shape_infer(output(i), ShapeInferDesc::make_identity(input(i)));
+    }
     auto infer_oshp = [](TensorShape& dest, const InpVal& iv) {
         TensorLayout dst;
         dst.shape[0] = 1;
@@ -532,7 +553,7 @@ void CheckNonFinite::init_output_static_infer_desc() {
     DepVal deps;
     for (auto i : input())
         deps.push_back({i, DepType::SHAPE});
-    mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_oshp});
+    mgr.register_shape_infer(output(size), {SourceType::DEP, deps, infer_oshp});

     auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
         dest.ndim = 1;
@@ -541,10 +562,11 @@ void CheckNonFinite::init_output_static_infer_desc() {
             inp_arr[i] = {NULL, {inp.val.at(i).shape(), input(0)->dtype()}};
         }
         dest.shape[0] = megdnn_opr()->get_workspace_in_bytes(
-                inp_arr, {output(0)->shape(), output(0)->dtype()});
+                inp_arr, {output(input().size() + 1)->shape(),
+                          output(input().size() + 1)->dtype()});
         return true;
     };
-    mgr.register_shape_infer(output(1), {SourceType::DEP, deps, infer_wk});
+    mgr.register_shape_infer(output(size + 1), {SourceType::DEP, deps, infer_wk});
 }

 void CheckNonFinite::add_input_layout_constraint() {
......
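After this change the graph operator carries N + 2 outputs: one rescaled float32 var per input, then the int32 flag, then the workspace var appended by `cg::add_workspace_output`; `make()` drops the workspace and returns N + 1 SymbolVars. A consumer-side fragment (a sketch, not a compilable unit: `inputs`, `param`, and `config` are assumed to be in scope):

auto outs = opr::CheckNonFinite::make(inputs, param, config);
SymbolVar flag = outs.back();  // int32 flag: 1 if any input element was inf/nan
// outs[0 .. outs.size() - 2] are the inputs multiplied by param.scale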
@@ -56,7 +56,16 @@ struct OprMaker<opr::TopK, 2> {
 };

 template <>
-struct OprMaker<opr::CheckNonFinite, 0> : public OprMakerVariadic<opr::CheckNonFinite> {
+struct OprMaker<opr::CheckNonFinite, 0> {
+    using Opr = opr::CheckNonFinite;
+    using Param = Opr::Param;
+    static cg::OperatorNodeBase* make(
+            const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
+            const OperatorNodeConfig& config) {
+        MGB_MARK_USED_VAR(graph);
+        auto out = Opr::make(inputs, param, config);
+        return out[0].node()->owner_opr();
+    }
 };

 } // namespace serialization
......
@@ -183,18 +183,19 @@ public:
             const OperatorNodeConfig& config = {});
 };

-MGB_DEFINE_OPR_CLASS(CheckNonFinite, intl::CheckNonFiniteBase) //{
+MGB_DEFINE_OPR_CLASS(CheckNonFinite, intl::CheckNonFiniteBase) // {
     void scn_do_execute() override;
     void init_output_static_infer_desc() override;
     void add_input_layout_constraint() override;
+    float m_scale = 1;

 public:
     MGE_WIN_DECLSPEC_FUC CheckNonFinite(
             const VarNodeArrayView& inp, const Param& param,
             const OperatorNodeConfig& config);

-    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
             const VarNodeArrayView& inp, const Param& param = {},
             const OperatorNodeConfig& config = {});
 };

 } // namespace opr
......
@@ -115,6 +115,7 @@ union OperatorParam {
     param.SlidingWindowTranspose = 81,
     param.Padding = 82,
     param.ShuffleRNG = 83,
+    param.CheckNonFinite = 84,
 }

 table Operator {
......