diff --git a/imperative/src/impl/ops/collective_comm.cpp b/imperative/src/impl/ops/collective_comm.cpp index 69ba2bddcd636f98a906bd5301aea3987349d540..7d23b764f42421453bbeb5836e6897e6af617e53 100644 --- a/imperative/src/impl/ops/collective_comm.cpp +++ b/imperative/src/impl/ops/collective_comm.cpp @@ -15,6 +15,7 @@ #include "../op_trait.h" #include "../proxy_graph_detail.h" #include "megbrain/opr/mm_handler.h" +#include "megbrain/utils/hash.h" #endif // MGB_ENABLE_OPR_MM #include "megbrain/imperative/ops/collective_comm.h" @@ -52,6 +53,45 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) .apply_on_var_node(apply_on_var_node) .fallback(); } // anonymous namespace + + +bool CollectiveComm::is_same_st(const Hashable& another) const{ + auto* comm_opr = another.try_cast_final(); + if(!comm_opr){ + return false; + } + return as_tuple() == comm_opr->as_tuple(); +} + +size_t CollectiveComm::hash() const{ + XXHash xxhash{}; + auto append = [&xxhash](auto field){ + auto hash_val = HashTrait::eval(field); + xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); + }; + append(key); + append(nr_devices); + append(rank); + append(is_root); + append(local_grad); + append(addr); + append(port); + append(mode); + append(backend); + append(comp_node); + return xxhash.digest(); +} + +#else + +bool CollectiveComm::is_same_st(const Hashable& another) const{ + return OpDef::is_same_st(another); +} + +size_t CollectiveComm::hash() const{ + return OpDef::hash(); +} + #endif // MGB_ENABLE_OPR_MM MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm); diff --git a/imperative/src/include/megbrain/imperative/ops/collective_comm.h b/imperative/src/include/megbrain/imperative/ops/collective_comm.h index 0bc4ab492e2634f8f4a639abf4c8403956d0c160..f45fff652fcc976b3ec6af2f7a14c0e040cbf415 100644 --- a/imperative/src/include/megbrain/imperative/ops/collective_comm.h +++ b/imperative/src/include/megbrain/imperative/ops/collective_comm.h @@ -52,6 +52,15 @@ public: DType dtype; std::string backend; std::string comp_node; + + size_t hash() const override; + + bool is_same_st(const Hashable& another) const override; + auto as_tuple() const{ + return std::tuple(key, nr_devices, rank, is_root, + local_grad, addr, port, mode, dtype, + backend, comp_node); + } }; } // namespace imperative diff --git a/imperative/src/include/megbrain/imperative/ops/opr_attr.h b/imperative/src/include/megbrain/imperative/ops/opr_attr.h index c08c76a3e0d05bf981032ae972efdd5a778bea3c..da011c7e501daaee08567b9f7f4880f7de508e78 100644 --- a/imperative/src/include/megbrain/imperative/ops/opr_attr.h +++ b/imperative/src/include/megbrain/imperative/ops/opr_attr.h @@ -45,8 +45,8 @@ public: std::string repr() const; - bool is_same_st(const Hashable& rhs) const; - size_t hash() const; + bool is_same_st(const Hashable& rhs) const override; + size_t hash() const override; }; } // namespace imperative