diff --git a/python_module/megengine/core/device.py b/python_module/megengine/core/device.py index 863386cda1d784404c986a8e65c31d3a9da95a65..ebc34750cd2b35d191203ac7dff416c946cd9876 100644 --- a/python_module/megengine/core/device.py +++ b/python_module/megengine/core/device.py @@ -38,9 +38,12 @@ def set_default_device(device: str = "xpux"): :param device: default device type. The type can be 'cpu0', 'cpu1', etc., or 'gpu0', 'gpu1', etc., to specify the particular cpu or gpu to use. - To specify multiple devices, use cpu0:1 or gpu0:2. 'cpux' and 'gpux' can also be used to specify any number of cpu or gpu devices. + 'multithread' device type is available for inference, which implements + multi-threading parallelism at the operator level. For example, + 'multithread4' will compute with 4 threads. + The default value is 'xpux' to specify any device available. It can also be set by environmental variable `MGE_DEFAULT_DEVICE`. diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index 002d281f95e26d8427ee69f5577cd53f88cb43ea..13ccd9e638615b53a1b7b1b2c06b58cd15ced4b0 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -603,11 +603,11 @@ Args Args::from_argv(int argc, char **argv) { ++ i; ret.multithread_number = std::stoi(argv[i]); ret.load_config.comp_node_mapper = - [nr_thread = + [nr_threads = ret.multithread_number](CompNode::Locator& loc) { loc.type = CompNode::DeviceType::MULTITHREAD; loc.device = 0; - loc.stream = nr_thread; + loc.nr_threads = nr_threads; }; continue; } @@ -615,11 +615,12 @@ Args Args::from_argv(int argc, char **argv) { mgb_log_warn("use multithread:default mode"); ++i; ret.multithread_number = std::stoi(argv[i]); - ret.load_config.comp_node_mapper = [nr_thread = - ret.multithread_number](CompNode::Locator& loc) { + ret.load_config.comp_node_mapper = [nr_threads = + ret.multithread_number]( + CompNode::Locator& loc) { loc.type = CompNode::DeviceType::MULTITHREAD; 
loc.device = CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT; - loc.stream = nr_thread; + loc.nr_threads = nr_threads; }; continue; } diff --git a/src/core/impl/comp_node/comp_node.cpp b/src/core/impl/comp_node/comp_node.cpp index a9f256888ad147a6d9e0074c1a6685d33ac21094..ae0401907cc792a10678aafa954352593987e343 100644 --- a/src/core/impl/comp_node/comp_node.cpp +++ b/src/core/impl/comp_node/comp_node.cpp @@ -127,13 +127,19 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { // current parsing location const char *ptr = id.data(); if (id == "cpu:default") { - return {DeviceType::CPU, DEVICE_CPU_DEFAULT, 0}; + return {DeviceType::CPU, DEVICE_CPU_DEFAULT, {0}}; } if (!strncmp(ptr, "multithread:default", 19)) { //! the multithread default compnode string like "multithread:default:x" - ptr += 20; - int nr_thread =std::stoi(ptr); - return {DeviceType::MULTITHREAD, DEVICE_MULTITHREAD_DEFAULT, nr_thread}; + if (id.size() > 20) { + ptr += 20; + int nr_thread = std::stoi(ptr); + return {DeviceType::MULTITHREAD, + DEVICE_MULTITHREAD_DEFAULT, + {nr_thread}}; + } else { + err(); + } } DeviceType dev_type; @@ -192,8 +198,16 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { int num_stream = parse_int(); if (*ptr) err(); + //! multithread with a thread number (held in num_dev before the swap below) of zero is illegal + if (dev_type == DeviceType::MULTITHREAD) { + if (num_dev == 0) { + err(); + } + //! 
num_stream stores nr_threads + std::swap(num_dev, num_stream); + } - return {dev_type, num_dev, num_stream}; + return {dev_type, num_dev, {num_stream}}; } void CompNode::Locator::set_device_map(DeviceType type, int from, int to) { @@ -242,16 +256,22 @@ CompNode::Locator CompNode::Locator::to_physical() const { stream_physical = 1023; } } - return {type_physical, device_physical, stream_physical}; + return {type_physical, device_physical, {stream_physical}}; } std::string CompNode::Locator::to_string() const { if (device == DEVICE_CPU_DEFAULT) { return "cpu:default"; } else if (device == DEVICE_MULTITHREAD_DEFAULT) { - std::string ret="multithread:default:"; + std::string ret = "multithread:default:"; ret.append(get_stream_str(stream)); return ret; + } else if (type == DeviceType::MULTITHREAD) { + std::string ret("multithread"); + ret.append(get_stream_str(stream)) + .append(":") + .append(get_stream_str(device)); + return ret; } char numstr[32]; if (device == -1) { diff --git a/src/core/impl/comp_node/cpu/comp_node.cpp b/src/core/impl/comp_node/cpu/comp_node.cpp index b317d1904efcca267b983715c659882d5791210b..3c5375c8401cf0cdbcf43f40ef659057ce526567 100644 --- a/src/core/impl/comp_node/cpu/comp_node.cpp +++ b/src/core/impl/comp_node/cpu/comp_node.cpp @@ -380,9 +380,9 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase { m_locator_logical(locator_logical) { auto cn = make_comp_node_from_impl(this); if (locator.type == DeviceType::MULTITHREAD) { - //! 
When multi-thread the stream stand for thread number - m_thread_pool = std::unique_ptr( - new ThreadPool(static_cast(locator.stream))); + m_thread_pool = std::unique_ptr(new ThreadPool( + static_cast(locator.nr_threads))); + mgb_assert(m_thread_pool, "ThradPool create failed"); } if (locator.type == DeviceType::CPU) { @@ -398,7 +398,6 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase { cn); } } else if (locator.type == DeviceType::MULTITHREAD) { - mgb_assert(m_thread_pool, "ThradPool create failed"); if (locator.device == Locator::DEVICE_MULTITHREAD_DEFAULT) { m_env.init_cpu( {std::make_shared( @@ -745,15 +744,14 @@ CpuCompNode::Impl* CpuCompNode::load_cpu(Locator locator, } else { mgb_assert(locator.type == DeviceType::MULTITHREAD); auto&& pqueue_weak = sm_pool->physical2queue_multithead[{ - locator.device, locator.stream}]; + locator.device, locator.nr_threads}]; auto pqueue = pqueue_weak.lock(); if (!pqueue) { pqueue = std::make_shared(locator); pqueue_weak = pqueue; } auto&& pimpl = sm_pool->logical2impl_multi_thread[{ - static_cast(compact_logical_device), - locator_logical.stream}]; + compact_logical_device, locator_logical.nr_threads}]; if (!pimpl) { mgb_assert(sm_pool->nr_used_impl_storage < Pool::MAX_NR_COMP_NODE, "too many cpu multithread comp nodes; max %d allowed", diff --git a/src/core/include/megbrain/comp_node.h b/src/core/include/megbrain/comp_node.h index 1bd56c662d368f588372b5ff29700a2f758cc5ad..17921d35dc50db273b0e7f878548f9a3468d03a7 100644 --- a/src/core/include/megbrain/comp_node.h +++ b/src/core/include/megbrain/comp_node.h @@ -153,8 +153,12 @@ class CompNode { int device = -1; //! multiple streams can execute on one computing device and share - //! memory - int stream = 0; + //! memory, when compnode type is multithread the field also stand + //! for nr_threads + union { + int stream = 0; + int nr_threads; + }; /*! 
* \brief parse a string identifier @@ -162,7 +166,7 @@ class CompNode { * currently supported ID format: (gpu|cpu)[:m] where n is the * device number, possibly with m as the stream id. */ - static Locator parse(const std::string &id); + static Locator parse(const std::string& id); /*! * \brief set mapping between device numbers of a device type diff --git a/src/core/test/comp_node.cpp b/src/core/test/comp_node.cpp index d25901000a166c7dc45390655e4ba7f6cb305ab0..7947d6773ba2231ab8aa32c808da52a0b58453cd 100644 --- a/src/core/test/comp_node.cpp +++ b/src/core/test/comp_node.cpp @@ -28,9 +28,7 @@ using namespace mgb; TEST(TestCompNode, Parse) { using L = CompNode::Locator; using D = CompNode::DeviceType; - auto make_lc = [](D t, int dev, int s) -> L { - return {t, dev, s}; - }; + auto make_lc = [](D t, int dev, int s) -> L { return {t, dev, {s}}; }; ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0)); ASSERT_EQ(L::parse("xpux:23"), make_lc(D::UNSPEC, -1, 23)); @@ -47,10 +45,9 @@ TEST(TestCompNode, Parse) { ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0)); ASSERT_EQ(L::parse("xpu23:1"), make_lc(D::UNSPEC, 23, 1)); - ASSERT_EQ(L::parse("cpu:default"), - make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0)); - ASSERT_EQ(L::parse("multithread0:2"), make_lc(D::MULTITHREAD, 0, 2)); - ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 1, 3)); + ASSERT_EQ(L::parse("cpu:default"), make_lc(D::CPU, L::DEVICE_CPU_DEFAULT, 0)); + ASSERT_EQ(L::parse("multithread2:0"), make_lc(D::MULTITHREAD, 0, 2)); + ASSERT_EQ(L::parse("multithread1:3"), make_lc(D::MULTITHREAD, 3, 1)); ASSERT_EQ(L::parse("multithread:default:2"), make_lc(D::MULTITHREAD, L::DEVICE_MULTITHREAD_DEFAULT, 2)); @@ -65,6 +62,10 @@ TEST(TestCompNode, Parse) { ASSERT_THROW(L::parse("heaxgon0"), MegBrainError); ASSERT_THROW(L::parse("rcom0"), MegBrainError); ASSERT_THROW(L::parse("cmabricon0"), MegBrainError); + ASSERT_THROW(L::parse("multithread"), MegBrainError); + ASSERT_THROW(L::parse("multithread1:"), 
MegBrainError); + ASSERT_THROW(L::parse("multithread1:default"), MegBrainError); + ASSERT_THROW(L::parse("multithread1:default:0"), MegBrainError); } TEST(TestCompNode, SetDefaultDev) { @@ -107,12 +108,12 @@ TEST(TestCompNode, Load) { #endif #if MGB_HAVE_THREAD - auto cn_multi_thread0 = CompNode::load("multithread0:2"); - auto cn_multi_thread1 = CompNode::load("multithread1:2"); - ASSERT_EQ(CompNode::load("multithread0:2"), cn_multi_thread0); - ASSERT_EQ(CompNode::load("multithread1:2"), cn_multi_thread1); - ASSERT_NE(CompNode::load("multithread0:4"), cn_multi_thread0); - ASSERT_NE(CompNode::load("multithread1:4"), cn_multi_thread1); + auto cn_multi_thread0 = CompNode::load("multithread2:0"); + auto cn_multi_thread1 = CompNode::load("multithread2:1"); + ASSERT_EQ(CompNode::load("multithread2:0"), cn_multi_thread0); + ASSERT_EQ(CompNode::load("multithread2:1"), cn_multi_thread1); + ASSERT_NE(CompNode::load("multithread4:0"), cn_multi_thread0); + ASSERT_NE(CompNode::load("multithread4:1"), cn_multi_thread1); auto cn_multi_default0 = CompNode::load("multithread:default:2"); auto cn_multi_default1 = CompNode::load("multithread:default:4"); @@ -139,7 +140,7 @@ TEST(TestCompNode, FreeAfterFinalize) { auto type = static_cast(i); if (!CompNode::get_device_count(type)) continue; - auto cn = CompNode::load(CompNode::Locator{type}); + auto cn = CompNode::load(CompNode::Locator{type, -1, {0}}); auto ptr = cn.alloc_device(123); CompNode::finalize(); cn.free_device(ptr); @@ -190,13 +191,13 @@ TEST(TestCompNodeCPU, CoreAffinity) { size_t data0, data1 = 0; auto empty_task = []() {}; auto cn0 = CompNode::load("cpu:default"), cn1 = CompNode::load("cpu0"), - cn2 = CompNode::load("multithread0:2"); + cn2 = CompNode::load("multithread2:0"); auto binding0 = [&](size_t) { data0 = 10; }; CompNodeEnv::from_comp_node(cn0).cpu_env().set_affinity(binding0); CompNodeEnv::from_comp_node(cn0).cpu_env().dispatch(empty_task); cn0.sync(); - auto binding1 = [&](size_t) { data1 = 20; }; + auto 
binding1 = [&](size_t ) { data1 = 20; }; CompNodeEnv::from_comp_node(cn1).cpu_env().set_affinity(binding1); CompNodeEnv::from_comp_node(cn1).cpu_env().dispatch(empty_task); cn1.sync(); @@ -238,7 +239,7 @@ TEST(TestCompNode, CPU_MULTI_THREAD) { }; for (auto&& str : std::vector{ - "multithread0:2", "multithread0:4", "multithread:default:4"}) { + "multithread2:0", "multithread4:0", "multithread:default:4"}) { auto cn0 = CompNode::load("cpu0"), cn1 = CompNode::load(str); std::thread wk_thread0{std::ref(worker), std::ref(dst0), std::ref(cn0)}; std::thread wk_thread1{std::ref(worker), std::ref(dst1), std::ref(cn1)}; @@ -271,9 +272,9 @@ TEST(TestCompNodeCPU, PhysicalDispatch) { L::set_device_map(DT, ID, 0); L::set_device_map(DT, ID + 1, 0); L::set_device_map(DT, ID + 2, 1); - auto cn0 = CompNode::load({DT, ID, 0}), - cn1 = CompNode::load({DT, ID + 1, 0}), - cn2 = CompNode::load({DT, ID + 2, 0}); + auto cn0 = CompNode::load({DT, ID, {0}}), + cn1 = CompNode::load({DT, ID + 1, {0}}), + cn2 = CompNode::load({DT, ID + 2, {0}}); #if MGB_HAVE_THREAD ASSERT_NE(cn0, cn1); #else @@ -532,10 +533,10 @@ TEST(TestCompNode, MultipleLoad) { for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) { auto dt = static_cast(i); if (CompNode::get_device_count(dt)) { - auto cn = CompNode::load({dt}); + auto cn = CompNode::load({dt, 0, {0}}); mgb_log("comp node %s is available", cn.to_string().c_str()); run(cn); - cn = CompNode::load({dt}); + cn = CompNode::load({dt, 0, {0}}); run(cn); } } @@ -591,7 +592,7 @@ TYPED_TEST(TestCPUCompSeqRec, run_default_cpu) { comp_node_test::seq_rec::run(CompNode::load("cpu:default")); } TYPED_TEST(TestCPUCompSeqRec, run_multi_thread) { - auto cn = CompNode::load("multithread0:4"); + auto cn = CompNode::load("multithread4:0"); comp_node_test::seq_rec::run(cn); }