/** * \file imperative/src/impl/ops/opr_attr.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #include "megbrain/imperative/ops/opr_attr.h" #include "megbrain/serialization/opr_load_dump.h" #include "../op_trait.h" #include "megbrain/imperative/proxy_graph_detail.h" namespace mgb { namespace imperative { namespace { class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD { const OprAttr::Param& m_param; size_t m_pos = 0; ComputingGraph* m_graph; void read_raw(void* dest, size_t size) override final { mgb_assert(m_pos + size <= m_param.size(), "too many bytes requested"); memcpy(dest, m_param.data() + m_pos, size); m_pos += size; } std::shared_ptr load_tensor() override { mgb_assert(0); } std::shared_ptr load_tensor_shared() override { mgb_assert(0); } const serialization::GraphLoadConfig& config() const override { mgb_assert(0); } public: OprParamsLoadContext(const OprAttr::Param& param, ComputingGraph* graph) : serialization::OprLoadContextRawPOD(false), m_param(param), m_graph(graph) {} ~OprParamsLoadContext() { mgb_assert(m_pos == m_param.size(), "param not fully consumed"); } ComputingGraph& graph() override { return *m_graph; } }; class OprParamsDumpContext final : public serialization::OprDumpContextRawPOD { public: OprAttr::Param m_param; OprParamsDumpContext() : serialization::OprDumpContextRawPOD(false) {} void write_raw(const void* data, size_t size) { const char* src = static_cast(data); m_param.insert(m_param.end(), src, src + size); } void dump_tensor( const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) { mgb_assert(0); } const serialization::GraphDumpConfig& config() const { mgb_assert(0); } }; VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { auto&& attr = def.cast_final_safe(); auto config = attr.config; config.name(attr.make_name()); mgb_assert(!inputs.empty()); auto registry = serialization::OprRegistry::find_by_name(attr.type); mgb_assert(registry, "operator %s not found", attr.type.c_str()); OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()}; return registry->loader(ctx, inputs, config).usable_output(); } std::shared_ptr make_from_op_node(cg::OperatorNodeBase* opr) { OprParamsDumpContext ctx; auto registry = serialization::OprRegistry::find_by_type(opr->dyn_typeinfo()); mgb_assert(registry, "operator %s not found", opr->dyn_typeinfo()->name); mgb_assert( registry->dumper, "operator %s cannot be serialized", opr->dyn_typeinfo()->name); registry->dumper(ctx, *opr); return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config()); } std::vector> props(const OpDef& def) { return {}; } std::string make_name(const OpDef& def) { auto&& attr = def.cast_final_safe(); return attr.type; } OP_TRAIT_REG(OprAttr, OprAttr) .make_from_op_node(make_from_op_node) .apply_on_var_node(apply_on_var_node) .props(props) .make_name(make_name) .fallback(); } // anonymous namespace bool OprAttr::is_same_st(const Hashable& rhs_) const { auto&& rhs = static_cast(rhs_); return type == rhs.type && param == rhs.param && config.comp_node() == rhs.config.comp_node() && config.output_dtype() == rhs.config.output_dtype(); } size_t OprAttr::hash() const { return hash_pair_combine( hash_pair_combine( mgb::hash(type), mgb::hash(static_cast>(param))), config.hash()); } MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr); } // namespace imperative } // namespace mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}