# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import subprocess
import sys

import numpy as np


def worker(master_ip, master_port, world_size, rank, dev, trace):
    """Run one distributed-training worker process.

    Trains a tiny MLP on random data for a few iterations to exercise the
    distributed machinery.  Returns silently (skipping the test) when CUDA
    is not available.

    :param master_ip: address of the rank-0 master.
    :param master_port: port of the rank-0 master.
    :param world_size: total number of workers in the group.
    :param rank: this worker's rank in ``[0, world_size)``.
    :param dev: device index to bind this worker to.
    :param trace: whether to enable ``jit.trace`` compilation.
    """
    # megengine is imported lazily so that merely importing this test module
    # does not require megengine to be installed/initialized in the parent.
    import megengine.distributed as dist
    import megengine.functional as F
    from megengine import is_cuda_available
    from megengine import jit
    from megengine.module import Linear, Module
    from megengine.optimizer import SGD

    if not is_cuda_available():
        return

    class MLP(Module):
        # Minimal two-layer perceptron: 3*224*224 -> 500 -> 10.
        def __init__(self):
            super().__init__()
            self.fc0 = Linear(3 * 224 * 224, 500)
            self.fc1 = Linear(500, 10)

        def forward(self, x):
            x = self.fc0(x)
            x = F.relu(x)
            x = self.fc1(x)
            return x

    # BUG FIX: the original hard-coded ``master_port=3456`` here, silently
    # ignoring the ``master_port`` argument.  Use the parameter instead
    # (callers in this file pass 3456, so observable behavior is unchanged).
    dist.init_process_group(
        master_ip=master_ip,
        master_port=master_port,
        world_size=world_size,
        rank=rank,
        dev=dev,
    )

    net = MLP()
    opt = SGD(net.parameters(requires_grad=True), lr=0.02)

    data = np.random.random((64, 3 * 224 * 224)).astype(np.float32)
    label = np.random.randint(0, 10, size=(64,)).astype(np.int32)

    jit.trace.enabled = trace

    @jit.trace()
    def train_func(data, label):
        pred = net(data)
        loss = F.cross_entropy_with_softmax(pred, label)
        opt.backward(loss)
        return loss

    # A handful of steps is enough to verify the distributed training loop.
    for _ in range(5):
        opt.zero_grad()
        loss = train_func(data, label)
        opt.step()


def start_workers(worker, world_size, trace=False):
    """Launch ``world_size`` worker subprocesses and wait for them to finish.

    Each worker runs in a fresh ``python3`` interpreter (spawned through a
    ``multiprocessing.Process`` wrapper) so that CUDA/megengine state is not
    shared between ranks.  Asserts that every subprocess exits cleanly.

    :param worker: unused placeholder kept for call-site compatibility; the
        actual worker is resolved by name inside the subprocess command.
    :param world_size: number of worker processes to launch.
    :param trace: forwarded to each worker's ``trace`` argument.
    """

    def run_subproc(rank):
        # Build a one-off script that imports and invokes the worker; the
        # device index is set equal to the rank (one GPU per rank).
        cmd = "from test.integration.test_distributed import worker\n"
        cmd += "worker('localhost', 3456, {}, {}, {}, {})".format(
            world_size, rank, rank, "True" if trace else "False"
        )
        cmd = ["python3", "-c", cmd]
        ret = subprocess.run(
            cmd, stdout=sys.stdout, stderr=sys.stderr, universal_newlines=True
        )
        assert ret.returncode == 0, "subprocess failed"

    procs = []
    for rank in range(world_size):
        p = mp.Process(target=run_subproc, args=(rank,))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
        assert p.exitcode == 0


def test_distributed():
    """Exercise distributed training both with and without jit tracing."""
    start_workers(worker, 2, trace=True)
    start_workers(worker, 2, trace=False)