Commit 20e8541b authored by Megvii Engine Team

refactor(imperative): bind fallback impl on first op method call

GitOrigin-RevId: 82ae1e32052f274dea67ced95dc6ab694883425b
Parent 18274e02
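
Editor's note: before this change, OpTraitRegistry::fallback() eagerly copied the proxy-graph implementations into every unset op method at registration time; after it, each OpMeth only records allow_fallback = true, and the concrete fallback (or a "not implemented" error) is bound lazily inside OpMeth::operator() on the first call. The following is a minimal standalone sketch of that pattern; the names (Meth, AddTag), the trivial addition "fallback", and the use of std::function instead of MegEngine's thin_function are illustrative assumptions, not the real types.

    // Illustrative sketch only: bind an implementation on the first call.
    #include <cstdio>
    #include <functional>
    #include <stdexcept>
    #include <string>

    struct AddTag {
        static constexpr const char* name = "Add";
    };

    template <typename Tag>
    struct Meth : std::function<int(int, int)> {
        using Base = std::function<int(int, int)>;
        bool allow_fallback = false;

        int operator()(int x, int y) const {
            if (!static_cast<const Base&>(*this)) {
                // Nothing registered yet: bind something on this first call.
                auto* self = const_cast<Meth*>(this);
                if (allow_fallback) {
                    self->Base::operator=([](int a, int b) { return a + b; });
                } else {
                    self->Base::operator=([](int, int) -> int {
                        throw std::runtime_error(std::string(Tag::name) +
                                                 " was not implemented");
                    });
                }
            }
            return Base::operator()(x, y);  // later calls hit the bound impl directly
        }
    };

    int main() {
        Meth<AddTag> add;
        add.allow_fallback = true;       // what OpTraitRegistry::fallback() now toggles
        std::printf("%d\n", add(1, 2));  // fallback is bound here, prints 3
        return 0;
    }
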
@@ -38,6 +38,38 @@ StaticData& static_data() {
    return data;
}
+void OpMethFallback::impl(ApplyOnPhysicalTensor& func,
+                          op_meth_tag::ApplyOnPhysicalTensor) {
+    func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
+}
+void OpMethFallback::impl(Execute& func, op_meth_tag::Execute) {
+    func.Base::operator=(proxy_graph_detail::execute);
+}
+void OpMethFallback::impl(InferOutputMemDesc& func,
+                          op_meth_tag::InferOutputMemDesc) {
+    func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
+}
+void OpMethFallback::impl(InferOutputAttrsFallible& func,
+                          op_meth_tag::InferOutputAttrsFallible) {
+    func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
+}
+void OpMethFallback::impl(GradMaker& func, op_meth_tag::GradMaker) {
+    func.Base::operator=(proxy_graph_detail::make_backward_graph);
+}
+void OpMethFallback::impl(DecideDispatchMode& func,
+                          op_meth_tag::DecideDispatchMode) {
+    static auto decide_dispatch_mode =
+            [](const OpDef&, const SmallVector<LogicalTensorDesc>&) {
+                return DispatchMode::KERNEL;
+            };
+    func.Base::operator=(decide_dispatch_mode);
+}
+void OpMethFallback::impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc) {
+    static auto make_name = [](const OpDef& def) -> std::string {
+        return def.trait()->name;
+    };
+    func.Base::operator=(make_name);
+}
} // detail
OpTrait::OpTrait(const char* name_): name(name_) {}
@@ -66,44 +98,17 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
    }
}
-DispatchMode fallback_decide_dispatch_mode(
-        const OpDef& def,
-        const SmallVector<LogicalTensorDesc>& inputs) {
-    return KERNEL;
-}
OpTraitRegistry& OpTraitRegistry::fallback() {
    if (trait->apply_on_var_node) {
        // fallback to proxy graph impl
-        if (!trait->apply_on_physical_tensor) {
-            trait->apply_on_physical_tensor =
-                    proxy_graph_detail::apply_on_physical_tensor;
-        }
-        if (!trait->execute) {
-            trait->execute = proxy_graph_detail::execute;
-        }
-        if (!trait->infer_output_mem_desc) {
-            trait->infer_output_mem_desc =
-                    proxy_graph_detail::infer_output_mem_desc;
-        }
-        if (!trait->infer_output_attrs_fallible) {
-            trait->infer_output_attrs_fallible =
-                    proxy_graph_detail::infer_output_attrs_fallible;
-        }
-        if (!trait->make_backward_graph) {
-            trait->make_backward_graph =
-                    proxy_graph_detail::make_backward_graph;
-        }
-    }
-    if (!trait->decide_dispatch_mode) {
-        trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
-    }
-    if (!trait->make_name) {
-        static auto make_name = [](const OpDef& def) -> std::string {
-            return def.trait()->name;
-        };
-        trait->make_name = make_name;
-    }
+        trait->apply_on_physical_tensor.allow_fallback = true;
+        trait->execute.allow_fallback = true;
+        trait->infer_output_mem_desc.allow_fallback = true;
+        trait->infer_output_attrs_fallible.allow_fallback = true;
+        trait->make_backward_graph.allow_fallback = true;
+    }
+    trait->decide_dispatch_mode.allow_fallback = true;
+    trait->make_name.allow_fallback = true;
    return *this;
}
......
@@ -15,21 +15,10 @@
namespace mgb {
namespace imperative {
namespace detail {
-template <typename Signature>
+template <typename Tag, typename Signature>
struct OpMeth;
-template <typename RType, typename... Args>
-struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> {
-    using Base = thin_function<RType(Args...)>;
-    using Base::Base;
-    RType operator()(Args... args) const {
-        if (!this->Base::operator bool()) {
-            mgb_throw(MegBrainError, "Not Implemented");
-        }
-        return this->Base::operator()(std::forward<Args>(args)...);
-    }
-};
template<typename T>
struct ToVarNodeArray: std::false_type {};
template<>
@@ -58,28 +47,95 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
};
} // namespace detail
-using OpDefMaker = detail::OpMeth<
-        decltype(OpDef::make_from_op_node)>;
-using DecideDispatchMode = detail::OpMeth<
-        decltype(OpDef::decide_dispatch_mode)>;
-using ApplyOnPhysicalTensor = detail::OpMeth<
-        decltype(OpDef::apply_on_physical_tensor)>;
-using InferOutputMemDesc = detail::OpMeth<
-        decltype(OpDef::infer_output_mem_desc)>;
-using Execute = detail::OpMeth<
-        decltype(OpDef::execute)>;
-using ApplyOnDeviceTensorND = detail::OpMeth<
-        decltype(OpDef::apply_on_device_tensornd)>;
-using ApplyOnVarNode = detail::OpMeth<
-        decltype(OpDef::apply_on_var_node)>;
-using InferOutputAttrsFallible = detail::OpMeth<
-        decltype(OpDef::infer_output_attrs_fallible)>;
-using GradMaker = detail::OpMeth<
-        decltype(OpDef::make_backward_graph)>;
-using Props = detail::OpMeth<decltype(OpDef::props)>;
-using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
-using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
-using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>;
+// clang-format off
+#define OpMethType(TYPE, SIG)                                   \
+    namespace detail::op_meth_tag {                             \
+    struct TYPE {                                               \
+        constexpr static char name[] = #TYPE;                   \
+    };                                                          \
+    }                                                           \
+    using TYPE = detail::OpMeth<detail::op_meth_tag::TYPE, SIG>
+OpMethType(OpDefMaker,
+           decltype(OpDef::make_from_op_node));
+OpMethType(DecideDispatchMode,
+           decltype(OpDef::decide_dispatch_mode));
+OpMethType(ApplyOnPhysicalTensor,
+           decltype(OpDef::apply_on_physical_tensor));
+OpMethType(InferOutputMemDesc,
+           decltype(OpDef::infer_output_mem_desc));
+OpMethType(Execute,
+           decltype(OpDef::execute));
+OpMethType(ApplyOnDeviceTensorND,
+           decltype(OpDef::apply_on_device_tensornd));
+OpMethType(ApplyOnVarNode,
+           decltype(OpDef::apply_on_var_node));
+OpMethType(InferOutputAttrsFallible,
+           decltype(OpDef::infer_output_attrs_fallible));
+OpMethType(GradMaker,
+           decltype(OpDef::make_backward_graph));
+OpMethType(Props,
+           decltype(OpDef::props));
+OpMethType(HashFunc,
+           size_t(const OpDef&));
+OpMethType(IsSame,
+           bool(const OpDef&, const OpDef&));
+OpMethType(MakeNameFunc,
+           std::string(const OpDef&));
+// clang-format on
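
Editor's note: for readers skimming the hunk above, a single instantiation such as OpMethType(Execute, decltype(OpDef::execute)); expands, modulo whitespace, to roughly the following (schematic only; it relies on the surrounding declarations rather than compiling on its own):

    namespace detail::op_meth_tag {
    struct Execute {
        constexpr static char name[] = "Execute";  // from #TYPE stringification
    };
    }
    using Execute = detail::OpMeth<detail::op_meth_tag::Execute, decltype(OpDef::execute)>;
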
+namespace detail {
+struct OpMethNotImpl {
+    template <typename Tag, typename RType, typename... Args>
+    static void impl(thin_function<RType(Args...)>& func, Tag) {
+        func = [](Args... args) -> RType {
+            mgb_throw(MegBrainError, "%s was not implemented yet", Tag::name);
+        };
+    }
+};
+struct OpMethFallback : public OpMethNotImpl {
+    using OpMethNotImpl::impl;
+    static void impl(ApplyOnPhysicalTensor& func,
+                     op_meth_tag::ApplyOnPhysicalTensor);
+    static void impl(Execute& func, op_meth_tag::Execute);
+    static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
+    static void impl(InferOutputAttrsFallible& func,
+                     op_meth_tag::InferOutputAttrsFallible);
+    static void impl(GradMaker& func, op_meth_tag::GradMaker);
+    static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
+    static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
+};
+template <typename Tag, typename RType, typename... Args>
+struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
+    using Base = thin_function<RType(Args...)>;
+    using Base::operator bool;
+    OpMeth() : Base{}, allow_fallback(false){};
+    explicit OpMeth(const Base& base) { this->Base::operator=(base); }
+    RType operator()(Args... args) const {
+        if (!this->Base::operator bool()) {
+            if (allow_fallback) {
+                OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{});
+            } else {
+                OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{});
+            }
+        }
+        return this->Base::operator()(std::forward<Args>(args)...);
+    }
+    bool allow_fallback = false;
+};
+} // namespace detail
struct OpTrait {
    const char* name;
@@ -102,28 +158,31 @@ struct OpTrait {
    static void for_each_trait(thin_function<void(OpTrait&)> visitor);
};
-#define FOR_EACH_OP_METH(cb) \
-    cb(make_from_op_node) \
-    cb(decide_dispatch_mode) \
-    cb(apply_on_physical_tensor) \
-    cb(infer_output_mem_desc) \
-    cb(execute) \
-    cb(apply_on_device_tensornd) \
-    cb(apply_on_var_node) \
-    cb(infer_output_attrs_fallible) \
-    cb(make_backward_graph) \
-    cb(props) \
-    cb(hash) \
-    cb(is_same_st) \
-    cb(make_name)
+// clang-format off
+#define FOR_EACH_OP_METH(cb)        \
+    cb(make_from_op_node)           \
+    cb(decide_dispatch_mode)        \
+    cb(apply_on_physical_tensor)    \
+    cb(infer_output_mem_desc)       \
+    cb(execute)                     \
+    cb(apply_on_device_tensornd)    \
+    cb(apply_on_var_node)           \
+    cb(infer_output_attrs_fallible) \
+    cb(make_backward_graph)         \
+    cb(props)                       \
+    cb(hash)                        \
+    cb(is_same_st)                  \
+    cb(make_name)
+// clang-format on
struct OpTraitRegistry {
    OpTrait* trait;
-#define DECL(meth) \
-    OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
-        mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
-        trait->meth = f; \
-        return *this; \
+#define DECL(meth)                                                             \
+    OpTraitRegistry& meth(decltype(OpTrait::meth)::Base f) {                   \
+        mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, \
+                   #meth);                                                     \
+        trait->meth.Base::operator=(f);                                        \
+        return *this;                                                          \
    }
    FOR_EACH_OP_METH(DECL)
#undef DECL
@@ -162,7 +221,7 @@ struct OpTraitRegistry {
    }
};
} // namespace imperative
} // namespace mgb
#define OP_TRAIT_REG(name, ...) \
......
@@ -80,26 +80,30 @@ void TensorSanityCheck::enable() {
    OpTrait::for_each_trait([this](OpTrait& trait) {
        auto backup = std::make_unique<ApplyOnPhysicalTensor>(
                std::move(trait.apply_on_physical_tensor));
-        trait.apply_on_physical_tensor = [this, backup = backup.get()] (
-                const OpDef& def, const SmallVector<TensorPtr>& inputs) {
-            for (auto&& i: inputs) {
-                if (!m_checker->check(i)) {
-                    mgb_throw(TensorChecksumCalc::Error,
-                            "tensor modified before exec %s", print_op(def).c_str());
-                }
-            }
-            auto output = (*backup)(def, inputs);
-            for (auto&& i: output) {
-                mgb_assert(m_checker->check(i));
-            }
-            for (auto&& i: inputs) {
-                if (!m_checker->check(i)) {
-                    mgb_throw(TensorChecksumCalc::Error,
-                            "tensor modified after exec %s", print_op(def).c_str());
-                }
-            }
-            return output;
-        };
+        trait.apply_on_physical_tensor = ApplyOnPhysicalTensor(
+                [this, backup = backup.get()](
+                        const OpDef& def,
+                        const SmallVector<TensorPtr>& inputs) {
+                    for (auto&& i : inputs) {
+                        if (!m_checker->check(i)) {
+                            mgb_throw(TensorChecksumCalc::Error,
+                                      "tensor modified before exec %s",
+                                      print_op(def).c_str());
+                        }
+                    }
+                    auto output = (*backup)(def, inputs);
+                    for (auto&& i : output) {
+                        mgb_assert(m_checker->check(i));
+                    }
+                    for (auto&& i : inputs) {
+                        if (!m_checker->check(i)) {
+                            mgb_throw(TensorChecksumCalc::Error,
+                                      "tensor modified after exec %s",
+                                      print_op(def).c_str());
+                        }
+                    }
+                    return output;
+                });
        m_checker->hook_list.push_back({&trait, std::move(backup)});
    });
}
......
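
Editor's note: one consequence visible in the last hunk is that the new OpMeth constructor taking a thin_function is explicit, so TensorSanityCheck can no longer assign a bare lambda to trait.apply_on_physical_tensor and now wraps it in ApplyOnPhysicalTensor(...) explicitly. A minimal standalone illustration of that C++ behaviour, using generic names and std::function in place of thin_function:

    #include <functional>

    // Hypothetical stand-in for detail::OpMeth: explicit construction from the base.
    struct Wrapper : std::function<int(int)> {
        using Base = std::function<int(int)>;
        Wrapper() = default;
        explicit Wrapper(const Base& base) { Base::operator=(base); }
    };

    int main() {
        auto lam = [](int x) { return x + 1; };
        Wrapper w;
        // w = lam;        // does not compile: needs the explicit Wrapper(const Base&)
        w = Wrapper(lam);  // fine: wrap explicitly, as the sanity-check hunk now does
        return w(1) == 2 ? 0 : 1;
    }
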