/**
 * \file src/serialization/impl/opr_shallow_copy.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 */

#include "megbrain/serialization/opr_shallow_copy.h"

#include "megbrain/gopt/basic_arith.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_registry.h"
#include "megbrain/utils/big_key_hashmap.h"

using namespace mgb;
using namespace serialization;

namespace {
//! dump single opr to memory for shallow copy
class OprDumpContextMemory final : public OprDumpContextRawPOD {
    std::vector<uint8_t> m_buf;

    void write_raw(const void* data, size_t size) override {
        auto pos = m_buf.size();
        auto end = pos + size;
        if (end > m_buf.capacity())
            m_buf.reserve(end * 2);
        m_buf.resize(end);
        memcpy(m_buf.data() + pos, data, size);
    }

M
Megvii Engine Team 已提交
34 35 36
    void dump_tensor(
            const std::string&, const HostTensorND&, TensorWriteMethod) override {
        mgb_throw(GraphError, "OprDumpContextMemory does not support dump tensor");
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
    }

    const GraphDumpConfig& config() const override {
        mgb_throw(GraphError, "OprDumpContextMemory has no associated config");
    }

public:
    OprDumpContextMemory() : OprDumpContextRawPOD(false) {}

    auto&& buf() const { return m_buf; }
};

//! load single opr from memory for shallow copy
class OprLoadContextMemory final : public OprLoadContextRawPOD {
    const uint8_t* m_ptr;
    size_t m_size, m_pos = 0;
    ComputingGraph* m_graph;

    void read_raw(void* dest, size_t size) override {
        auto end = m_pos + size;
        mgb_assert(end <= m_size);
        memcpy(dest, m_ptr + m_pos, size);
        m_pos = end;
    }

    ComputingGraph& graph() override { return *m_graph; }

    std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }

M
Megvii Engine Team 已提交
66
    std::shared_ptr<DeviceTensorND> load_tensor_shared() override { mgb_assert(0); }
67 68 69 70 71 72

    const GraphLoadConfig& config() const override {
        mgb_throw(GraphError, "OprLoadContextMemory has no associated config");
    }

public:
M
Megvii Engine Team 已提交
73
    OprLoadContextMemory(ComputingGraph* graph, const OprDumpContextMemory& dumper)
74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
            : OprLoadContextRawPOD(false),
              m_ptr{dumper.buf().data()},
              m_size{dumper.buf().size()},
              m_graph{graph} {}

    ~OprLoadContextMemory() { mgb_assert(m_pos == m_size); }
};

class ShallowCopyCacheContainer final : public UserDataContainer::UserData {
    MGB_TYPEINFO_OBJ_DECL;

    struct HashEq {
        template <typename T>
        static bool eq(const T& x, const T& y) {
            return x == y;
        }
M
Megvii Engine Team 已提交
90
        static bool eq(const OperatorNodeConfig& x, const OperatorNodeConfig& y) {
91 92
            return x.is_same(y);
        }
M
Megvii Engine Team 已提交
93
        static size_t hash(const void* ptr) { return std::hash<const void*>{}(ptr); }
94 95 96
        static size_t hash(const VarNodeArray& inputs) {
            return PODHash<VarNode*>::perform(inputs.data(), inputs.size());
        }
M
Megvii Engine Team 已提交
97
        static size_t hash(const OperatorNodeConfig& config) { return config.hash(); }
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
    };

public:
    big_key_hash_map::BigKeyHashMap<
            cg::OperatorNodeBase*, HashEq,
            big_key_hash_map::Copy<const cg::OperatorNodeBase*>,
            big_key_hash_map::Ref<VarNodeArray>,
            big_key_hash_map::Ref<OperatorNodeConfig>>
            cache;
};
MGB_TYPEINFO_OBJ_IMPL(ShallowCopyCacheContainer);

}  // anonymous namespace

ComputingGraph* serialization::OprShallowCopyContext::owner_graph(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs) const {
    // explicit destination graph: it must agree with the inputs' graph, if any
    if (m_owner_graph) {
        if (!inputs.empty())
            mgb_assert(m_owner_graph == inputs[0]->owner_graph());
        return m_owner_graph;
    }
    // no explicit graph: infer from the inputs, falling back to the graph
    // that owns the source operator when there are no inputs
    return inputs.empty() ? opr.owner_graph() : inputs[0]->owner_graph();
}

cg::OperatorNodeBase* serialization::copy_opr_shallow(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs,
        const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) {
    // resolve the copy implementation: registered per-type copier, or the
    // generic dump/load round-trip fallback
    OprShallowCopy shallow_copy = nullptr;
    if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
        shallow_copy = registry->shallow_copy;
    } else {
        shallow_copy = intl::copy_opr_shallow_default_impl;
    }

    mgb_assert(inputs.size() == opr.input().size());
    auto dst_og = ctx.owner_graph(opr, inputs);
    auto do_copy = [&]() {
        auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph();
        auto ret = shallow_copy(ctx, opr, inputs, config);

        // fix up node attributes only when a genuinely new opr was created
        // (copy into another graph, or opr count in this graph changed)
        if (dst_og != opr.owner_graph() ||
            opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) {
            auto&& attr = ret->node_prop().attribute();
            if (!attr.src_opr) {
                auto src = cg::get_opr_root_source_opr(
                        const_cast<cg::OperatorNodeBase*>(&opr));
                if (ret != src)
                    attr.src_opr = src;
            }
            if (!attr.priority) {
                // priority may have been changed by OprInserted event handlers
                // (like in python case)
                attr.priority = opr.node_prop().attribute().priority;
            }
        }
        return ret;
    };
    cg::OperatorNodeBase* ret;
    if (dst_og == opr.owner_graph()) {
        // use cache for copy in same graph
        auto&& cache =
                dst_og->options()
                        .user_data.get_user_data_or_create<ShallowCopyCacheContainer>()
                        ->cache;
        auto ins = cache.get(&opr, inputs, config);
        if (ins.first) {
            // newly inserted key: perform the copy and store it
            *ins.second = do_copy();
        } else {
            // cache hit: shapes may be stale, refresh them
            cg::update_output_var_shapes(*ins.second);
        }
        ret = *ins.second;
    } else {
        ret = do_copy();
    }

    // sanity check: the copy must preserve the output arity, and an identical
    // opr may only be returned when the inputs are unchanged (inplace arith
    // optimization is allowed to violate this)
    mgb_assert(
            gopt::has_inplace_basic_arith_opt(opr) ||
                    ((  // outputs match
                             opr.usable_output().size() ==
                             ret->usable_output().size()) &&
                     (  // new opr is returned
                             (&opr != ret) || opr.input() == inputs)),
            "bad opr copy: src=%s{%s} dst=%s{%s}", opr.cname(),
            opr.dyn_typeinfo()->name, ret->cname(), ret->dyn_typeinfo()->name);

    return ret;
}

cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
        const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
        const VarNodeArray& inputs, const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);

    OprDumper opr_dumper = nullptr;
    OprLoaderWrapper opr_loader = nullptr;

    // prefer the legacy registry; fall back to the versioned (v2) registry.
    // NOTE: the v2 lookup result is null-checked before dereference — the
    // previous code dereferenced it unconditionally and would crash before
    // reaching the diagnostic assert below
    if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
        opr_loader = registry->loader;
        opr_dumper = registry->dumper;
    } else if (
            auto registryv2 = OprRegistryV2::versioned_find_by_typeinfo(
                    opr.dyn_typeinfo(), CURRENT_VERSION)) {
        opr_loader = registryv2->loader;
        opr_dumper = registryv2->dumper;
    }

    mgb_assert(
            opr_dumper && opr_loader,
            "can not shallow_copy operator %s{%s}: "
            "no dumper/loader registered",
            opr.cname(), opr.dyn_typeinfo()->name);

    // round-trip the opr through an in-memory dump to produce the copy
    OprDumpContextMemory memory_dumper;
    opr_dumper(memory_dumper, opr);

    OprLoadContextMemory loader{opr.owner_graph(), memory_dumper};
    return opr_loader(loader, inputs, config).opr();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}