/**
 * \file imperative/src/impl/op_trait.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/imperative/op_def.h"

namespace mgb {
namespace imperative {

19
namespace detail {
20
template <typename Signature>
21
struct OpMeth;
22 23
template <typename RType, typename... Args>
struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> {
24 25 26 27 28 29
    using Base = thin_function<RType(Args...)>;
    using Base::Base;
    RType operator()(Args... args) const {
        if (!this->Base::operator bool()) {
            mgb_throw(MegBrainError, "Not Implemented");
        }
30
        return this->Base::operator()(std::forward<Args>(args)...);
31 32
    }
};
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
template<typename T>
struct ToVarNodeArray: std::false_type {};
template<>
struct ToVarNodeArray<SymbolVar>: std::true_type {
    VarNodeArray operator()(const SymbolVar& inp) {
        return {inp.node()};
    }
};
template<>
struct ToVarNodeArray<SymbolVarArray>: std::true_type {
    VarNodeArray operator()(const SymbolVarArray& inputs) {
        return cg::to_var_node_array(inputs);
    }
};
template<size_t N>
struct ToVarNodeArray<std::array<SymbolVar, N>>: std::true_type {
    VarNodeArray operator()(const std::array<SymbolVar, N>& inp) {
        return cg::to_var_node_array({inp.begin(), inp.end()});
    }
};
template<>
struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
    VarNodeArray operator()(const cg::OperatorNodeBase* opr) {
        return opr->usable_output();
    }
};
59
}  // namespace detail
60 61

/*
 * Method slot types stored in OpTrait. The first five take their signature
 * verbatim from the matching OpDef member via decltype, so a slot always has
 * exactly the signature of the OpDef entry point it implements.
 */
using OpDefMaker = detail::OpMeth<decltype(OpDef::make_from_op_node)>;
using ApplyOnPhysicalTensor =
        detail::OpMeth<decltype(OpDef::apply_on_physical_tensor)>;
using ApplyOnVarNode = detail::OpMeth<decltype(OpDef::apply_on_var_node)>;
using InferOutputAttrsFallible =
        detail::OpMeth<decltype(OpDef::infer_output_attrs_fallible)>;
using GradMaker = detail::OpMeth<decltype(OpDef::make_backward_graph)>;

//! computes a hash value for an OpDef
using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
//! binary predicate over two OpDefs; its exact notion of "same" is defined
//! by whichever implementation is registered (NOTE(review): confirm)
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;

struct OpTrait {
    const char* name;
    OpDefMaker make_from_op_node;
    ApplyOnPhysicalTensor apply_on_physical_tensor;
    ApplyOnVarNode apply_on_var_node;
    InferOutputAttrsFallible infer_output_attrs_fallible;
    GradMaker make_backward_graph;
81 82
    HashFunc hash;
    IsSame is_same_st;
83 84 85 86 87 88
    OpTrait(const char* name);
    static OpTrait* find_by_name(const char* name);
    static OpTrait* find_by_typeinfo(Typeinfo* type);
    static void for_each_trait(thin_function<void(OpTrait&)> visitor);
};

89 90 91 92 93
/*
 * X-macro over every method slot of OpTrait: cb(meth) is expanded once per
 * slot, in member-declaration order. Keep this list in sync with OpTrait.
 */
#define FOR_EACH_OP_METH(cb)        \
    cb(make_from_op_node)           \
    cb(apply_on_physical_tensor)    \
    cb(apply_on_var_node)           \
    cb(infer_output_attrs_fallible) \
    cb(make_backward_graph)         \
    cb(hash)                        \
    cb(is_same_st)

struct OpTraitRegistry {
    OpTrait* trait;
100 101 102 103 104
#define DECL(meth) \
    OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
        mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
        trait->meth = f; \
        return *this; \
105
    }
106 107 108
    FOR_EACH_OP_METH(DECL)
#undef DECL

109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
    OpTraitRegistry& fallback();

    template<typename T>
    void insert() {
        do_insert(T::typeinfo());
    }

    template<typename T0, typename T1, typename ...Ts>
    void insert() {
        insert<T0>();
        insert<T1, Ts...>();
    }

    template<typename ...Args>
    static OpTraitRegistry insert(const char* name) {
        auto&& ret = do_insert(name);
        ret.insert<Args...>();
        return ret;
    }

    void do_insert(Typeinfo* type);

    static OpTraitRegistry do_insert(const char* name);
132 133 134 135 136 137 138 139 140

    template<typename T,
        typename To = detail::ToVarNodeArray<T>,
        typename = std::enable_if_t<To::value>>
    OpTraitRegistry& apply_on_var_node(T (*f)(const OpDef&, const VarNodeArray&)) {
        return apply_on_var_node([=](const OpDef& opdef, const VarNodeArray& inputs) {
            return To()(f(opdef, inputs));
        });
    }
141 142 143 144 145 146 147
};

} // namespace imperative
} // namespace mgb

/*
 * Register the operator trait `name` and bind every op type listed in the
 * variadic arguments (at least one) to it. Expands to the definition of a
 * static OpTraitRegistry, so slot setters can be chained onto it, e.g.
 *     OP_TRAIT_REG(Foo, Foo).apply_on_var_node(foo_apply).hash(foo_hash);
 * The registry type is fully qualified so the macro also works outside
 * mgb::imperative.
 */
#define OP_TRAIT_REG(name, ...)                               \
    static ::mgb::imperative::OpTraitRegistry                 \
            __##name##_global_registry__ =                    \
                    ::mgb::imperative::OpTraitRegistry::insert<__VA_ARGS__>(#name)

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}