diff --git a/src/serialization/impl/extern_c_opr.cpp b/src/serialization/impl/extern_c_opr.cpp index 141dcdf31e11102510770bfa109c06106306476e..f6e217dbab10f4349402fe0dd59836a26d2a990e 100644 --- a/src/serialization/impl/extern_c_opr.cpp +++ b/src/serialization/impl/extern_c_opr.cpp @@ -13,6 +13,7 @@ using namespace opr; namespace { const char PLACEHOLDER_TYPE_NAME[] = "placeholder"; +const char* DUMMY_LOADER_NAME = "extern_opr_dummy"; typedef MGBOprDesc* (*opr_desc_transformer_t)(void* input); @@ -437,6 +438,14 @@ cg::OperatorNodeBase* ExternCOprRunner::make_placeholder( const SymbolVarArray& inputs, const TensorShapeArray& output_shapes, const char* name, const void* data, size_t data_len, const OperatorNodeConfig& config, const SmallVector& output_dtypes) { + std::string _name(name); + size_t pos = 0; + if ((pos = _name.find(':')) != std::string::npos) { + _name = _name.substr(0, pos); + } + mgb_throw_if( + _name.compare(DUMMY_LOADER_NAME) == 0, MegBrainError, + "The \"extern_opr_dummy[:xxx]\" is a reserved loader name.\n"); auto desc = PlaceholderMGBOprDesc::make( inputs.size(), name, output_shapes, output_dtypes, data, data_len); @@ -492,11 +501,45 @@ cg::OperatorNodeBase* ExternCOprRunner::load( name = name.substr(0, index); auto&& map = loader_map(); auto iter = map.find(name); + // !!For MegCC loader. + // If MegCC compiles a model containing an extern opr, it needs to use an extern + // opr loader. In order to make MegCC's compiler not dependent on the extern opr + // loader, MegCC registers a dummy loader with the name "extern_opr_dummy". + bool dummy_loader = false; + if (iter == map.end()) { + mgb_log_debug( + "Can NOT find loader '%s', try to find the dummy loader '%s'\n", + name.c_str(), DUMMY_LOADER_NAME); + iter = map.find(DUMMY_LOADER_NAME); + if (iter != map.end()) { + dummy_loader = true; + mgb_log_debug("Found the dummy extern opr loader.\n"); + } else { + mgb_log_debug("Can NOT find the dummy loader '%s'\n", DUMMY_LOADER_NAME); + } + } mgb_assert( - iter != map.end(), "can not find loader for ExternCOprRunner `%s'", + iter != map.end(), "can not find loader for ExternCOprRunner '%s'", name.c_str()); auto data = ctx.load_shared_buf_with_len(); - auto desc = iter->second.first.create_desc(inputs.size(), data.data(), data.size()); + MGBOprDesc* desc = nullptr; + + // For MegCC loader. + // If the loader is a MegCC dummy loader, copy the 'dump_name' in front of the + // 'data' to relate 'data' to 'name'. So multi loaders and multi extern oprs in + // single loader are supported. + if (dummy_loader) { + size_t buf_len = data.size() + sizeof(size_t) + dump_name.size(); + std::shared_ptr buf{malloc(buf_len), free}; + char* buf_ptr = reinterpret_cast(buf.get()); + *(size_t*)buf_ptr = dump_name.size(); + buf_ptr += sizeof(size_t); + memmove(buf_ptr, dump_name.c_str(), dump_name.size()); + buf_ptr += dump_name.size(); + memmove(buf_ptr, data.data(), data.size()); + desc = iter->second.first.create_desc(inputs.size(), buf.get(), buf_len); + } else + desc = iter->second.first.create_desc(inputs.size(), data.data(), data.size()); mgb_throw_if(nullptr == desc, MegBrainError, "loader create desc returns nullptr"); diff --git a/src/serialization/test/extern_c_opr.cpp b/src/serialization/test/extern_c_opr.cpp index 191788e00d482a90e21fda599c1b8def82817212..64bcbad6ae747f3f747e447f6b28b146b52a6557 100644 --- a/src/serialization/test/extern_c_opr.cpp +++ b/src/serialization/test/extern_c_opr.cpp @@ -497,4 +497,23 @@ TEST(TestExternCOpr, Dedup) { ASSERT_EQ(0, MGBOprDescImpl<>::nr_inst); } +TEST(TestExternCOpr, DummyLoaderName) { + float bias = 0.0f; + ASSERT_THROW( + opr::ExternCOprRunner::make_placeholder( + {SymbolVar()}, {TensorShape{1}, TensorShape{1}}, "extern_opr_dummy", + &bias, sizeof(bias)), + MegBrainError); + ASSERT_THROW( + opr::ExternCOprRunner::make_placeholder( + {SymbolVar()}, {TensorShape{1}, TensorShape{1}}, + "extern_opr_dummy:opr0", &bias, sizeof(bias)), + MegBrainError); + ASSERT_THROW( + opr::ExternCOprRunner::make_placeholder( + {SymbolVar()}, {TensorShape{1}, TensorShape{1}}, + "extern_opr_dummy:opr1", &bias, sizeof(bias)), + MegBrainError); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}