From 3bbfef300921a083ea29f2033120bdefc8b0c16d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 4 Sep 2020 16:29:48 +0800 Subject: [PATCH] fix(mge/imperative): impl hashable for coll-comm GitOrigin-RevId: 76ab16a89bf9519c7749192ebdf831a9a7c2950c --- imperative/src/impl/ops/collective_comm.cpp | 40 +++++++++++++++++++ .../megbrain/imperative/ops/collective_comm.h | 9 +++++ .../megbrain/imperative/ops/opr_attr.h | 4 +- 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/imperative/src/impl/ops/collective_comm.cpp b/imperative/src/impl/ops/collective_comm.cpp index 69ba2bdd..7d23b764 100644 --- a/imperative/src/impl/ops/collective_comm.cpp +++ b/imperative/src/impl/ops/collective_comm.cpp @@ -15,6 +15,7 @@ #include "../op_trait.h" #include "../proxy_graph_detail.h" #include "megbrain/opr/mm_handler.h" +#include "megbrain/utils/hash.h" #endif // MGB_ENABLE_OPR_MM #include "megbrain/imperative/ops/collective_comm.h" @@ -52,6 +53,45 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) .apply_on_var_node(apply_on_var_node) .fallback(); } // anonymous namespace + + +bool CollectiveComm::is_same_st(const Hashable& another) const{ + auto* comm_opr = another.try_cast_final(); + if(!comm_opr){ + return false; + } + return as_tuple() == comm_opr->as_tuple(); +} + +size_t CollectiveComm::hash() const{ + XXHash xxhash{}; + auto append = [&xxhash](auto field){ + auto hash_val = HashTrait::eval(field); + xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); + }; + append(key); + append(nr_devices); + append(rank); + append(is_root); + append(local_grad); + append(addr); + append(port); + append(mode); + append(backend); + append(comp_node); + return xxhash.digest(); +} + +#else + +bool CollectiveComm::is_same_st(const Hashable& another) const{ + return OpDef::is_same_st(another); +} + +size_t CollectiveComm::hash() const{ + return OpDef::hash(); +} + #endif // MGB_ENABLE_OPR_MM MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm); diff --git a/imperative/src/include/megbrain/imperative/ops/collective_comm.h b/imperative/src/include/megbrain/imperative/ops/collective_comm.h index 0bc4ab49..f45fff65 100644 --- a/imperative/src/include/megbrain/imperative/ops/collective_comm.h +++ b/imperative/src/include/megbrain/imperative/ops/collective_comm.h @@ -52,6 +52,15 @@ public: DType dtype; std::string backend; std::string comp_node; + + size_t hash() const override; + + bool is_same_st(const Hashable& another) const override; + auto as_tuple() const{ + return std::tuple(key, nr_devices, rank, is_root, + local_grad, addr, port, mode, dtype, + backend, comp_node); + } }; } // namespace imperative diff --git a/imperative/src/include/megbrain/imperative/ops/opr_attr.h b/imperative/src/include/megbrain/imperative/ops/opr_attr.h index c08c76a3..da011c7e 100644 --- a/imperative/src/include/megbrain/imperative/ops/opr_attr.h +++ b/imperative/src/include/megbrain/imperative/ops/opr_attr.h @@ -45,8 +45,8 @@ public: std::string repr() const; - bool is_same_st(const Hashable& rhs) const; - size_t hash() const; + bool is_same_st(const Hashable& rhs) const override; + size_t hash() const override; }; } // namespace imperative -- GitLab