提交 8cf7150d 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn/mge): add compnode multithread in python

GitOrigin-RevId: 47373d291d3649c8e012fa8823e3e43980c1f16d
上级 49972701
...@@ -38,9 +38,12 @@ def set_default_device(device: str = "xpux"): ...@@ -38,9 +38,12 @@ def set_default_device(device: str = "xpux"):
:param device: default device type. The type can be 'cpu0', 'cpu1', etc., :param device: default device type. The type can be 'cpu0', 'cpu1', etc.,
or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use. or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use.
To specify multiple devices, use cpu0:1 or gpu0:2.
'cpux' and 'gupx' can also be used to specify any number of cpu or gpu devices. 'cpux' and 'gupx' can also be used to specify any number of cpu or gpu devices.
'multithread' device type is avaliable when inference, which implements
multi-threading parallelism at the operator level. For example,
'multithread4' will compute with 4 threads. which implements
The default value is 'xpux' to specify any device available. The default value is 'xpux' to specify any device available.
It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. It can also be set by environmental variable `MGE_DEFAULT_DEVICE`.
......
...@@ -603,11 +603,11 @@ Args Args::from_argv(int argc, char **argv) { ...@@ -603,11 +603,11 @@ Args Args::from_argv(int argc, char **argv) {
++ i; ++ i;
ret.multithread_number = std::stoi(argv[i]); ret.multithread_number = std::stoi(argv[i]);
ret.load_config.comp_node_mapper = ret.load_config.comp_node_mapper =
[nr_thread = [nr_threads =
ret.multithread_number](CompNode::Locator& loc) { ret.multithread_number](CompNode::Locator& loc) {
loc.type = CompNode::DeviceType::MULTITHREAD; loc.type = CompNode::DeviceType::MULTITHREAD;
loc.device = 0; loc.device = 0;
loc.stream = nr_thread; loc.nr_threads = nr_threads;
}; };
continue; continue;
} }
...@@ -615,11 +615,12 @@ Args Args::from_argv(int argc, char **argv) { ...@@ -615,11 +615,12 @@ Args Args::from_argv(int argc, char **argv) {
mgb_log_warn("use multithread:default mode"); mgb_log_warn("use multithread:default mode");
++i; ++i;
ret.multithread_number = std::stoi(argv[i]); ret.multithread_number = std::stoi(argv[i]);
ret.load_config.comp_node_mapper = [nr_thread = ret.load_config.comp_node_mapper = [nr_threads =
ret.multithread_number](CompNode::Locator& loc) { ret.multithread_number](
CompNode::Locator& loc) {
loc.type = CompNode::DeviceType::MULTITHREAD; loc.type = CompNode::DeviceType::MULTITHREAD;
loc.device = CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; loc.device = CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
loc.stream = nr_thread; loc.nr_threads = nr_threads;
}; };
continue; continue;
} }
......
...@@ -127,13 +127,19 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { ...@@ -127,13 +127,19 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
// current parsing location // current parsing location
const char *ptr = id.data(); const char *ptr = id.data();
if (id == "cpu:default") { if (id == "cpu:default") {
return {DeviceType::CPU, DEVICE_CPU_DEFAULT, 0}; return {DeviceType::CPU, DEVICE_CPU_DEFAULT, {0}};
} }
if (!strncmp(ptr, "multithread:default", 19)) { if (!strncmp(ptr, "multithread:default", 19)) {
//! the multithread default compnode string like "multithread:default:x" //! the multithread default compnode string like "multithread:default:x"
if (id.size() > 20) {
ptr += 20; ptr += 20;
int nr_thread =std::stoi(ptr); int nr_thread = std::stoi(ptr);
return {DeviceType::MULTITHREAD, DEVICE_MULTITHREAD_DEFAULT, nr_thread}; return {DeviceType::MULTITHREAD,
DEVICE_MULTITHREAD_DEFAULT,
{nr_thread}};
} else {
err();
}
} }
DeviceType dev_type; DeviceType dev_type;
...@@ -192,8 +198,16 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { ...@@ -192,8 +198,16 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) {
int num_stream = parse_int(); int num_stream = parse_int();
if (*ptr) if (*ptr)
err(); err();
//! multi thread with thread number(num_stream) being zero is illegal
if (dev_type == DeviceType::MULTITHREAD) {
if (num_dev == 0) {
err();
}
//! num_steam store the nr_thread
std::swap(num_dev, num_stream);
}
return {dev_type, num_dev, num_stream}; return {dev_type, num_dev, {num_stream}};
} }
void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { void CompNode::Locator::set_device_map(DeviceType type, int from, int to) {
...@@ -242,16 +256,22 @@ CompNode::Locator CompNode::Locator::to_physical() const { ...@@ -242,16 +256,22 @@ CompNode::Locator CompNode::Locator::to_physical() const {
stream_physical = 1023; stream_physical = 1023;
} }
} }
return {type_physical, device_physical, stream_physical}; return {type_physical, device_physical, {stream_physical}};
} }
std::string CompNode::Locator::to_string() const { std::string CompNode::Locator::to_string() const {
if (device == DEVICE_CPU_DEFAULT) { if (device == DEVICE_CPU_DEFAULT) {
return "cpu:default"; return "cpu:default";
} else if (device == DEVICE_MULTITHREAD_DEFAULT) { } else if (device == DEVICE_MULTITHREAD_DEFAULT) {
std::string ret="multithread:default:"; std::string ret = "multithread:default:";
ret.append(get_stream_str(stream)); ret.append(get_stream_str(stream));
return ret; return ret;
} else if (type == DeviceType::MULTITHREAD) {
std::string ret("multithread");
ret.append(get_stream_str(stream))
.append(":")
.append(get_stream_str(device));
return ret;
} }
char numstr[32]; char numstr[32];
if (device == -1) { if (device == -1) {
......
...@@ -380,9 +380,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase { ...@@ -380,9 +380,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
m_locator_logical(locator_logical) { m_locator_logical(locator_logical) {
auto cn = make_comp_node_from_impl(this); auto cn = make_comp_node_from_impl(this);
if (locator.type == DeviceType::MULTITHREAD) { if (locator.type == DeviceType::MULTITHREAD) {
//! When multi-thread the stream stand for thread number m_thread_pool = std::unique_ptr<ThreadPool>(new ThreadPool(
m_thread_pool = std::unique_ptr<ThreadPool>( static_cast<size_t>(locator.nr_threads)));
new ThreadPool(static_cast<size_t>(locator.stream))); mgb_assert(m_thread_pool, "ThradPool create failed");
} }
if (locator.type == DeviceType::CPU) { if (locator.type == DeviceType::CPU) {
...@@ -398,7 +398,6 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase { ...@@ -398,7 +398,6 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase {
cn); cn);
} }
} else if (locator.type == DeviceType::MULTITHREAD) { } else if (locator.type == DeviceType::MULTITHREAD) {
mgb_assert(m_thread_pool, "ThradPool create failed");
if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) { if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) {
m_env.init_cpu( m_env.init_cpu(
{std::make_shared<InplaceCPUDispatcher>( {std::make_shared<InplaceCPUDispatcher>(
...@@ -745,15 +744,14 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, ...@@ -745,15 +744,14 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator,
} else { } else {
mgb_assert(locator.type == DeviceType::MULTITHREAD); mgb_assert(locator.type == DeviceType::MULTITHREAD);
auto&& pqueue_weak = sm_pool->physical2queue_multithead[{ auto&& pqueue_weak = sm_pool->physical2queue_multithead[{
locator.device, locator.stream}]; locator.device, locator.nr_threads}];
auto pqueue = pqueue_weak.lock(); auto pqueue = pqueue_weak.lock();
if (!pqueue) { if (!pqueue) {
pqueue = std::make_shared<WorkerQueue>(locator); pqueue = std::make_shared<WorkerQueue>(locator);
pqueue_weak = pqueue; pqueue_weak = pqueue;
} }
auto&& pimpl = sm_pool->logical2impl_multi_thread[{ auto&& pimpl = sm_pool->logical2impl_multi_thread[{
static_cast<int>(compact_logical_device), compact_logical_device, locator_logical.nr_threads}];
locator_logical.stream}];
if (!pimpl) { if (!pimpl) {
mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE,
"too many cpu multithread comp nodes; max %d allowed", "too many cpu multithread comp nodes; max %d allowed",
......
...@@ -153,8 +153,12 @@ class CompNode { ...@@ -153,8 +153,12 @@ class CompNode {
int device = -1; int device = -1;
//! multiple streams can execute on one computing device and share //! multiple streams can execute on one computing device and share
//! memory //! memory, when compnode type is multithread the field also stand
//! for nr_threads
union {
int stream = 0; int stream = 0;
int nr_threads;
};
/*! /*!
* \brief parse a string identifier * \brief parse a string identifier
...@@ -162,7 +166,7 @@ class CompNode { ...@@ -162,7 +166,7 @@ class CompNode {
* currently supported ID format: (gpu|cpu)<n>[:m] where n is the * currently supported ID format: (gpu|cpu)<n>[:m] where n is the
* device number, possibly with m as the stream id. * device number, possibly with m as the stream id.
*/ */
static Locator parse(const std::string &id); static Locator parse(const std::string& id);
/*! /*!
* \brief set mapping between device numbers of a device type * \brief set mapping between device numbers of a device type
......
...@@ -28,9 +28,7 @@ using namespace mgb; ...@@ -28,9 +28,7 @@ using namespace mgb;
TEST(TestCompNode, Parse) { TEST(TestCompNode, Parse) {
using L = CompNode::Locator; using L = CompNode::Locator;
using D = CompNode::DeviceType; using D = CompNode::DeviceType;
auto make_lc = [](D t, int dev, int s) -> L { auto make_lc = [](D t, int dev, int s) -> L { return {t, dev, {s}}; };
return {t, dev, s};
};
ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0)); ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0));
ASSERT_EQ(L::parse("xpux:23"), make_lc(D::UNSPEC, -1, 23)); ASSERT_EQ(L::parse("xpux:23"), make_lc(D::UNSPEC, -1, 23));
...@@ -47,10 +45,9 @@ TEST(TestCompNode, Parse) { ...@@ -47,10 +45,9 @@ TEST(TestCompNode, Parse) {
ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0)); ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0));
ASSERT_EQ(L::parse("xpu23:1"), make_lc(D::UNSPEC, 23, 1)); ASSERT_EQ(L::parse("xpu23:1"), make_lc(D::UNSPEC, 23, 1));
ASSERT_EQ(L::parse("cpu:default"), ASSERT_EQ(L::parse("cpu:default"), make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0));
make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0)); ASSERT_EQ(L::parse("multithread2:0"), make_lc(D::MULTITHREAD, 0, 2));
ASSERT_EQ(L::parse("multithread0:2"), make_lc(D::MULTITHREAD, 0, 2)); ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 3, 1));
ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 1, 3));
ASSERT_EQ(L::parse("multithread:default:2"), ASSERT_EQ(L::parse("multithread:default:2"),
make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2)); make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2));
...@@ -65,6 +62,10 @@ TEST(TestCompNode, Parse) { ...@@ -65,6 +62,10 @@ TEST(TestCompNode, Parse) {
ASSERT_THROW(L::parse("heaxgon0"), MegBrainError); ASSERT_THROW(L::parse("heaxgon0"), MegBrainError);
ASSERT_THROW(L::parse("rcom0"), MegBrainError); ASSERT_THROW(L::parse("rcom0"), MegBrainError);
ASSERT_THROW(L::parse("cmabricon0"), MegBrainError); ASSERT_THROW(L::parse("cmabricon0"), MegBrainError);
ASSERT_THROW(L::parse("multithread"), MegBrainError);
ASSERT_THROW(L::parse("multithread1:"), MegBrainError);
ASSERT_THROW(L::parse("multithread1:default"), MegBrainError);
ASSERT_THROW(L::parse("multithread1:default:0"), MegBrainError);
} }
TEST(TestCompNode, SetDefaultDev) { TEST(TestCompNode, SetDefaultDev) {
...@@ -107,12 +108,12 @@ TEST(TestCompNode, Load) { ...@@ -107,12 +108,12 @@ TEST(TestCompNode, Load) {
#endif #endif
#if MGB_HAVE_THREAD #if MGB_HAVE_THREAD
auto cn_multi_thread0 = CompNode::load("multithread0:2"); auto cn_multi_thread0 = CompNode::load("multithread2:0");
auto cn_multi_thread1 = CompNode::load("multithread1:2"); auto cn_multi_thread1 = CompNode::load("multithread2:1");
ASSERT_EQ(CompNode::load("multithread0:2"), cn_multi_thread0); ASSERT_EQ(CompNode::load("multithread2:0"), cn_multi_thread0);
ASSERT_EQ(CompNode::load("multithread1:2"), cn_multi_thread1); ASSERT_EQ(CompNode::load("multithread2:1"), cn_multi_thread1);
ASSERT_NE(CompNode::load("multithread0:4"), cn_multi_thread0); ASSERT_NE(CompNode::load("multithread4:0"), cn_multi_thread0);
ASSERT_NE(CompNode::load("multithread1:4"), cn_multi_thread1); ASSERT_NE(CompNode::load("multithread4:1"), cn_multi_thread1);
auto cn_multi_default0 = CompNode::load("multithread:default:2"); auto cn_multi_default0 = CompNode::load("multithread:default:2");
auto cn_multi_default1 = CompNode::load("multithread:default:4"); auto cn_multi_default1 = CompNode::load("multithread:default:4");
...@@ -139,7 +140,7 @@ TEST(TestCompNode, FreeAfterFinalize) { ...@@ -139,7 +140,7 @@ TEST(TestCompNode, FreeAfterFinalize) {
auto type = static_cast<CompNode::DeviceType>(i); auto type = static_cast<CompNode::DeviceType>(i);
if (!CompNode::get_device_count(type)) if (!CompNode::get_device_count(type))
continue; continue;
auto cn = CompNode::load(CompNode::Locator{type}); auto cn = CompNode::load(CompNode::Locator{type, -1, {0}});
auto ptr = cn.alloc_device(123); auto ptr = cn.alloc_device(123);
CompNode::finalize(); CompNode::finalize();
cn.free_device(ptr); cn.free_device(ptr);
...@@ -190,13 +191,13 @@ TEST(TestCompNodeCPU, CoreAffinity) { ...@@ -190,13 +191,13 @@ TEST(TestCompNodeCPU, CoreAffinity) {
size_t data0, data1 = 0; size_t data0, data1 = 0;
auto empty_task = []() {}; auto empty_task = []() {};
auto cn0 = CompNode::load("cpu:default"), cn1 = CompNode::load("cpu0"), auto cn0 = CompNode::load("cpu:default"), cn1 = CompNode::load("cpu0"),
cn2 = CompNode::load("multithread0:2"); cn2 = CompNode::load("multithread2:0");
auto binding0 = [&](size_t) { data0 = 10; }; auto binding0 = [&](size_t) { data0 = 10; };
CompNodeEnv::from_comp_node(cn0).cpu_env().set_affinity(binding0); CompNodeEnv::from_comp_node(cn0).cpu_env().set_affinity(binding0);
CompNodeEnv::from_comp_node(cn0).cpu_env().dispatch(empty_task); CompNodeEnv::from_comp_node(cn0).cpu_env().dispatch(empty_task);
cn0.sync(); cn0.sync();
auto binding1 = [&](size_t) { data1 = 20; }; auto binding1 = [&](size_t ) { data1 = 20; };
CompNodeEnv::from_comp_node(cn1).cpu_env().set_affinity(binding1); CompNodeEnv::from_comp_node(cn1).cpu_env().set_affinity(binding1);
CompNodeEnv::from_comp_node(cn1).cpu_env().dispatch(empty_task); CompNodeEnv::from_comp_node(cn1).cpu_env().dispatch(empty_task);
cn1.sync(); cn1.sync();
...@@ -238,7 +239,7 @@ TEST(TestCompNode, CPU_MULTI_THREAD) { ...@@ -238,7 +239,7 @@ TEST(TestCompNode, CPU_MULTI_THREAD) {
}; };
for (auto&& str : std::vector<std::string>{ for (auto&& str : std::vector<std::string>{
"multithread0:2", "multithread0:4", "multithread:default:4"}) { "multithread2:0", "multithread4:0", "multithread:default:4"}) {
auto cn0 = CompNode::load("cpu0"), cn1 = CompNode::load(str); auto cn0 = CompNode::load("cpu0"), cn1 = CompNode::load(str);
std::thread wk_thread0{std::ref(worker), std::ref(dst0), std::ref(cn0)}; std::thread wk_thread0{std::ref(worker), std::ref(dst0), std::ref(cn0)};
std::thread wk_thread1{std::ref(worker), std::ref(dst1), std::ref(cn1)}; std::thread wk_thread1{std::ref(worker), std::ref(dst1), std::ref(cn1)};
...@@ -271,9 +272,9 @@ TEST(TestCompNodeCPU, PhysicalDispatch) { ...@@ -271,9 +272,9 @@ TEST(TestCompNodeCPU, PhysicalDispatch) {
L::set_device_map(DT, ID, 0); L::set_device_map(DT, ID, 0);
L::set_device_map(DT, ID + 1, 0); L::set_device_map(DT, ID + 1, 0);
L::set_device_map(DT, ID + 2, 1); L::set_device_map(DT, ID + 2, 1);
auto cn0 = CompNode::load({DT, ID, 0}), auto cn0 = CompNode::load({DT, ID, {0}}),
cn1 = CompNode::load({DT, ID + 1, 0}), cn1 = CompNode::load({DT, ID + 1, {0}}),
cn2 = CompNode::load({DT, ID + 2, 0}); cn2 = CompNode::load({DT, ID + 2, {0}});
#if MGB_HAVE_THREAD #if MGB_HAVE_THREAD
ASSERT_NE(cn0, cn1); ASSERT_NE(cn0, cn1);
#else #else
...@@ -532,10 +533,10 @@ TEST(TestCompNode, MultipleLoad) { ...@@ -532,10 +533,10 @@ TEST(TestCompNode, MultipleLoad) {
for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) { for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) {
auto dt = static_cast<CompNode::DeviceType>(i); auto dt = static_cast<CompNode::DeviceType>(i);
if (CompNode::get_device_count(dt)) { if (CompNode::get_device_count(dt)) {
auto cn = CompNode::load({dt}); auto cn = CompNode::load({dt, 0, {0}});
mgb_log("comp node %s is available", cn.to_string().c_str()); mgb_log("comp node %s is available", cn.to_string().c_str());
run(cn); run(cn);
cn = CompNode::load({dt}); cn = CompNode::load({dt, 0, {0}});
run(cn); run(cn);
} }
} }
...@@ -591,7 +592,7 @@ TYPED_TEST(TestCPUCompSeqRec, run_default_cpu) { ...@@ -591,7 +592,7 @@ TYPED_TEST(TestCPUCompSeqRec, run_default_cpu) {
comp_node_test::seq_rec::run<TypeParam>(CompNode::load("cpu:default")); comp_node_test::seq_rec::run<TypeParam>(CompNode::load("cpu:default"));
} }
TYPED_TEST(TestCPUCompSeqRec, run_multi_thread) { TYPED_TEST(TestCPUCompSeqRec, run_multi_thread) {
auto cn = CompNode::load("multithread0:4"); auto cn = CompNode::load("multithread4:0");
comp_node_test::seq_rec::run<TypeParam>(cn); comp_node_test::seq_rec::run<TypeParam>(cn);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册