提交 cfad9a5d 编写于 作者: M Megvii Engine Team 提交者: XindaH

fix(mgb/cambricon): fix magicmind runtime opr when set workspace point second time

GitOrigin-RevId: 1ac9d0eabad312dcbcffacd7f35522a37884ddb3
上级 262e124b
......@@ -168,7 +168,8 @@ MagicMindRuntimeOpr::MagicMindRuntimeOpr(
m_allocator{std::move(allocator)},
m_engine{nullptr},
m_context{nullptr},
m_model{std::move(model)} {
m_model{std::move(model)},
m_current_ptr{nullptr} {
mgb_assert(
inputs[0]->comp_node().device_type() == CompNode::DeviceType::CAMBRICON,
"MagicMindRuntimeOpr can only be used on cambricon comp node; "
......@@ -230,8 +231,18 @@ void MagicMindRuntimeOpr::scn_do_execute() {
MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(output(i)->shape())));
MM_CHECK(tensor->SetData(output(i)->dev_tensor().raw_ptr()));
}
auto size = output().back()->dev_tensor().layout().span().dist_byte();
MM_CHECK(m_context->SetWorkspace(output().back()->dev_tensor().raw_ptr(), size));
if (m_current_ptr == nullptr) {
auto size = output().back()->dev_tensor().layout().span().dist_byte();
m_current_ptr = output().back()->dev_tensor().raw_ptr();
MM_CHECK(m_context->SetWorkspace(m_current_ptr, size));
} else {
auto current_ptr = output().back()->dev_tensor().raw_ptr();
mgb_assert(
current_ptr == m_current_ptr,
"workspace has been changed, the execution context should be "
"reconstructed, but now this is not supported (got:%p,prev:%p)",
current_ptr, m_current_ptr);
}
MM_CHECK(m_context->Enqueue(inputs, outputs, cnrt_env.queue));
for (auto&& i : inputs) {
i->Destroy();
......@@ -293,7 +304,6 @@ void MagicMindRuntimeOpr::get_output_var_shape(
false, "static shape infer for MagicMindRuntimeOpr(%s) failed",
cname());
}
return;
for (auto&& i : inputs) {
i->Destroy();
}
......
......@@ -93,6 +93,7 @@ private:
IEnginePtr m_engine;
mutable IContextPtr m_context;
IModelPtr m_model;
dt_byte* m_current_ptr;
};
} // namespace opr
......
......@@ -642,6 +642,7 @@ TEST(TestMagicMindRuntimeOpr, GraphShapeMutable) {
auto func = graph->compile(
{make_callback_copy(out1, o1), make_callback_copy(out2, o2)});
func->execute();
func->execute();
HostTensorND o1_mm(cn, mkshp(no, co, ho, wo), dtype::Float32()),
o2_mm(cn, mkshp(no, co, ho, wo), dtype::Float32());
std::memcpy(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册