提交 120e719e 编写于 作者: M Megvii Engine Team

refactor(imperative): refactor OpTrait methods and registration

remove unused op methods and enable OpTrait registered from seperated files, also fix cond_take's infer_output_attrs

GitOrigin-RevId: 134c8215ce1b8efaf6c14da00ab588bb336163ee
上级 aea829c9
...@@ -36,13 +36,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( ...@@ -36,13 +36,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
return def.trait()->apply_on_physical_tensor(def, inputs); return def.trait()->apply_on_physical_tensor(def, inputs);
} }
void OpDef::exec(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs) {
def.trait()->exec(def, inputs, outputs);
}
cg::OperatorNodeBase* OpDef::apply_on_var_node( cg::OperatorNodeBase* OpDef::apply_on_var_node(
const OpDef& def, const OpDef& def,
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
...@@ -55,12 +48,6 @@ SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible( ...@@ -55,12 +48,6 @@ SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible(
return def.trait()->infer_output_attrs_fallible(def, inputs); return def.trait()->infer_output_attrs_fallible(def, inputs);
} }
SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
return def.trait()->infer_output_attrs(def, inputs);
}
BackwardGraphResult OpDef::make_backward_graph( BackwardGraphResult OpDef::make_backward_graph(
const OpDef& def, const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs, const SmallVector<LogicalTensorDesc>& inputs,
......
...@@ -34,16 +34,6 @@ StaticData& static_data() { ...@@ -34,16 +34,6 @@ StaticData& static_data() {
return data; return data;
} }
template<typename T>
struct __not_implementation__;
template<typename RType, typename ...Args>
struct __not_implementation__<RType(Args...)> {
static RType raise(Args ...) {
mgb_throw(MegBrainError, "Not Implemented");
}
};
} // detail } // detail
OpTrait::OpTrait(const char* name_): name(name_) {} OpTrait::OpTrait(const char* name_): name(name_) {}
...@@ -72,89 +62,45 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){ ...@@ -72,89 +62,45 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
} }
} }
OpTraitRegistry& OpTraitRegistry::finalize() {
std::ostringstream msg;
#define CHECK(field) if (!trait->field) { \
msg << ", " #field; \
trait->field = \
detail::__not_implementation__<decltype(OpDef::field)>::raise; \
}
CHECK(make_from_op_node);
CHECK(apply_on_physical_tensor);
CHECK(exec);
CHECK(apply_on_var_node);
CHECK(infer_output_attrs_fallible);
CHECK(infer_output_attrs);
CHECK(make_backward_graph);
#undef CHECK
#ifdef DEBUG
if (msg.tellp() > 0) {
mgb_log_warn(
"%s op trait missing: %s",
trait->name ? trait->name : "(anonymous)",
msg.str().c_str() + 2 /* skip first ", " */);
}
#endif
return *this;
}
SmallVector<TensorPtr> fallback_apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto desc = OpDef::infer_output_attrs(def, inputs);
SmallVector<TensorPtr> outputs;
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
}
OpDef::exec(def, inputs, outputs);
return outputs;
}
SmallVector<LogicalTensorDesc> fallback_infer_output_attrs(const OpDef& def,
const SmallVector<TensorPtr>& inputs){
SmallVector<LogicalTensorDesc> input_descs;
for(auto&& input: inputs){
input_descs.push_back({input->layout(), input->comp_node()});
}
return input_descs;
}
OpTraitRegistry& OpTraitRegistry::fallback() { OpTraitRegistry& OpTraitRegistry::fallback() {
if (!trait->exec && trait->apply_on_var_node) { if (trait->apply_on_var_node) {
trait->exec = proxy_graph_detail::exec; // fallback to proxy graph impl
} if (!trait->apply_on_physical_tensor) {
if (!trait->infer_output_attrs && trait->apply_on_var_node) { trait->apply_on_physical_tensor =
trait->infer_output_attrs = proxy_graph_detail::infer_output_attrs; proxy_graph_detail::apply_on_physical_tensor;
} }
if (!trait->infer_output_attrs_fallible && trait->apply_on_var_node) { if (!trait->infer_output_attrs_fallible) {
trait->infer_output_attrs_fallible = proxy_graph_detail::infer_output_attrs_fallible; trait->infer_output_attrs_fallible =
} proxy_graph_detail::infer_output_attrs_fallible;
if (!trait->make_backward_graph && trait->apply_on_var_node) { }
trait->make_backward_graph = proxy_graph_detail::make_backward_graph; if (!trait->make_backward_graph) {
} trait->make_backward_graph =
if (!trait->apply_on_physical_tensor && trait->infer_output_attrs && trait->exec) { proxy_graph_detail::make_backward_graph;
trait->apply_on_physical_tensor = fallback_apply_on_physical_tensor; }
}
if(!trait->infer_output_attrs && trait->infer_output_attrs_fallible){
trait->infer_output_attrs = fallback_infer_output_attrs;
} }
return *this; return *this;
} }
void OpTraitRegistry::do_insert(Typeinfo* type) { void OpTraitRegistry::do_insert(Typeinfo* type) {
auto&& sd = detail::static_data(); auto&& sd = detail::static_data();
mgb_assert(sd.type2reg.emplace(type, trait).second); auto ret = sd.type2reg.emplace(type, trait);
mgb_assert(ret.second || ret.first->second == trait,
"OpTrait for %s has already been registered", type->name);
} }
OpTraitRegistry OpTraitRegistry::do_insert(const char* name) { OpTraitRegistry OpTraitRegistry::do_insert(const char* name) {
auto&& sd = detail::static_data(); auto&& sd = detail::static_data();
if (name) { if (name) {
mgb_assert(!sd.name2reg.count(name), auto iter = sd.name2reg.find(name);
"duplicated opr trait %s", name); if (iter != sd.name2reg.end()) {
return {iter->second};
}
} }
sd.registries.emplace_back(name); sd.registries.emplace_back(name);
auto ret = &sd.registries.back(); auto ret = &sd.registries.back();
sd.name2reg.emplace(name, ret); if (name) {
sd.name2reg.emplace(name, ret);
}
return {ret}; return {ret};
} }
......
...@@ -16,29 +16,39 @@ ...@@ -16,29 +16,39 @@
namespace mgb { namespace mgb {
namespace imperative { namespace imperative {
using OpDefMaker = thin_function< namespace detail {
template<typename Signature>
struct OpMeth;
template<typename RType, typename ...Args>
struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> {
using Base = thin_function<RType(Args...)>;
using Base::Base;
RType operator()(Args... args) const {
if (!this->Base::operator bool()) {
mgb_throw(MegBrainError, "Not Implemented");
}
return this->Base::operator ()(args...);
}
};
} // detail
using OpDefMaker = detail::OpMeth<
decltype(OpDef::make_from_op_node)>; decltype(OpDef::make_from_op_node)>;
using ApplyOnPhysicalTensor = thin_function< using ApplyOnPhysicalTensor = detail::OpMeth<
decltype(OpDef::apply_on_physical_tensor)>; decltype(OpDef::apply_on_physical_tensor)>;
using PhysicalTensorExecutor = thin_function< using ApplyOnVarNode = detail::OpMeth<
decltype(OpDef::exec)>;
using ApplyOnVarNode = thin_function<
decltype(OpDef::apply_on_var_node)>; decltype(OpDef::apply_on_var_node)>;
using InferOutputAttrsFallible = thin_function< using InferOutputAttrsFallible = detail::OpMeth<
decltype(OpDef::infer_output_attrs_fallible)>; decltype(OpDef::infer_output_attrs_fallible)>;
using InferOutputAttrs = thin_function< using GradMaker = detail::OpMeth<
decltype(OpDef::infer_output_attrs)>;
using GradMaker = thin_function<
decltype(OpDef::make_backward_graph)>; decltype(OpDef::make_backward_graph)>;
struct OpTrait { struct OpTrait {
const char* name; const char* name;
OpDefMaker make_from_op_node; OpDefMaker make_from_op_node;
ApplyOnPhysicalTensor apply_on_physical_tensor; ApplyOnPhysicalTensor apply_on_physical_tensor;
PhysicalTensorExecutor exec;
ApplyOnVarNode apply_on_var_node; ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible; InferOutputAttrsFallible infer_output_attrs_fallible;
InferOutputAttrs infer_output_attrs;
GradMaker make_backward_graph; GradMaker make_backward_graph;
OpTrait(const char* name); OpTrait(const char* name);
static OpTrait* find_by_name(const char* name); static OpTrait* find_by_name(const char* name);
...@@ -46,38 +56,25 @@ struct OpTrait { ...@@ -46,38 +56,25 @@ struct OpTrait {
static void for_each_trait(thin_function<void(OpTrait&)> visitor); static void for_each_trait(thin_function<void(OpTrait&)> visitor);
}; };
#define FOR_EACH_OP_METH(cb) \
cb(make_from_op_node) \
cb(apply_on_physical_tensor) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
cb(make_backward_graph)
struct OpTraitRegistry { struct OpTraitRegistry {
OpTrait* trait; OpTrait* trait;
OpTraitRegistry& make_from_op_node(OpDefMaker f) { #define DECL(meth) \
trait->make_from_op_node = f; OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
return *this; mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
} trait->meth = f; \
OpTraitRegistry& apply_on_physical_tensor(ApplyOnPhysicalTensor f) { return *this; \
trait->apply_on_physical_tensor = f;
return *this;
}
OpTraitRegistry& physical_tensor_executor(PhysicalTensorExecutor f) {
trait->exec = f;
return *this;
}
OpTraitRegistry& apply_on_var_node(ApplyOnVarNode f) {
trait->apply_on_var_node = f;
return *this;
}
OpTraitRegistry& infer_output_attrs_fallible(InferOutputAttrsFallible f) {
trait->infer_output_attrs_fallible = f;
return *this;
}
OpTraitRegistry& infer_output_attrs(InferOutputAttrs f) {
trait->infer_output_attrs = f;
return *this;
}
OpTraitRegistry& grad_maker(GradMaker f) {
trait->make_backward_graph = f;
return *this;
} }
FOR_EACH_OP_METH(DECL)
#undef DECL
OpTraitRegistry& fallback(); OpTraitRegistry& fallback();
OpTraitRegistry& finalize();
template<typename T> template<typename T>
void insert() { void insert() {
...@@ -102,20 +99,11 @@ struct OpTraitRegistry { ...@@ -102,20 +99,11 @@ struct OpTraitRegistry {
static OpTraitRegistry do_insert(const char* name); static OpTraitRegistry do_insert(const char* name);
}; };
namespace detail {
struct _RegisterHelper {
OpTraitRegistry registry;
~_RegisterHelper() {
registry.finalize();
}
};
} // namespace detail
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
#define OP_TRAIT_REG(name, ...) \ #define OP_TRAIT_REG(name, ...) \
static OpTraitRegistry __##name##_global_registry__ = \ static OpTraitRegistry __##name##_global_registry__ = \
detail::_RegisterHelper{OpTraitRegistry::insert<__VA_ARGS__>(#name)}.registry OpTraitRegistry::insert<__VA_ARGS__>(#name)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -110,20 +110,20 @@ SmallVector<TensorPtr> apply_on_physical_tensor( ...@@ -110,20 +110,20 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return out; return out;
} }
SmallVector<LogicalTensorDesc> infer_output_attrs( SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
const OpDef& def, const OpDef& def,
const SmallVector<TensorPtr>& inputs) { const SmallVector<LogicalTensorDesc>& inputs) {
SmallVector<LogicalTensorDesc> out; auto cn = inputs[0].comp_node;
for (size_t i = 0; i < 2; ++ i) { return {
out.push_back({TensorLayout(), inputs[0]->comp_node()}); {TensorLayout(inputs[0].layout.dtype), cn},
} {TensorLayout(dtype::Int32()), cn}
return out; };
} }
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor) .apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs(infer_output_attrs) .infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback(); .fallback();
} // namespace } // namespace
......
...@@ -28,10 +28,12 @@ namespace { ...@@ -28,10 +28,12 @@ namespace {
CompNode::UnorderedSet collect_comp_nodes( CompNode::UnorderedSet collect_comp_nodes(
const OpDef& def, const SmallVector<TensorPtr>& inputs) { const OpDef& def, const SmallVector<TensorPtr>& inputs) {
CompNode::UnorderedSet comp_nodes; CompNode::UnorderedSet comp_nodes;
for (auto&& input : inputs) { SmallVector<LogicalTensorDesc> descs;
comp_nodes.insert(input->comp_node()); for (auto&& i : inputs) {
comp_nodes.insert(i->comp_node());
descs.push_back({i->layout(), i->comp_node(), {}});
} }
for (auto&& output_attr : def.infer_output_attrs(def, inputs)) { for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) {
comp_nodes.insert(output_attr.comp_node); comp_nodes.insert(output_attr.comp_node);
} }
return comp_nodes; return comp_nodes;
......
...@@ -31,7 +31,6 @@ SmallVector<Tensor*> to_raw_ptr_array( ...@@ -31,7 +31,6 @@ SmallVector<Tensor*> to_raw_ptr_array(
} }
return ret; return ret;
} }
} // anonymous namespace
void exec(const OpDef& def, void exec(const OpDef& def,
const SmallVector<TensorPtr>& inputs_, const SmallVector<TensorPtr>& inputs_,
...@@ -61,11 +60,25 @@ void exec(const OpDef& def, ...@@ -61,11 +60,25 @@ void exec(const OpDef& def,
} }
} }
SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def, SmallVector<LogicalTensorDesc>
infer_output_attrs(const OpDef& def,
const SmallVector<TensorPtr>& inputs) { const SmallVector<TensorPtr>& inputs) {
auto&& graph = ProxyGraph::get_default_graph(); auto&& graph = ProxyGraph::get_default_graph();
return graph->infer_output_attrs(def, to_raw_ptr_array(inputs)); return graph->infer_output_attrs(def, to_raw_ptr_array(inputs));
} }
} // anonymous namespace
SmallVector<TensorPtr>
apply_on_physical_tensor(const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto desc = infer_output_attrs(def, inputs);
SmallVector<TensorPtr> outputs;
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
}
exec(def, inputs, outputs);
return outputs;
}
SmallVector<LogicalTensorDesc> SmallVector<LogicalTensorDesc>
infer_output_attrs_fallible(const OpDef& def, infer_output_attrs_fallible(const OpDef& def,
......
...@@ -17,11 +17,8 @@ namespace mgb { ...@@ -17,11 +17,8 @@ namespace mgb {
namespace imperative { namespace imperative {
namespace proxy_graph_detail { namespace proxy_graph_detail {
void exec(const OpDef& def, SmallVector<TensorPtr>
const SmallVector<TensorPtr>& inputs_, apply_on_physical_tensor(const OpDef& def,
const SmallVector<TensorPtr>& outputs_);
SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def,
const SmallVector<TensorPtr>& inputs); const SmallVector<TensorPtr>& inputs);
SmallVector<LogicalTensorDesc> SmallVector<LogicalTensorDesc>
......
...@@ -40,11 +40,6 @@ public: ...@@ -40,11 +40,6 @@ public:
const OpDef& def, const OpDef& def,
const SmallVector<TensorPtr>& inputs); const SmallVector<TensorPtr>& inputs);
static void exec(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs);
static cg::OperatorNodeBase* apply_on_var_node( static cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def, const OpDef& def,
const VarNodeArray& inputs); const VarNodeArray& inputs);
...@@ -53,10 +48,6 @@ public: ...@@ -53,10 +48,6 @@ public:
const OpDef& def, const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs); const SmallVector<LogicalTensorDesc>& inputs);
static SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& def,
const SmallVector<TensorPtr>& inputs);
static BackwardGraphResult make_backward_graph( static BackwardGraphResult make_backward_graph(
const OpDef& def, const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs, const SmallVector<LogicalTensorDesc>& inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册