#pragma once #include <optional> #include <string> #include "pybind11/pybind11.h" #include "megbrain/imperative/dispatch.h" #include "megbrain/imperative/transformation.h" #include "megbrain/imperative/utils/helper.h" #include "megbrain/imperative/value.h" #include "megbrain/utils/small_vector.h" namespace mgb::imperative::python { struct TransformationManager { public: enum Segment { ModuleTrace, GroupComm, DTypePromote, DimExpansion, Format, Grad, Scalar, Symbol, Trace, Eval, }; std::array<std::vector<std::shared_ptr<Transformation>>, 10> segments; private: template <Segment segment> void unregister(std::shared_ptr<Transformation> transformation) noexcept { mgb_assert(segment < segments.size()); auto iter = std::find( segments[segment].begin(), segments[segment].end(), transformation); mgb_assert(iter != segments[segment].end()); transformation->unregister(); segments[segment].erase(iter); } public: template <Segment segment> [[nodiscard]] std::unique_ptr<CleanupGuard<>> register_at( std::shared_ptr<Transformation> transformation) { mgb_assert(segment < segments.size()); std::shared_ptr<Transformation> next; for (size_t i = segment; i < segments.size(); ++i) { if (!segments[i].empty()) { next = segments[i].back(); break; } } if (!next) { transformation->register_at(Transformation::bottom()); } else { transformation->register_at(next->pos()); } segments[segment].push_back(transformation); return std::make_unique<CleanupGuard<>>( [this, transformation]() { unregister<segment>(transformation); }); } static TransformationManager& get_instance() { static TransformationManager sl_instance; return sl_instance; } }; class PyValue final : public PrimitiveValue<PyValue, pybind11::object> { public: using PrimitiveValue::PrimitiveValue; std::string to_string() const { return pybind11::str((const pybind11::object&)*this).cast<std::string>(); } }; } // namespace mgb::imperative::python