From e9c036ccba3e7988ebb310db9d0ab421a4a8574e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 19 Dec 2020 14:12:07 +0800 Subject: [PATCH] fix(mge/imperative): fix op specializations and tuple hash GitOrigin-RevId: e3df93fd5b1303aeab76019f4a75596a051a5bf1 --- imperative/src/impl/ops/autogen.cpp | 33 ++++++++++++--------- imperative/src/impl/ops/specializations.cpp | 22 +++++++++----- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/imperative/src/impl/ops/autogen.cpp b/imperative/src/impl/ops/autogen.cpp index 75e00ff1c..4956915e4 100644 --- a/imperative/src/impl/ops/autogen.cpp +++ b/imperative/src/impl/ops/autogen.cpp @@ -18,23 +18,28 @@ using namespace megdnn; // FIXME: remove this when mgb::hash support tuple_hash namespace mgb { namespace { -template -auto tail(T t, std::index_sequence) { - return std::make_tuple(std::get(t)...); + +struct HashWrapper { + size_t hash; + constexpr operator size_t() {return hash;} + + constexpr HashWrapper operator+(HashWrapper rhs) { + // NOTE: use a + b + c + d, not a + (b + (c + d)) !!! + return {hash * 20141203 + rhs.hash}; + } +}; + +template +constexpr size_t hash_many(const Args&... args) { + return (... + HashWrapper{mgb::hash(args)}); } + } // anonymous namespace + template -class HashTrait> { - constexpr static size_t length = sizeof...(Args); -public: +struct HashTrait> { static size_t eval(const std::tuple &t) { - const T& val = std::get<0>(t); - if constexpr (!length) { - return mgb::hash(val); - } else { - return mgb::hash_pair_combine(mgb::hash(val), - mgb::hash(tail(t, std::make_index_sequence{}))); - } + return std::apply(hash_many, t); } }; } // namespace mgb @@ -43,4 +48,4 @@ namespace mgb::imperative { #include "./opdef.cpp.inl" -} // namespace mgb::imperative \ No newline at end of file +} // namespace mgb::imperative diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 3d88826f3..c04fe09db 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -158,7 +158,13 @@ auto apply_on_var_node( } } -OP_TRAIT_REG(Reduce, Reduce) +std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { + auto* node = &node_->cast_final_safe(); + return Reduce::make(node->param()); +} + +OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) + .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) .fallback(); }} // reduce @@ -439,12 +445,13 @@ OP_TRAIT_REG(GaussianRNG, GaussianRNG) }} // gaussian_rng namespace { namespace roi_align { -auto apply_on_var_node( +VarNodeArray apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::ROIAlign::make(inputs[0], inputs[1], op.param()); + auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param()).node()->owner_opr(); + return {opr->output(0), opr->output(1)}; } OP_TRAIT_REG(ROIAlign, ROIAlign) .apply_on_var_node(apply_on_var_node) @@ -496,12 +503,13 @@ OP_TRAIT_REG(Eye, Eye) }} // eye namespace { namespace roi_pooling { -auto apply_on_var_node( +VarNodeArray apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 3); - return opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()); + auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], op.param()).node()->owner_opr(); + return {opr->output(0), opr->output(1)}; } OP_TRAIT_REG(ROIPooling, ROIPooling) .apply_on_var_node(apply_on_var_node) @@ -620,11 +628,11 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& op = static_cast(def); mgb_assert(inputs.size() == 1); - return opr::SVD::make(inputs[0], op.param()); + return opr::SVD::make(inputs[0], op.param())[0].node()->owner_opr()->usable_output(); } OP_TRAIT_REG(SVD, SVD) .apply_on_var_node(apply_on_var_node) .fallback(); }} // svd -} // namespace mgb::imperative \ No newline at end of file +} // namespace mgb::imperative -- GitLab