/**
 * \file imperative/src/impl/ops/io_remote.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain_build_config.h"

#include <memory>

#if MGB_ENABLE_OPR_MM
#include "../op_trait.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/opr/io_remote.h"
#include "megbrain/opr/mm_handler.h"
#endif // MGB_ENABLE_OPR_MM

#include "megbrain/imperative/ops/autogen.h"

namespace mgb {
namespace imperative {

#if MGB_ENABLE_OPR_MM
namespace {

/*!
 * \brief Create a proxy client for the group server referenced by an op def.
 *
 * Works for any op def exposing `addr` and `port` members (used here for both
 * RemoteSend and RemoteRecv). The endpoint is formatted as "addr:port".
 * Keeping this in one place guarantees send and recv resolve the server the
 * same way.
 */
template <typename RemoteOp>
std::shared_ptr<GroupClientProxy> make_group_client(const RemoteOp& op) {
    return std::make_shared<GroupClientProxy>(
            ssprintf("%s:%d", op.addr.data(), op.port));
}

/*!
 * \brief Lower a RemoteSend op def into an mgb::opr::RemoteSend operator.
 *
 * The operator is inserted into the graph that owns the (single) input var.
 *
 * \param def    op def; must be a RemoteSend (enforced by cast_final_safe)
 * \param inputs exactly one var: the value to be sent
 * \return the operator node returned by insert_opr
 */
cg::OperatorNodeBase* apply_on_var_node_remote_send(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& send = def.cast_final_safe<RemoteSend>();
    // Only inputs[0] is consumed below; fail loudly instead of indexing out of
    // bounds (or silently ignoring extra vars) on a malformed input list.
    mgb_assert(
            inputs.size() == 1, "RemoteSend expects 1 input, got %zu",
            inputs.size());
    auto&& graph = inputs[0]->owner_graph();

    OperatorNodeConfig config{send.make_name()};
    // NOTE(review): the literal `true` is forwarded positionally to the
    // mgb::opr::RemoteSend constructor; its meaning is defined in
    // megbrain/opr/io_remote.h -- confirm there before changing it.
    return graph->insert_opr(std::make_unique<mgb::opr::RemoteSend>(
            send.key, inputs[0], make_group_client(send), true, config));
}

/*!
 * \brief Lower a RemoteRecv op def into an mgb::opr::RemoteRecv operator.
 *
 * The operator is inserted into the graph that owns the (single) input var;
 * the expected shape and dtype of the received value come from the op def.
 *
 * \param def    op def; must be a RemoteRecv (enforced by cast_final_safe)
 * \param inputs exactly one var, forwarded to mgb::opr::RemoteRecv
 * \return the operator node returned by insert_opr
 */
cg::OperatorNodeBase* apply_on_var_node_remote_recv(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& recv = def.cast_final_safe<RemoteRecv>();
    // Same arity guard as the send path: only inputs[0] is used.
    mgb_assert(
            inputs.size() == 1, "RemoteRecv expects 1 input, got %zu",
            inputs.size());
    auto&& graph = inputs[0]->owner_graph();

    // Config is seeded from recv.cn (presumably the comp node the result is
    // placed on -- confirm in the op definition), then given the op's name.
    OperatorNodeConfig config{recv.cn};
    config.name(recv.make_name());
    return graph->insert_opr(std::make_unique<mgb::opr::RemoteRecv>(
            recv.key, inputs[0], *graph, make_group_client(recv), config,
            recv.shape, recv.dtype));
}

// Register graph lowering for both ops; the remaining traits are filled in by
// fallback() (see op_trait.h).
OP_TRAIT_REG(RemoteSend, RemoteSend, mgb::opr::RemoteSend)
        .apply_on_var_node(apply_on_var_node_remote_send)
        .fallback();

OP_TRAIT_REG(RemoteRecv, RemoteRecv, mgb::opr::RemoteRecv)
        .apply_on_var_node(apply_on_var_node_remote_recv)
        .fallback();
}  // anonymous namespace
#endif // MGB_ENABLE_OPR_MM

}  // namespace imperative
}  // namespace mgb