提交 3c61e0e0 编写于 作者: M Megvii Engine Team

feat(ops): add JITFusion op

GitOrigin-RevId: 7dc35d4e80f1ac9334ebb49b0202b96b004e45b1
上级 aa587446
......@@ -657,6 +657,85 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
} // namespace compiled_op
} // namespace
namespace {
namespace jit_fusion {
static thread_local bool tm_enabled = true;
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto& op = def.cast_final_safe<JITFusionOp>();
op.op->set_scope(op.scope());
auto outputs = OpDef::apply_on_var_node(*op.op, inputs);
if (!tm_enabled) {
// skip for dump (JITExecutor can not be dumped)
return outputs;
}
for (auto& output : outputs) {
jit::InternalGraphGenerator igg{output->owner_opr()};
std::vector<cg::OperatorNodeBase*> reverse_order;
cg::DepOprIter iter{
[&](cg::OperatorNodeBase* opr) { reverse_order.push_back(opr); }};
for (auto&& input : inputs) {
iter.set_visited(input->owner_opr());
}
iter.add(output->owner_opr());
std::reverse(reverse_order.begin(), reverse_order.end());
for (auto&& opr : reverse_order) {
igg.add_opr(opr);
}
auto ig = igg.generate();
output = jit::JITExecutor::make(ig, igg.orig_inps()).node();
}
return outputs;
}
auto infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
return OpDef::infer_output_attrs_fallible(
*def.cast_final_safe<JITFusionOp>().op, input_descs);
}
auto props(const OpDef& def) {
return OpDef::props(*def.cast_final_safe<JITFusionOp>().op);
}
auto hash(const OpDef& def) {
return def.cast_final_safe<JITFusionOp>().op->hash();
}
auto is_samt_st(const OpDef& def, const OpDef& another) {
if (!another.same_type<JITFusionOp>()) {
return false;
}
auto& lhs = def.cast_final_safe<JITFusionOp>();
auto& rhs = another.cast_final_safe<JITFusionOp>();
return lhs.op->is_same(*rhs.op);
}
EncodedSubgraph make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad) {
return {};
}
OP_TRAIT_REG(JITFusionOp, JITFusionOp)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.props(props)
.hash(hash)
.is_same_st(is_samt_st)
.make_backward_graph(make_backward_graph)
.fallback();
} // namespace jit_fusion
} // namespace
bool JITFusionOp::set_enabled(bool enabled) {
std::swap(enabled, jit_fusion::tm_enabled);
return enabled;
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);
......@@ -665,4 +744,6 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompiledOp);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(JITFusionOp);
} // namespace mgb::imperative
......@@ -111,4 +111,12 @@ struct CompiledOp final : OpDefImplBase<CompiledOp> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};
struct JITFusionOp final : OpDefImplBase<JITFusionOp> {
std::shared_ptr<OpDef> op;
JITFusionOp() = default;
JITFusionOp(std::shared_ptr<OpDef> op) : op{op} {}
static bool set_enabled(bool enabled);
MGB_DYN_TYPE_OBJ_FINAL_DECL;
};
} // namespace mgb::imperative
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册