#pragma once #include #include #include "pybind11/pybind11.h" #include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/transformation.h" #include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/value.h" #include "megbrain/utils/small_vector.h" namespace mgb::imperative::python { struct TransformationManager { public: enum Segment { ModuleTrace, GroupComm, DTypePromote, DimExpansion, Complex, Format, Grad, ExternalConvert, Scalar, Symbol, Trace, Eval, SEGMENT_COUNT, }; std::array>, SEGMENT_COUNT> segments; private: template void unregister(std::shared_ptr transformation) noexcept { mgb_assert(segment < segments.size()); auto iter = std::find( segments[segment].begin(), segments[segment].end(), transformation); mgb_assert(iter != segments[segment].end()); transformation->unregister(); segments[segment].erase(iter); } public: template [[nodiscard]] std::unique_ptr> register_at( std::shared_ptr transformation) { mgb_assert(segment < segments.size()); std::shared_ptr next; for (size_t i = segment; i < segments.size(); ++i) { if (!segments[i].empty()) { next = segments[i].back(); break; } } if (!next) { transformation->register_at(Transformation::bottom()); } else { transformation->register_at(next->pos()); } segments[segment].push_back(transformation); return std::make_unique>( [this, transformation]() { unregister(transformation); }); } static TransformationManager& get_instance() { static TransformationManager sl_instance; return sl_instance; } }; class PyValue final : public PrimitiveValue { public: using PrimitiveValue::PrimitiveValue; std::string to_string() const { return pybind11::str((const pybind11::object&)*this).cast(); } }; } // namespace mgb::imperative::python