提交 e0da7485 编写于 作者: M Megvii Engine Team

feat(opr): add confidential operator

GitOrigin-RevId: 53c2d4bc4510cebacade1be42c4e49ee47c04716
上级 20cdd11d
......@@ -1579,3 +1579,9 @@ def batched_nms(
indices = indices[0][: count.item()]
keep_inds = sorted_idx[indices]
return keep_inds
from .loss import * # isort:skip
from .quantized import conv_bias_activation # isort:skip
......@@ -551,3 +551,5 @@ def test_nms_is_same():
assert op1 != op3
assert op1 != op4
assert op3 != op4
......@@ -159,6 +159,7 @@ void Cumsum::init_output_static_infer_desc() {
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace});
}
/* ================= CondTake ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake);
......
......@@ -63,4 +63,5 @@ decl_opr('TopK',
inputs=['data', 'k'], params='TopK',
desc='Select the top k values from sorted result.')
# vim: ft=python
......@@ -70,6 +70,7 @@ namespace opr {
using CumsumV1 = opr::Cumsum;
MGB_SEREG_OPR(CumsumV1, 1);
} // namespace opr
} // namespace mgb
......
......@@ -94,6 +94,7 @@ MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT<
void init_output_static_infer_desc() override;
};
namespace intl {
using CondTakeBase =
cg::SingleCNOperatorNode<cg::OperatorNodeBase,
......
......@@ -28,6 +28,7 @@ table Blob {
}
table Reserved0 {}
table Reserved1 {}
union OperatorParam {
param.Empty = 1,
......@@ -100,6 +101,7 @@ union OperatorParam {
param.Remap = 68,
param.NMSKeep = 69,
param.AdaptivePooling = 70,
Reserved1 = 71,
}
table Operator {
......
......@@ -143,3 +143,4 @@ pdef('PersistentOutputStorage').add_fields(
' no branch is taken')
)
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册