/**
 * \file imperative/src/impl/op_trait.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/imperative/op_def.h"

namespace mgb {
namespace imperative {

namespace detail {

// OpMeth wraps thin_function so that invoking a method slot that was never
// registered throws "Not Implemented" instead of calling an empty functor.
template <typename Signature>
struct OpMeth;

template <typename RType, typename... Args>
struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> {
    using Base = thin_function<RType(Args...)>;
    using Base::Base;
    RType operator()(Args... args) const {
        if (!this->Base::operator bool()) {
            mgb_throw(MegBrainError, "Not Implemented");
        }
        return this->Base::operator()(std::forward<Args>(args)...);
    }
};

// ToVarNodeArray adapts the various return types an apply_on_var_node
// implementation may produce into a plain VarNodeArray.
template <typename T>
struct ToVarNodeArray : std::false_type {};

template <>
struct ToVarNodeArray<SymbolVar> : std::true_type {
    VarNodeArray operator()(const SymbolVar& inp) { return {inp.node()}; }
};

template <>
struct ToVarNodeArray<SymbolVarArray> : std::true_type {
    VarNodeArray operator()(const SymbolVarArray& inputs) {
        return cg::to_var_node_array(inputs);
    }
};

template <size_t N>
struct ToVarNodeArray<std::array<SymbolVar, N>> : std::true_type {
    VarNodeArray operator()(const std::array<SymbolVar, N>& inp) {
        return cg::to_var_node_array({inp.begin(), inp.end()});
    }
};

template <>
struct ToVarNodeArray<cg::OperatorNodeBase*> : std::true_type {
    VarNodeArray operator()(const cg::OperatorNodeBase* opr) {
        return opr->usable_output();
    }
};

}  // namespace detail

using OpDefMaker = detail::OpMeth<
        decltype(OpDef::make_from_op_node)>;
using ApplyOnPhysicalTensor = detail::OpMeth<
        decltype(OpDef::apply_on_physical_tensor)>;
using ApplyOnVarNode = detail::OpMeth<
        decltype(OpDef::apply_on_var_node)>;
using InferOutputAttrsFallible = detail::OpMeth<
        decltype(OpDef::infer_output_attrs_fallible)>;
using GradMaker = detail::OpMeth<
        decltype(OpDef::make_backward_graph)>;
using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;

struct OpTrait {
    const char* name;
    OpDefMaker make_from_op_node;
    ApplyOnPhysicalTensor apply_on_physical_tensor;
    ApplyOnVarNode apply_on_var_node;
    InferOutputAttrsFallible infer_output_attrs_fallible;
    GradMaker make_backward_graph;
    HashFunc hash;
    IsSame is_same_st;
    OpTrait(const char* name);
    static OpTrait* find_by_name(const char* name);
    static OpTrait* find_by_typeinfo(Typeinfo* type);
    static void for_each_trait(thin_function<void(OpTrait&)> visitor);
};

#define FOR_EACH_OP_METH(cb)        \
    cb(make_from_op_node)           \
    cb(apply_on_physical_tensor)    \
    cb(apply_on_var_node)           \
    cb(infer_output_attrs_fallible) \
    cb(make_backward_graph)         \
    cb(hash)                        \
    cb(is_same_st)

struct OpTraitRegistry {
    OpTrait* trait;
// DECL generates one fluent setter per method slot; setting the same slot
// twice is a registration error and triggers the assertion.
#define DECL(meth)                                                \
    OpTraitRegistry& meth(decltype(OpTrait::meth) f) {            \
        mgb_assert(!trait->meth, "op %s has duplicate method %s", \
                   trait->name, #meth);                           \
        trait->meth = f;                                          \
        return *this;                                             \
    }
    FOR_EACH_OP_METH(DECL)
#undef DECL

    OpTraitRegistry& fallback();

    template <typename T>
    void insert() {
        do_insert(T::typeinfo());
    }

    template <typename T0, typename T1, typename... Ts>
    void insert() {
        insert<T0>();
        insert<T1, Ts...>();
    }

    template <typename... Args>
    static OpTraitRegistry insert(const char* name) {
        auto&& ret = do_insert(name);
        ret.insert<Args...>();
        return ret;
    }

    void do_insert(Typeinfo* type);

    static OpTraitRegistry do_insert(const char* name);

    // Overload that accepts an implementation returning any type covered by
    // detail::ToVarNodeArray and adapts its result to VarNodeArray.
    template <typename T, typename To = detail::ToVarNodeArray<T>,
              typename = std::enable_if_t<To::value>>
    OpTraitRegistry& apply_on_var_node(T (*f)(const OpDef&, const VarNodeArray&)) {
        return apply_on_var_node([=](const OpDef& opdef, const VarNodeArray& inputs) {
            return To()(f(opdef, inputs));
        });
    }
};

}  // namespace imperative
}  // namespace mgb

#define OP_TRAIT_REG(name, ...)                           \
    static OpTraitRegistry __##name##_global_registry__ = \
            OpTraitRegistry::insert<__VA_ARGS__>(#name)
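/*
 * Usage sketch (illustration only, not part of this header): a hypothetical
 * operator MyOp backed by an OpDef subclass MyOpDef could be registered as
 * below. MyOpDef and my_apply_on_var_node are assumed names; the
 * SymbolVarArray return value is adapted to VarNodeArray by the
 * detail::ToVarNodeArray-based apply_on_var_node overload above.
 *
 *     SymbolVarArray my_apply_on_var_node(
 *             const OpDef& def, const VarNodeArray& inputs);
 *
 *     OP_TRAIT_REG(MyOp, MyOpDef)
 *             .apply_on_var_node(my_apply_on_var_node)
 *             .fallback();
 *
 * OP_TRAIT_REG expands to a static OpTraitRegistry initializer, so
 * registration happens during static initialization. Each chained setter
 * asserts its slot is still empty; fallback() is assumed here to fill the
 * remaining unset method slots with default implementations.
 */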
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}