提交 939bf3b2 编写于 作者: M Megvii Engine Team

fix(mge/imperative): impl hashable for nms and cond_take

GitOrigin-RevId: 56918db014326fb8c317119039585f2f3707e8e3
上级 8dc23e0f
......@@ -11,6 +11,7 @@ import itertools
import numpy as np
import pytest
import megengine.core.ops.builtin as builtin
import megengine.core.tensor.dtype as dtype
import megengine.functional as F
from megengine import Buffer, Parameter, is_cuda_available, tensor
......@@ -631,3 +632,20 @@ def test_condtake():
val, idx = F.cond_take(yy, xx)
np.testing.assert_equal(val.numpy(), x[y])
np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
def test_condtake_is_same():
op1 = builtin.CondTake()
op2 = builtin.CondTake()
assert op1 == op2
def test_nms_is_same():
op1 = builtin.NMSKeep(0.7, 100)
op2 = builtin.NMSKeep(0.7, 100)
op3 = builtin.NMSKeep(0.8, 100)
op4 = builtin.NMSKeep(0.7, 200)
assert op1 == op2
assert op1 != op3
assert op1 != op4
assert op3 != op4
......@@ -19,6 +19,15 @@ class CondTake : public OpDefImplBase<CondTake> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
CondTake() = default;
size_t hash() const override {
return reinterpret_cast<std::uintptr_t>(dyn_typeinfo());
}
bool is_same_st(const Hashable& rhs) const override {
return rhs.dyn_typeinfo() == dyn_typeinfo();
}
};
} // namespace mgb::imperative
......@@ -23,6 +23,20 @@ public:
NMSKeep() = default;
NMSKeep(float iou_thresh_, uint32_t max_output_):
iou_thresh(iou_thresh_), max_output(max_output_) {}
size_t hash() const override {
return hash_pair_combine(
hash_pair_combine(mgb::hash(iou_thresh), mgb::hash(max_output)),
reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
}
bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const NMSKeep&>(rhs_);
return rhs.dyn_typeinfo() == dyn_typeinfo()
&& rhs.iou_thresh == iou_thresh
&& rhs.max_output == max_output;
}
};
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册