diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 3401ad102bd2c15d8fb6b325b0e9afeb8745a7a2..4aec162711f6178661b58755a27b297312bfb134 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -11,6 +11,7 @@ import itertools import numpy as np import pytest +import megengine.core.ops.builtin as builtin import megengine.core.tensor.dtype as dtype import megengine.functional as F from megengine import Buffer, Parameter, is_cuda_available, tensor @@ -631,3 +632,20 @@ def test_condtake(): val, idx = F.cond_take(yy, xx) np.testing.assert_equal(val.numpy(), x[y]) np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0]) + + +def test_condtake_is_same(): + op1 = builtin.CondTake() + op2 = builtin.CondTake() + assert op1 == op2 + + +def test_nms_is_same(): + op1 = builtin.NMSKeep(0.7, 100) + op2 = builtin.NMSKeep(0.7, 100) + op3 = builtin.NMSKeep(0.8, 100) + op4 = builtin.NMSKeep(0.7, 200) + assert op1 == op2 + assert op1 != op3 + assert op1 != op4 + assert op3 != op4 diff --git a/imperative/src/include/megbrain/imperative/ops/cond_take.h b/imperative/src/include/megbrain/imperative/ops/cond_take.h index 42a1c1a4892f6665adb56bf912f2da03fef13f3a..bed3465cecd016a2fb7bb3b9d991edbf1460a3a7 100644 --- a/imperative/src/include/megbrain/imperative/ops/cond_take.h +++ b/imperative/src/include/megbrain/imperative/ops/cond_take.h @@ -19,6 +19,15 @@ class CondTake : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; public: CondTake() = default; + + size_t hash() const override { + return reinterpret_cast(dyn_typeinfo()); + } + + bool is_same_st(const Hashable& rhs) const override { + return rhs.dyn_typeinfo() == dyn_typeinfo(); + } + }; } // namespace mgb::imperative diff --git a/imperative/src/include/megbrain/imperative/ops/nms.h b/imperative/src/include/megbrain/imperative/ops/nms.h index 4e0c21ec0b29a93e11aa2d676025453ea5e20dd8..ad1bd96bec474373fd0bae92d5ac6aea2e0e214a 100644 --- a/imperative/src/include/megbrain/imperative/ops/nms.h +++ b/imperative/src/include/megbrain/imperative/ops/nms.h @@ -23,6 +23,20 @@ public: NMSKeep() = default; NMSKeep(float iou_thresh_, uint32_t max_output_): iou_thresh(iou_thresh_), max_output(max_output_) {} + + size_t hash() const override { + return hash_pair_combine( + hash_pair_combine(mgb::hash(iou_thresh), mgb::hash(max_output)), + reinterpret_cast(dyn_typeinfo())); + } + + bool is_same_st(const Hashable& rhs_) const override { + auto&& rhs = static_cast(rhs_); + return rhs.dyn_typeinfo() == dyn_typeinfo() + && rhs.iou_thresh == iou_thresh + && rhs.max_output == max_output; + } + }; } // namespace mgb::imperative