From 9389a805d8d170ad62fc85e821c942b5f3861abd Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 9 Sep 2020 14:39:15 +0800 Subject: [PATCH] fix(mge/oprmm): fix grad for collective comm GitOrigin-RevId: e55dc47a6eb8f70955c94743ba274feb2c141624 --- imperative/python/test/unit/test_distributed.py | 8 ++++++++ imperative/src/impl/ops/collective_comm.cpp | 14 +------------- .../include/megbrain/imperative/ops/tensor_manip.h | 14 ++++---------- 3 files changed, 13 insertions(+), 23 deletions(-) diff --git a/imperative/python/test/unit/test_distributed.py b/imperative/python/test/unit/test_distributed.py index dd5add7e..f81b9f42 100644 --- a/imperative/python/test/unit/test_distributed.py +++ b/imperative/python/test/unit/test_distributed.py @@ -14,6 +14,7 @@ import pytest import megengine as mge import megengine.distributed as dist +from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit from megengine.distributed.helper import get_device_count_by_fork @@ -187,3 +188,10 @@ def test_synchronized(): for p in procs: p.join(20) assert p.exitcode == 0 + + +def test_oprmm_hashable(): + lhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit()) + rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit()) + assert lhs == rhs + assert hash(lhs) == hash(rhs) diff --git a/imperative/src/impl/ops/collective_comm.cpp b/imperative/src/impl/ops/collective_comm.cpp index 4c23be99..68a97bbb 100644 --- a/imperative/src/impl/ops/collective_comm.cpp +++ b/imperative/src/impl/ops/collective_comm.cpp @@ -71,7 +71,7 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) .make_from_op_node(make_from_op_node) .fallback(); } // anonymous namespace - +#endif // MGB_ENABLE_OPR_MM bool CollectiveComm::is_same_st(const Hashable& another) const{ auto* comm_opr = another.try_cast_final(); @@ -100,18 +100,6 @@ size_t CollectiveComm::hash() const{ return xxhash.digest(); } -#else - -bool CollectiveComm::is_same_st(const Hashable& another) const{ - return OpDef::is_same_st(another); -} - -size_t CollectiveComm::hash() const{ - return OpDef::hash(); -} - -#endif // MGB_ENABLE_OPR_MM - MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm); } // namespace imperative diff --git a/imperative/src/include/megbrain/imperative/ops/tensor_manip.h b/imperative/src/include/megbrain/imperative/ops/tensor_manip.h index 23a5bc31..8d3d44d3 100644 --- a/imperative/src/include/megbrain/imperative/ops/tensor_manip.h +++ b/imperative/src/include/megbrain/imperative/ops/tensor_manip.h @@ -64,11 +64,8 @@ public: } bool is_same_st(const Hashable& rhs) const override { - auto* pps = rhs.try_cast_final(); - if(pps == nullptr){ - return false; - } - return offsets == pps->offsets && shapes == pps->shapes; + auto&& pps = rhs.cast_final_safe(); + return offsets == pps.offsets && shapes == pps.shapes; } }; @@ -94,11 +91,8 @@ public: } bool is_same_st(const Hashable& rhs) const override { - auto* ppc = rhs.try_cast_final(); - if(ppc == nullptr){ - return false; - } - return offsets == ppc->offsets; + auto&& ppc = rhs.cast_final_safe(); + return offsets == ppc.offsets; } }; -- GitLab