Commit 8cf7150d authored by Megvii Engine Team, committed by Xinran Xu

feat(dnn/mge): add compnode multithread in python

GitOrigin-RevId: 47373d291d3649c8e012fa8823e3e43980c1f16d
Parent 49972701
......@@ -38,9 +38,12 @@ def set_default_device(device: str = "xpux"):
:param device: default device type. The type can be 'cpu0', 'cpu1', etc.,
or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use.
To specify multiple devices, use cpu0:1 or gpu0:2.
'cpux' and 'gpux' can also be used to specify any number of cpu or gpu devices.
The 'multithread' device type is available for inference; it implements
multi-threading parallelism at the operator level. For example,
'multithread4' will compute with 4 threads.
The default value is 'xpux' to specify any device available.
It can also be set by environmental variable `MGE_DEFAULT_DEVICE`.
......
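For reference, the same device-string grammar surfaces at the C++ level through `CompNode::load`; the strings below are the ones exercised by the updated tests later in this commit (a minimal sketch, assuming the usual `megbrain/comp_node.h` include path):

```cpp
// Sketch: device strings accepted after this change, per the updated tests.
// "multithread<nr_threads>:<device>" puts the thread count first;
// "multithread:default:<nr_threads>" targets the default device.
#include "megbrain/comp_node.h"

void device_string_examples() {
    auto cpu  = mgb::CompNode::load("cpu0");                  // first cpu device
    auto mt   = mgb::CompNode::load("multithread4:0");        // 4 threads on device 0
    auto mtd  = mgb::CompNode::load("multithread:default:4"); // default device, 4 threads
    auto cpud = mgb::CompNode::load("cpu:default");           // inplace default cpu
    (void)cpu; (void)mt; (void)mtd; (void)cpud;
}
```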
......@@ -603,11 +603,11 @@ Args Args::from_argv(int argc, char **argv) {
++ i;
ret.multithread_number = std::stoi(argv[i]);
ret.load_config.comp_node_mapper =
[nr_thread =
[nr_threads =
ret.multithread_number](CompNode::Locator& loc) {
loc.type = CompNode::DeviceType::MULTITHREAD;
loc.device = 0;
loc.stream = nr_thread;
loc.nr_threads = nr_threads;
};
continue;
}
......@@ -615,11 +615,12 @@ Args Args::from_argv(int argc, char **argv) {
mgb_log_warn("use multithread:default mode");
++i;
ret.multithread_number = std::stoi(argv[i]);
ret.load_config.comp_node_mapper = [nr_thread =
ret.multithread_number](CompNode::Locator& loc) {
ret.load_config.comp_node_mapper = [nr_threads =
ret.multithread_number](
CompNode::Locator& loc) {
loc.type = CompNode::DeviceType::MULTITHREAD;
loc.device = CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
loc.stream = nr_thread;
loc.nr_threads = nr_threads;
};
continue;
}
......
......@@ -127,13 +127,19 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
// current parsing location
const char *ptr = id.data();
if (id == "cpu:default") {
return {DeviceType::CPU, DEVICE_CPU_DEFAULT, 0};
return {DeviceType::CPU, DEVICE_CPU_DEFAULT, {0}};
}
if (!strncmp(ptr, "multithread:default", 19)) {
//! the multithread default compnode string is like "multithread:default:x"
ptr += 20;
int nr_thread =std::stoi(ptr);
return {DeviceType::MULTITHREAD, DEVICE_MULTITHREAD_DEFAULT, nr_thread};
if (id.size() > 20) {
ptr += 20;
int nr_thread = std::stoi(ptr);
return {DeviceType::MULTITHREAD,
DEVICE_MULTITHREAD_DEFAULT,
{nr_thread}};
} else {
err();
}
}
DeviceType dev_type;
......@@ -192,8 +198,16 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
int num_stream = parse_int();
if (*ptr)
err();
        //! a multithread compnode with zero threads is illegal; note that
        //! the thread count is parsed into num_dev before the swap below
if (dev_type == DeviceType::MULTITHREAD) {
if (num_dev == 0) {
err();
}
            //! after the swap, num_stream stores nr_threads
std::swap(num_dev, num_stream);
}
return {dev_type, num_dev, num_stream};
return {dev_type, num_dev, {num_stream}};
}
void CompNode::Locator::set_device_map(DeviceType type, int from, int to) {
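To make the swap above concrete: in the new `multithread<a>:<b>` grammar, `a` is parsed into `num_dev` and `b` into `num_stream`, and the two are exchanged before the Locator is built, so `a` ends up as the thread count. A sketch mirroring the updated test expectations:

```cpp
// Sketch: post-swap field layout, with values taken from the updated tests.
using L = mgb::CompNode::Locator;
using D = mgb::CompNode::DeviceType;

void parse_examples() {
    L a = L::parse("multithread2:0");  // device = 0, nr_threads = 2
    L b = L::parse("multithread1:3");  // device = 3, nr_threads = 1
    L c = L::parse("multithread:default:2");
    // device = L::DEVICE_MULTITHREAD_DEFAULT, nr_threads = 2
    (void)a; (void)b; (void)c;
}
```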
......@@ -242,16 +256,22 @@ CompNode::Locator CompNode::Locator::to_physical() const {
stream_physical = 1023;
}
}
return {type_physical, device_physical, stream_physical};
return {type_physical, device_physical, {stream_physical}};
}
std::string CompNode::Locator::to_string() const {
if (device == DEVICE_CPU_DEFAULT) {
return "cpu:default";
} else if (device == DEVICE_MULTITHREAD_DEFAULT) {
std::string ret="multithread:default:";
std::string ret = "multithread:default:";
ret.append(get_stream_str(stream));
return ret;
} else if (type == DeviceType::MULTITHREAD) {
std::string ret("multithread");
ret.append(get_stream_str(stream))
.append(":")
.append(get_stream_str(device));
return ret;
}
char numstr[32];
if (device == -1) {
......
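With the new branch above, `to_string` emits the thread count before the device index, so the multithread format round-trips through `parse`. A small sketch, assuming `get_stream_str` simply renders the integer:

```cpp
// Sketch: parse/to_string round trip on the new multithread format.
auto loc = mgb::CompNode::Locator::parse("multithread2:1");
// loc.device == 1, loc.nr_threads == 2
mgb_assert(loc.to_string() == "multithread2:1");
```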
......@@ -380,9 +380,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
m_locator_logical(locator_logical) {
auto cn = make_comp_node_from_impl(this);
if (locator.type == DeviceType::MULTITHREAD) {
//! When multi-thread the stream stand for thread number
m_thread_pool = std::unique_ptr<ThreadPool>(
new ThreadPool(static_cast<size_t>(locator.stream)));
m_thread_pool = std::unique_ptr<ThreadPool>(new ThreadPool(
static_cast<size_t>(locator.nr_threads)));
            mgb_assert(m_thread_pool, "ThreadPool create failed");
}
if (locator.type == DeviceType::CPU) {
......@@ -398,7 +398,6 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
cn);
}
} else if (locator.type == DeviceType::MULTITHREAD) {
mgb_assert(m_thread_pool, "ThradPool create failed");
if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
m_env.init_cpu(
{std::make_shared<InplaceCPUDispatcher>(
......@@ -745,15 +744,14 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
} else {
mgb_assert(locator.type == DeviceType::MULTITHREAD);
auto&& pqueue_weak = sm_pool->physical2queue_multithead[{
locator.device, locator.stream}];
locator.device, locator.nr_threads}];
auto pqueue = pqueue_weak.lock();
if (!pqueue) {
pqueue = std::make_shared<WorkerQueue>(locator);
pqueue_weak = pqueue;
}
auto&& pimpl = sm_pool->logical2impl_multi_thread[{
static_cast<int>(compact_logical_device),
locator_logical.stream}];
compact_logical_device, locator_logical.nr_threads}];
if (!pimpl) {
mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE,
"too many cpu multithread comp nodes; max %d allowed",
......
......@@ -153,8 +153,12 @@ class CompNode {
int device = -1;
//! multiple streams can execute on one computing device and share
//! memory
int stream = 0;
        //! memory; when the compnode type is multithread, this field also
        //! stands for nr_threads
union {
int stream = 0;
int nr_threads;
};
/*!
* \brief parse a string identifier
......@@ -162,7 +166,7 @@ class CompNode {
* currently supported ID format: (gpu|cpu)<n>[:m] where n is the
* device number, possibly with m as the stream id.
*/
static Locator parse(const std::string &id);
static Locator parse(const std::string& id);
/*!
* \brief set mapping between device numbers of a device type
......
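Since `stream` and `nr_threads` are now members of one anonymous union (both `int`, so they alias the same storage), existing brace-initializers keep compiling while multithread code can read the count under the clearer name. A minimal illustration, not taken from the commit:

```cpp
// Sketch: one int serves as stream for ordinary compnodes and as
// nr_threads for multithread compnodes.
mgb::CompNode::Locator loc{
        mgb::CompNode::DeviceType::MULTITHREAD, /*device=*/0, {/*nr_threads=*/4}};
// loc.nr_threads == 4, and loc.stream reads the same storage
```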
......@@ -28,9 +28,7 @@ using namespace mgb;
TEST(TestCompNode, Parse) {
using L = CompNode::Locator;
using D = CompNode::DeviceType;
auto make_lc = [](D t, int dev, int s) -> L {
return {t, dev, s};
};
auto make_lc = [](D t, int dev, int s) -> L { return {t, dev, {s}}; };
ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0));
ASSERT_EQ(L::parse("xpux:23"), make_lc(D::UNSPEC, -1, 23));
......@@ -47,10 +45,9 @@ TEST(TestCompNode, Parse) {
ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0));
ASSERT_EQ(L::parse("xpu23:1"), make_lc(D::UNSPEC, 23, 1));
ASSERT_EQ(L::parse("cpu:default"),
make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0));
ASSERT_EQ(L::parse("multithread0:2"), make_lc(D::MULTITHREAD, 0, 2));
ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 1, 3));
ASSERT_EQ(L::parse("cpu:default"), make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0));
ASSERT_EQ(L::parse("multithread2:0"), make_lc(D::MULTITHREAD, 0, 2));
ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 3, 1));
ASSERT_EQ(L::parse("multithread:default:2"),
make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2));
......@@ -65,6 +62,10 @@ TEST(TestCompNode, Parse) {
ASSERT_THROW(L::parse("heaxgon0"), MegBrainError);
ASSERT_THROW(L::parse("rcom0"), MegBrainError);
ASSERT_THROW(L::parse("cmabricon0"), MegBrainError);
ASSERT_THROW(L::parse("multithread"), MegBrainError);
ASSERT_THROW(L::parse("multithread1:"), MegBrainError);
ASSERT_THROW(L::parse("multithread1:default"), MegBrainError);
ASSERT_THROW(L::parse("multithread1:default:0"), MegBrainError);
}
TEST(TestCompNode, SetDefaultDev) {
......@@ -107,12 +108,12 @@ TEST(TestCompNode, Load) {
#endif
#if MGB_HAVE_THREAD
auto cn_multi_thread0 = CompNode::load("multithread0:2");
auto cn_multi_thread1 = CompNode::load("multithread1:2");
ASSERT_EQ(CompNode::load("multithread0:2"), cn_multi_thread0);
ASSERT_EQ(CompNode::load("multithread1:2"), cn_multi_thread1);
ASSERT_NE(CompNode::load("multithread0:4"), cn_multi_thread0);
ASSERT_NE(CompNode::load("multithread1:4"), cn_multi_thread1);
auto cn_multi_thread0 = CompNode::load("multithread2:0");
auto cn_multi_thread1 = CompNode::load("multithread2:1");
ASSERT_EQ(CompNode::load("multithread2:0"), cn_multi_thread0);
ASSERT_EQ(CompNode::load("multithread2:1"), cn_multi_thread1);
ASSERT_NE(CompNode::load("multithread4:0"), cn_multi_thread0);
ASSERT_NE(CompNode::load("multithread4:1"), cn_multi_thread1);
auto cn_multi_default0 = CompNode::load("multithread:default:2");
auto cn_multi_default1 = CompNode::load("multithread:default:4");
......@@ -139,7 +140,7 @@ TEST(TestCompNode, FreeAfterFinalize) {
auto type = static_cast<CompNode::DeviceType>(i);
if (!CompNode::get_device_count(type))
continue;
auto cn = CompNode::load(CompNode::Locator{type});
auto cn = CompNode::load(CompNode::Locator{type, -1, {0}});
auto ptr = cn.alloc_device(123);
CompNode::finalize();
cn.free_device(ptr);
......@@ -190,13 +191,13 @@ TEST(TestCompNodeCPU, CoreAffinity) {
size_t data0, data1 = 0;
auto empty_task = []() {};
auto cn0 = CompNode::load("cpu:default"), cn1 = CompNode::load("cpu0"),
cn2 = CompNode::load("multithread0:2");
cn2 = CompNode::load("multithread2:0");
auto binding0 = [&](size_t) { data0 = 10; };
CompNodeEnv::from_comp_node(cn0).cpu_env().set_affinity(binding0);
CompNodeEnv::from_comp_node(cn0).cpu_env().dispatch(empty_task);
cn0.sync();
    auto binding1 = [&](size_t) { data1 = 20; };
CompNodeEnv::from_comp_node(cn1).cpu_env().set_affinity(binding1);
CompNodeEnv::from_comp_node(cn1).cpu_env().dispatch(empty_task);
cn1.sync();
......@@ -238,7 +239,7 @@ TEST(TestCompNode, CPU_MULTI_THREAD) {
};
for (auto&& str : std::vector<std::string>{
"multithread0:2", "multithread0:4", "multithread:default:4"}) {
"multithread2:0", "multithread4:0", "multithread:default:4"}) {
auto cn0 = CompNode::load("cpu0"), cn1 = CompNode::load(str);
std::thread wk_thread0{std::ref(worker), std::ref(dst0), std::ref(cn0)};
std::thread wk_thread1{std::ref(worker), std::ref(dst1), std::ref(cn1)};
......@@ -271,9 +272,9 @@ TEST(TestCompNodeCPU, PhysicalDispatch) {
L::set_device_map(DT, ID, 0);
L::set_device_map(DT, ID + 1, 0);
L::set_device_map(DT, ID + 2, 1);
auto cn0 = CompNode::load({DT, ID, 0}),
cn1 = CompNode::load({DT, ID + 1, 0}),
cn2 = CompNode::load({DT, ID + 2, 0});
auto cn0 = CompNode::load({DT, ID, {0}}),
cn1 = CompNode::load({DT, ID + 1, {0}}),
cn2 = CompNode::load({DT, ID + 2, {0}});
#if MGB_HAVE_THREAD
ASSERT_NE(cn0, cn1);
#else
......@@ -532,10 +533,10 @@ TEST(TestCompNode, MultipleLoad) {
for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) {
auto dt = static_cast<CompNode::DeviceType>(i);
if (CompNode::get_device_count(dt)) {
auto cn = CompNode::load({dt});
auto cn = CompNode::load({dt, 0, {0}});
mgb_log("comp node %s is available", cn.to_string().c_str());
run(cn);
cn = CompNode::load({dt});
cn = CompNode::load({dt, 0, {0}});
run(cn);
}
}
......@@ -591,7 +592,7 @@ TYPED_TEST(TestCPUCompSeqRec, run_default_cpu) {
comp_node_test::seq_rec::run<TypeParam>(CompNode::load("cpu:default"));
}
TYPED_TEST(TestCPUCompSeqRec, run_multi_thread) {
auto cn = CompNode::load("multithread0:4");
auto cn = CompNode::load("multithread4:0");
comp_node_test::seq_rec::run<TypeParam>(cn);
}
......