From 120e719e177c5608142edad0fadc5361df5b1a1b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 4 Nov 2020 14:47:31 +0800 Subject: [PATCH] refactor(imperative): refactor OpTrait methods and registration remove unused op methods and enable OpTrait registered from seperated files, also fix cond_take's infer_output_attrs GitOrigin-RevId: 134c8215ce1b8efaf6c14da00ab588bb336163ee --- imperative/src/impl/op_def.cpp | 13 --- imperative/src/impl/op_trait.cpp | 102 +++++------------- imperative/src/impl/op_trait.h | 86 +++++++-------- imperative/src/impl/ops/cond_take.cpp | 16 +-- imperative/src/impl/profiler.cpp | 8 +- imperative/src/impl/proxy_graph_detail.cpp | 17 ++- imperative/src/impl/proxy_graph_detail.h | 7 +- .../src/include/megbrain/imperative/op_def.h | 9 -- 8 files changed, 91 insertions(+), 167 deletions(-) diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp index d3aaa8b20..770aab4f1 100644 --- a/imperative/src/impl/op_def.cpp +++ b/imperative/src/impl/op_def.cpp @@ -36,13 +36,6 @@ SmallVector OpDef::apply_on_physical_tensor( return def.trait()->apply_on_physical_tensor(def, inputs); } -void OpDef::exec( - const OpDef& def, - const SmallVector& inputs, - const SmallVector& outputs) { - def.trait()->exec(def, inputs, outputs); -} - cg::OperatorNodeBase* OpDef::apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { @@ -55,12 +48,6 @@ SmallVector OpDef::infer_output_attrs_fallible( return def.trait()->infer_output_attrs_fallible(def, inputs); } -SmallVector OpDef::infer_output_attrs( - const OpDef& def, - const SmallVector& inputs) { - return def.trait()->infer_output_attrs(def, inputs); -} - BackwardGraphResult OpDef::make_backward_graph( const OpDef& def, const SmallVector& inputs, diff --git a/imperative/src/impl/op_trait.cpp b/imperative/src/impl/op_trait.cpp index f36043ca9..a6a81ce17 100644 --- a/imperative/src/impl/op_trait.cpp +++ b/imperative/src/impl/op_trait.cpp @@ -34,16 +34,6 @@ StaticData& static_data() { return data; } -template -struct __not_implementation__; - -template -struct __not_implementation__ { - static RType raise(Args ...) { - mgb_throw(MegBrainError, "Not Implemented"); - } -}; - } // detail OpTrait::OpTrait(const char* name_): name(name_) {} @@ -72,89 +62,45 @@ void OpTrait::for_each_trait(thin_function visitor){ } } -OpTraitRegistry& OpTraitRegistry::finalize() { - std::ostringstream msg; - #define CHECK(field) if (!trait->field) { \ - msg << ", " #field; \ - trait->field = \ - detail::__not_implementation__::raise; \ - } - CHECK(make_from_op_node); - CHECK(apply_on_physical_tensor); - CHECK(exec); - CHECK(apply_on_var_node); - CHECK(infer_output_attrs_fallible); - CHECK(infer_output_attrs); - CHECK(make_backward_graph); - #undef CHECK - #ifdef DEBUG - if (msg.tellp() > 0) { - mgb_log_warn( - "%s op trait missing: %s", - trait->name ? trait->name : "(anonymous)", - msg.str().c_str() + 2 /* skip first ", " */); - } - #endif - return *this; -} - -SmallVector fallback_apply_on_physical_tensor( - const OpDef& def, - const SmallVector& inputs) { - auto desc = OpDef::infer_output_attrs(def, inputs); - SmallVector outputs; - for (auto&& i : desc) { - outputs.push_back(Tensor::make(i.layout, i.comp_node)); - } - OpDef::exec(def, inputs, outputs); - return outputs; -} - -SmallVector fallback_infer_output_attrs(const OpDef& def, - const SmallVector& inputs){ - SmallVector input_descs; - for(auto&& input: inputs){ - input_descs.push_back({input->layout(), input->comp_node()}); - } - return input_descs; -} - OpTraitRegistry& OpTraitRegistry::fallback() { - if (!trait->exec && trait->apply_on_var_node) { - trait->exec = proxy_graph_detail::exec; - } - if (!trait->infer_output_attrs && trait->apply_on_var_node) { - trait->infer_output_attrs = proxy_graph_detail::infer_output_attrs; - } - if (!trait->infer_output_attrs_fallible && trait->apply_on_var_node) { - trait->infer_output_attrs_fallible = proxy_graph_detail::infer_output_attrs_fallible; - } - if (!trait->make_backward_graph && trait->apply_on_var_node) { - trait->make_backward_graph = proxy_graph_detail::make_backward_graph; - } - if (!trait->apply_on_physical_tensor && trait->infer_output_attrs && trait->exec) { - trait->apply_on_physical_tensor = fallback_apply_on_physical_tensor; - } - if(!trait->infer_output_attrs && trait->infer_output_attrs_fallible){ - trait->infer_output_attrs = fallback_infer_output_attrs; + if (trait->apply_on_var_node) { + // fallback to proxy graph impl + if (!trait->apply_on_physical_tensor) { + trait->apply_on_physical_tensor = + proxy_graph_detail::apply_on_physical_tensor; + } + if (!trait->infer_output_attrs_fallible) { + trait->infer_output_attrs_fallible = + proxy_graph_detail::infer_output_attrs_fallible; + } + if (!trait->make_backward_graph) { + trait->make_backward_graph = + proxy_graph_detail::make_backward_graph; + } } return *this; } void OpTraitRegistry::do_insert(Typeinfo* type) { auto&& sd = detail::static_data(); - mgb_assert(sd.type2reg.emplace(type, trait).second); + auto ret = sd.type2reg.emplace(type, trait); + mgb_assert(ret.second || ret.first->second == trait, + "OpTrait for %s has already been registered", type->name); } OpTraitRegistry OpTraitRegistry::do_insert(const char* name) { auto&& sd = detail::static_data(); if (name) { - mgb_assert(!sd.name2reg.count(name), - "duplicated opr trait %s", name); + auto iter = sd.name2reg.find(name); + if (iter != sd.name2reg.end()) { + return {iter->second}; + } } sd.registries.emplace_back(name); auto ret = &sd.registries.back(); - sd.name2reg.emplace(name, ret); + if (name) { + sd.name2reg.emplace(name, ret); + } return {ret}; } diff --git a/imperative/src/impl/op_trait.h b/imperative/src/impl/op_trait.h index 3122494c9..a539c084a 100644 --- a/imperative/src/impl/op_trait.h +++ b/imperative/src/impl/op_trait.h @@ -16,29 +16,39 @@ namespace mgb { namespace imperative { -using OpDefMaker = thin_function< +namespace detail { +template +struct OpMeth; +template +struct OpMeth: public thin_function { + using Base = thin_function; + using Base::Base; + RType operator()(Args... args) const { + if (!this->Base::operator bool()) { + mgb_throw(MegBrainError, "Not Implemented"); + } + return this->Base::operator ()(args...); + } +}; +} // detail + +using OpDefMaker = detail::OpMeth< decltype(OpDef::make_from_op_node)>; -using ApplyOnPhysicalTensor = thin_function< +using ApplyOnPhysicalTensor = detail::OpMeth< decltype(OpDef::apply_on_physical_tensor)>; -using PhysicalTensorExecutor = thin_function< - decltype(OpDef::exec)>; -using ApplyOnVarNode = thin_function< +using ApplyOnVarNode = detail::OpMeth< decltype(OpDef::apply_on_var_node)>; -using InferOutputAttrsFallible = thin_function< +using InferOutputAttrsFallible = detail::OpMeth< decltype(OpDef::infer_output_attrs_fallible)>; -using InferOutputAttrs = thin_function< - decltype(OpDef::infer_output_attrs)>; -using GradMaker = thin_function< +using GradMaker = detail::OpMeth< decltype(OpDef::make_backward_graph)>; struct OpTrait { const char* name; OpDefMaker make_from_op_node; ApplyOnPhysicalTensor apply_on_physical_tensor; - PhysicalTensorExecutor exec; ApplyOnVarNode apply_on_var_node; InferOutputAttrsFallible infer_output_attrs_fallible; - InferOutputAttrs infer_output_attrs; GradMaker make_backward_graph; OpTrait(const char* name); static OpTrait* find_by_name(const char* name); @@ -46,38 +56,25 @@ struct OpTrait { static void for_each_trait(thin_function visitor); }; +#define FOR_EACH_OP_METH(cb) \ + cb(make_from_op_node) \ + cb(apply_on_physical_tensor) \ + cb(apply_on_var_node) \ + cb(infer_output_attrs_fallible) \ + cb(make_backward_graph) + struct OpTraitRegistry { OpTrait* trait; - OpTraitRegistry& make_from_op_node(OpDefMaker f) { - trait->make_from_op_node = f; - return *this; - } - OpTraitRegistry& apply_on_physical_tensor(ApplyOnPhysicalTensor f) { - trait->apply_on_physical_tensor = f; - return *this; - } - OpTraitRegistry& physical_tensor_executor(PhysicalTensorExecutor f) { - trait->exec = f; - return *this; - } - OpTraitRegistry& apply_on_var_node(ApplyOnVarNode f) { - trait->apply_on_var_node = f; - return *this; - } - OpTraitRegistry& infer_output_attrs_fallible(InferOutputAttrsFallible f) { - trait->infer_output_attrs_fallible = f; - return *this; - } - OpTraitRegistry& infer_output_attrs(InferOutputAttrs f) { - trait->infer_output_attrs = f; - return *this; - } - OpTraitRegistry& grad_maker(GradMaker f) { - trait->make_backward_graph = f; - return *this; +#define DECL(meth) \ + OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \ + mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \ + trait->meth = f; \ + return *this; \ } + FOR_EACH_OP_METH(DECL) +#undef DECL + OpTraitRegistry& fallback(); - OpTraitRegistry& finalize(); template void insert() { @@ -102,20 +99,11 @@ struct OpTraitRegistry { static OpTraitRegistry do_insert(const char* name); }; -namespace detail { -struct _RegisterHelper { - OpTraitRegistry registry; - ~_RegisterHelper() { - registry.finalize(); - } -}; -} // namespace detail - } // namespace imperative } // namespace mgb #define OP_TRAIT_REG(name, ...) \ static OpTraitRegistry __##name##_global_registry__ = \ - detail::_RegisterHelper{OpTraitRegistry::insert<__VA_ARGS__>(#name)}.registry + OpTraitRegistry::insert<__VA_ARGS__>(#name) // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp index 74b3fc51b..133a9a933 100644 --- a/imperative/src/impl/ops/cond_take.cpp +++ b/imperative/src/impl/ops/cond_take.cpp @@ -110,20 +110,20 @@ SmallVector apply_on_physical_tensor( return out; } -SmallVector infer_output_attrs( +SmallVector infer_output_attrs_fallible( const OpDef& def, - const SmallVector& inputs) { - SmallVector out; - for (size_t i = 0; i < 2; ++ i) { - out.push_back({TensorLayout(), inputs[0]->comp_node()}); - } - return out; + const SmallVector& inputs) { + auto cn = inputs[0].comp_node; + return { + {TensorLayout(inputs[0].layout.dtype), cn}, + {TensorLayout(dtype::Int32()), cn} + }; } OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) .apply_on_var_node(apply_on_var_node) .apply_on_physical_tensor(apply_on_physical_tensor) - .infer_output_attrs(infer_output_attrs) + .infer_output_attrs_fallible(infer_output_attrs_fallible) .fallback(); } // namespace diff --git a/imperative/src/impl/profiler.cpp b/imperative/src/impl/profiler.cpp index ccb98081d..ebd28c7fc 100644 --- a/imperative/src/impl/profiler.cpp +++ b/imperative/src/impl/profiler.cpp @@ -28,10 +28,12 @@ namespace { CompNode::UnorderedSet collect_comp_nodes( const OpDef& def, const SmallVector& inputs) { CompNode::UnorderedSet comp_nodes; - for (auto&& input : inputs) { - comp_nodes.insert(input->comp_node()); + SmallVector descs; + for (auto&& i : inputs) { + comp_nodes.insert(i->comp_node()); + descs.push_back({i->layout(), i->comp_node(), {}}); } - for (auto&& output_attr : def.infer_output_attrs(def, inputs)) { + for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) { comp_nodes.insert(output_attr.comp_node); } return comp_nodes; diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp index 6415e49a4..42a02a82d 100644 --- a/imperative/src/impl/proxy_graph_detail.cpp +++ b/imperative/src/impl/proxy_graph_detail.cpp @@ -31,7 +31,6 @@ SmallVector to_raw_ptr_array( } return ret; } -} // anonymous namespace void exec(const OpDef& def, const SmallVector& inputs_, @@ -61,11 +60,25 @@ void exec(const OpDef& def, } } -SmallVector infer_output_attrs(const OpDef& def, +SmallVector +infer_output_attrs(const OpDef& def, const SmallVector& inputs) { auto&& graph = ProxyGraph::get_default_graph(); return graph->infer_output_attrs(def, to_raw_ptr_array(inputs)); } +} // anonymous namespace + +SmallVector +apply_on_physical_tensor(const OpDef& def, + const SmallVector& inputs) { + auto desc = infer_output_attrs(def, inputs); + SmallVector outputs; + for (auto&& i : desc) { + outputs.push_back(Tensor::make(i.layout, i.comp_node)); + } + exec(def, inputs, outputs); + return outputs; +} SmallVector infer_output_attrs_fallible(const OpDef& def, diff --git a/imperative/src/impl/proxy_graph_detail.h b/imperative/src/impl/proxy_graph_detail.h index 3d0601e98..e148b0bba 100644 --- a/imperative/src/impl/proxy_graph_detail.h +++ b/imperative/src/impl/proxy_graph_detail.h @@ -17,11 +17,8 @@ namespace mgb { namespace imperative { namespace proxy_graph_detail { -void exec(const OpDef& def, - const SmallVector& inputs_, - const SmallVector& outputs_); - -SmallVector infer_output_attrs(const OpDef& def, +SmallVector +apply_on_physical_tensor(const OpDef& def, const SmallVector& inputs); SmallVector diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h index cba1fad43..0aff1d53e 100644 --- a/imperative/src/include/megbrain/imperative/op_def.h +++ b/imperative/src/include/megbrain/imperative/op_def.h @@ -40,11 +40,6 @@ public: const OpDef& def, const SmallVector& inputs); - static void exec( - const OpDef& def, - const SmallVector& inputs, - const SmallVector& outputs); - static cg::OperatorNodeBase* apply_on_var_node( const OpDef& def, const VarNodeArray& inputs); @@ -53,10 +48,6 @@ public: const OpDef& def, const SmallVector& inputs); - static SmallVector infer_output_attrs( - const OpDef& def, - const SmallVector& inputs); - static BackwardGraphResult make_backward_graph( const OpDef& def, const SmallVector& inputs, -- GitLab