/**
 * \file imperative/src/impl/ops/opr_attr.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/serialization/opr_load_dump.h"

#include "../op_trait.h"
#include "megbrain/imperative/proxy_graph_detail.h"

#include <cstring>

namespace mgb {
namespace imperative {

namespace {
/*!
 * \brief Load context that feeds an operator loader from the raw bytes held
 *      in an OprAttr::Param.
 *
 * Only raw POD reads (read_raw) are supported. load_tensor(),
 * load_tensor_shared() and config() are unavailable in this context and fail
 * with an assertion. On destruction the context asserts that every byte of
 * the param blob has been consumed.
 */
class OprParamsLoadContext final: public serialization::OprLoadContextRawPOD {
    const OprAttr::Param& m_param;  //!< serialized param bytes (not owned)
    size_t m_pos = 0;               //!< read cursor into m_param
    ComputingGraph *m_graph;        //!< graph returned by graph() (not owned)

    //! copy the next \p size bytes of the param blob into \p dest
    void read_raw(void *dest, size_t size) override final {
        // Compare against the remaining byte count instead of computing
        // m_pos + size: the sum could wrap around for a huge `size` and let
        // an out-of-bounds read pass the check. m_pos never exceeds
        // m_param.size(), so the subtraction cannot underflow.
        mgb_assert(size <= m_param.size() - m_pos, "too many bytes requested");
        std::memcpy(dest, m_param.data() + m_pos, size);
        m_pos += size;
    }

    std::shared_ptr<HostTensorND> load_tensor() override {
        mgb_assert(0, "load_tensor() is not supported when loading OprAttr params");
    }

    std::shared_ptr<DeviceTensorND> load_tensor_shared() override {
        mgb_assert(0,
                "load_tensor_shared() is not supported when loading OprAttr params");
    }

    const serialization::GraphLoadConfig& config() const override {
        mgb_assert(0, "config() is not supported when loading OprAttr params");
    }

public:
    /*!
     * \param param serialized operator params; must outlive this context
     * \param graph graph handed to the loader through graph()
     *
     * NOTE(review): the `false` passed to the base presumably disables a
     * param-tag check; confirm against OprLoadContextRawPOD.
     */
    OprParamsLoadContext(const OprAttr::Param& param,
            ComputingGraph *graph):
        serialization::OprLoadContextRawPOD(false), m_param(param), m_graph(graph)
    {}

    // NOTE(review): destructors are implicitly noexcept. If mgb_assert throws
    // on failure, a partially consumed blob (e.g. the loader threw midway)
    // ends in std::terminate rather than propagating an error.
    ~OprParamsLoadContext() {
        mgb_assert(m_pos == m_param.size(), "param not fully consumed");
    }

    ComputingGraph& graph() override {
        return *m_graph;
    }
};

class OprParamsDumpContext final: public serialization::OprDumpContextRawPOD {
public:
    OprAttr::Param m_param;
    OprParamsDumpContext() : serialization::OprDumpContextRawPOD(false) {}
    void write_raw(const void *data, size_t size) {
        const char* src = static_cast<const char*>(data);
        m_param.insert(m_param.end(), src, src + size);
    }
    void dump_tensor(
            const std::string &name,
            const HostTensorND &tensor,
            TensorWriteMethod method) {
        mgb_assert(0);
    }
    const serialization::GraphDumpConfig& config() const {
        mgb_assert(0);
    }
};

/*!
 * \brief Build a graph operator from an OprAttr by running the loader
 *      registered for `attr.type` over the serialized `attr.param` bytes.
 *
 * \param def OpDef that must be an OprAttr (checked by cast_final_safe)
 * \param inputs input vars; must be non-empty because the owning graph is
 *      taken from inputs[0]
 * \return the operator created by the registry loader
 */
cg::OperatorNodeBase* apply_on_var_node(
        const OpDef& def, const VarNodeArray& inputs) {
    auto&& attr = def.cast_final_safe<OprAttr>();

    // work on a copy so the name can be set without touching the (const) def
    auto config = attr.config;
    config.name(attr.make_name());

    mgb_assert(!inputs.empty(),
            "OprAttr(%s) requires at least one input to locate the graph",
            attr.type.c_str());
    auto registry = serialization::OprRegistry::find_by_name(attr.type);
    mgb_assert(registry, "operator %s not found", attr.type.c_str());
    OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
    return registry->loader(ctx, inputs, config);
}

/*!
 * \brief Turn an existing graph operator back into an OprAttr.
 *
 * Looks up the serialization registry entry for the operator's dynamic type,
 * runs its dumper into an in-memory byte buffer and wraps the result, the
 * registry name and the operator's config into a new OprAttr.
 */
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
    auto opr_type = opr->dyn_typeinfo();
    auto registry = serialization::OprRegistry::find_by_type(opr_type);
    mgb_assert(registry, "operator %s not found", opr_type->name);
    mgb_assert(registry->dumper, "operator %s cannot be serialized",
            opr_type->name);

    OprParamsDumpContext dump_ctx;
    registry->dumper(dump_ctx, *opr);
    auto&& bytes = dump_ctx.m_param;
    return OprAttr::make(registry->name, std::move(bytes), opr->config());
}

100 101 102 103
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
    return {};
}

104 105 106 107
std::string make_name(const OpDef& def) {
    return "OprAttr";
}

108 109 110
OP_TRAIT_REG(OprAttr, OprAttr)
    .make_from_op_node(make_from_op_node)
    .apply_on_var_node(apply_on_var_node)
111
    .props(props)
112
    .make_name(make_name)
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
    .fallback();

} // anonymous namespace

/*!
 * \brief Structural equality between two OprAttr objects.
 *
 * Two OprAttrs are the same when their operator type, serialized params,
 * comp node and output dtype all match. \p rhs_ is cast to OprAttr without a
 * runtime check, so the caller must pass an OprAttr.
 *
 * NOTE(review): other config fields (e.g. the name) are not compared here,
 * while hash() mixes in config.hash(). Confirm that config.hash() depends on
 * nothing beyond these fields, otherwise equal objects could hash differently.
 */
bool OprAttr::is_same_st(const Hashable& rhs_) const {
    const auto& other = static_cast<const OprAttr&>(rhs_);

    const bool same_op = type == other.type && param == other.param;
    if (!same_op) {
        return false;
    }

    const auto& lhs_cfg = config;
    const auto& rhs_cfg = other.config;
    return lhs_cfg.comp_node() == rhs_cfg.comp_node() &&
           lhs_cfg.output_dtype() == rhs_cfg.output_dtype();
}

/*!
 * \brief Hash combining the operator type, the serialized param bytes and
 *      config.hash().
 */
size_t OprAttr::hash() const {
    // Bind a const reference to param's std::vector<char> view. The previous
    // static_cast<std::vector<char>>(param) built a temporary vector, copying
    // every param byte on each hash() call.
    const std::vector<char>& param_bytes = param;
    return hash_pair_combine(
            hash_pair_combine(mgb::hash(type), mgb::hash(param_bytes)),
            config.hash());
}

// MegEngine dynamic-type boilerplate for OprAttr. NOTE(review): presumably
// defines the type info used by dyn_typeinfo()/cast_final_safe for this final
// class; confirm against the macro definition.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr);

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}