diff --git a/src/cambricon/impl/magicmind_runtime_opr.cpp b/src/cambricon/impl/magicmind_runtime_opr.cpp
index c25b0b8d0e71c2e86d1a88d27bc045e88448419a..7bf5c534232279b0cf4d47f119fd5a26ccec42a9 100644
--- a/src/cambricon/impl/magicmind_runtime_opr.cpp
+++ b/src/cambricon/impl/magicmind_runtime_opr.cpp
@@ -168,7 +168,8 @@ MagicMindRuntimeOpr::MagicMindRuntimeOpr(
           m_allocator{std::move(allocator)},
           m_engine{nullptr},
           m_context{nullptr},
-          m_model{std::move(model)} {
+          m_model{std::move(model)},
+          m_current_ptr{nullptr} {
     mgb_assert(
             inputs[0]->comp_node().device_type() == CompNode::DeviceType::CAMBRICON,
             "MagicMindRuntimeOpr can only be used on cambricon comp node; "
@@ -230,8 +231,18 @@ void MagicMindRuntimeOpr::scn_do_execute() {
         MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(output(i)->shape())));
         MM_CHECK(tensor->SetData(output(i)->dev_tensor().raw_ptr()));
     }
-    auto size = output().back()->dev_tensor().layout().span().dist_byte();
-    MM_CHECK(m_context->SetWorkspace(output().back()->dev_tensor().raw_ptr(), size));
+    if (m_current_ptr == nullptr) {
+        auto size = output().back()->dev_tensor().layout().span().dist_byte();
+        m_current_ptr = output().back()->dev_tensor().raw_ptr();
+        MM_CHECK(m_context->SetWorkspace(m_current_ptr, size));
+    } else {
+        auto current_ptr = output().back()->dev_tensor().raw_ptr();
+        mgb_assert(
+                current_ptr == m_current_ptr,
+                "workspace has been changed, the execution context should be "
+                "reconstructed, but now this is not supported (got:%p,prev:%p)",
+                current_ptr, m_current_ptr);
+    }
     MM_CHECK(m_context->Enqueue(inputs, outputs, cnrt_env.queue));
     for (auto&& i : inputs) {
         i->Destroy();
@@ -293,7 +304,6 @@ void MagicMindRuntimeOpr::get_output_var_shape(
                 false, "static shape infer for MagicMindRuntimeOpr(%s) failed",
                 cname());
     }
-    return;
     for (auto&& i : inputs) {
         i->Destroy();
     }
diff --git a/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h b/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h
index b72cf87162679f36dcabe81ef029aa50eb9b14fc..0fd8a318f0f904e4a03cd12e9873d10f8d3483b7 100644
--- a/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h
+++ b/src/cambricon/include/megbrain/cambricon/magicmind_runtime_opr.h
@@ -93,6 +93,7 @@ private:
     IEnginePtr m_engine;
     mutable IContextPtr m_context;
     IModelPtr m_model;
+    dt_byte* m_current_ptr;
 };
 
 }  // namespace opr
diff --git a/src/cambricon/test/magicmind_runtime_opr.cpp b/src/cambricon/test/magicmind_runtime_opr.cpp
index 4a7a446041f39f6e6bfc4b473ceaf1f3246e9f50..7ddeb1117e4340bcd377850db3452d58ee267b07 100644
--- a/src/cambricon/test/magicmind_runtime_opr.cpp
+++ b/src/cambricon/test/magicmind_runtime_opr.cpp
@@ -642,6 +642,7 @@ TEST(TestMagicMindRuntimeOpr, GraphShapeMutable) {
     auto func = graph->compile(
             {make_callback_copy(out1, o1), make_callback_copy(out2, o2)});
     func->execute();
+    func->execute();
     HostTensorND o1_mm(cn, mkshp(no, co, ho, wo), dtype::Float32()),
             o2_mm(cn, mkshp(no, co, ho, wo), dtype::Float32());
     std::memcpy(
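
Note: the hunk in scn_do_execute() caches the workspace pointer the first time it is handed to the MagicMind execution context, and on every later run only asserts that the pointer has not moved; the second func->execute() added to the test drives exactly that cached-pointer path. The following is a minimal, self-contained C++ sketch of the same cache-once / assert-stable pattern, not the real MegEngine or MagicMind API: FakeContext, set_workspace() and RuntimeOpr are hypothetical stand-ins for IContext::SetWorkspace() and the runtime operator.

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    // Hypothetical stand-in for the MagicMind execution context: the
    // workspace is assumed to be bound once and stay at that address
    // for the lifetime of the context.
    struct FakeContext {
        void* workspace = nullptr;
        void set_workspace(void* ptr, size_t size) {
            std::printf("bind workspace %p (%zu bytes)\n", ptr, size);
            workspace = ptr;
        }
    };

    // Cache-once / assert-stable pattern mirroring scn_do_execute():
    // the first run records the workspace pointer, later runs only
    // verify it is unchanged.
    struct RuntimeOpr {
        FakeContext ctx;
        void* m_current_ptr = nullptr;

        void execute(void* workspace_ptr, size_t size) {
            if (m_current_ptr == nullptr) {
                m_current_ptr = workspace_ptr;
                ctx.set_workspace(m_current_ptr, size);
            } else {
                // Rebinding the workspace would require rebuilding the
                // context, which (as in the patch) is not supported.
                assert(workspace_ptr == m_current_ptr &&
                       "workspace changed; execution context must be rebuilt");
            }
            // ... enqueue the model with its inputs/outputs here ...
        }
    };

    int main() {
        RuntimeOpr opr;
        static char buf[256];
        opr.execute(buf, sizeof(buf));  // first run: binds the workspace
        opr.execute(buf, sizeof(buf));  // second run: pointer check only
        return 0;
    }

The trade-off mirrored here is the one the patch makes explicit: rather than rebuilding the execution context when the workspace buffer moves, it forbids the move and fails loudly with the old and new addresses.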