Commit aeb7980b authored by: Megvii Engine Team

perf(mgb): outputs of the same opr and same compnode share the same CallbackCaller

GitOrigin-RevId: 59b8e3bcbe0dd76f80f85bc1f46733364df769d1
Parent 89b6dbc7
...@@ -148,13 +148,15 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) { ...@@ -148,13 +148,15 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) {
/* ========================== CallbackCaller ========================== */ /* ========================== CallbackCaller ========================== */
MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller, MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
SingleCNOperatorNodeBase) // { SingleCNOperatorNodeBase) // {
std::vector<ComputingGraph::Callback> m_cb; std::vector<std::vector<ComputingGraph::Callback>> m_cb;
void scn_do_execute() override { void scn_do_execute() override {
auto&& dv = input(0)->dev_tensor(); for (size_t i = 0; i < input().size(); ++i) {
for (auto&& i : m_cb) { auto&& in = input(i)->dev_tensor();
for (auto&& callback : m_cb[i]) {
// const cast for backward API compatibility // const cast for backward API compatibility
i(const_cast<DeviceTensorND&>(dv)); callback(const_cast<DeviceTensorND&>(in));
}
} }
} }
...@@ -168,14 +170,29 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller, ...@@ -168,14 +170,29 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
if (owner_graph()->options().comp_node_seq_record_level) { if (owner_graph()->options().comp_node_seq_record_level) {
// the user callback usually copies from device to host, which // the user callback usually copies from device to host, which
// involves tmp alloc if input is not contiguous // involves tmp alloc if input is not contiguous
input(0)->add_layout_constraint_contiguous(); for (auto&& inp : input()) {
inp->add_layout_constraint_contiguous();
}
}
}
void init_output_dtype() override {
if (output(0)->dtype().valid()) {
return;
} }
mgb_assert(!input().empty());
DType dtype = input(0)->dtype();
mgb_assert(dtype.valid() && dtype != dtype::Byte());
output(0)->dtype(dtype);
} }
NodeProp* do_make_node_prop() const override { NodeProp* do_make_node_prop() const override {
auto ret = Super::do_make_node_prop(); auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0), for (auto&& inp : input()) {
NodeProp::DepType::VALUE_ALLOW_EMPTY); ret->add_dep_type_existing_var(
inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
}
return ret; return ret;
} }
...@@ -185,25 +202,38 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller, ...@@ -185,25 +202,38 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
} }
public:
    //! \param inp all vars this caller observes; must be non-empty
    //! (make() asserts this before construction — note inp[0] is already
    //! dereferenced in the member-init list, so the in-body assert only
    //! documents the precondition)
    CallbackCaller(const VarNodeArrayView& inp)
            : Super{inp[0]->owner_graph(), {}, "callback", inp} {
        mgb_assert(!inp.empty());
        // one callback list per input
        m_cb.resize(inp.size());
        for (auto&& i : inp) {
            add_input({i});
        }
        using F = VarNode::Flag;
        // the output exists only to anchor this opr in the comp seq; its
        // content is never read
        add_output(None)
                ->add_flag(F::ALLOW_EMPTY_SHAPE)
                .add_flag(F::VOLATILE_CONTENT);
    }
static SymbolVar make(SymbolVar inp) { static SymbolVar make(const VarNodeArrayView& inp) {
return inp.insert_single_output_opr<CallbackCaller>(inp.node()); mgb_assert(!inp.empty());
return SymbolVar{inp[0]}
.node()
->owner_graph()
->insert_opr(std::make_unique<CallbackCaller>(inp))
->output(0);
} }
void add_callback(const ComputingGraph::Callback& cb) { void add_callback(const ComputingGraph::Callback& cb, size_t i = 0) {
mgb_assert(cb); mgb_assert(cb && i < m_cb.size());
m_cb.push_back(cb); m_cb[i].push_back(cb);
} }
void clear_callback() { m_cb.clear(); } void clear_callback() {
for (size_t i = 0; i < m_cb.size(); ++i) {
m_cb[i].clear();
}
}
}; };
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller); MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller);
...@@ -529,22 +559,39 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( ...@@ -529,22 +559,39 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars); cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);
auto init_opr_seq = [&]() { auto init_opr_seq = [&]() {
ThinHashMap<VarNode*, CallbackCaller*> var2cb_caller; ThinHashMap<VarNode*, size_t> var2idx;
std::unordered_map<CallbackCallerKey, CallbackCallerVal,
CallbackCallerKey::Hash>
opr2vars;
for (size_t i = 0; i < out_spec.size(); ++i) { for (size_t i = 0; i < out_spec.size(); ++i) {
auto&& cb = out_spec[i].second; auto&& cb = out_spec[i].second;
if (cb) { if (cb) {
auto var = dest_vars[i]; auto var = dest_vars[i];
auto&& cb_caller = var2cb_caller[var]; CallbackCallerKey key{var->owner_opr(), var->comp_node()};
if (!cb_caller) { auto&& vals = opr2vars[key];
auto dvar = CallbackCaller::make(var); auto&& var2idx_iter = var2idx.find(var);
cb_caller = &dvar.node() if ( var2idx_iter == var2idx.end()) {
vals.vars.push_back(var);
vals.indexs.push_back({i});
var2idx[var] = vals.vars.size() - 1;
} else {
vals.indexs[var2idx_iter->second].push_back(i);
}
}
}
for (auto& item : opr2vars) {
auto&& val = item.second;
auto dvar = CallbackCaller::make(val.vars);
CallbackCaller* cb_caller = &dvar.node()
->owner_opr() ->owner_opr()
->cast_final_safe<CallbackCaller>(); ->cast_final_safe<CallbackCaller>();
++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req; ++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req;
cb_caller->clear_callback(); cb_caller->clear_callback();
for (size_t i=0;i<val.vars.size(); ++i) {
for (auto&& idx : val.indexs[i]) {
cb_caller->add_callback(out_spec[idx].second, i);
dest_vars[idx] = cb_caller->output(0);
} }
cb_caller->add_callback(cb);
dest_vars[i] = cb_caller->output(0);
} }
} }
opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars); opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars);
......
...@@ -40,6 +40,28 @@ class ComputingGraphImpl final : public ComputingGraph { ...@@ -40,6 +40,28 @@ class ComputingGraphImpl final : public ComputingGraph {
const OprNodeArray* opr_seq = nullptr; const OprNodeArray* opr_seq = nullptr;
}; };
//! Hash-map key grouping requested output vars: all outputs of the same
//! operator on the same comp node share one CallbackCaller.
struct CallbackCallerKey {
    OperatorNodeBase* opr;
    CompNode comp_node;

    bool operator==(const CallbackCallerKey& rhs) const {
        return opr == rhs.opr && comp_node == rhs.comp_node;
    }

    struct Hash {
        size_t operator()(const CallbackCallerKey& b) const {
            return hash_pair_combine(mgb::hash(b.opr),
                                     mgb::hash(b.comp_node));
        }
    };
};

//! Value for a CallbackCallerKey group.
struct CallbackCallerVal {
    //! distinct vars in this group, in insertion order
    SmallVector<VarNode*> vars;
    //! indices (sic: "indexs") of vars in out_spec; indexs[i] lists every
    //! out_spec position that refers to vars[i]
    SmallVector<SmallVector<size_t>> indexs;
};
/*! /*!
* Components for implementing algorithms on a computing graph. * Components for implementing algorithms on a computing graph.
* *
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
// NOTE: the original change re-added "megbrain/opr/tensor_manip.h" although
// it was already included above — the duplicate is dropped here.
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/graph/helper.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/graph/event.h"
...@@ -30,6 +31,7 @@ ...@@ -30,6 +31,7 @@
#include <atomic>
#include <chrono>
#include <array>
#include <memory>

using namespace mgb;
...@@ -2236,4 +2238,52 @@ TEST(TestGraph, FreeBias) { ...@@ -2236,4 +2238,52 @@ TEST(TestGraph, FreeBias) {
} }
} }
//! Compile four callback outputs coming from two Split oprs (one of them
//! spread over different comp nodes) and check every element is routed to
//! the right host buffer — exercising the shared-CallbackCaller path.
TEST(TestGraph, CallbackCaller) {
    using namespace opr;
    auto cns = load_multiple_xpus(3);
    // spl0 partitions channels as C1+C2, spl1 as C3+C4; both sum to C
    constexpr size_t C1 = 20, C2 = 30, C3 = 10, C4 = 40;
    constexpr size_t N = 2, C = C1 + C2;
    HostTensorGenerator<> gen;
    auto host_opr0 = gen({N, C}, cns[0]);
    auto graph = ComputingGraph::make();
    SymbolVar opr0 = opr::Host2DeviceCopy::make(*graph, host_opr0, {"opr0"});
    // split0 places its two outputs on different comp nodes
    auto spl0 = opr::Split::make(
            opr0, Split::Options::make_partition(opr0, 1, {C1, C2}),
            OperatorNodeConfig("split0").comp_node_arr({cns[1], cns[2]}));
    auto spl1 = opr::Split::make(
            opr0, Split::Options::make_partition(opr0, 1, {C3, C4}),
            OperatorNodeConfig("split1"));
    HostTensorND host_spl00, host_spl01, host_spl10, host_spl11;
    auto func = graph->compile({make_callback_copy(spl0[0], host_spl00),
                                make_callback_copy(spl0[1], host_spl01),
                                make_callback_copy(spl1[0], host_spl10),
                                make_callback_copy(spl1[1], host_spl11)});
    func->execute();
    auto o00 = host_spl00.ptr<float>(), o01 = host_spl01.ptr<float>(),
         o10 = host_spl10.ptr<float>(), o11 = host_spl11.ptr<float>(),
         c = host_opr0->ptr<float>();
    // walk the flat (N, C) input; each element must appear in exactly one
    // partition of each split
    for (size_t i = 0, it = host_opr0->layout().total_nr_elems(); i < it; i++) {
        auto ch = i % C;
        auto n = i / C;
        if (ch < C1) {
            MGB_ASSERT_FLOAT_EQ(o00[n * C1 + ch], c[i])
                    << ssprintf("failed at %zd", i);
        } else {
            MGB_ASSERT_FLOAT_EQ(o01[n * C2 + ch - C1], c[i])
                    << ssprintf("failed at %zd", i);
        }
        if (ch < C3) {
            MGB_ASSERT_FLOAT_EQ(o10[n * C3 + ch], c[i])
                    << ssprintf("failed at %zd", i);
        } else {
            MGB_ASSERT_FLOAT_EQ(o11[n * C4 + ch - C3], c[i])
                    << ssprintf("failed at %zd", i);
        }
    }
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment