提交 2b8e7940 编写于 作者: M Megvii Engine Team 提交者: XindaH

fix(lite/cambricon): fix cambricon models which have multiple comp node

GitOrigin-RevId: 624fd7f0ce7cadcaabf5584b10619532a7f5a231
上级 cfad9a5d
......@@ -20,11 +20,20 @@ namespace {
namespace magicmind_runtime {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
#if CNRT_MAJOR_VERSION >= 5
auto&& op = static_cast<const MagicMindRuntime&>(def);
SymbolVarArray symbol_var_inputs(inputs.begin(), inputs.end());
OperatorNodeConfig config{op.make_name()};
return opr::MagicMindRuntimeOpr::make(
op.buf.c_str(), op.buf_size, symbol_var_inputs, config);
#else
mgb_assert(
false,
"Magicmind runtime opr is disabled at compile time, the reason of which is "
"the version of cnrt runtime is lower than 5.0. Please check the version "
"of your cambricon toolkit, and recompile megengine.");
return SymbolVar{};
#endif
}
OP_TRAIT_REG(MagicMindRuntime, MagicMindRuntime)
.apply_on_var_node(apply_on_var_node)
......
......@@ -129,6 +129,17 @@ void NetworkImplDft::application_config() {
loc.stream = m_nr_threads;
}
};
//! currently not set Locator type because a cambricon mgb model is a
//! cross-compnode graph
} else if (device_type == LiteDeviceType::LITE_CAMBRICON) {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::CAMBRICON) {
loc.device = m_compnode_locator.device;
loc.stream = m_compnode_locator.stream;
} else if (loc.type == mgb::CompNode::DeviceType::MULTITHREAD) {
loc.stream = m_nr_threads;
}
};
} else {
m_load_config.comp_node_mapper = [this](mgb::CompNode::Locator& loc) {
loc = m_compnode_locator;
......
......@@ -14,6 +14,7 @@
#include "megbrain/comp_node_env.h"
#if MGB_CAMBRICON
#if CNRT_MAJOR_VERSION >= 5
using namespace mgb;
using namespace opr;
......@@ -387,6 +388,7 @@ SymbolVarArray MagicMindRuntimeOpr::make(
return make(std::move(model), std::move(cambricon_allocator), src, config);
}
#endif // CNRT_MAJOR_VERSION
#endif // MGB_CAMBRICON
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -12,6 +12,8 @@
#include "megbrain/cambricon/magicmind_runtime_opr.h"
#include "megbrain/serialization/sereg.h"
#if CNRT_MAJOR_VERSION >= 5
namespace mgb {
namespace serialization {
......@@ -62,4 +64,6 @@ MGB_REG_OPR_SHALLOW_COPY(MagicMindRuntimeOpr, opr_shallow_copy_magicmind_runtime
} // namespace opr
} // namespace mgb
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -15,6 +15,7 @@
#include "megbrain/serialization/file.h"
#if MGB_CAMBRICON
#if CNRT_MAJOR_VERSION >= 5
#include <sstream>
#include "interface_runtime.h"
......@@ -99,6 +100,7 @@ private:
} // namespace opr
} // namespace mgb
#endif // CNRT_MAJOR_VERSION
#endif // MGB_CAMBRICON
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -17,6 +17,7 @@
#include "megbrain/test/helper.h"
#if MGB_CAMBRICON
#if CNRT_MAJOR_VERSION >= 5
#include "megbrain/cambricon/magicmind_runtime_opr.h"
......@@ -827,6 +828,7 @@ TEST(TestMagicMindRuntimeOpr, CrossCNCopy) {
MGB_ASSERT_TENSOR_NEAR(o2, o2_mm, 1e-4);
}
#endif
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -1097,7 +1097,8 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp
mgb_throw_if(
type != CompNode::DeviceType::CPU &&
type != CompNode::DeviceType::CUDA
&& type != CompNode::DeviceType::ATLAS
&& type != CompNode::DeviceType::ATLAS &&
type != CompNode::DeviceType::CAMBRICON
,
MegBrainError,
"currently CPU can only wait for CPU, CUDA, ATLAS"
......@@ -1116,7 +1117,7 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by(Impl* cn_imp
#else
mgb_throw(
MegBrainError,
"Cambricon comp_node used but MGB_CAMBRICON not enabled");
"Cambricon comp_node used but CAMBRICON BUILD not enabled");
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册