提交 42ec91aa 编写于 作者: M Megvii Engine Team

feat(serialization): adapted to dummy extern opr loader for MegCC

GitOrigin-RevId: 9cb32215204b822d47a0a9d72bab314ad6fb5a6d
上级 2a2c1275
...@@ -13,6 +13,7 @@ using namespace opr; ...@@ -13,6 +13,7 @@ using namespace opr;
namespace { namespace {
const char PLACEHOLDER_TYPE_NAME[] = "placeholder"; const char PLACEHOLDER_TYPE_NAME[] = "placeholder";
const char* DUMMY_LOADER_NAME = "extern_opr_dummy";
typedef MGBOprDesc* (*opr_desc_transformer_t)(void* input); typedef MGBOprDesc* (*opr_desc_transformer_t)(void* input);
...@@ -437,6 +438,14 @@ cg::OperatorNodeBase* ExternCOprRunner::make_placeholder( ...@@ -437,6 +438,14 @@ cg::OperatorNodeBase* ExternCOprRunner::make_placeholder(
const SymbolVarArray& inputs, const TensorShapeArray& output_shapes, const SymbolVarArray& inputs, const TensorShapeArray& output_shapes,
const char* name, const void* data, size_t data_len, const char* name, const void* data, size_t data_len,
const OperatorNodeConfig& config, const SmallVector<DType>& output_dtypes) { const OperatorNodeConfig& config, const SmallVector<DType>& output_dtypes) {
std::string _name(name);
size_t pos = 0;
if ((pos = _name.find(':')) != std::string::npos) {
_name = _name.substr(0, pos);
}
mgb_throw_if(
_name.compare(DUMMY_LOADER_NAME) == 0, MegBrainError,
"The \"extern_opr_dummy[:xxx]\" is a reserved loader name.\n");
auto desc = PlaceholderMGBOprDesc::make( auto desc = PlaceholderMGBOprDesc::make(
inputs.size(), name, output_shapes, output_dtypes, data, data_len); inputs.size(), name, output_shapes, output_dtypes, data, data_len);
...@@ -492,11 +501,45 @@ cg::OperatorNodeBase* ExternCOprRunner::load( ...@@ -492,11 +501,45 @@ cg::OperatorNodeBase* ExternCOprRunner::load(
name = name.substr(0, index); name = name.substr(0, index);
auto&& map = loader_map(); auto&& map = loader_map();
auto iter = map.find(name); auto iter = map.find(name);
// !!For MegCC loader.
// If MegCC compiles a model containing an extern opr, it needs to use an extern
// opr loader. In order to make MegCC's compiler not dependent on the extern opr
// loader, MegCC registers a dummy loader with the name "extern_opr_dummy".
bool dummy_loader = false;
if (iter == map.end()) {
mgb_log_debug(
"Can NOT find loader '%s', try to find the dummy loader '%s'\n",
name.c_str(), DUMMY_LOADER_NAME);
iter = map.find(DUMMY_LOADER_NAME);
if (iter != map.end()) {
dummy_loader = true;
mgb_log_debug("Found the dummy extern opr loader.\n");
} else {
mgb_log_debug("Can NOT find the dummy loader '%s'\n", DUMMY_LOADER_NAME);
}
}
mgb_assert( mgb_assert(
iter != map.end(), "can not find loader for ExternCOprRunner `%s'", iter != map.end(), "can not find loader for ExternCOprRunner '%s'",
name.c_str()); name.c_str());
auto data = ctx.load_shared_buf_with_len(); auto data = ctx.load_shared_buf_with_len();
auto desc = iter->second.first.create_desc(inputs.size(), data.data(), data.size()); MGBOprDesc* desc = nullptr;
// For MegCC loader.
// If the loader is a MegCC dummy loader, copy the 'dump_name' in front of the
// 'data' to relate 'data' to 'name'. So multi loaders and multi extern oprs in
// single loader are supported.
if (dummy_loader) {
size_t buf_len = data.size() + sizeof(size_t) + dump_name.size();
std::shared_ptr<void> buf{malloc(buf_len), free};
char* buf_ptr = reinterpret_cast<char*>(buf.get());
*(size_t*)buf_ptr = dump_name.size();
buf_ptr += sizeof(size_t);
memmove(buf_ptr, dump_name.c_str(), dump_name.size());
buf_ptr += dump_name.size();
memmove(buf_ptr, data.data(), data.size());
desc = iter->second.first.create_desc(inputs.size(), buf.get(), buf_len);
} else
desc = iter->second.first.create_desc(inputs.size(), data.data(), data.size());
mgb_throw_if(nullptr == desc, MegBrainError, "loader create desc returns nullptr"); mgb_throw_if(nullptr == desc, MegBrainError, "loader create desc returns nullptr");
......
...@@ -497,4 +497,23 @@ TEST(TestExternCOpr, Dedup) { ...@@ -497,4 +497,23 @@ TEST(TestExternCOpr, Dedup) {
ASSERT_EQ(0, MGBOprDescImpl<>::nr_inst); ASSERT_EQ(0, MGBOprDescImpl<>::nr_inst);
} }
TEST(TestExternCOpr, DummyLoaderName) {
float bias = 0.0f;
ASSERT_THROW(
opr::ExternCOprRunner::make_placeholder(
{SymbolVar()}, {TensorShape{1}, TensorShape{1}}, "extern_opr_dummy",
&bias, sizeof(bias)),
MegBrainError);
ASSERT_THROW(
opr::ExternCOprRunner::make_placeholder(
{SymbolVar()}, {TensorShape{1}, TensorShape{1}},
"extern_opr_dummy:opr0", &bias, sizeof(bias)),
MegBrainError);
ASSERT_THROW(
opr::ExternCOprRunner::make_placeholder(
{SymbolVar()}, {TensorShape{1}, TensorShape{1}},
"extern_opr_dummy:opr1", &bias, sizeof(bias)),
MegBrainError);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册