提交 e0da7485 编写于 作者: M Megvii Engine Team

feat(opr): add confidential operator

GitOrigin-RevId: 53c2d4bc4510cebacade1be42c4e49ee47c04716
上级 20cdd11d
...@@ -1579,3 +1579,9 @@ def batched_nms( ...@@ -1579,3 +1579,9 @@ def batched_nms(
indices = indices[0][: count.item()] indices = indices[0][: count.item()]
keep_inds = sorted_idx[indices] keep_inds = sorted_idx[indices]
return keep_inds return keep_inds
from .loss import * # isort:skip
from .quantized import conv_bias_activation # isort:skip
...@@ -551,3 +551,5 @@ def test_nms_is_same(): ...@@ -551,3 +551,5 @@ def test_nms_is_same():
assert op1 != op3 assert op1 != op3
assert op1 != op4 assert op1 != op4
assert op3 != op4 assert op3 != op4
...@@ -159,6 +159,7 @@ void Cumsum::init_output_static_infer_desc() { ...@@ -159,6 +159,7 @@ void Cumsum::init_output_static_infer_desc() {
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace}); {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_workspace});
} }
/* ================= CondTake ================= */ /* ================= CondTake ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake); MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondTake);
......
...@@ -63,4 +63,5 @@ decl_opr('TopK', ...@@ -63,4 +63,5 @@ decl_opr('TopK',
inputs=['data', 'k'], params='TopK', inputs=['data', 'k'], params='TopK',
desc='Select the top k values from sorted result.') desc='Select the top k values from sorted result.')
# vim: ft=python # vim: ft=python
...@@ -70,6 +70,7 @@ namespace opr { ...@@ -70,6 +70,7 @@ namespace opr {
using CumsumV1 = opr::Cumsum; using CumsumV1 = opr::Cumsum;
MGB_SEREG_OPR(CumsumV1, 1); MGB_SEREG_OPR(CumsumV1, 1);
} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb
......
...@@ -94,6 +94,7 @@ MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT< ...@@ -94,6 +94,7 @@ MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT<
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
}; };
namespace intl { namespace intl {
using CondTakeBase = using CondTakeBase =
cg::SingleCNOperatorNode<cg::OperatorNodeBase, cg::SingleCNOperatorNode<cg::OperatorNodeBase,
......
...@@ -28,6 +28,7 @@ table Blob { ...@@ -28,6 +28,7 @@ table Blob {
} }
table Reserved0 {} table Reserved0 {}
table Reserved1 {}
union OperatorParam { union OperatorParam {
param.Empty = 1, param.Empty = 1,
...@@ -100,6 +101,7 @@ union OperatorParam { ...@@ -100,6 +101,7 @@ union OperatorParam {
param.Remap = 68, param.Remap = 68,
param.NMSKeep = 69, param.NMSKeep = 69,
param.AdaptivePooling = 70, param.AdaptivePooling = 70,
Reserved1 = 71,
} }
table Operator { table Operator {
......
...@@ -143,3 +143,4 @@ pdef('PersistentOutputStorage').add_fields( ...@@ -143,3 +143,4 @@ pdef('PersistentOutputStorage').add_fields(
' no branch is taken') ' no branch is taken')
) )
) )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册