diff --git a/python_module/CMakeLists.txt b/python_module/CMakeLists.txt index f4142ee0adb9dc083a80afb7026e65006a8334ee..28b0d1f99cd911f7fce6fa57b75804ee2fc46e03 100644 --- a/python_module/CMakeLists.txt +++ b/python_module/CMakeLists.txt @@ -55,10 +55,10 @@ add_custom_command( add_custom_target(mgb_opr_py DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/megengine/_internal/opr.py) -set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp) +set(SRCS src/cpp/craniotome.cpp src/cpp/function_replace.cpp src/cpp/intbx.cpp src/cpp/megbrain_config.cpp src/cpp/megbrain_pubapi.cpp src/cpp/megbrain_serialize.cpp src/cpp/megbrain_wrap.cpp src/cpp/mm_handler.cpp src/cpp/opr_defs.cpp src/cpp/opr_helper.cpp src/cpp/plugin.cpp src/cpp/python_helper.cpp) if(MGE_WITH_DISTRIBUTED) - list(APPEND SRCS src/cpp/mm_handler.cpp src/cpp/zmq_rpc.cpp) + list(APPEND SRCS src/cpp/zmq_rpc.cpp) endif() include(UseSWIG) diff --git a/python_module/src/cpp/megbrain_config.h b/python_module/src/cpp/megbrain_config.h index 65a8e0aa58a4b5b6603eb40c7ccc6d16d7b270df..e6d89646e2242a3da1819dc1eac87d6778b3420d 100644 --- a/python_module/src/cpp/megbrain_config.h +++ b/python_module/src/cpp/megbrain_config.h @@ -65,12 +65,10 @@ class _config { static std::vector> dump_registered_oprs(); -#if MGB_ENABLE_OPR_MM static int create_mm_server(const std::string& server_addr, int port); static void group_barrier(const std::string& server_addr, int port, uint32_t size, uint32_t rank); -#endif }; // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/python_module/src/cpp/mm_handler.cpp b/python_module/src/cpp/mm_handler.cpp index 54aa795828a9f5aa26cd6e9a2dbe7816630c4d6d..b9da63c0ba24a7b5ed5f13ecbb8fc8bbfb8df73a 100644 --- a/python_module/src/cpp/mm_handler.cpp +++ b/python_module/src/cpp/mm_handler.cpp @@ -12,7 +12,7 @@ #include "megbrain/exception.h" #include "megbrain_config.h" -#if MGB_CUDA +#if MGB_ENABLE_OPR_MM #include "zmq_rpc.h" #include @@ -242,17 +242,11 @@ int _config::create_mm_server(const std::string& server_addr, int port) { server_addr, port, std::make_unique()); } -#else - -int _config::create_mm_server(const std::string& server_addr, int port) { - mgb_throw(mgb::MegBrainError, "CUDA suppport disable at compile time"); - return 0; -} - -#endif - /* ======================== Group Barrier ========================== */ +/*! see definition : src/cpp/megbrain_config.h. + * Block until all ranks in the group reach this barrier + */ void _config::group_barrier(const std::string& server_addr, int port, uint32_t size, uint32_t rank) { mgb_assert(rank < size, "invalid rank %d", rank); @@ -263,4 +257,18 @@ void _config::group_barrier(const std::string& server_addr, mgb_assert(size == rsp, "inconsistent size: %d, expect %d", size, rsp); } +#else + +int _config::create_mm_server(const std::string& server_addr, int port) { + mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time"); + return 0; +} + +void _config::group_barrier(const std::string& server_addr, + int port, uint32_t size, uint32_t rank) { + mgb_throw(mgb::MegBrainError, "distributed mode disabled at compile time"); +} + +#endif + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/python_module/src/cpp/mm_handler.h b/python_module/src/cpp/mm_handler.h index 3d4ec403a8b1a4e72041eb2d1136dd5937acdf6b..338ea36c14129038f49d6ed68656aac5f67af752 100644 --- a/python_module/src/cpp/mm_handler.h +++ b/python_module/src/cpp/mm_handler.h @@ -11,7 +11,7 @@ #include "megbrain_build_config.h" -#if MGB_CUDA +#if MGB_ENABLE_OPR_MM #include "zmq_rpc.h" diff --git a/python_module/test/integration/test_distributed.py b/python_module/test/integration/test_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..c816e5e9f32e0eb909571d2dce3eb183f33b381e --- /dev/null +++ b/python_module/test/integration/test_distributed.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import multiprocessing as mp +import subprocess +import sys + +import numpy as np + + +def worker(master_ip, master_port, world_size, rank, dev, trace): + import megengine.distributed as dist + import megengine.functional as F + from megengine import is_cuda_available + from megengine import jit + from megengine.module import Linear, Module + from megengine.optimizer import SGD + + if not is_cuda_available(): + return + + class MLP(Module): + def __init__(self): + super().__init__() + self.fc0 = Linear(3 * 224 * 224, 500) + self.fc1 = Linear(500, 10) + + def forward(self, x): + x = self.fc0(x) + x = F.relu(x) + x = self.fc1(x) + return x + + dist.init_process_group( + master_ip=master_ip, master_port=3456, world_size=world_size, rank=rank, dev=dev + ) + net = MLP() + + opt = SGD(net.parameters(requires_grad=True), lr=0.02) + + data = np.random.random((64, 3 * 224 * 224)).astype(np.float32) + label = np.random.randint(0, 10, size=(64,)).astype(np.int32) + + jit.trace.enabled = trace + + @jit.trace() + def train_func(data, label): + pred = net(data) + loss = F.cross_entropy_with_softmax(pred, label) + opt.backward(loss) + return loss + + for i in range(5): + opt.zero_grad() + loss = train_func(data, label) + opt.step() + + +def start_workers(worker, world_size, trace=False): + def run_subproc(rank): + cmd = "from test.integration.test_distributed import worker\n" + cmd += "worker('localhost', 3456, {}, {}, {}, {})".format( + world_size, rank, rank, "True" if trace else "False" + ) + cmd = ["python3", "-c", cmd] + ret = subprocess.run( + cmd, stdout=sys.stdout, stderr=sys.stderr, universal_newlines=True + ) + assert ret.returncode == 0, "subprocess failed" + + procs = [] + for rank in range(world_size): + p = mp.Process(target=run_subproc, args=(rank,)) + p.start() + procs.append(p) + + for p in procs: + p.join() + assert p.exitcode == 0 + + +def test_distributed(): + start_workers(worker, 2, trace=True) + start_workers(worker, 2, trace=False)