forward_sereg.cpp 9.8 KB
Newer Older
1
#include "./forward_sereg.h"
M
Megvii Engine Team 已提交
2 3
#include "./impl.h"
#include "megbrain/opr/internal/param_tag_defs.h"
4 5
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_shallow_copy.h"
M
Megvii Engine Team 已提交
6
#include "megbrain/serialization/serializer.h"
7 8 9 10 11 12 13

using namespace mgb;
using namespace mgb::opr::intl;
using namespace mgb::serialization;

namespace {

M
Megvii Engine Team 已提交
14 15
class LoopDumpContext : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;
16

M
Megvii Engine Team 已提交
17 18
public:
    ThinHashMap<VarNode*, size_t> ogvar2inpidx;
19

M
Megvii Engine Team 已提交
20 21 22 23 24 25 26 27
    static LoopDumpContext& from_dump_ctx(OprDumpContext& ctx) {
        auto ret = ctx.config().user_data->get_user_data<LoopDumpContext>();
        mgb_assert(ret.second);
        return *ret.first[ret.second - 1];
    }
};
class LoopLoadContext : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;
28

M
Megvii Engine Team 已提交
29 30 31
public:
    const VarNodeArray& input_vars;
    opr::Loop::Desc& desc;
32

M
Megvii Engine Team 已提交
33 34
    LoopLoadContext(const VarNodeArray& input_vars_, opr::Loop::Desc& desc_)
            : input_vars{input_vars_}, desc{desc_} {}
35

M
Megvii Engine Team 已提交
36 37 38 39 40 41
    static LoopLoadContext& from_load_ctx(OprLoadContext& ctx) {
        auto ret = ctx.config().user_data->get_user_data<LoopLoadContext>();
        mgb_assert(ret.second);
        return *ret.first[ret.second - 1];
    }
};
42

M
Megvii Engine Team 已提交
43 44
// NOTE(review): presumably these supply the out-of-class part matching the
// MGB_TYPEINFO_OBJ_DECL inside each context class -- confirm against the macro
MGB_TYPEINFO_OBJ_IMPL(LoopDumpContext);
MGB_TYPEINFO_OBJ_IMPL(LoopLoadContext);
45

M
Megvii Engine Team 已提交
46
}  // anonymous namespace
47 48 49 50 51

namespace mgb {
namespace opr {
namespace intl {

M
Megvii Engine Team 已提交
52 53 54 55
//! use LoopSerializer because it is friend of LoopImpl
class LoopSerializer {
    using InputMaker = LoopImpl::InputMaker;
    using CounterProvider = LoopImpl::DescImplBase::CounterProvider;
56

M
Megvii Engine Team 已提交
57 58 59 60 61
    struct LoopParam {
        static constexpr uint32_t TAG = opr::param_tag::LOOP;
        Loop::Param opr_param;
        uint64_t cond_var_id;
    };
62

M
Megvii Engine Team 已提交
63 64 65 66 67
    struct InputMakerParam {
        static constexpr uint32_t TAG = opr::param_tag::LOOP_INPUT_MAKER;
        bool has_assign;
        uint64_t ogvar_id;  //! id of proxied var in owner graph
    };
68

M
Megvii Engine Team 已提交
69 70 71 72
    struct OutputListEntry {
        uint64_t subvar_id;
        LoopImpl::Desc::OutputMode mode;
    } MGB_PACKED;
73

M
Megvii Engine Team 已提交
74 75 76
    struct AssignListEntry {
        uint64_t dst_id, src_id;
    };
77

M
Megvii Engine Team 已提交
78
    static void dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr);
79

M
Megvii Engine Team 已提交
80
    static void dump_input_maker(OprDumpContext& ctx, const cg::OperatorNodeBase& opr);
81

M
Megvii Engine Team 已提交
82 83
    static void dump_counter_provider(
            OprDumpContext& ctx, const cg::OperatorNodeBase& opr);
84

M
Megvii Engine Team 已提交
85 86 87
    static cg::OperatorNodeBase* load_loop(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config);
88

M
Megvii Engine Team 已提交
89 90 91
    static cg::OperatorNodeBase* load_input_maker(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config);
92

M
Megvii Engine Team 已提交
93 94 95
    static cg::OperatorNodeBase* load_counter_provider(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config);
96

M
Megvii Engine Team 已提交
97 98
public:
    static void reg_all();
99

M
Megvii Engine Team 已提交
100 101 102 103 104 105
    // we need dedicated shallow_copy because some oprs can be copied
    // but can not be dumped; also record InterGraphVarTransformer
    static cg::OperatorNodeBase* shallow_copy(
            const OprShallowCopyContext& orig_ctx, const Loop& opr,
            const VarNodeArray& inputs, const OperatorNodeConfig& config);
};
106

M
Megvii Engine Team 已提交
107 108 109
}  // namespace intl
}  // namespace opr
}  // namespace mgb
110 111 112 113 114 115

namespace mgb {
namespace serialization {
namespace fbs {

template <>
M
Megvii Engine Team 已提交
116
struct SupportFlatBuffersSerialization<opr::intl::LoopSerializer::LoopParam> : No {};
117 118

template <>
M
Megvii Engine Team 已提交
119 120
struct SupportFlatBuffersSerialization<opr::intl::LoopSerializer::InputMakerParam>
        : No {};
121 122 123 124 125 126

}  // namespace fbs
}  // namespace serialization
}  // namespace mgb

/*!
 * \brief shallow copy entry point for the Loop opr
 *
 * Casts \p opr to opr::Loop (via cast_final_safe) and forwards all arguments
 * to LoopSerializer::shallow_copy().
 */
cg::OperatorNodeBase* serialization::opr_shallow_copy_loop(
        const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
        const VarNodeArray& inputs, const OperatorNodeConfig& config) {
    // `opr::` below names the namespace: lookup of a name followed by `::`
    // ignores the parameter that is also called opr
    auto&& loop_opr = opr.cast_final_safe<opr::Loop>();
    return opr::intl::LoopSerializer::shallow_copy(ctx, loop_opr, inputs, config);
}

void LoopSerializer::reg_all() {
134 135
    MGB_SEREG_OPR_INTL_CALL_ADD(opr::Loop, dump_loop, load_loop, true);
    MGB_SEREG_OPR_INTL_CALL_ADD(InputMaker, dump_input_maker, load_input_maker, true);
M
Megvii Engine Team 已提交
136
    MGB_SEREG_OPR_INTL_CALL_ADD(
137
            CounterProvider, dump_counter_provider, load_counter_provider, true);
138 139

    MGB_SEREG_OPR_INTL_CALL_ADD_V2(
140
            opr::Loop, dump_loop, load_loop, nullptr, 2, CURRENT_VERSION);
141 142 143 144 145 146
    MGB_SEREG_OPR_INTL_CALL_ADD_V2(
            InputMaker, dump_input_maker, load_input_maker, nullptr, 2,
            CURRENT_VERSION);
    MGB_SEREG_OPR_INTL_CALL_ADD_V2(
            CounterProvider, dump_counter_provider, load_counter_provider, nullptr, 2,
            CURRENT_VERSION);
147 148
}

M
Megvii Engine Team 已提交
149
/*!
 * \brief dumper registered for the Loop opr
 *
 * Dumping a Loop opr is not implemented: this always throws
 * SerializationError and never uses \p ctx or \p opr.
 */
void LoopSerializer::dump_loop(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
    // parameters are intentionally unused; mark them like
    // dump_counter_provider() does
    MGB_MARK_USED_VAR(ctx);
    MGB_MARK_USED_VAR(opr);
    // the condition is kept in a (non-const) local instead of throwing
    // unconditionally
    bool dump_implemented = false;
    mgb_throw_if(
            !dump_implemented, SerializationError,
            "Serialization of Loop opr not implemented");
}

void LoopSerializer::dump_input_maker(
M
Megvii Engine Team 已提交
157 158 159 160 161
        OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
    auto&& ogvar2inpidx = LoopDumpContext::from_dump_ctx(ctx).ogvar2inpidx;
    auto&& opr_im = opr.cast_final_safe<InputMaker>();
    ctx.write_param<InputMakerParam>(
            {opr_im.param().has_assign, ogvar2inpidx.at(opr_im.orig_var())});
162 163 164
}

void LoopSerializer::dump_counter_provider(
M
Megvii Engine Team 已提交
165
        OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
166 167 168 169 170 171
    // there is nothing needs to do
    MGB_MARK_USED_VAR(ctx);
    MGB_MARK_USED_VAR(opr);
}

/*!
 * \brief loader registered for the Loop opr
 *
 * Loading a Loop opr is not implemented: this always throws
 * SerializationError and never uses its parameters.
 *
 * \return never returns normally; the trailing return only exists so that
 *      every path of the function has a value
 */
cg::OperatorNodeBase* LoopSerializer::load_loop(
        OprLoadContext& ctx, const cg::VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    // parameters are intentionally unused; mark them like the sibling
    // dump/load helpers do
    MGB_MARK_USED_VAR(ctx);
    MGB_MARK_USED_VAR(inputs);
    MGB_MARK_USED_VAR(config);
    // the condition is kept in a (non-const) local instead of throwing
    // unconditionally
    bool load_implemented = false;
    cg::OperatorNodeBase* load_result = nullptr;
    mgb_throw_if(
            !load_implemented, SerializationError,
            "Serialization of Loop opr not implemented");
    return load_result;
}

/*!
 * \brief load an InputMaker by adding an input to the loop desc being built
 *
 * Reads one InputMakerParam from \p ctx, uses its ogvar_id as a
 * bounds-checked index into LoopLoadContext::input_vars, and adds that var as
 * an input of LoopLoadContext::desc.
 *
 * \param inputs unused
 * \param config unused
 * \return owner opr of the var returned by Desc::add_input()
 */
cg::OperatorNodeBase* LoopSerializer::load_input_maker(
        OprLoadContext& ctx, const cg::VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(inputs);
    MGB_MARK_USED_VAR(config);
    auto&& loop_load_ctx = LoopLoadContext::from_load_ctx(ctx);
    auto param = ctx.read_param<InputMakerParam>();
    // at() rejects an ogvar_id that is out of range for input_vars
    auto ogvar = loop_load_ctx.input_vars.at(param.ogvar_id);
    return loop_load_ctx.desc.add_input(ogvar, param.has_assign)
            .node()
            ->owner_opr();
}

/*!
 * \brief load a CounterProvider by asking the loop desc for its counter var
 *
 * Nothing is read from \p ctx besides the LoopLoadContext lookup.
 *
 * \param inputs must be empty (asserted)
 * \param config unused
 * \return owner opr of Desc::get_counter_var()
 */
cg::OperatorNodeBase* LoopSerializer::load_counter_provider(
        OprLoadContext& ctx, const cg::VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(inputs);
    MGB_MARK_USED_VAR(config);
    mgb_assert(inputs.empty());
    auto&& loop_load_ctx = LoopLoadContext::from_load_ctx(ctx);
    return loop_load_ctx.desc.get_counter_var().node()->owner_opr();
}

/*!
 * \brief shallow copy a Loop opr by replaying its desc onto a new Loop
 *
 * The desc maker passed to Loop::make() rebuilds the sub graph in this order:
 *  1. re-add every input, taking the owner-graph var from \p inputs at the
 *     position the original var had in opr.input();
 *  2. shallow copy every sub graph opr except InputMaker (handled in step 1)
 *     and CounterProvider (mapped to the counter var of the new desc);
 *  3. re-add the outputs in their original order, then the assignments, then
 *     the loop condition.
 * Afterwards an InterGraphVarTransformer is registered that maps vars of the
 * original sub graph to the copied ones.
 *
 * \param inputs replacement inputs; must have the same size as opr.input()
 * \return the newly created Loop opr
 *
 * NOTE(review): \p config is never read in this function, so it is not
 * forwarded to Loop::make() -- confirm that this is intended.
 */
cg::OperatorNodeBase* LoopSerializer::shallow_copy(
        const OprShallowCopyContext& orig_ctx, const Loop& opr,
        const VarNodeArray& inputs, const OperatorNodeConfig& config) {
    auto orig_desc = static_cast<LoopImpl::FwdDesc*>(opr.m_desc.get());

    // position of each original owner-graph input in opr.input()
    ThinHashMap<VarNode*, size_t> ogvar2inpidx;
    mgb_assert(inputs.size() == opr.input().size());
    for (size_t idx = 0; idx < inputs.size(); ++idx) {
        ogvar2inpidx[opr.input(idx)] = idx;
    }

    // scratch buffer reused for the inputs of every copied opr
    VarNodeArray cur_opr_inputs;
    // original sub graph var => copied sub graph var; shared because the
    // transformer registered below keeps using it
    auto varmap_buf = std::make_shared<ThinHashMap<VarNode*, VarNode*>>();

    auto desc_maker = [&](Loop::Desc& desc) {
        ThinHashMap<VarNode*, LoopImpl::InputMaker*> assignee2orig_im;
        auto&& varmap = *varmap_buf;

        // add inputs
        // NOTE(review): the owner graph of ctx is only updated inside this
        // loop, i.e. when the original desc has at least one input
        OprShallowCopyContext ctx{orig_ctx};
        for (auto orig_im : orig_desc->all_inputs()) {
            auto ogvar = inputs.at(ogvar2inpidx.at(orig_im->orig_var()));
            auto subvar = desc.add_input(ogvar, orig_im->param().has_assign);
            varmap[orig_im->output(0)] = subvar.node();
            if (orig_im->param().has_assign) {
                assignee2orig_im[subvar.node()] = orig_im;
            }
            ctx.owner_graph(subvar.node()->owner_graph());
        }

        // copy oprs
        for (auto sub_opr : orig_desc->sub_graph_oprs()) {
            if (sub_opr->same_type<LoopImpl::InputMaker>()) {
                // already recreated by add_input() above
                continue;
            }

            if (sub_opr->same_type<LoopImpl::DescImplBase::CounterProvider>()) {
                varmap[sub_opr->output(0)] = desc.get_counter_var().node();
                continue;
            }

            cur_opr_inputs.clear();
            for (auto orig_inp : sub_opr->input()) {
                cur_opr_inputs.push_back(varmap.at(orig_inp));
            }
            auto new_opr = copy_opr_shallow(
                    *sub_opr, cur_opr_inputs, sub_opr->config(), ctx);
            mgb_assert(new_opr->output().size() == sub_opr->output().size());
            for (size_t idx = 0; idx < new_opr->output().size(); ++idx) {
                varmap[sub_opr->output(idx)] = new_opr->output(idx);
            }
        }

        // add outputs in original order
        for (auto&& spec : orig_desc->output_record_spec_no_dedup()) {
            desc.add_output(varmap.at(spec->var_sub()), spec->output_mode());
        }

        // add assignments
        for (auto&& assign : assignee2orig_im) {
            desc.assign(assign.first, varmap.at(assign.second->assignor()));
        }

        desc.set_loop_condition(
                varmap.at(orig_desc->loop_cond_manager().var().node()));
    };

    auto new_base_opr = opr::Loop::make(desc_maker)[0].node()->owner_opr();
    auto&& ret = new_base_opr->cast_final_safe<Loop>();
    mgb_assert(ret.output().size() == opr.output().size());

    // captures the shared map by value so it outlives this function
    auto trans_src_var = [varmap_buf](VarNode* src) -> VarNode* {
        auto iter = varmap_buf->find(src);
        mgb_throw_if(
                iter == varmap_buf->end(), GraphError,
                "loop fwd shallow copy: "
                "can not to get copied var from unused src var: %s",
                cg::dump_var_info({src}).c_str());
        return iter->second;
    };
    cg::InterGraphVarTransformer::register_to(
            ret.m_desc->sub_graph(), opr.m_desc->sub_graph(), trans_src_var);

    return &ret;
}

void LoopSerializerReg::entry() {
    LoopSerializer::reg_all();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}