From aeb7980b294633318b0328651470170c7fa45241 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Mon, 21 Dec 2020 18:26:26 +0800
Subject: [PATCH] perf(mgb): outputs of the same opr and same compnode share
 the same callbackcaller

GitOrigin-RevId: 59b8e3bcbe0dd76f80f85bc1f46733364df769d1
---
 src/core/impl/graph/cg_impl.cpp | 103 +++++++++++++++++++++++---------
 src/core/impl/graph/cg_impl.h   |  22 +++++++
 src/core/test/graph/misc.cpp   |  50 ++++++++++++++++
 3 files changed, 147 insertions(+), 28 deletions(-)

diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp
index dbead12e2..2a091f19c 100644
--- a/src/core/impl/graph/cg_impl.cpp
+++ b/src/core/impl/graph/cg_impl.cpp
@@ -148,13 +148,15 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) {
 /* ========================== CallbackCaller ========================== */
 MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
                      SingleCNOperatorNodeBase) // {
-    std::vector<ComputingGraph::Callback> m_cb;
+    std::vector<std::vector<ComputingGraph::Callback>> m_cb;
 
     void scn_do_execute() override {
-        auto&& dv = input(0)->dev_tensor();
-        for (auto&& i : m_cb) {
-            // const cast for backward API compatibility
-            i(const_cast<DeviceTensorND&>(dv));
+        for (size_t i = 0; i < input().size(); ++i) {
+            auto&& in = input(i)->dev_tensor();
+            for (auto&& callback : m_cb[i]) {
+                // const cast for backward API compatibility
+                callback(const_cast<DeviceTensorND&>(in));
+            }
         }
     }
 
@@ -168,14 +170,29 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
         if (owner_graph()->options().comp_node_seq_record_level) {
             // the user callback usually copies from device to host, which
             // involves tmp alloc if input is not contiguous
-            input(0)->add_layout_constraint_contiguous();
+            for (auto&& inp : input()) {
+                inp->add_layout_constraint_contiguous();
+            }
         }
     }
 
+    void init_output_dtype() override {
+        if (output(0)->dtype().valid()) {
+            return;
+        }
+
+        mgb_assert(!input().empty());
+        DType dtype = input(0)->dtype();
+        mgb_assert(dtype.valid() && dtype != dtype::Byte());
+        output(0)->dtype(dtype);
+    }
+
     NodeProp* do_make_node_prop() const override {
         auto ret = Super::do_make_node_prop();
-        ret->add_dep_type_existing_var(input(0),
-                                       NodeProp::DepType::VALUE_ALLOW_EMPTY);
+        for (auto&& inp : input()) {
+            ret->add_dep_type_existing_var(
+                    inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
+        }
         return ret;
     }
 
@@ -185,25 +202,38 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
     }
 
 public:
-    CallbackCaller(VarNode* inp)
-            : Super{inp->owner_graph(), {}, "callback", {inp}} {
-        add_input({inp});
+    CallbackCaller(const VarNodeArrayView& inp)
+            : Super{inp[0]->owner_graph(), {}, "callback", inp} {
+        mgb_assert(!inp.empty());
+        m_cb.resize(inp.size());
+        for (auto&& i : inp) {
+            add_input({i});
+        }
         using F = VarNode::Flag;
         add_output(None)
                 ->add_flag(F::ALLOW_EMPTY_SHAPE)
                 .add_flag(F::VOLATILE_CONTENT);
     }
 
-    static SymbolVar make(SymbolVar inp) {
-        return inp.insert_single_output_opr<CallbackCaller>(inp.node());
+    static SymbolVar make(const VarNodeArrayView& inp) {
+        mgb_assert(!inp.empty());
+        return SymbolVar{inp[0]}
+                .node()
+                ->owner_graph()
+                ->insert_opr(std::make_unique<CallbackCaller>(inp))
+                ->output(0);
     }
 
-    void add_callback(const ComputingGraph::Callback& cb) {
-        mgb_assert(cb);
-        m_cb.push_back(cb);
+    void add_callback(const ComputingGraph::Callback& cb, size_t i = 0) {
+        mgb_assert(cb && i < m_cb.size());
+        m_cb[i].push_back(cb);
     }
 
-    void clear_callback() { m_cb.clear(); }
+    void clear_callback() {
+        for (size_t i = 0; i < m_cb.size(); ++i) {
+            m_cb[i].clear();
+        }
+    }
 };
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller);
 
@@ -529,22 +559,39 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
     cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);
 
     auto init_opr_seq = [&]() {
-        ThinHashMap<VarNode*, CallbackCaller*> var2cb_caller;
+        ThinHashMap<VarNode*, size_t> var2idx;
+        std::unordered_map<CallbackCallerKey, CallbackCallerVal,
+                           CallbackCallerKey::Hash>
+                opr2vars;
         for (size_t i = 0; i < out_spec.size(); ++i) {
             auto&& cb = out_spec[i].second;
             if (cb) {
                 auto var = dest_vars[i];
-                auto&& cb_caller = var2cb_caller[var];
-                if (!cb_caller) {
-                    auto dvar = CallbackCaller::make(var);
-                    cb_caller = &dvar.node()
-                                         ->owner_opr()
-                                         ->cast_final_safe<CallbackCaller>();
-                    ++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req;
-                    cb_caller->clear_callback();
+                CallbackCallerKey key{var->owner_opr(), var->comp_node()};
+                auto&& vals = opr2vars[key];
+                auto&& var2idx_iter = var2idx.find(var);
+                if (var2idx_iter == var2idx.end()) {
+                    vals.vars.push_back(var);
+                    vals.indexs.push_back({i});
+                    var2idx[var] = vals.vars.size() - 1;
+                } else {
+                    vals.indexs[var2idx_iter->second].push_back(i);
+                }
+            }
+        }
+        for (auto& item : opr2vars) {
+            auto&& val = item.second;
+            auto dvar = CallbackCaller::make(val.vars);
+            CallbackCaller* cb_caller = &dvar.node()
+                                                 ->owner_opr()
+                                                 ->cast_final_safe<CallbackCaller>();
+            ++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req;
+            cb_caller->clear_callback();
+            for (size_t i = 0; i < val.vars.size(); ++i) {
+                for (auto&& idx : val.indexs[i]) {
+                    cb_caller->add_callback(out_spec[idx].second, i);
+                    dest_vars[idx] = cb_caller->output(0);
                 }
-                cb_caller->add_callback(cb);
-                dest_vars[i] = cb_caller->output(0);
             }
         }
         opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars);
diff --git a/src/core/impl/graph/cg_impl.h b/src/core/impl/graph/cg_impl.h
index 8c24785f0..e3969e655 100644
--- a/src/core/impl/graph/cg_impl.h
+++ b/src/core/impl/graph/cg_impl.h
@@ -40,6 +40,28 @@ class ComputingGraphImpl final : public ComputingGraph {
         const OprNodeArray* opr_seq = nullptr;
     };
 
+    struct CallbackCallerKey {
+        OperatorNodeBase* opr;
+        CompNode comp_node;
+
+        bool operator==(const CallbackCallerKey& rhs) const {
+            return opr == rhs.opr && comp_node == rhs.comp_node;
+        }
+
+        struct Hash {
+            size_t operator()(const CallbackCallerKey& b) const {
+                return hash_pair_combine(mgb::hash(b.opr),
+                                         mgb::hash(b.comp_node));
+            }
+        };
+    };
+
+    struct CallbackCallerVal {
+        SmallVector<VarNode*> vars;
+        //! indices of vars in out_spec
+        SmallVector<std::vector<size_t>> indexs;
+    };
+
     /*!
      * Components for implementing algorithms on a computing graph.
      *
diff --git a/src/core/test/graph/misc.cpp b/src/core/test/graph/misc.cpp
index 830ff21ab..a7895243f 100644
--- a/src/core/test/graph/misc.cpp
+++ b/src/core/test/graph/misc.cpp
@@ -17,6 +17,7 @@
 #include "megbrain/opr/tensor_manip.h"
 #include "megbrain/opr/misc.h"
 #include "megbrain/opr/indexing.h"
+#include "megbrain/opr/tensor_manip.h"
 #include "megbrain/graph/helper.h"
 #include "megbrain/graph/grad_impl.h"
 #include "megbrain/graph/event.h"
@@ -30,6 +31,7 @@
 #include
 #include
 #include
+#include
 
 using namespace mgb;
 
@@ -2236,4 +2238,52 @@ TEST(TestGraph, FreeBias) {
     }
 }
 
+TEST(TestGraph, CallbackCaller) {
+    using namespace opr;
+    auto cns = load_multiple_xpus(3);
+    constexpr size_t C1 = 20, C2 = 30, C3 = 10, C4 = 40;
+    constexpr size_t N = 2, C = C1 + C2;
+    HostTensorGenerator<> gen;
+    auto host_opr0 = gen({N, C}, cns[0]);
+    auto graph = ComputingGraph::make();
+    SymbolVar opr0 = opr::Host2DeviceCopy::make(*graph, host_opr0, {"opr0"});
+
+    auto spl0 = opr::Split::make(
+            opr0, Split::Options::make_partition(opr0, 1, {C1, C2}),
+            OperatorNodeConfig("split0").comp_node_arr({cns[1], cns[2]}));
+
+    auto spl1 = opr::Split::make(
+            opr0, Split::Options::make_partition(opr0, 1, {C3, C4}),
+            OperatorNodeConfig("split1"));
+
+    HostTensorND host_spl00, host_spl01, host_spl10, host_spl11;
+    auto func = graph->compile({make_callback_copy(spl0[0], host_spl00),
+                                make_callback_copy(spl0[1], host_spl01),
+                                make_callback_copy(spl1[0], host_spl10),
+                                make_callback_copy(spl1[1], host_spl11)});
+    func->execute();
+    auto o00 = host_spl00.ptr<float>(),
+         o01 = host_spl01.ptr<float>(),
+         o10 = host_spl10.ptr<float>(),
+         o11 = host_spl11.ptr<float>(), c = host_opr0->ptr<float>();
+    for (size_t i = 0, it = host_opr0->layout().total_nr_elems(); i < it; i++) {
+        auto ch = i % C;
+        auto n = i / C;
+        if (ch < C1) {
+            MGB_ASSERT_FLOAT_EQ(o00[n * C1 + ch], c[i])
+                    << ssprintf("failed at %zd", i);
+        } else {
+            MGB_ASSERT_FLOAT_EQ(o01[n * C2 + ch - C1], c[i])
+                    << ssprintf("failed at %zd", i);
+        }
+        if (ch < C3) {
+            MGB_ASSERT_FLOAT_EQ(o10[n * C3 + ch], c[i])
+                    << ssprintf("failed at %zd", i);
+        } else {
+            MGB_ASSERT_FLOAT_EQ(o11[n * C4 + ch - C3], c[i])
+                    << ssprintf("failed at %zd", i);
+        }
+    }
+}
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
--
GitLab
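
Note on the grouping scheme: the patch keys callback-caller groups on (opr, comp_node) rather than on the operator alone because CallbackCaller is a SingleCNOperatorNodeBase, so all of its inputs must live on one comp node. Outputs of a multi-comp-node operator therefore still get one caller per comp node, which is exactly what the new TestGraph.CallbackCaller exercises: split0's two outputs land on cns[1] and cns[2] and must not share a caller, while split1's outputs stay on opr0's comp node and do.

The standalone program below is a minimal sketch of the bookkeeping in the new init_opr_seq, written with hypothetical stand-in types (Opr, Var, plain ints for comp nodes) instead of the real OperatorNodeBase*/VarNode*/CompNode; it mirrors the CallbackCallerKey/CallbackCallerVal pattern from the patch, not MegEngine's actual implementation.

    #include <cstddef>
    #include <cstdio>
    #include <functional>
    #include <unordered_map>
    #include <vector>

    struct Opr { const char* name; };           // stand-in for OperatorNodeBase*
    using CompNodeId = int;                     // stand-in for mgb::CompNode
    struct Var { Opr* owner; CompNodeId cn; };  // stand-in for VarNode*

    // cf. CallbackCallerKey in cg_impl.h
    struct Key {
        Opr* opr;
        CompNodeId cn;
        bool operator==(const Key& rhs) const {
            return opr == rhs.opr && cn == rhs.cn;
        }
        // cf. hash_pair_combine(mgb::hash(opr), mgb::hash(comp_node))
        struct Hash {
            size_t operator()(const Key& k) const {
                return std::hash<Opr*>()(k.opr) * 31u +
                       std::hash<CompNodeId>()(k.cn);
            }
        };
    };

    // cf. CallbackCallerVal: the distinct vars of one group, plus the
    // out_spec indices that requested each var (indexs[i] belongs to vars[i])
    struct Val {
        std::vector<Var*> vars;
        std::vector<std::vector<size_t>> indexs;
    };

    int main() {
        Opr split{"split"};
        Var v0{&split, 0}, v1{&split, 0}, v2{&split, 1};
        // requested outputs, in out_spec order: v0 appears twice
        std::vector<Var*> out_spec{&v0, &v1, &v0, &v2};

        std::unordered_map<Key, Val, Key::Hash> opr2vars;
        std::unordered_map<Var*, size_t> var2idx;
        for (size_t i = 0; i < out_spec.size(); ++i) {
            Var* var = out_spec[i];
            Val& vals = opr2vars[Key{var->owner, var->cn}];
            auto iter = var2idx.find(var);
            if (iter == var2idx.end()) {
                // first request for this var: new input slot in its group
                vals.vars.push_back(var);
                vals.indexs.push_back({i});
                var2idx[var] = vals.vars.size() - 1;
            } else {
                // repeated request for the same var: reuse its slot
                vals.indexs[iter->second].push_back(i);
            }
        }

        // the real code builds one multi-input CallbackCaller per group here
        for (auto& item : opr2vars) {
            std::printf("caller on comp node %d serves %zu var(s)\n",
                        item.first.cn, item.second.vars.size());
        }
        // two groups result: {v0, v1} on comp node 0 and {v2} on comp node 1
    }

In this sketch v0 is requested twice, so out_spec slots 0 and 2 end up in one indexs entry and would be served by a single input of one shared caller, matching the patch's behavior of reusing one CallbackCaller per (opr, comp_node) group instead of one per requested output.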