提交 aeb7980b 编写于 作者: M Megvii Engine Team

perf(mgb): outputs of the same opr on the same comp node share one CallbackCaller

GitOrigin-RevId: 59b8e3bcbe0dd76f80f85bc1f46733364df769d1
上级 89b6dbc7
......@@ -148,13 +148,15 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) {
/* ========================== CallbackCaller ========================== */
MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
SingleCNOperatorNodeBase) // {
std::vector<ComputingGraph::Callback> m_cb;
std::vector<std::vector<ComputingGraph::Callback>> m_cb;
void scn_do_execute() override {
auto&& dv = input(0)->dev_tensor();
for (auto&& i : m_cb) {
// const cast for backward API compatibility
i(const_cast<DeviceTensorND&>(dv));
for (size_t i = 0; i < input().size(); ++i) {
auto&& in = input(i)->dev_tensor();
for (auto&& callback : m_cb[i]) {
// const cast for backward API compatibility
callback(const_cast<DeviceTensorND&>(in));
}
}
}
......@@ -168,14 +170,29 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
if (owner_graph()->options().comp_node_seq_record_level) {
    // the user callback usually copies from device to host, which
    // involves tmp alloc if input is not contiguous; constrain every
    // input (not just input(0)) now that the opr is multi-input
    for (auto&& inp : input()) {
        inp->add_layout_constraint_contiguous();
    }
}
}
void init_output_dtype() override {
if (output(0)->dtype().valid()) {
return;
}
mgb_assert(!input().empty());
DType dtype = input(0)->dtype();
mgb_assert(dtype.valid() && dtype != dtype::Byte());
output(0)->dtype(dtype);
}
NodeProp* do_make_node_prop() const override {
    auto ret = Super::do_make_node_prop();
    // Callbacks must still fire when an input tensor is empty, so allow
    // empty values on every input dependency.
    for (auto&& inp : input()) {
        ret->add_dep_type_existing_var(
                inp, NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return ret;
}
......@@ -185,25 +202,38 @@ MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
}
public:
CallbackCaller(VarNode* inp)
: Super{inp->owner_graph(), {}, "callback", {inp}} {
add_input({inp});
CallbackCaller(const VarNodeArrayView& inp)
: Super{inp[0]->owner_graph(), {}, "callback", inp} {
mgb_assert(!inp.empty());
m_cb.resize(inp.size());
for (auto&& i : inp) {
add_input({i});
}
using F = VarNode::Flag;
add_output(None)
->add_flag(F::ALLOW_EMPTY_SHAPE)
.add_flag(F::VOLATILE_CONTENT);
}
//! Insert a CallbackCaller over \p inp into the owner graph of inp[0] and
//! return its (volatile) output var.
static SymbolVar make(const VarNodeArrayView& inp) {
    mgb_assert(!inp.empty());
    return SymbolVar{inp[0]}
            .node()
            ->owner_graph()
            ->insert_opr(std::make_unique<CallbackCaller>(inp))
            ->output(0);
}
//! Register \p cb to be invoked with the value of input(i).
void add_callback(const ComputingGraph::Callback& cb, size_t i = 0) {
    mgb_assert(cb && i < m_cb.size());
    m_cb[i].push_back(cb);
}
//! Drop all registered callbacks but keep the outer vector's size, so
//! slot i stays aligned with input(i) for subsequent add_callback calls.
void clear_callback() {
    for (auto&& cbs : m_cb) {
        cbs.clear();
    }
}
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ComputingGraphImpl::CallbackCaller);
......@@ -529,22 +559,39 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
cmpnt.seq_comp_node_opt.optimize_comp_nodes(dest_vars);
auto init_opr_seq = [&]() {
ThinHashMap<VarNode*, CallbackCaller*> var2cb_caller;
ThinHashMap<VarNode*, size_t> var2idx;
std::unordered_map<CallbackCallerKey, CallbackCallerVal,
CallbackCallerKey::Hash>
opr2vars;
for (size_t i = 0; i < out_spec.size(); ++i) {
auto&& cb = out_spec[i].second;
if (cb) {
auto var = dest_vars[i];
auto&& cb_caller = var2cb_caller[var];
if (!cb_caller) {
auto dvar = CallbackCaller::make(var);
cb_caller = &dvar.node()
->owner_opr()
->cast_final_safe<CallbackCaller>();
++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req;
cb_caller->clear_callback();
CallbackCallerKey key{var->owner_opr(), var->comp_node()};
auto&& vals = opr2vars[key];
auto&& var2idx_iter = var2idx.find(var);
if ( var2idx_iter == var2idx.end()) {
vals.vars.push_back(var);
vals.indexs.push_back({i});
var2idx[var] = vals.vars.size() - 1;
} else {
vals.indexs[var2idx_iter->second].push_back(i);
}
}
}
for (auto& item : opr2vars) {
auto&& val = item.second;
auto dvar = CallbackCaller::make(val.vars);
CallbackCaller* cb_caller = &dvar.node()
->owner_opr()
->cast_final_safe<CallbackCaller>();
++extra_info.var2recvinfo[dvar.node()].nr_direct_comp_req;
cb_caller->clear_callback();
for (size_t i=0;i<val.vars.size(); ++i) {
for (auto&& idx : val.indexs[i]) {
cb_caller->add_callback(out_spec[idx].second, i);
dest_vars[idx] = cb_caller->output(0);
}
cb_caller->add_callback(cb);
dest_vars[i] = cb_caller->output(0);
}
}
opr_seq = topo_sorter().get_comp_seq(extra_info, dest_vars);
......
......@@ -40,6 +40,28 @@ class ComputingGraphImpl final : public ComputingGraph {
const OprNodeArray* opr_seq = nullptr;
};
//! key that groups callback vars: one CallbackCaller is created per
//! (owner opr, comp node) pair
struct CallbackCallerKey {
    OperatorNodeBase* opr;
    CompNode comp_node;

    bool operator==(const CallbackCallerKey& rhs) const {
        return (opr == rhs.opr) && (comp_node == rhs.comp_node);
    }

    //! hasher for use with std::unordered_map
    struct Hash {
        size_t operator()(const CallbackCallerKey& key) const {
            auto h0 = mgb::hash(key.opr);
            auto h1 = mgb::hash(key.comp_node);
            return hash_pair_combine(h0, h1);
        }
    };
};
//! per-(opr, comp_node) value: the vars served by one CallbackCaller,
//! together with the out_spec entries each var must satisfy
struct CallbackCallerVal {
// vars that become the inputs of the shared CallbackCaller, in slot order
SmallVector<VarNode*> vars;
//! indexs[i] lists the positions in out_spec whose callbacks are bound to
//! vars[i] (one var may be requested by several out_spec entries)
SmallVector<SmallVector<size_t>> indexs;
};
/*!
* Components for implementing algorithms on a computing graph.
*
......
......@@ -17,6 +17,7 @@
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/indexing.h"
#include "megbrain/graph/helper.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/graph/event.h"
......@@ -30,6 +31,7 @@
#include <atomic>
#include <chrono>
#include <array>
#include <memory>
using namespace mgb;
......@@ -2236,4 +2238,52 @@ TEST(TestGraph, FreeBias) {
}
}
// Checks that callbacks attached to multiple outputs of the same opr are all
// invoked with correct data, covering both the shared-CallbackCaller case
// (spl1: same opr, same comp node) and the distinct-comp-node case (spl0:
// outputs placed on cns[1] / cns[2] via the opr config).
TEST(TestGraph, CallbackCaller) {
using namespace opr;
auto cns = load_multiple_xpus(3);
// C1+C2 and C3+C4 are two different partitions of the same channel axis C
constexpr size_t C1 = 20, C2 = 30, C3 = 10, C4 = 40;
constexpr size_t N = 2, C = C1 + C2;
HostTensorGenerator<> gen;
auto host_opr0 = gen({N, C}, cns[0]);
auto graph = ComputingGraph::make();
SymbolVar opr0 = opr::Host2DeviceCopy::make(*graph, host_opr0, {"opr0"});
// split on axis 1 with each output pinned to a different comp node
auto spl0 = opr::Split::make(
opr0, Split::Options::make_partition(opr0, 1, {C1, C2}),
OperatorNodeConfig("split0").comp_node_arr({cns[1], cns[2]}));
// split on axis 1 with both outputs on the default comp node — these two
// callbacks should share one CallbackCaller
auto spl1 = opr::Split::make(
opr0, Split::Options::make_partition(opr0, 1, {C3, C4}),
OperatorNodeConfig("split1"));
HostTensorND host_spl00, host_spl01, host_spl10, host_spl11;
// four callbacks: one per split output
auto func = graph->compile({make_callback_copy(spl0[0], host_spl00),
make_callback_copy(spl0[1], host_spl01),
make_callback_copy(spl1[0], host_spl10),
make_callback_copy(spl1[1], host_spl11)});
func->execute();
auto o00 = host_spl00.ptr<float>(),
o01 = host_spl01.ptr<float>(),
o10 = host_spl10.ptr<float>(),
o11 = host_spl11.ptr<float>(), c = host_opr0->ptr<float>();
// walk the source tensor element-wise and check each value landed in the
// right partition of both splits
for (size_t i = 0, it = host_opr0->layout().total_nr_elems(); i < it; i++) {
auto ch = i % C;
auto n = i / C;
if (ch < C1) {
MGB_ASSERT_FLOAT_EQ(o00[n * C1 + ch], c[i])
<< ssprintf("failed at %zd", i);
} else {
MGB_ASSERT_FLOAT_EQ(o01[n * C2 + ch - C1], c[i])
<< ssprintf("failed at %zd", i);
}
if (ch < C3) {
MGB_ASSERT_FLOAT_EQ(o10[n * C3 + ch], c[i])
<< ssprintf("failed at %zd", i);
} else {
MGB_ASSERT_FLOAT_EQ(o11[n * C4 + ch - C3], c[i])
<< ssprintf("failed at %zd", i);
}
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册