提交 2b8e7940 编写于 作者: M Megvii Engine Team 提交者: XindaH

fix(lite/cambricon): fix cambricon models which have multiple comp node

GitOrigin-RevId: 624fd7f0ce7cadcaabf5584b10619532a7f5a231
上级 cfad9a5d
...@@ -20,11 +20,20 @@ namespace { ...@@ -20,11 +20,20 @@ namespace {
namespace magicmind_runtime { namespace magicmind_runtime {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
#if CNRT_MAJOR_VERSION >= 5
auto&& op = static_cast<const MagicMindRuntime&>(def); auto&& op = static_cast<const MagicMindRuntime&>(def);
SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end()); SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
OperatorNodeConfig config{op.make_name()}; OperatorNodeConfig config{op.make_name()};
return opr::MagicMindRuntimeOpr::make( return opr::MagicMindRuntimeOpr::make(
op.buf.c_str(), op.buf_size, symbol_var_inputs, config); op.buf.c_str(), op.buf_size, symbol_var_inputs, config);
#else
mgb_assert(
false,
"Magicmind runtime opr is disabled at compile time, the reason of which is "
"the version of cnrt runtime is lower than 5.0. Please check the version "
"of your cambricon toolkit, and recompile megengine.");
return SymbolVar{};
#endif
} }
OP_TRAIT_REG(MagicMindRuntime, MagicMindRuntime) OP_TRAIT_REG(MagicMindRuntime, MagicMindRuntime)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
......
...@@ -129,6 +129,17 @@ void NetworkImplDft::application_config() { ...@@ -129,6 +129,17 @@ void NetworkImplDft::application_config() {
loc.stream = m_nr_threads; loc.stream = m_nr_threads;
} }
}; };
//! currently not set Locator type because a cambricon mgb model is a
//! cross-compnode graph
} else if (device_type == LiteDeviceType::LITE_CAMBRICON) {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::CAMBRICON) {
loc.device = m_compnode_locator.device;
loc.stream = m_compnode_locator.stream;
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
loc.stream = m_nr_threads;
}
};
} else { } else {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) { m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
loc = m_compnode_locator; loc = m_compnode_locator;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#if MGB_CAMBRICON #if MGB_CAMBRICON
#if CNRT_MAJOR_VERSION >= 5
using namespace mgb; using namespace mgb;
using namespace opr; using namespace opr;
...@@ -168,7 +169,7 @@ MagicMindRuntimeOpr::MagicMindRuntimeOpr( ...@@ -168,7 +169,7 @@ MagicMindRuntimeOpr::MagicMindRuntimeOpr(
m_allocator{std::move(allocator)}, m_allocator{std::move(allocator)},
m_engine{nullptr}, m_engine{nullptr},
m_context{nullptr}, m_context{nullptr},
m_model{std::move(model)}, m_model{std::move(model)},
m_current_ptr{nullptr} { m_current_ptr{nullptr} {
mgb_assert( mgb_assert(
inputs[0]->comp_node().device_type() == CompNode::DeviceType::CAMBRICON, inputs[0]->comp_node().device_type() == CompNode::DeviceType::CAMBRICON,
...@@ -387,6 +388,7 @@ SymbolVarArray MagicMindRuntimeOpr::make( ...@@ -387,6 +388,7 @@ SymbolVarArray MagicMindRuntimeOpr::make(
return make(std::move(model), std::move(cambricon_allocator), src, config); return make(std::move(model), std::move(cambricon_allocator), src, config);
} }
#endif // CNRT_MAJOR_VERSION
#endif // MGB_CAMBRICON #endif // MGB_CAMBRICON
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include "megbrain/cambricon/magicmind_runtime_opr.h" #include "megbrain/cambricon/magicmind_runtime_opr.h"
#include "megbrain/serialization/sereg.h" #include "megbrain/serialization/sereg.h"
#if CNRT_MAJOR_VERSION >= 5
namespace mgb { namespace mgb {
namespace serialization { namespace serialization {
...@@ -62,4 +64,6 @@ MGB_REG_OPR_SHALLOW_COPY(MagicMindRuntimeOpr, opr_shallow_copy_magicmind_runtime ...@@ -62,4 +64,6 @@ MGB_REG_OPR_SHALLOW_COPY(MagicMindRuntimeOpr, opr_shallow_copy_magicmind_runtime
} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "megbrain/serialization/file.h" #include "megbrain/serialization/file.h"
#if MGB_CAMBRICON #if MGB_CAMBRICON
#if CNRT_MAJOR_VERSION >= 5
#include <sstream> #include <sstream>
#include "interface_runtime.h" #include "interface_runtime.h"
...@@ -99,6 +100,7 @@ private: ...@@ -99,6 +100,7 @@ private:
} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb
#endif // CNRT_MAJOR_VERSION
#endif // MGB_CAMBRICON #endif // MGB_CAMBRICON
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "megbrain/test/helper.h" #include "megbrain/test/helper.h"
#if MGB_CAMBRICON #if MGB_CAMBRICON
#if CNRT_MAJOR_VERSION >= 5
#include "megbrain/cambricon/magicmind_runtime_opr.h" #include "megbrain/cambricon/magicmind_runtime_opr.h"
...@@ -827,6 +828,7 @@ TEST(TestMagicMindRuntimeOpr, CrossCNCopy) { ...@@ -827,6 +828,7 @@ TEST(TestMagicMindRuntimeOpr, CrossCNCopy) {
MGB_ASSERT_TENSOR_NEAR(o2, o2_mm, 1e-4); MGB_ASSERT_TENSOR_NEAR(o2, o2_mm, 1e-4);
} }
#endif
#endif #endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -1097,7 +1097,8 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp ...@@ -1097,7 +1097,8 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp
mgb_throw_if( mgb_throw_if(
type != CompNode::DeviceType::CPU && type != CompNode::DeviceType::CPU &&
type != CompNode::DeviceType::CUDA type != CompNode::DeviceType::CUDA
&& type != CompNode::DeviceType::ATLAS && type != CompNode::DeviceType::ATLAS &&
type != CompNode::DeviceType::CAMBRICON
, ,
MegBrainError, MegBrainError,
"currently CPU can only wait for CPU, CUDA, ATLAS" "currently CPU can only wait for CPU, CUDA, ATLAS"
...@@ -1116,7 +1117,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp ...@@ -1116,7 +1117,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp
#else #else
mgb_throw( mgb_throw(
MegBrainError, MegBrainError,
"Cambricon comp_node used but MGB_CAMBRICON not enabled"); "Cambricon comp_node used but CAMBRICON BUILD not enabled");
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册