/**
 * \file src/opr-mm/include/megbrain/opr/io_remote.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/graph.h"
#include "megbrain/opr/internal/mixin_base.h"
#include "megbrain/opr/group_manager.h"

#include "megray.h"

namespace mgb {
namespace opr {

/*!
 * \brief base class for remote I/O nodes
 */
MGB_DEFINE_CLS_WITH_SUPER(RemoteIOBase, cg::SingleCNOperatorNodeBase) // {
    public:
28
        const std::string& key() const { return m_key; }
29 30 31 32 33 34

        std::shared_ptr<GroupClient> group_client() const {
            return m_group_client;
        }

    protected:
35
        std::string m_key;
36 37 38 39 40 41 42 43 44
        std::shared_ptr<GroupClient> m_group_client;
        std::shared_ptr<MegRay::Communicator> m_megray_comm;
        std::shared_ptr<MegRay::Context> m_megray_ctx;
        bool m_init = false;
        using Super::Super;
};

/*!
 * \brief send a variable to remote address; a virtual output is produced
 *        for expressing dependency
 */
MGB_DEFINE_OPR_CLASS(RemoteSend, RemoteIOBase) // {
    public:
49
        RemoteSend(const std::string& key, VarNode* var,
50
                   std::shared_ptr<GroupClient> group_client,
51
                   bool is_grad, const OperatorNodeConfig& config);
52 53

        static SymbolVar make(
54
                const std::string& key, SymbolVar var,
55
                std::shared_ptr<GroupClient> group_client,
56 57 58
                bool is_grad, const OperatorNodeConfig& config = {});

        bool is_grad() const { return m_is_grad; }
59 60 61

    private:
        HostTensorND m_output_val;
62
        bool m_is_grad;
63 64 65 66 67 68 69

        void scn_do_execute() override;
        void init_output_static_infer_desc() override;
        NodeProp* do_make_node_prop() const override;
};

/*!
 * \brief receive a variable from remote address; target computing node
 *        of the var must be specified in config
 */
MGB_DEFINE_OPR_CLASS(RemoteRecv, RemoteIOBase) // {
    public:
75
        RemoteRecv(const std::string& key, cg::ComputingGraph& graph,
76 77 78 79
                   std::shared_ptr<GroupClient> group_client,
                   const OperatorNodeConfig& config, const TensorShape& shape,
                   DType dtype);

80 81 82 83 84
        RemoteRecv(const std::string& key, VarNode* var, cg::ComputingGraph& graph,
                   std::shared_ptr<GroupClient> group_client,
                   const OperatorNodeConfig& config, const TensorShape& shape,
                   DType dtype);

85
        static SymbolVar make(
86
                const std::string& key, cg::ComputingGraph& graph,
87 88 89 90
                std::shared_ptr<GroupClient> group_client,
                const OperatorNodeConfig& config, const TensorShape& shape,
                DType dtype);

91 92 93 94 95 96
        static SymbolVar make(
                const std::string& key, SymbolVar var, cg::ComputingGraph& graph,
                std::shared_ptr<GroupClient> group_client,
                const OperatorNodeConfig& config, const TensorShape& shape,
                DType dtype);

97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
    private:
        const TensorShape m_shape;
        const DType m_dtype;
        const CompNode m_comp_node;
        DeviceTensorND m_dev_buffer;

        void scn_do_execute() override;
        void init_output_static_infer_desc() override;
        NodeProp* do_make_node_prop() const override;
};

} // namespace opr
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}