diff --git a/imperative/python/test/unit/test_distributed.py b/imperative/python/test/unit/test_distributed.py index dd5add7e5c529fa87d492dbb5a735777b14c1ee5..f81b9f42203c00f8932eaab14feba8098a935e9b 100644 --- a/imperative/python/test/unit/test_distributed.py +++ b/imperative/python/test/unit/test_distributed.py @@ -14,6 +14,7 @@ import pytest import megengine as mge import megengine.distributed as dist +from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit from megengine.distributed.helper import get_device_count_by_fork @@ -187,3 +188,10 @@ def test_synchronized(): for p in procs: p.join(20) assert p.exitcode == 0 + + +def test_oprmm_hashable(): + lhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit()) + rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit()) + assert lhs == rhs + assert hash(lhs) == hash(rhs) diff --git a/imperative/src/impl/ops/collective_comm.cpp b/imperative/src/impl/ops/collective_comm.cpp index 4c23be99fd309ac47346f686c14fa64c6018f821..68a97bbb5fbc98c20f2bcdfcb92f2ed92a5e4d60 100644 --- a/imperative/src/impl/ops/collective_comm.cpp +++ b/imperative/src/impl/ops/collective_comm.cpp @@ -71,7 +71,7 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) .make_from_op_node(make_from_op_node) .fallback(); } // anonymous namespace - +#endif // MGB_ENABLE_OPR_MM bool CollectiveComm::is_same_st(const Hashable& another) const{ auto* comm_opr = another.try_cast_final(); @@ -100,18 +100,6 @@ size_t CollectiveComm::hash() const{ return xxhash.digest(); } -#else - -bool CollectiveComm::is_same_st(const Hashable& another) const{ - return OpDef::is_same_st(another); -} - -size_t CollectiveComm::hash() const{ - return OpDef::hash(); -} - -#endif // MGB_ENABLE_OPR_MM - MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm); } // namespace imperative diff --git a/imperative/src/include/megbrain/imperative/ops/tensor_manip.h b/imperative/src/include/megbrain/imperative/ops/tensor_manip.h index 23a5bc31d20233b42d2a1830a62bef6d779df10f..8d3d44d38347de2b320e80d309fabbcb7cf14db4 100644 --- a/imperative/src/include/megbrain/imperative/ops/tensor_manip.h +++ b/imperative/src/include/megbrain/imperative/ops/tensor_manip.h @@ -64,11 +64,8 @@ public: } bool is_same_st(const Hashable& rhs) const override { - auto* pps = rhs.try_cast_final(); - if(pps == nullptr){ - return false; - } - return offsets == pps->offsets && shapes == pps->shapes; + auto&& pps = rhs.cast_final_safe(); + return offsets == pps.offsets && shapes == pps.shapes; } }; @@ -94,11 +91,8 @@ public: } bool is_same_st(const Hashable& rhs) const override { - auto* ppc = rhs.try_cast_final(); - if(ppc == nullptr){ - return false; - } - return offsets == ppc->offsets; + auto&& ppc = rhs.cast_final_safe(); + return offsets == ppc.offsets; } };