diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h new file mode 100644 index 0000000000000000000000000000000000000000..df10d2c52c2a6b99be348e365c81be13f1687df0 --- /dev/null +++ b/imperative/python/src/transformation.h @@ -0,0 +1,61 @@ +/** + * \file imperative/python/src/transformation.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain/imperative/transformation.h" + +namespace mgb::imperative::python { +struct TransformationManager { + enum Segment { + ModuleTrace, + Grad, + Scalar, + Trace, + Eval, + }; + + std::array>, 5> segments; + + template + void register_at(std::shared_ptr transformation) { + mgb_assert(segment < segments.size()); + std::shared_ptr next; + for (size_t i = segment; i < segments.size(); ++i) { + if (!segments[i].empty()) { + next = segments[i].back(); + break; + } + } + if (!next) { + transformation->register_at(Transformation::bottom()); + } else { + transformation->register_at(next->pos()); + } + segments[segment].push_back(transformation); + } + + template + void unregister(std::shared_ptr transformation) noexcept { + mgb_assert(segment < segments.size()); + auto iter = std::find( + segments[segment].begin(), segments[segment].end(), transformation); + mgb_assert(iter != segments[segment].end()); + transformation->unregister(); + segments[segment].erase(iter); + } + + static TransformationManager& get_instance() { + static TransformationManager sl_instance; + return sl_instance; + } +}; +} // namespace mgb::imperative::python