提交 120e719e 编写于 作者: M Megvii Engine Team

refactor(imperative): refactor OpTrait methods and registration

remove unused op methods and enable OpTrait registered from seperated files, also fix cond_take's infer_output_attrs

GitOrigin-RevId: 134c8215ce1b8efaf6c14da00ab588bb336163ee
上级 aea829c9
......@@ -36,13 +36,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
return def.trait()->apply_on_physical_tensor(def, inputs);
}
void OpDef::exec(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs) {
def.trait()->exec(def, inputs, outputs);
}
cg::OperatorNodeBase* OpDef::apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
......@@ -55,12 +48,6 @@ SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible(
return def.trait()->infer_output_attrs_fallible(def, inputs);
}
SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
return def.trait()->infer_output_attrs(def, inputs);
}
BackwardGraphResult OpDef::make_backward_graph(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
......
......@@ -34,16 +34,6 @@ StaticData& static_data() {
return data;
}
template<typename T>
struct __not_implementation__;
template<typename RType, typename ...Args>
struct __not_implementation__<RType(Args...)> {
static RType raise(Args ...) {
mgb_throw(MegBrainError, "Not Implemented");
}
};
} // detail
OpTrait::OpTrait(const char* name_): name(name_) {}
......@@ -72,89 +62,45 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
}
}
OpTraitRegistry& OpTraitRegistry::finalize() {
std::ostringstream msg;
#define CHECK(field) if (!trait->field) { \
msg << ", " #field; \
trait->field = \
detail::__not_implementation__<decltype(OpDef::field)>::raise; \
}
CHECK(make_from_op_node);
CHECK(apply_on_physical_tensor);
CHECK(exec);
CHECK(apply_on_var_node);
CHECK(infer_output_attrs_fallible);
CHECK(infer_output_attrs);
CHECK(make_backward_graph);
#undef CHECK
#ifdef DEBUG
if (msg.tellp() > 0) {
mgb_log_warn(
"%s op trait missing: %s",
trait->name ? trait->name : "(anonymous)",
msg.str().c_str() + 2 /* skip first ", " */);
}
#endif
return *this;
}
SmallVector<TensorPtr> fallback_apply_on_physical_tensor(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto desc = OpDef::infer_output_attrs(def, inputs);
SmallVector<TensorPtr> outputs;
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
}
OpDef::exec(def, inputs, outputs);
return outputs;
}
SmallVector<LogicalTensorDesc> fallback_infer_output_attrs(const OpDef& def,
const SmallVector<TensorPtr>& inputs){
SmallVector<LogicalTensorDesc> input_descs;
for(auto&& input: inputs){
input_descs.push_back({input->layout(), input->comp_node()});
}
return input_descs;
}
OpTraitRegistry& OpTraitRegistry::fallback() {
if (!trait->exec && trait->apply_on_var_node) {
trait->exec = proxy_graph_detail::exec;
if (trait->apply_on_var_node) {
// fallback to proxy graph impl
if (!trait->apply_on_physical_tensor) {
trait->apply_on_physical_tensor =
proxy_graph_detail::apply_on_physical_tensor;
}
if (!trait->infer_output_attrs && trait->apply_on_var_node) {
trait->infer_output_attrs = proxy_graph_detail::infer_output_attrs;
if (!trait->infer_output_attrs_fallible) {
trait->infer_output_attrs_fallible =
proxy_graph_detail::infer_output_attrs_fallible;
}
if (!trait->infer_output_attrs_fallible && trait->apply_on_var_node) {
trait->infer_output_attrs_fallible = proxy_graph_detail::infer_output_attrs_fallible;
if (!trait->make_backward_graph) {
trait->make_backward_graph =
proxy_graph_detail::make_backward_graph;
}
if (!trait->make_backward_graph && trait->apply_on_var_node) {
trait->make_backward_graph = proxy_graph_detail::make_backward_graph;
}
if (!trait->apply_on_physical_tensor && trait->infer_output_attrs && trait->exec) {
trait->apply_on_physical_tensor = fallback_apply_on_physical_tensor;
}
if(!trait->infer_output_attrs && trait->infer_output_attrs_fallible){
trait->infer_output_attrs = fallback_infer_output_attrs;
}
return *this;
}
void OpTraitRegistry::do_insert(Typeinfo* type) {
auto&& sd = detail::static_data();
mgb_assert(sd.type2reg.emplace(type, trait).second);
auto ret = sd.type2reg.emplace(type, trait);
mgb_assert(ret.second || ret.first->second == trait,
"OpTrait for %s has already been registered", type->name);
}
OpTraitRegistry OpTraitRegistry::do_insert(const char* name) {
auto&& sd = detail::static_data();
if (name) {
mgb_assert(!sd.name2reg.count(name),
"duplicated opr trait %s", name);
auto iter = sd.name2reg.find(name);
if (iter != sd.name2reg.end()) {
return {iter->second};
}
}
sd.registries.emplace_back(name);
auto ret = &sd.registries.back();
if (name) {
sd.name2reg.emplace(name, ret);
}
return {ret};
}
......
......@@ -16,29 +16,39 @@
namespace mgb {
namespace imperative {
using OpDefMaker = thin_function<
namespace detail {
template<typename Signature>
struct OpMeth;
template<typename RType, typename ...Args>
struct OpMeth<RType(Args...)>: public thin_function<RType(Args...)> {
using Base = thin_function<RType(Args...)>;
using Base::Base;
RType operator()(Args... args) const {
if (!this->Base::operator bool()) {
mgb_throw(MegBrainError, "Not Implemented");
}
return this->Base::operator ()(args...);
}
};
} // detail
using OpDefMaker = detail::OpMeth<
decltype(OpDef::make_from_op_node)>;
using ApplyOnPhysicalTensor = thin_function<
using ApplyOnPhysicalTensor = detail::OpMeth<
decltype(OpDef::apply_on_physical_tensor)>;
using PhysicalTensorExecutor = thin_function<
decltype(OpDef::exec)>;
using ApplyOnVarNode = thin_function<
using ApplyOnVarNode = detail::OpMeth<
decltype(OpDef::apply_on_var_node)>;
using InferOutputAttrsFallible = thin_function<
using InferOutputAttrsFallible = detail::OpMeth<
decltype(OpDef::infer_output_attrs_fallible)>;
using InferOutputAttrs = thin_function<
decltype(OpDef::infer_output_attrs)>;
using GradMaker = thin_function<
using GradMaker = detail::OpMeth<
decltype(OpDef::make_backward_graph)>;
struct OpTrait {
const char* name;
OpDefMaker make_from_op_node;
ApplyOnPhysicalTensor apply_on_physical_tensor;
PhysicalTensorExecutor exec;
ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible;
InferOutputAttrs infer_output_attrs;
GradMaker make_backward_graph;
OpTrait(const char* name);
static OpTrait* find_by_name(const char* name);
......@@ -46,38 +56,25 @@ struct OpTrait {
static void for_each_trait(thin_function<void(OpTrait&)> visitor);
};
#define FOR_EACH_OP_METH(cb) \
cb(make_from_op_node) \
cb(apply_on_physical_tensor) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
cb(make_backward_graph)
struct OpTraitRegistry {
OpTrait* trait;
OpTraitRegistry& make_from_op_node(OpDefMaker f) {
trait->make_from_op_node = f;
return *this;
}
OpTraitRegistry& apply_on_physical_tensor(ApplyOnPhysicalTensor f) {
trait->apply_on_physical_tensor = f;
return *this;
}
OpTraitRegistry& physical_tensor_executor(PhysicalTensorExecutor f) {
trait->exec = f;
return *this;
}
OpTraitRegistry& apply_on_var_node(ApplyOnVarNode f) {
trait->apply_on_var_node = f;
return *this;
}
OpTraitRegistry& infer_output_attrs_fallible(InferOutputAttrsFallible f) {
trait->infer_output_attrs_fallible = f;
return *this;
}
OpTraitRegistry& infer_output_attrs(InferOutputAttrs f) {
trait->infer_output_attrs = f;
return *this;
}
OpTraitRegistry& grad_maker(GradMaker f) {
trait->make_backward_graph = f;
return *this;
#define DECL(meth) \
OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
trait->meth = f; \
return *this; \
}
FOR_EACH_OP_METH(DECL)
#undef DECL
OpTraitRegistry& fallback();
OpTraitRegistry& finalize();
template<typename T>
void insert() {
......@@ -102,20 +99,11 @@ struct OpTraitRegistry {
static OpTraitRegistry do_insert(const char* name);
};
namespace detail {
struct _RegisterHelper {
OpTraitRegistry registry;
~_RegisterHelper() {
registry.finalize();
}
};
} // namespace detail
} // namespace imperative
} // namespace mgb
#define OP_TRAIT_REG(name, ...) \
static OpTraitRegistry __##name##_global_registry__ = \
detail::_RegisterHelper{OpTraitRegistry::insert<__VA_ARGS__>(#name)}.registry
OpTraitRegistry::insert<__VA_ARGS__>(#name)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -110,20 +110,20 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return out;
}
SmallVector<LogicalTensorDesc> infer_output_attrs(
SmallVector<LogicalTensorDesc> infer_output_attrs_fallible(
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
SmallVector<LogicalTensorDesc> out;
for (size_t i = 0; i < 2; ++ i) {
out.push_back({TensorLayout(), inputs[0]->comp_node()});
}
return out;
const SmallVector<LogicalTensorDesc>& inputs) {
auto cn = inputs[0].comp_node;
return {
{TensorLayout(inputs[0].layout.dtype), cn},
{TensorLayout(dtype::Int32()), cn}
};
}
OP_TRAIT_REG(CondTake, CondTake, opr::CondTake)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs(infer_output_attrs)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.fallback();
} // namespace
......
......@@ -28,10 +28,12 @@ namespace {
CompNode::UnorderedSet collect_comp_nodes(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
CompNode::UnorderedSet comp_nodes;
for (auto&& input : inputs) {
comp_nodes.insert(input->comp_node());
SmallVector<LogicalTensorDesc> descs;
for (auto&& i : inputs) {
comp_nodes.insert(i->comp_node());
descs.push_back({i->layout(), i->comp_node(), {}});
}
for (auto&& output_attr : def.infer_output_attrs(def, inputs)) {
for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) {
comp_nodes.insert(output_attr.comp_node);
}
return comp_nodes;
......
......@@ -31,7 +31,6 @@ SmallVector<Tensor*> to_raw_ptr_array(
}
return ret;
}
} // anonymous namespace
void exec(const OpDef& def,
const SmallVector<TensorPtr>& inputs_,
......@@ -61,11 +60,25 @@ void exec(const OpDef& def,
}
}
SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def,
SmallVector<LogicalTensorDesc>
infer_output_attrs(const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto&& graph = ProxyGraph::get_default_graph();
return graph->infer_output_attrs(def, to_raw_ptr_array(inputs));
}
} // anonymous namespace
SmallVector<TensorPtr>
apply_on_physical_tensor(const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
auto desc = infer_output_attrs(def, inputs);
SmallVector<TensorPtr> outputs;
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
}
exec(def, inputs, outputs);
return outputs;
}
SmallVector<LogicalTensorDesc>
infer_output_attrs_fallible(const OpDef& def,
......
......@@ -17,11 +17,8 @@ namespace mgb {
namespace imperative {
namespace proxy_graph_detail {
void exec(const OpDef& def,
const SmallVector<TensorPtr>& inputs_,
const SmallVector<TensorPtr>& outputs_);
SmallVector<LogicalTensorDesc> infer_output_attrs(const OpDef& def,
SmallVector<TensorPtr>
apply_on_physical_tensor(const OpDef& def,
const SmallVector<TensorPtr>& inputs);
SmallVector<LogicalTensorDesc>
......
......@@ -40,11 +40,6 @@ public:
const OpDef& def,
const SmallVector<TensorPtr>& inputs);
static void exec(
const OpDef& def,
const SmallVector<TensorPtr>& inputs,
const SmallVector<TensorPtr>& outputs);
static cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs);
......@@ -53,10 +48,6 @@ public:
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs);
static SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& def,
const SmallVector<TensorPtr>& inputs);
static BackwardGraphResult make_backward_graph(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册