/**
 * \file imperative/python/src/transformation.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include <algorithm>
#include <array>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "pybind11/pybind11.h"

#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/value.h"
#include "megbrain/utils/small_vector.h"

namespace mgb::imperative::python {
struct TransformationManager {
27
public:
28 29
    enum Segment {
        ModuleTrace,
30
        DTypePromote,
31
        DimExpansion,
32
        Grad,
33 34 35 36 37
        Scalar,
        Trace,
        Eval,
    };

38
    std::array<std::vector<std::shared_ptr<Transformation>>, 7> segments;
39

40 41 42 43 44 45 46 47 48 49 50 51
private:
    template <Segment segment>
    void unregister(std::shared_ptr<Transformation> transformation) noexcept {
        mgb_assert(segment < segments.size());
        auto iter = std::find(
                segments[segment].begin(), segments[segment].end(), transformation);
        mgb_assert(iter != segments[segment].end());
        transformation->unregister();
        segments[segment].erase(iter);
    }

public:
52
    template <Segment segment>
53 54
    [[nodiscard]] std::unique_ptr<CleanupGuard<>> register_at(
            std::shared_ptr<Transformation> transformation) {
55 56 57 58 59 60 61 62 63 64 65 66 67 68
        mgb_assert(segment < segments.size());
        std::shared_ptr<Transformation> next;
        for (size_t i = segment; i < segments.size(); ++i) {
            if (!segments[i].empty()) {
                next = segments[i].back();
                break;
            }
        }
        if (!next) {
            transformation->register_at(Transformation::bottom());
        } else {
            transformation->register_at(next->pos());
        }
        segments[segment].push_back(transformation);
69 70
        return std::make_unique<CleanupGuard<>>(
                [this, transformation]() { unregister<segment>(transformation); });
71 72 73 74 75 76 77
    }

    static TransformationManager& get_instance() {
        static TransformationManager sl_instance;
        return sl_instance;
    }
};
/**
 * Primitive value that wraps a pybind11::object.
 */
class PyValue final : public PrimitiveValue<PyValue, pybind11::object> {
public:
    using PrimitiveValue::PrimitiveValue;

    //! Returns Python's str() of the wrapped object as a std::string.
    //! NOTE(review): this calls into the Python C API, so the caller
    //! presumably must hold the GIL; confirm.
    std::string to_string() const {
        const auto& object = static_cast<const pybind11::object&>(*this);
        return pybind11::str(object).cast<std::string>();
    }
};

}  // namespace mgb::imperative::python