提交 a599725c 编写于 作者: M Megvii Engine Team

fix(mgb/core): fix operator input waiting spec

GitOrigin-RevId: 5906dd7ee6c0bdb3c8187681253ee80f50945132
上级 528ef930
...@@ -45,6 +45,10 @@ class PostExecActions { ...@@ -45,6 +45,10 @@ class PostExecActions {
} }
}; };
CompNode m_comp_node; CompNode m_comp_node;
// VarNodes in m_items should be listed in the same order as in the
// output of the owner_opr, because opr would generate input_wating_spec()
// according to this order
// see `SeqCompNodeOptimizerImpl::init_ready_event()` for more details
SmallVector<Item> m_items; SmallVector<Item> m_items;
MGB_IF_COND_EXEC(ExecutionMask* m_mask = nullptr); MGB_IF_COND_EXEC(ExecutionMask* m_mask = nullptr);
......
...@@ -216,16 +216,19 @@ void SeqCompNodeOptimizerImpl::init_ready_event( ...@@ -216,16 +216,19 @@ void SeqCompNodeOptimizerImpl::init_ready_event(
} }
m_cnpair2opr_step.clear(); m_cnpair2opr_step.clear();
// opr step, idx of output
using VarStep = std::pair<size_t, size_t>;
// cn0 -> (cn1 -> step): step on cn1 is known to have finished for current // cn0 -> (cn1 -> step): step on cn1 is known to have finished for current
// opr on cn0 // opr on cn0
CompNode::UnorderedMap<CompNode::UnorderedMap<size_t>> cnpair2step; CompNode::UnorderedMap<CompNode::UnorderedMap<VarStep>> cnpair2step;
// vars to be waited on for current opr; only the latest var needs to be // vars to be waited on for current opr; only the latest var needs to be
// waited for each comp node // waited for each comp node
CompNode::UnorderedMap<VarNode*> vars_to_wait; CompNode::UnorderedMap<VarNode*> vars_to_wait;
CompNode::UnorderedSet cur_used_cn; CompNode::UnorderedSet cur_used_cn;
ThinHashMap<OperatorNodeBase*, size_t> opr2step; ThinHashMap<VarNode*, VarStep> var2step;
size_t cur_step = 0; size_t cur_step = 0;
using OprNodeProp = OperatorNodeBase::NodeProp; using OprNodeProp = OperatorNodeBase::NodeProp;
...@@ -266,7 +269,7 @@ void SeqCompNodeOptimizerImpl::init_ready_event( ...@@ -266,7 +269,7 @@ void SeqCompNodeOptimizerImpl::init_ready_event(
} }
if ((OprNodeProp::is_device_comp_order_dep(i.second) && if ((OprNodeProp::is_device_comp_order_dep(i.second) &&
i.first->comp_node() != cn) || pdv_need_sync_host) { i.first->comp_node() != cn) || pdv_need_sync_host) {
auto step = opr2step.at(i.first->owner_opr()); auto step = var2step.at(i.first);
auto ins = dep2step.insert({i.first->comp_node(), step}); auto ins = dep2step.insert({i.first->comp_node(), step});
// only wait for var if it is beyond currently known // only wait for var if it is beyond currently known
// synchronized step // synchronized step
...@@ -290,16 +293,25 @@ void SeqCompNodeOptimizerImpl::init_ready_event( ...@@ -290,16 +293,25 @@ void SeqCompNodeOptimizerImpl::init_ready_event(
auto&& record = m_cnpair2opr_step[cn]; auto&& record = m_cnpair2opr_step[cn];
for (auto&& i : vars_to_wait) { for (auto&& i : vars_to_wait) {
auto step_done = opr2step.at(i.second->owner_opr()); auto step_done = var2step.at(i.second).first;
auto&& seq = record[i.first]; auto&& seq = record[i.first];
mgb_assert(seq.empty() || step_done > seq.back().second); // for multi-output operator, there might be multiple other
seq.emplace_back(cur_step, step_done); // operators which depand on different output varnodes, and
// those output vars share the same opr step number
mgb_assert(seq.empty() || step_done >= seq.back().second);
if (seq.empty() || step_done > seq.back().second) {
seq.emplace_back(cur_step, step_done);
}
} }
} }
} }
opr->input_waiting_spec(std::move(waiting_spec)); opr->input_waiting_spec(std::move(waiting_spec));
opr2step[opr] = cur_step ++; auto&& usable_output = opr->usable_output();
for (size_t i = 0; i < usable_output.size(); ++ i) {
var2step[usable_output[i]] = {cur_step, i};
}
cur_step ++;
} }
mgb_assert(cur_step == seq.size()); mgb_assert(cur_step == seq.size());
} }
......
...@@ -1085,6 +1085,22 @@ TEST(TestGraph, DynShapeDepCrossCN) { ...@@ -1085,6 +1085,22 @@ TEST(TestGraph, DynShapeDepCrossCN) {
ASSERT_EQ(24.f, host_b.ptr<int>()[0]); ASSERT_EQ(24.f, host_b.ptr<int>()[0]);
} }
namespace {
void check_wait(SymbolVar dest, SymbolVar dep) {
if (!dep.node()) {
ASSERT_EQ(0u,
dest.node()->owner_opr()->input_waiting_spec().size());
return;
}
cg::OperatorNodeBase::InputWaitingSpecElem ws;
unpack_vector(dest.node()->owner_opr()->input_waiting_spec(), ws);
ASSERT_EQ(ws.comp_node, dest.node()->comp_node());
VarNode *get;
unpack_vector(ws.dev_ready, get);
ASSERT_EQ(dep, get);
};
}
TEST(TestGraph, InputWaitingSpec) { TEST(TestGraph, InputWaitingSpec) {
auto cns = load_multiple_xpus(2); auto cns = load_multiple_xpus(2);
constexpr size_t SIZE = 12345; constexpr size_t SIZE = 12345;
...@@ -1115,26 +1131,40 @@ TEST(TestGraph, InputWaitingSpec) { ...@@ -1115,26 +1131,40 @@ TEST(TestGraph, InputWaitingSpec) {
MGB_ASSERT_FLOAT_EQ(px[i] + 1, pz0[i]); MGB_ASSERT_FLOAT_EQ(px[i] + 1, pz0[i]);
MGB_ASSERT_FLOAT_EQ(px[i] + 2, pz1[i]); MGB_ASSERT_FLOAT_EQ(px[i] + 2, pz1[i]);
} }
auto check_wait = [](SymbolVar dest, SymbolVar dep) {
if (!dep.node()) {
ASSERT_EQ(0u,
dest.node()->owner_opr()->input_waiting_spec().size());
return;
}
cg::OperatorNodeBase::InputWaitingSpecElem ws;
unpack_vector(dest.node()->owner_opr()->input_waiting_spec(), ws);
ASSERT_EQ(ws.comp_node, dest.node()->comp_node());
VarNode *get;
unpack_vector(ws.dev_ready, get);
ASSERT_EQ(dep, get);
};
check_wait(y0, x); check_wait(y0, x);
check_wait(y1, x + 1); check_wait(y1, x + 1);
check_wait(z1, y1 + 1); check_wait(z1, y1 + 1);
check_wait(z0, {}); check_wait(z0, {});
} }
TEST(TestGraph, InputWaitingSpecMultiOut) {
auto cn0 = CompNode::load("xpu0:0"), cn1 = CompNode::load("xpu0:1");
HostTensorGenerator<> gen;
auto graph = cg::ComputingGraph::make();
graph->options().graph_opt_level = 0;
graph->options().var_sanity_check_first_run = 0;
graph->options().async_exec_level = 0b100;
graph->options().seq_opt.enable_seq_comp_node_opt = false;
size_t nr_out = 1024, length = 32;
auto hv = gen({nr_out * length}, cn0);
auto x = opr::Host2DeviceCopy::make(*graph, hv);
auto outs = opr::Split::make(x, opr::Split::Options::make_average(0, nr_out));
cg::ComputingGraph::OutputSpec output_spec;
for (size_t i = 0; i < nr_out; ++ i) {
auto y = opr::Copy::make(outs[i], cn1);
y.node()->owner_opr()->node_prop().attribute().priority = i ? nr_out - i : 0;
output_spec.push_back({y, {}});
}
auto func = graph->compile(output_spec);
func->execute().wait();
check_wait(output_spec[0].first, outs[0]);
check_wait(output_spec[nr_out - 1].first, outs[nr_out - 1]);
for (size_t i = 1; i < nr_out - 1; ++ i) {
check_wait(output_spec[i].first, {});
}
}
TEST(TestGraph, GradStaticShape) { TEST(TestGraph, GradStaticShape) {
for (bool enable: {false, true}) { for (bool enable: {false, true}) {
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册