Commit 262e124b authored by Megvii Engine Team, committed by XindaH

feat(lite): add lite interface for cambricon models

GitOrigin-RevId: cd5feafd7be88c4ef6153400976795978abf7e87
Parent 0c3699ff
......@@ -31,9 +31,10 @@ typedef enum {
LITE_CUDA = 1,
LITE_ATLAS = 3,
LITE_NPU = 4,
LITE_CAMBRICON = 5,
//! when the device information is set in model, so set LITE_DEVICE_DEFAULT
//! in lite
LITE_DEVICE_DEFAULT = 5,
LITE_DEVICE_DEFAULT = 6,
} LiteDeviceType;
typedef enum {
......
......@@ -21,7 +21,8 @@ class LiteDeviceType(IntEnum):
LITE_CUDA = 1
LITE_ATLAS = 3
LITE_NPU = 4
LITE_DEVICE_DEFAULT = 5
LITE_CAMBRICON = 5
LITE_DEVICE_DEFAULT = 6
class LiteDataType(IntEnum):
......
......@@ -164,6 +164,9 @@ mgb::CompNode::Locator lite::to_compnode_locator(const LiteDeviceType& device) {
case LiteDeviceType::LITE_ATLAS:
loc.type = mgb::CompNode::DeviceType::ATLAS;
break;
case LiteDeviceType::LITE_CAMBRICON:
loc.type = mgb::CompNode::DeviceType::CAMBRICON;
break;
case LiteDeviceType::LITE_DEVICE_DEFAULT:
loc.type = mgb::CompNode::DeviceType::UNSPEC;
break;
......@@ -183,6 +186,8 @@ LiteDeviceType lite::get_device_from_locator(const mgb::CompNode::Locator& locat
return LiteDeviceType::LITE_CUDA;
case mgb::CompNode::DeviceType::ATLAS:
return LiteDeviceType::LITE_ATLAS;
case mgb::CompNode::DeviceType::CAMBRICON:
return LiteDeviceType::LITE_CAMBRICON;
case mgb::CompNode::DeviceType::UNSPEC:
return LiteDeviceType::LITE_DEVICE_DEFAULT;
default:
......
......@@ -69,6 +69,8 @@ bool default_parse_info(
return LiteDeviceType::LITE_CUDA;
if (type == "ATLAS")
return LiteDeviceType::LITE_ATLAS;
if (type == "CAMBRICON")
return LiteDeviceType::LITE_CAMBRICON;
if (type == "NPU")
return LiteDeviceType::LITE_NPU;
else {
......
......@@ -1232,72 +1232,102 @@ TEST(TestNetWork, DeviceAsyncExec) {
#endif
#endif
#if MGB_ATLAS
TEST(TestNetWork, AtlasLoadNoDevice) {
#if MGB_ATLAS || MGB_CAMBRICON
namespace {
void load_no_device(LiteDeviceType device_type, const std::string& model_path) {
lite::Config config;
config.device_type = LiteDeviceType::LITE_DEVICE_DEFAULT;
config.device_type = device_type;
auto network = std::make_shared<lite::Network>(config);
network->load_model("./model_atlas.mgb");
network->load_model(model_path);
network->forward();
network->wait();
}
TEST(TestNetWork, AtlasLoadDeviceInput) {
void load_device_input(
LiteDeviceType device_type, const std::string& model_path,
const std::vector<std::string>& inputs) {
lite::NetworkIO networkio;
lite::IO input_data_io = {};
input_data_io.name = "data";
input_data_io.name = inputs[0];
input_data_io.is_host = false;
networkio.inputs.emplace_back(input_data_io);
lite::IO input_input0_io = {};
input_input0_io.name = "input0";
input_input0_io.name = inputs[1];
input_input0_io.is_host = false;
networkio.inputs.emplace_back(input_input0_io);
lite::Config config;
config.device_type = LiteDeviceType::LITE_DEVICE_DEFAULT;
config.device_type = device_type;
auto network = std::make_shared<lite::Network>(config, networkio);
network->load_model("./model_atlas.mgb");
network->load_model(model_path);
network->forward();
network->wait();
}
TEST(TestNetWork, AtlasLoadAtlas) {
void load_device_id(
LiteDeviceType device_type, int device_id, const std::string& model_path) {
lite::Config config;
config.device_type = LiteDeviceType::LITE_ATLAS;
config.device_type = device_type;
auto network = std::make_shared<lite::Network>(config);
network->load_model("./model_atlas.mgb");
network->set_device_id(device_id);
network->load_model(model_path);
std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
network->forward();
network->wait();
ASSERT_EQ(output_tensor->get_device_id(), device_id);
}
} // namespace
#endif
#if MGB_ATLAS
// Smoke test: the Atlas model loads and runs with LITE_DEVICE_DEFAULT, i.e.
// the device information recorded in the model itself is used.
TEST(TestNetWork, AtlasLoadNoDevice) {
load_no_device(LiteDeviceType::LITE_DEVICE_DEFAULT, "./model_atlas.mgb");
}
// Same model with LITE_DEVICE_DEFAULT, but "data"/"input0" declared as
// device-side (is_host = false) inputs.
TEST(TestNetWork, AtlasLoadDeviceInput) {
load_device_input(
LiteDeviceType::LITE_DEVICE_DEFAULT, "./model_atlas.mgb",
{"data", "input0"});
}
// Smoke test with the device type explicitly pinned to LITE_ATLAS rather than
// resolved from the model.
TEST(TestNetWork, AtlasLoadAtlas) {
load_no_device(LiteDeviceType::LITE_ATLAS, "./model_atlas.mgb");
}
// Device-side inputs with the device type explicitly pinned to LITE_ATLAS.
// NOTE(review): the rendered diff had left the deleted pre-refactor inline
// body (manual NetworkIO setup + hard-coded load/forward/wait) inside this
// TEST alongside the new helper call, so the work ran twice; only the
// post-commit helper call is kept.
TEST(TestNetWork, AtlasLoadAtlasDeviceInput) {
    load_device_input(
            LiteDeviceType::LITE_ATLAS, "./model_atlas.mgb", {"data", "input0"});
}
// Checks set_device_id(): the output tensor must report the requested Atlas
// device id (1).
// NOTE(review): the rendered diff had left the deleted pre-refactor inline
// body inside this TEST alongside the new helper call, duplicating the run;
// only the post-commit helper call is kept.
TEST(TestNetWork, AtlasDeviceID) {
    load_device_id(LiteDeviceType::LITE_ATLAS, 1, "./model_atlas.mgb");
}
#endif
#if MGB_CAMBRICON
// Smoke test: the Cambricon (MagicMind) model loads and runs with
// LITE_DEVICE_DEFAULT, i.e. the device recorded in the model is used.
TEST(TestNetWork, CambriconLoadNoDevice) {
load_no_device(LiteDeviceType::LITE_DEVICE_DEFAULT, "./model_magicmind.mgb");
}
// Cambricon model with "data"/"input0" declared as device-side inputs and the
// device type resolved from the model.
TEST(TestNetWork, CambriconLoadDeviceInput) {
load_device_input(
LiteDeviceType::LITE_DEVICE_DEFAULT, "./model_magicmind.mgb",
{"data", "input0"});
}
// Smoke test with the device type explicitly pinned to the new LITE_CAMBRICON
// enum value added by this commit.
TEST(TestNetWork, CambriconLoadCambricon) {
load_no_device(LiteDeviceType::LITE_CAMBRICON, "./model_magicmind.mgb");
}
// Device-side inputs with the device type explicitly pinned to LITE_CAMBRICON.
TEST(TestNetWork, CambriconLoadCambriconDeviceInput) {
load_device_input(
LiteDeviceType::LITE_CAMBRICON, "./model_magicmind.mgb",
{"data", "input0"});
}
// Checks set_device_id() on Cambricon: the output tensor must report device
// id 0.
TEST(TestNetWork, CambriconDeviceID) {
load_device_id(LiteDeviceType::LITE_CAMBRICON, 0, "./model_magicmind.mgb");
}
#endif
#endif
......
......@@ -679,7 +679,7 @@ TEST(TestMagicMindRuntimeOpr, Serialization) {
reinterpret_cast<const void*>(buf.data()), buf.size(), {x_, add_});
auto out1 = outs[0];
auto out2 = outs[1];
auto fname = output_file("MagicMindRuntimeOprTest");
auto fname = output_file("model_magicmind.mgb");
auto dump = [&]() {
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()));
auto rst = dumper->dump({out1, out2});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment