diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 06a78e73aed48c9570d97248c33f9b70a0e0797b..905a7ac50bd8f64dbaab33aca7eb289bbad852a1 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1579,3 +1579,9 @@ def batched_nms( indices = indices[0][: count.item()] keep_inds = sorted_idx[indices] return keep_inds + + + + +from .loss import * # isort:skip +from .quantized import conv_bias_activation # isort:skip diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index f3187ec4d0bf9c3bc36d7a24fa48727400ee550d..3dfce3fadbdead1fcea60870733fddf7b170f661 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -551,3 +551,5 @@ def test_nms_is_same(): assert op1 != op3 assert op1 != op4 assert op3 != op4 + + diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index 3f738d4eab5ccec96b2e419ff645ce2ad4198412..d4c1bb0bedb03c0b29e403ba0fdff1298376e77e 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -159,6 +159,7 @@ void Cumsum::init_output_static_infer_desc() { {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace}); } + /* ================= CondTake ================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); diff --git a/src/opr/impl/misc.oprdecl b/src/opr/impl/misc.oprdecl index b2444c650460182ab8c6e9a16f3087ce6d399d28..d76f473d4f93e752458956450169036c3e7c92e8 100644 --- a/src/opr/impl/misc.oprdecl +++ b/src/opr/impl/misc.oprdecl @@ -63,4 +63,5 @@ decl_opr('TopK', inputs=['data', 'k'], params='TopK', desc='Select the top k values from sorted result.') + # vim: ft=python diff --git a/src/opr/impl/misc.sereg.h b/src/opr/impl/misc.sereg.h index 7c5e7ea6e532543e7f4275c653112f0c5621095e..b8562ee5f9d19818e5b20dc5fd646bb541549a4f 100644 --- a/src/opr/impl/misc.sereg.h +++ b/src/opr/impl/misc.sereg.h @@ -70,6 +70,7 @@ namespace opr { using CumsumV1 = opr::Cumsum; MGB_SEREG_OPR(CumsumV1, 1); + } // namespace opr } // namespace mgb diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h index e6285a41534a9ca289ebe208026ccb47a262e069..314adefc284fd583073f251b07b42a1b4808946f 100644 --- a/src/opr/include/megbrain/opr/misc.h +++ b/src/opr/include/megbrain/opr/misc.h @@ -94,6 +94,7 @@ MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT< void init_output_static_infer_desc() override; }; + namespace intl { using CondTakeBase = cg::SingleCNOperatorNode