diff --git a/imperative/src/impl/ops/utility.cpp b/imperative/src/impl/ops/utility.cpp
index 9ac3fbf7da00d6c4e978f25c870a6298729e0fb4..07a434a3ffcc2af61df917d86a02c3b6be402d2b 100644
--- a/imperative/src/impl/ops/utility.cpp
+++ b/imperative/src/impl/ops/utility.cpp
@@ -657,6 +657,101 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
 } // namespace compiled_op
 } // namespace
 
+namespace {
+namespace jit_fusion {
+
+// Per-thread switch read by apply_on_var_node(): while false, the wrapped op's
+// outputs are returned unchanged instead of being replaced by JITExecutor nodes.
+static thread_local bool tm_enabled = true;
+
+// Applies the wrapped op to `inputs`; unless disabled on this thread, each output
+// var is then replaced by the node of a jit::JITExecutor made from an internal
+// graph generated over the operators collected for that output.
+auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto& op = def.cast_final_safe<JITFusionOp>();
+    op.op->set_scope(op.scope());
+    auto outputs = OpDef::apply_on_var_node(*op.op, inputs);
+    if (!tm_enabled) {
+        // skip for dump (JITExecutor can not be dumped)
+        return outputs;
+    }
+    for (auto& output : outputs) {
+        jit::InternalGraphGenerator igg{output->owner_opr()};
+        // Record every operator DepOprIter reports while walking from the output.
+        // Input owners are pre-marked visited -- presumably so the walk stops at
+        // them; confirm against DepOprIter.
+        std::vector<cg::OperatorNodeBase*> reverse_order;
+        cg::DepOprIter iter{
+                [&](cg::OperatorNodeBase* opr) { reverse_order.push_back(opr); }};
+        for (auto&& input : inputs) {
+            iter.set_visited(input->owner_opr());
+        }
+        iter.add(output->owner_opr());
+        // add_opr() is fed the operators in the reverse of the recorded order.
+        std::reverse(reverse_order.begin(), reverse_order.end());
+        for (auto&& opr : reverse_order) {
+            igg.add_opr(opr);
+        }
+        auto ig = igg.generate();
+        output = jit::JITExecutor::make(ig, igg.orig_inps()).node();
+    }
+    return outputs;
+}
+
+// Output attribute inference is delegated unchanged to the wrapped op.
+auto infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& input_descs) {
+    return OpDef::infer_output_attrs_fallible(
+            *def.cast_final_safe<JITFusionOp>().op, input_descs);
+}
+
+// Properties are those reported for the wrapped op.
+auto props(const OpDef& def) {
+    return OpDef::props(*def.cast_final_safe<JITFusionOp>().op);
+}
+
+// The hash is the wrapped op's hash.
+auto hash(const OpDef& def) {
+    return def.cast_final_safe<JITFusionOp>().op->hash();
+}
+
+// Two JITFusionOps are the same iff their wrapped ops are; an op of any other
+// type never compares equal.
+auto is_same_st(const OpDef& def, const OpDef& another) {
+    if (!another.same_type<JITFusionOp>()) {
+        return false;
+    }
+    auto& lhs = def.cast_final_safe<JITFusionOp>();
+    auto& rhs = another.cast_final_safe<JITFusionOp>();
+    return lhs.op->is_same(*rhs.op);
+}
+
+// Always returns a default-constructed subgraph; every argument is ignored.
+EncodedSubgraph make_backward_graph(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
+        const SmallVector<bool>& input_requires_grad,
+        const SmallVector<bool>& output_has_grad) {
+    return {};
+}
+
+OP_TRAIT_REG(JITFusionOp, JITFusionOp)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .props(props)
+        .hash(hash)
+        .is_same_st(is_same_st)
+        .make_backward_graph(make_backward_graph)
+        .fallback();
+
+} // namespace jit_fusion
+} // namespace
+
+// Sets the calling thread's fusion switch and returns the value it had before.
+bool JITFusionOp::set_enabled(bool enabled) {
+    std::swap(enabled, jit_fusion::tm_enabled);
+    return enabled;
+}
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey);
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);
@@ -665,4 +760,6 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(CompiledOp);
 
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(JITFusionOp);
+
 } // namespace mgb::imperative
diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h
index be984ffb17d854c5c56985c0c049438ec5859b68..79dd970fd93f4e6e901e8493e1401205a006d003 100644
--- a/imperative/src/include/megbrain/imperative/ops/utility.h
+++ b/imperative/src/include/megbrain/imperative/ops/utility.h
@@ -111,4 +111,16 @@ struct CompiledOp final : OpDefImplBase<CompiledOp> {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 };
 
+// Op that holds another op in `op`; its registered traits forward to that op
+// (see the OP_TRAIT_REG(JITFusionOp, ...) registration in utility.cpp).
+struct JITFusionOp final : OpDefImplBase<JITFusionOp> {
+    // The wrapped op.
+    std::shared_ptr<OpDef> op;
+    JITFusionOp() = default;
+    JITFusionOp(std::shared_ptr<OpDef> op) : op{op} {}
+    // Sets the calling thread's fusion switch; returns the previous value.
+    static bool set_enabled(bool enabled);
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+};
+
 } // namespace mgb::imperative