/**
 * \file imperative/src/impl/op_trait.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include <exception>
#include <list>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>

#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/tensor.h"

#include "./op_trait.h"

namespace mgb {
namespace imperative {

namespace detail {

struct StaticData {
    std::list<OpTrait> registries;
30
    std::unordered_map<std::string, OpTrait*> name2reg;
31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
    std::unordered_map<Typeinfo*, OpTrait*> type2reg;
};

// use "Construct On First Use" to prevent "static initialization order fiasco"
// (i.e., ensure global registry was initialized before calling opr registration)
StaticData& static_data() {
    static StaticData data;
    return data;
}

} // detail

OpTrait::OpTrait(const char* name_): name(name_) {}

//! Look up the trait registered for \p type.
//! \return the registered trait, or nullptr if \p type has none.
OpTrait* OpTrait::find_by_typeinfo(Typeinfo* type) {
    auto&& registry = detail::static_data().type2reg;
    auto found = registry.find(type);
    return found != registry.end() ? found->second : nullptr;
}

//! Look up the trait registered under \p name.
//!
//! \param name op name to search for; a null pointer never matches.
//! \return the registered trait, or nullptr if no trait has this name.
OpTrait* OpTrait::find_by_name(const char* name) {
    // Building the std::string key from a null pointer is undefined
    // behavior, and unnamed traits are never inserted into name2reg, so
    // there is nothing to find.
    if (!name) {
        return nullptr;
    }
    auto&& name2reg = detail::static_data().name2reg;
    auto iter = name2reg.find(name);
    // A miss is signalled by end(); comparing against another find() result
    // would always be equal and report every lookup as a miss.
    if (iter == name2reg.end()) {
        return nullptr;
    }
    return iter->second;
}

//! Invoke \p visitor on every registered trait, in registration order.
void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor) {
    auto&& traits = detail::static_data().registries;
    for (auto it = traits.begin(); it != traits.end(); ++it) {
        visitor(*it);
    }
}

69 70 71 72 73 74
DispatchMode fallback_decide_dispatch_mode(
    const OpDef& def,
    const SmallVector<LogicalTensorDesc>& inputs) {
    return KERNEL;
}

75
OpTraitRegistry& OpTraitRegistry::fallback() {
76 77 78 79 80 81
    if (trait->apply_on_var_node) {
        // fallback to proxy graph impl
        if (!trait->apply_on_physical_tensor) {
            trait->apply_on_physical_tensor =
                    proxy_graph_detail::apply_on_physical_tensor;
        }
82 83 84 85 86 87 88
        if (!trait->execute) {
            trait->execute = proxy_graph_detail::execute;
        }
        if (!trait->infer_output_mem_desc) {
            trait->infer_output_mem_desc =
                    proxy_graph_detail::infer_output_mem_desc;
        }
89 90 91 92 93 94 95 96
        if (!trait->infer_output_attrs_fallible) {
            trait->infer_output_attrs_fallible =
                    proxy_graph_detail::infer_output_attrs_fallible;
        }
        if (!trait->make_backward_graph) {
            trait->make_backward_graph =
                    proxy_graph_detail::make_backward_graph;
        }
97
    }
98 99 100
    if (!trait->decide_dispatch_mode) {
        trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
    }
101 102 103 104 105 106
    if (!trait->make_name) {
        static auto make_name = [](const OpDef& def) -> std::string {
            return def.trait()->name;
        };
        trait->make_name = make_name;
    }
107 108 109 110 111
    return *this;
}

//! Bind \p type to this registry's trait in the type index.
//! Registering the same trait again for a type is tolerated; binding a
//! different trait to an already-registered type fails the assertion.
void OpTraitRegistry::do_insert(Typeinfo* type) {
    auto&& registry = detail::static_data().type2reg;
    // emplace() leaves an existing entry untouched, so afterwards the mapped
    // trait equals ours exactly when the insert succeeded or the very same
    // trait was already registered.
    auto slot = registry.emplace(type, trait).first;
    mgb_assert(slot->second == trait,
            "OpTrait for %s has already been registered", type->name);
}

//! Return a registry for the trait named \p name, creating the trait on
//! first use. A non-null name that is already registered yields the
//! existing trait. A null name always creates a fresh trait that is not
//! indexed by name.
OpTraitRegistry OpTraitRegistry::do_insert(const char* name) {
    auto&& sd = detail::static_data();
    if (!name) {
        sd.registries.emplace_back(name);
        return {&sd.registries.back()};
    }
    auto found = sd.name2reg.find(name);
    if (found != sd.name2reg.end()) {
        return {found->second};
    }
    sd.registries.emplace_back(name);
    // std::list keeps element addresses stable, so this pointer remains
    // valid after later registrations.
    OpTrait* created = &sd.registries.back();
    sd.name2reg.emplace(name, created);
    return {created};
}

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}