# test_distributed.py
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import subprocess
import sys

import numpy as np


def worker(master_ip, master_port, world_size, rank, dev, trace):
    """Train a small MLP for a few steps as one rank of a process group.

    Returns immediately, doing nothing, when CUDA is not available.

    Args:
        master_ip: address of the process-group master, forwarded to
            ``dist.init_process_group``.
        master_port: port of the process-group master, forwarded to
            ``dist.init_process_group``.
        world_size: total number of ranks in the group.
        rank: this process's rank within the group.
        dev: device index forwarded to ``dist.init_process_group``.
        trace: value assigned to ``jit.trace.enabled`` before the training
            step is decorated, selecting traced or untraced execution.
    """
    # megengine is imported inside the function; presumably this keeps the
    # module importable without megengine until a worker actually runs.
    import megengine.distributed as dist
    import megengine.functional as F
    from megengine import is_cuda_available
    from megengine import jit
    from megengine.module import Linear, Module
    from megengine.optimizer import SGD

    if not is_cuda_available():
        return

    class MLP(Module):
        """Two-layer perceptron: flattened 3x224x224 input -> 500 -> 10."""

        def __init__(self):
            super().__init__()
            self.fc0 = Linear(3 * 224 * 224, 500)
            self.fc1 = Linear(500, 10)

        def forward(self, x):
            x = self.fc0(x)
            x = F.relu(x)
            x = self.fc1(x)
            return x

    # Use the caller-supplied port; it was previously ignored in favour of a
    # hard-coded 3456.
    dist.init_process_group(
        master_ip=master_ip,
        master_port=master_port,
        world_size=world_size,
        rank=rank,
        dev=dev,
    )
    net = MLP()

    opt = SGD(net.parameters(requires_grad=True), lr=0.02)

    # Random batch of 64 flattened images with integer labels in [0, 10).
    data = np.random.random((64, 3 * 224 * 224)).astype(np.float32)
    label = np.random.randint(0, 10, size=(64,)).astype(np.int32)

    # Set before @jit.trace() below is applied so it governs train_func.
    jit.trace.enabled = trace

    @jit.trace()
    def train_func(data, label):
        pred = net(data)
        loss = F.cross_entropy_with_softmax(pred, label)
        opt.backward(loss)
        return loss

    for _ in range(5):
        opt.zero_grad()
        train_func(data, label)
        opt.step()


def start_workers(worker, world_size, trace=False, master_port=3456):
    """Launch ``world_size`` worker subprocesses concurrently; require success.

    Each rank runs in a fresh interpreter (``sys.executable -c ...``). That
    interpreter imports ``worker`` from ``test.integration.test_distributed``
    and calls it with:

    - ``master_ip='localhost'``;
    - the given ``master_port``;
    - ``world_size``;
    - its rank;
    - a device index equal to its rank;
    - ``trace``.

    Args:
        worker: accepted for interface compatibility but not called directly.
            NOTE(review): the subprocess imports ``worker`` by module path,
            so the callable passed here is effectively ignored.
        world_size: number of ranks to launch.
        trace: forwarded to every worker to toggle ``jit.trace``.
        master_port: process-group master port forwarded to every worker.
            Defaults to 3456, the value previously hard-coded.

    Raises:
        AssertionError: if any subprocess exits with a non-zero code.
    """

    def build_cmd(rank):
        # Build the interpreter command line for one rank.
        code = "from test.integration.test_distributed import worker\n"
        code += "worker('localhost', {}, {}, {}, {}, {})".format(
            master_port, world_size, rank, rank, "True" if trace else "False"
        )
        return [sys.executable, "-c", code]

    # Start every rank before waiting on any so they run concurrently (as the
    # previous one-process-per-rank launcher did). Popen children avoid
    # pickling a nested function, which fails under the "spawn" start method.
    procs = []
    try:
        for rank in range(world_size):
            procs.append(
                subprocess.Popen(
                    build_cmd(rank),
                    stdout=sys.stdout,
                    stderr=sys.stderr,
                    universal_newlines=True,
                )
            )
        returncodes = [p.wait() for p in procs]
    finally:
        # If launching or waiting was interrupted, don't leak live children.
        for p in procs:
            if p.poll() is None:
                p.kill()
                p.wait()

    for rank, code in enumerate(returncodes):
        assert code == 0, "subprocess failed (rank {}, exit code {})".format(
            rank, code
        )


def test_distributed():
    """Run a two-rank training job with jit tracing enabled, then disabled."""
    for use_trace in (True, False):
        start_workers(worker, 2, trace=use_trace)