提交 20e8541b 编写于 作者: M Megvii Engine Team

refactor(imperative): bind fallback impl on first op method call

GitOrigin-RevId: 82ae1e32052f274dea67ced95dc6ab694883425b
上级 18274e02
......@@ -38,6 +38,38 @@ StaticData& static_data() {
return data;
}
void OpMethFallback::impl(ApplyOnPhysicalTensor& func,
op_meth_tag::ApplyOnPhysicalTensor) {
func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor);
}
void OpMethFallback::impl(Execute& func, op_meth_tag::Execute) {
func.Base::operator=(proxy_graph_detail::execute);
}
void OpMethFallback::impl(InferOutputMemDesc& func,
op_meth_tag::InferOutputMemDesc) {
func.Base::operator=(proxy_graph_detail::infer_output_mem_desc);
}
void OpMethFallback::impl(InferOutputAttrsFallible& func,
op_meth_tag::InferOutputAttrsFallible) {
func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
}
void OpMethFallback::impl(GradMaker& func, op_meth_tag::GradMaker) {
func.Base::operator=(proxy_graph_detail::make_backward_graph);
}
void OpMethFallback::impl(DecideDispatchMode& func,
op_meth_tag::DecideDispatchMode) {
static auto decide_dispatch_mode =
[](const OpDef&, const SmallVector<LogicalTensorDesc>&) {
return DispatchMode::KERNEL;
};
func.Base::operator=(decide_dispatch_mode);
}
void OpMethFallback::impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc) {
static auto make_name = [](const OpDef& def) -> std::string {
return def.trait()->name;
};
func.Base::operator=(make_name);
}
} // detail
OpTrait::OpTrait(const char* name_): name(name_) {}
......@@ -66,44 +98,17 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
}
}
DispatchMode fallback_decide_dispatch_mode(
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs) {
return KERNEL;
}
OpTraitRegistry& OpTraitRegistry::fallback() {
if (trait->apply_on_var_node) {
// fallback to proxy graph impl
if (!trait->apply_on_physical_tensor) {
trait->apply_on_physical_tensor =
proxy_graph_detail::apply_on_physical_tensor;
}
if (!trait->execute) {
trait->execute = proxy_graph_detail::execute;
}
if (!trait->infer_output_mem_desc) {
trait->infer_output_mem_desc =
proxy_graph_detail::infer_output_mem_desc;
}
if (!trait->infer_output_attrs_fallible) {
trait->infer_output_attrs_fallible =
proxy_graph_detail::infer_output_attrs_fallible;
}
if (!trait->make_backward_graph) {
trait->make_backward_graph =
proxy_graph_detail::make_backward_graph;
}
}
if (!trait->decide_dispatch_mode) {
trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
}
if (!trait->make_name) {
static auto make_name = [](const OpDef& def) -> std::string {
return def.trait()->name;
};
trait->make_name = make_name;
trait->apply_on_physical_tensor.allow_fallback = true;
trait->execute.allow_fallback = true;
trait->infer_output_mem_desc.allow_fallback = true;
trait->infer_output_attrs_fallible.allow_fallback = true;
trait->make_backward_graph.allow_fallback = true;
}
trait->decide_dispatch_mode.allow_fallback = true;
trait->make_name.allow_fallback = true;
return *this;
}
......
......@@ -15,21 +15,10 @@
namespace mgb {
namespace imperative {
namespace detail {
template <typename Signature>
template <typename Tag, typename Signature>
struct OpMeth;
template <typename RType, typename... Args>
struct OpMeth<RType(Args...)> : public thin_function<RType(Args...)> {
using Base = thin_function<RType(Args...)>;
using Base::Base;
RType operator()(Args... args) const {
if (!this->Base::operator bool()) {
mgb_throw(MegBrainError, "Not Implemented");
}
return this->Base::operator()(std::forward<Args>(args)...);
}
};
template<typename T>
struct ToVarNodeArray: std::false_type {};
template<>
......@@ -58,28 +47,95 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
};
} // namespace detail
using OpDefMaker = detail::OpMeth<
decltype(OpDef::make_from_op_node)>;
using DecideDispatchMode = detail::OpMeth<
decltype(OpDef::decide_dispatch_mode)>;
using ApplyOnPhysicalTensor = detail::OpMeth<
decltype(OpDef::apply_on_physical_tensor)>;
using InferOutputMemDesc = detail::OpMeth<
decltype(OpDef::infer_output_mem_desc)>;
using Execute = detail::OpMeth<
decltype(OpDef::execute)>;
using ApplyOnDeviceTensorND = detail::OpMeth<
decltype(OpDef::apply_on_device_tensornd)>;
using ApplyOnVarNode = detail::OpMeth<
decltype(OpDef::apply_on_var_node)>;
using InferOutputAttrsFallible = detail::OpMeth<
decltype(OpDef::infer_output_attrs_fallible)>;
using GradMaker = detail::OpMeth<
decltype(OpDef::make_backward_graph)>;
using Props = detail::OpMeth<decltype(OpDef::props)>;
using HashFunc = detail::OpMeth<size_t(const OpDef&)>;
using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
using MakeNameFunc = detail::OpMeth<std::string(const OpDef&)>;
// clang-format off
#define OpMethType(TYPE, SIG) \
namespace detail::op_meth_tag { \
struct TYPE { \
constexpr static char name[] = #TYPE; \
}; \
} \
using TYPE = detail::OpMeth<detail::op_meth_tag::TYPE, SIG>
OpMethType(OpDefMaker,
decltype(OpDef::make_from_op_node));
OpMethType(DecideDispatchMode,
decltype(OpDef::decide_dispatch_mode));
OpMethType(ApplyOnPhysicalTensor,
decltype(OpDef::apply_on_physical_tensor));
OpMethType(InferOutputMemDesc,
decltype(OpDef::infer_output_mem_desc));
OpMethType(Execute,
decltype(OpDef::execute));
OpMethType(ApplyOnDeviceTensorND,
decltype(OpDef::apply_on_device_tensornd));
OpMethType(ApplyOnVarNode,
decltype(OpDef::apply_on_var_node));
OpMethType(InferOutputAttrsFallible,
decltype(OpDef::infer_output_attrs_fallible));
OpMethType(GradMaker,
decltype(OpDef::make_backward_graph));
OpMethType(Props,
decltype(OpDef::props));
OpMethType(HashFunc,
size_t(const OpDef&));
OpMethType(IsSame,
bool(const OpDef&, const OpDef&));
OpMethType(MakeNameFunc,
std::string(const OpDef&));
// clang-format on
namespace detail {
struct OpMethNotImpl {
template <typename Tag, typename RType, typename... Args>
static void impl(thin_function<RType(Args...)>& func, Tag) {
func = [](Args... args) -> RType {
mgb_throw(MegBrainError, "%s was not implemented yet", Tag::name);
};
}
};
struct OpMethFallback : public OpMethNotImpl {
using OpMethNotImpl::impl;
static void impl(ApplyOnPhysicalTensor& func,
op_meth_tag::ApplyOnPhysicalTensor);
static void impl(Execute& func, op_meth_tag::Execute);
static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc);
static void impl(InferOutputAttrsFallible& func,
op_meth_tag::InferOutputAttrsFallible);
static void impl(GradMaker& func, op_meth_tag::GradMaker);
static void impl(DecideDispatchMode& func, op_meth_tag::DecideDispatchMode);
static void impl(MakeNameFunc& func, op_meth_tag::MakeNameFunc);
};
template <typename Tag, typename RType, typename... Args>
struct OpMeth<Tag, RType(Args...)> : public thin_function<RType(Args...)> {
using Base = thin_function<RType(Args...)>;
using Base::operator bool;
OpMeth() : Base{}, allow_fallback(false){};
explicit OpMeth(const Base& base) { this->Base::operator=(base); }
RType operator()(Args... args) const {
if (!this->Base::operator bool()) {
if (allow_fallback) {
OpMethFallback::impl(*const_cast<OpMeth*>(this), Tag{});
} else {
OpMethNotImpl::impl(*const_cast<OpMeth*>(this), Tag{});
}
}
return this->Base::operator()(std::forward<Args>(args)...);
}
bool allow_fallback = false;
};
} // namespace detail
struct OpTrait {
const char* name;
......@@ -102,28 +158,31 @@ struct OpTrait {
static void for_each_trait(thin_function<void(OpTrait&)> visitor);
};
#define FOR_EACH_OP_METH(cb) \
cb(make_from_op_node) \
cb(decide_dispatch_mode) \
cb(apply_on_physical_tensor) \
cb(infer_output_mem_desc) \
cb(execute) \
cb(apply_on_device_tensornd) \
cb(apply_on_var_node) \
// clang-format off
#define FOR_EACH_OP_METH(cb) \
cb(make_from_op_node) \
cb(decide_dispatch_mode) \
cb(apply_on_physical_tensor) \
cb(infer_output_mem_desc) \
cb(execute) \
cb(apply_on_device_tensornd) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
cb(make_backward_graph) \
cb(props) \
cb(hash) \
cb(is_same_st) \
cb(make_backward_graph) \
cb(props) \
cb(hash) \
cb(is_same_st) \
cb(make_name)
// clang-format on
struct OpTraitRegistry {
OpTrait* trait;
#define DECL(meth) \
OpTraitRegistry& meth(decltype(OpTrait::meth) f) { \
mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, #meth); \
trait->meth = f; \
return *this; \
#define DECL(meth) \
OpTraitRegistry& meth(decltype(OpTrait::meth)::Base f) { \
mgb_assert(!trait->meth, "op %s has duplicate method %s", trait->name, \
#meth); \
trait->meth.Base::operator=(f); \
return *this; \
}
FOR_EACH_OP_METH(DECL)
#undef DECL
......@@ -162,7 +221,7 @@ struct OpTraitRegistry {
}
};
} // namespace imperative
} // namespace imperative
} // namespace mgb
#define OP_TRAIT_REG(name, ...) \
......
......@@ -80,26 +80,30 @@ void TensorSanityCheck::enable() {
OpTrait::for_each_trait([this](OpTrait& trait) {
auto backup = std::make_unique<ApplyOnPhysicalTensor>(
std::move(trait.apply_on_physical_tensor));
trait.apply_on_physical_tensor = [this, backup = backup.get()] (
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
for (auto&& i: inputs) {
if (!m_checker->check(i)) {
mgb_throw(TensorChecksumCalc::Error,
"tensor modified before exec %s", print_op(def).c_str());
}
}
auto output = (*backup)(def, inputs);
for (auto&& i: output) {
mgb_assert(m_checker->check(i));
}
for (auto&& i: inputs) {
if (!m_checker->check(i)) {
mgb_throw(TensorChecksumCalc::Error,
"tensor modified after exec %s", print_op(def).c_str());
}
}
return output;
};
trait.apply_on_physical_tensor = ApplyOnPhysicalTensor(
[this, backup = backup.get()](
const OpDef& def,
const SmallVector<TensorPtr>& inputs) {
for (auto&& i : inputs) {
if (!m_checker->check(i)) {
mgb_throw(TensorChecksumCalc::Error,
"tensor modified before exec %s",
print_op(def).c_str());
}
}
auto output = (*backup)(def, inputs);
for (auto&& i : output) {
mgb_assert(m_checker->check(i));
}
for (auto&& i : inputs) {
if (!m_checker->check(i)) {
mgb_throw(TensorChecksumCalc::Error,
"tensor modified after exec %s",
print_op(def).c_str());
}
}
return output;
});
m_checker->hook_list.push_back({&trait, std::move(backup)});
});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册