From d3689c3f3cb150f7e52f035e0d9607b72da1a9e2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 14 Jan 2022 13:51:47 +0800 Subject: [PATCH] feat(imperative/python): add transformation manager GitOrigin-RevId: a3c1732ffd4fafdde27c69fcdc3488ae19ec3dd0 --- imperative/python/src/transformation.h | 61 ++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 imperative/python/src/transformation.h diff --git a/imperative/python/src/transformation.h b/imperative/python/src/transformation.h new file mode 100644 index 000000000..df10d2c52 --- /dev/null +++ b/imperative/python/src/transformation.h @@ -0,0 +1,61 @@ +/** + * \file imperative/python/src/transformation.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megbrain/imperative/transformation.h" + +namespace mgb::imperative::python { +struct TransformationManager { + enum Segment { + ModuleTrace, + Grad, + Scalar, + Trace, + Eval, + }; + + std::array>, 5> segments; + + template + void register_at(std::shared_ptr transformation) { + mgb_assert(segment < segments.size()); + std::shared_ptr next; + for (size_t i = segment; i < segments.size(); ++i) { + if (!segments[i].empty()) { + next = segments[i].back(); + break; + } + } + if (!next) { + transformation->register_at(Transformation::bottom()); + } else { + transformation->register_at(next->pos()); + } + segments[segment].push_back(transformation); + } + + template + void unregister(std::shared_ptr transformation) noexcept { + mgb_assert(segment < segments.size()); + auto iter = std::find( + segments[segment].begin(), segments[segment].end(), transformation); + mgb_assert(iter != segments[segment].end()); + transformation->unregister(); + segments[segment].erase(iter); + } + + static TransformationManager& get_instance() { + static TransformationManager sl_instance; + return sl_instance; + } +}; +} // namespace mgb::imperative::python -- GitLab