/**
 * \file src/serialization/impl/opr_shallow_copy.cpp
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 */

#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_registry.h"
#include "megbrain/utils/big_key_hashmap.h"

using namespace mgb;
using namespace serialization;

namespace {

//! dump single opr to memory for shallow copy
//!
//! Only the raw byte stream produced through write_raw() is recorded; dumping
//! tensors and querying the dump config are unsupported and throw GraphError.
class OprDumpContextMemory final : public OprDumpContextRawPOD {
    //! bytes appended by write_raw(), in write order
    std::vector<uint8_t> m_buf;

    //! append \p size bytes starting at \p data to the in-memory buffer
    void write_raw(const void* data, size_t size) override {
        // A range insert lets std::vector manage amortized growth itself and
        // is well-defined for size == 0, whereas memcpy with a null pointer
        // (possible for an empty buffer / empty write) is undefined even for
        // a zero length.
        auto first = static_cast<const uint8_t*>(data);
        m_buf.insert(m_buf.end(), first, first + size);
    }

    void dump_tensor(
            const std::string&, const HostTensorND&, TensorWriteMethod) override {
        mgb_throw(GraphError, "OprDumpContextMemory does not support dump tensor");
    }

    const GraphDumpConfig& config() const override {
        mgb_throw(GraphError, "OprDumpContextMemory has no associated config");
    }

public:
    OprDumpContextMemory() : OprDumpContextRawPOD(false) {}

    //! the bytes recorded so far (returned by reference, not copied)
    auto&& buf() const { return m_buf; }
};

//!
load single opr from memory for shallow copy class OprLoadContextMemory final : public OprLoadContextRawPOD { const uint8_t* m_ptr; size_t m_size, m_pos = 0; ComputingGraph* m_graph; void read_raw(void* dest, size_t size) override { auto end = m_pos + size; mgb_assert(end <= m_size); memcpy(dest, m_ptr + m_pos, size); m_pos = end; } ComputingGraph& graph() override { return *m_graph; } std::shared_ptr load_tensor() override { mgb_assert(0); } std::shared_ptr load_tensor_shared() override { mgb_assert(0); } const GraphLoadConfig& config() const override { mgb_throw(GraphError, "OprLoadContextMemory has no associated config"); } public: OprLoadContextMemory(ComputingGraph* graph, const OprDumpContextMemory& dumper) : OprLoadContextRawPOD(false), m_ptr{dumper.buf().data()}, m_size{dumper.buf().size()}, m_graph{graph} {} ~OprLoadContextMemory() { mgb_assert(m_pos == m_size); } }; class ShallowCopyCacheContainer final : public UserDataContainer::UserData { MGB_TYPEINFO_OBJ_DECL; struct HashEq { template static bool eq(const T& x, const T& y) { return x == y; } static bool eq(const OperatorNodeConfig& x, const OperatorNodeConfig& y) { return x.is_same(y); } static size_t hash(const void* ptr) { return std::hash{}(ptr); } static size_t hash(const VarNodeArray& inputs) { return PODHash::perform(inputs.data(), inputs.size()); } static size_t hash(const OperatorNodeConfig& config) { return config.hash(); } }; public: big_key_hash_map::BigKeyHashMap< cg::OperatorNodeBase*, HashEq, big_key_hash_map::Copy, big_key_hash_map::Ref, big_key_hash_map::Ref> cache; }; MGB_TYPEINFO_OBJ_IMPL(ShallowCopyCacheContainer); } // anonymous namespace ComputingGraph* serialization::OprShallowCopyContext::owner_graph( const cg::OperatorNodeBase& opr, const VarNodeArray& inputs) const { if (!m_owner_graph) { if (inputs.empty()) return opr.owner_graph(); return inputs[0]->owner_graph(); } if (!inputs.empty()) mgb_assert(m_owner_graph == inputs[0]->owner_graph()); return m_owner_graph; } 
/*!
 * Create a copy of \p opr that reads \p inputs and uses \p config.
 *
 * The copy is produced by the type's registered shallow_copy hook, or by
 * intl::copy_opr_shallow_default_impl (a dump/load round-trip) when no hook is
 * available. Copies made inside the source opr's own graph are memoized in a
 * ShallowCopyCacheContainer keyed by (opr, inputs, config); on a cache hit the
 * cached opr's output shapes are refreshed instead of copying again.
 *
 * When the copy goes to another graph, or the copy inserted new oprs, the new
 * opr's src_opr attribute (if unset) points at the root source opr, and its
 * priority (if unset) is inherited from \p opr.
 */
cg::OperatorNodeBase* serialization::copy_opr_shallow(
        const cg::OperatorNodeBase& opr, const VarNodeArray& inputs,
        const OperatorNodeConfig& config, const OprShallowCopyContext& ctx) {
    OprShallowCopy shallow_copy = nullptr;
    if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo()))
        shallow_copy = registry->shallow_copy;
    // Fall back to the generic dump/load copy when the type is not in the
    // registry, or its record carries no shallow_copy hook (calling a null
    // function pointer would crash).
    if (!shallow_copy)
        shallow_copy = intl::copy_opr_shallow_default_impl;

    mgb_assert(inputs.size() == opr.input().size());
    auto dst_og = ctx.owner_graph(opr, inputs);

    auto do_copy = [&]() {
        auto nr_opr_before = opr.owner_graph()->nr_oprs_in_graph();
        auto ret = shallow_copy(ctx, opr, inputs, config);

        if (dst_og != opr.owner_graph() ||
            opr.owner_graph()->nr_oprs_in_graph() != nr_opr_before) {
            auto&& attr = ret->node_prop().attribute();
            if (!attr.src_opr) {
                auto src = cg::get_opr_root_source_opr(
                        const_cast<cg::OperatorNodeBase*>(&opr));
                if (ret != src)
                    attr.src_opr = src;
            }
            if (!attr.priority) {
                // priority may have been changed by OprInserted event handlers
                // (like in python case)
                attr.priority = opr.node_prop().attribute().priority;
            }
        }
        return ret;
    };

    cg::OperatorNodeBase* ret;
    if (dst_og == opr.owner_graph()) {
        // use cache for copy in same graph
        auto&& cache = dst_og->options()
                               .user_data
                               .get_user_data_or_create<ShallowCopyCacheContainer>()
                               ->cache;
        auto ins = cache.get(&opr, inputs, config);
        if (ins.first) {
            // fresh slot: fill it with a new copy
            *ins.second = do_copy();
        } else {
            // cached copy reused: its output shapes may need re-inference
            cg::update_output_var_shapes(*ins.second);
        }
        ret = *ins.second;
    } else {
        ret = do_copy();
    }

    // Unless basic-arith in-place optimization may have rewritten the opr, the
    // copy must expose the same number of usable outputs, and may be the
    // source opr itself only when the inputs are unchanged.
    mgb_assert(
            gopt::has_inplace_basic_arith_opt(opr) ||
                    ((  // outputs match
                             opr.usable_output().size() ==
                             ret->usable_output().size()) &&
                     (  // new opr is returned
                             (&opr != ret) || opr.input() == inputs)),
            "bad opr copy: src=%s{%s} dst=%s{%s}", opr.cname(),
            opr.dyn_typeinfo()->name, ret->cname(), ret->dyn_typeinfo()->name);

    return ret;
}

/*!
 * Generic shallow copy: serialize \p opr into memory with its registered
 * dumper, then rebuild it on \p inputs / \p config with the matching loader.
 *
 * The dumper/loader pair is looked up in OprRegistry first, then in
 * OprRegistryV2 at CURRENT_VERSION. Asserts if neither registry provides both.
 */
cg::OperatorNodeBase* serialization::intl::copy_opr_shallow_default_impl(
        const OprShallowCopyContext& ctx, const cg::OperatorNodeBase& opr,
        const VarNodeArray& inputs, const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);

    OprDumper opr_dumper = nullptr;
    OprLoaderWrapper opr_loader = nullptr;
    if (auto registry = OprRegistry::find_by_type(opr.dyn_typeinfo())) {
        opr_loader = registry->loader;
        opr_dumper = registry->dumper;
    } else if (auto registryv2 = OprRegistryV2::versioned_find_by_typeinfo(
                       opr.dyn_typeinfo(), CURRENT_VERSION)) {
        // The v2 lookup can also miss; only dereference on success so that
        // the assertion below reports the problem instead of a null-pointer
        // crash.
        opr_loader = registryv2->loader;
        opr_dumper = registryv2->dumper;
    }
    mgb_assert(
            opr_dumper && opr_loader,
            "can not shallow_copy operator %s{%s}: "
            "no dumper/loader registered",
            opr.cname(), opr.dyn_typeinfo()->name);

    OprDumpContextMemory memory_dumper;
    opr_dumper(memory_dumper, opr);
    // NOTE(review): the loader is bound to the source opr's graph, not to
    // ctx.owner_graph(opr, inputs); confirm that loaders always derive the
    // target graph from `inputs` when copying across graphs.
    OprLoadContextMemory loader{opr.owner_graph(), memory_dumper};
    return opr_loader(loader, inputs, config).opr();
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}