/**
 * \file imperative/src/test/collective_comm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./helper.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/mm_handler.h"

using namespace mgb;
using namespace imperative;

/**
 * Two-rank ALL_REDUCE_SUM through the imperative CollectiveComm op.
 *
 * Random float tensors are generated on gpu0 and gpu1; each rank feeds its
 * tensor to the op and checks that the result matches the host-computed
 * elementwise sum of both inputs. The RPC server is created first via
 * create_zmqrpc_server on server_addr:port, and the two ranks run on
 * separate threads so they can take part in the collective concurrently.
 */
TEST(TestImperative, AllReduceBasic) {
    REQUIRE_GPU(2);
    const char* server_addr = "127.0.0.1";
    // NOTE(review): fixed port; may collide with other tests or processes
    // bound to 3456 on the same host.
    uint32_t port = 3456;
    mgb_assert(create_zmqrpc_server(server_addr, port) > 0);
    HostTensorGenerator<> gen;
    CompNode cn0 = CompNode::load("gpu0"),
             cn1 = CompNode::load("gpu1");

    // Number of float elements in every tensor used by this test.
    constexpr size_t nr_elems = 233;
    auto host_x = gen({nr_elems}, cn0), host_y = gen({nr_elems}, cn1);
    auto expect = gen({nr_elems});

    // Reference result: elementwise sum of both ranks' inputs.
    const float* x = host_x->ptr<float>();
    const float* y = host_y->ptr<float>();
    float* sum = expect->ptr<float>();
    for (size_t i = 0; i < nr_elems; ++i) {
        sum[i] = x[i] + y[i];
    }

    // Runs one rank of the collective on `hnd` and checks its output.
    // Only rank 0 passes `true` for the argument after the rank
    // (presumably the is_root flag — TODO confirm against
    // CollectiveComm::make).
    auto run = [&](const std::shared_ptr<HostTensorND>& hnd, uint32_t rank) {
        auto def = imperative::CollectiveComm::make(
                megdnn::param::CollectiveComm::Mode::ALL_REDUCE_SUM,
                "all_reduce", 2, rank, rank == 0, false, server_addr, port,
                dtype::Float32(), "nccl", "");
        auto inp = Tensor::make(*hnd);
        auto oup = OpDef::apply_on_physical_tensor(*def, {inp});
        HostTensorND host_v;
        host_v.copy_from(oup[0]->dev_tensor()).sync();
        MGB_ASSERT_TENSOR_NEAR(*expect, host_v, 1e-6);
    };

    // An exception escaping `run` terminates the process (std::thread
    // semantics). NOTE(review): it is intentionally not caught and
    // re-thrown after join(), because the surviving rank would likely stay
    // blocked in the collective and join() would hang — confirm first.
    // Capturing by reference is safe: both threads are joined below.
    std::thread t0([&] { run(host_x, 0); });
    std::thread t1([&] { run(host_y, 1); });

    t0.join();
    t1.join();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}