Commit 905fab56 authored by Megvii Engine Team

fix(core): fix multithread load graph

GitOrigin-RevId: 696ec86d6776a470c3eb5424bc356a5664f4df65
Parent 0257b2e6
#include <thread>
#include "lite_build_config.h"
#if LITE_BUILD_WITH_MGE
@@ -87,6 +88,68 @@ TEST(TestNetWork, SetDeviceId) {
    ASSERT_EQ(output_tensor->get_device_id(), 4);
}

TEST(TestNetWork, MutliThreadLoad) {
    Config config;
    std::string model_path = "./shufflenet.mge";
    std::shared_ptr<Network> network0 = std::make_shared<Network>(config);
    std::shared_ptr<Network> network1 = std::make_shared<Network>(config);
    std::shared_ptr<Network> network2 = std::make_shared<Network>(config);
    std::shared_ptr<Network> network3 = std::make_shared<Network>(config);
    auto func0 = [&] {
        network0->set_device_id(0);
        network0->load_model(model_path);
        std::shared_ptr<Tensor> input_tensor = network0->get_input_tensor(0);
        std::shared_ptr<Tensor> output_tensor = network0->get_output_tensor(0);
        network0->forward();
        network0->wait();
        ASSERT_EQ(input_tensor->get_device_id(), 0);
        ASSERT_EQ(output_tensor->get_device_id(), 0);
    };
    auto func1 = [&] {
        network1->set_device_id(0);
        network1->load_model(model_path);
        std::shared_ptr<Tensor> input_tensor = network1->get_input_tensor(0);
        std::shared_ptr<Tensor> output_tensor = network1->get_output_tensor(0);
        network1->forward();
        network1->wait();
        ASSERT_EQ(input_tensor->get_device_id(), 0);
        ASSERT_EQ(output_tensor->get_device_id(), 0);
    };
    auto func2 = [&] {
        network2->set_device_id(1);
        network2->load_model(model_path);
        std::shared_ptr<Tensor> input_tensor = network2->get_input_tensor(0);
        std::shared_ptr<Tensor> output_tensor = network2->get_output_tensor(0);
        network2->forward();
        network2->wait();
        ASSERT_EQ(input_tensor->get_device_id(), 1);
        ASSERT_EQ(output_tensor->get_device_id(), 1);
    };
    auto func3 = [&] {
        network3->set_device_id(1);
        network3->load_model(model_path);
        std::shared_ptr<Tensor> input_tensor = network3->get_input_tensor(0);
        std::shared_ptr<Tensor> output_tensor = network3->get_output_tensor(0);
        network3->forward();
        network3->wait();
        ASSERT_EQ(input_tensor->get_device_id(), 1);
        ASSERT_EQ(output_tensor->get_device_id(), 1);
    };
    std::thread t0(func0);
    std::thread t1(func1);
    std::thread t2(func2);
    std::thread t3(func3);
    t0.join();
    t1.join();
    t2.join();
    t3.join();
}
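
The four worker lambdas above differ only in the Network instance and the target device id (networks 0 and 1 on device 0, networks 2 and 3 on device 1). As a design note, the same coverage can be written as a single parameterized worker; the following loop-based sketch uses only the lite API calls already shown in the test and is not part of the actual commit:

#include <memory>
#include <string>
#include <thread>
#include <vector>

// Sketch only: one worker per (network, device) pair instead of four copies.
void multi_thread_load_sketch(const std::string& model_path) {
    Config config;
    std::vector<std::thread> workers;
    for (int i = 0; i < 4; ++i) {
        const int device_id = i / 2;  // devices 0, 0, 1, 1, as in the test above
        workers.emplace_back([&, device_id] {
            auto network = std::make_shared<Network>(config);
            network->set_device_id(device_id);
            network->load_model(model_path);  // concurrent loads exercise the fixed path
            auto input_tensor = network->get_input_tensor(0);
            auto output_tensor = network->get_output_tensor(0);
            network->forward();
            network->wait();
            ASSERT_EQ(input_tensor->get_device_id(), device_id);
            ASSERT_EQ(output_tensor->get_device_id(), device_id);
        });
    }
    for (auto& t : workers) {
        t.join();
    }
}
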
TEST(TestNetWork, GetAllName) {
    Config config;
    auto lite_tensor = get_input_data("./input_data.npy");
......
@@ -266,7 +266,7 @@ SubTensorSpec FancyIndexingHelper::fancy_indexing_make_sub_spec(
    static DeviceTensorND fake_val;
    static MGB_MUTEX fake_val_mtx;
    if (mgb_unlikely(fake_val.empty())) {
        {
            MGB_LOCK_GUARD(fake_val_mtx);
            if (fake_val.empty()) {
                fake_val.comp_node(CompNode::default_cpu())
......
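
This second hunk is the heart of the fix: fake_val is a static shared by every graph load, and concurrent load_model calls could race on its lazy initialization. The added MGB_LOCK_GUARD plus the repeated empty() test form a double-checked lock: the unlocked outer check keeps the already-initialized fast path cheap, while the inner check under the mutex guarantees exactly one thread performs the initialization. Below is a minimal self-contained sketch of the same idea with standard-library primitives; LazyBuffer, get_fake_val, and the atomic flag are illustrative stand-ins, not MegEngine code:

#include <atomic>
#include <mutex>
#include <vector>

struct LazyBuffer {
    std::vector<char> data;  // stand-in for DeviceTensorND's storage
};

LazyBuffer& get_fake_val() {
    static LazyBuffer fake_val;
    static std::atomic<bool> initialized{false};
    static std::mutex fake_val_mtx;  // stand-in for MGB_MUTEX
    if (!initialized.load(std::memory_order_acquire)) {       // fast path, no lock taken
        std::lock_guard<std::mutex> lock(fake_val_mtx);       // stand-in for MGB_LOCK_GUARD
        if (!initialized.load(std::memory_order_relaxed)) {   // re-check under the lock
            fake_val.data.resize(1);                          // one-time initialization
            initialized.store(true, std::memory_order_release);
        }
    }
    return fake_val;
}

The sketch uses an explicit atomic flag so the unlocked fast-path read is well defined under the C++ memory model; the diff instead re-reads fake_val.empty() directly, which is the pragmatic form of the same pattern.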