ops.cpp 1.5 KB
Newer Older
M
Megvii Engine Team 已提交
1 2 3 4 5 6 7 8 9 10 11
/**
 * \file imperative/python/src/ops.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

12 13 14 15 16
#include "./ops.h"

#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
17
#include "megbrain/imperative/ops/autogen.h"
18 19 20

namespace py = pybind11;

21 22 23 24 25 26 27 28 29 30
namespace {
/**
 * \brief Return a copy of \p in with every character upper-cased.
 *
 * Characters are converted one at a time with toupper(), so the mapping
 * follows the current C locale; characters that have no upper-case form are
 * copied through unchanged.
 *
 * \param in string to convert; it is not modified
 * \return the upper-cased copy (empty when \p in is empty)
 */
std::string normalize_enum(const std::string& in) {
    std::string ret;
    ret.reserve(in.size());  // output has exactly as many characters as the input
    for (char c : in) {
        // toupper() is only defined for EOF and values representable as
        // unsigned char; a plain char may be negative, so convert it first.
        ret += static_cast<char>(toupper(static_cast<unsigned char>(c)));
    }
    return ret;
}
} // anonymous namespace

31
/**
 * Register the imperative operator bindings on the Python module \p m.
 *
 * Binds BackwardGraph (held through std::shared_ptr, with OpDef as its bound
 * base) under the Python name "BackwardGraph" and gives it an "interpret"
 * method. Afterwards "opdef.py.inl" is textually included inside this function
 * body, so its statements see \p m and the using-directive below.
 *
 * \param m module object that receives the bindings
 */
void init_ops(py::module m) {
    using namespace mgb::imperative;

    // Locals live in their own scope so that no extra identifiers are visible
    // to the code pulled in by the #include at the end of this function.
    {
        using PyObjectList = mgb::SmallVector<py::object>;

        // Python-facing BackwardGraph.interpret(pyf, pyc, inputs): forwards to
        // self.graph().interpret<py::object>() using two adapters built
        // around the supplied Python callables.
        auto interpret_graph = [](BackwardGraph& self, py::object pyf, py::object pyc,
                                  const PyObjectList& inputs) {
            // Calls pyf(op, args) and casts the Python result back to a list
            // of py::object.
            auto apply_op = [pyf](OpDef& op, const PyObjectList& args) {
                return py::cast<PyObjectList>(pyf(op.shared_from_this(), args));
            };
            // Passes the tensor's dev_tensor() to pyc and returns pyc's result.
            auto wrap_tensor = [pyc](const TensorPtr& tensor) {
                return pyc(tensor->dev_tensor());
            };
            return self.graph().interpret<py::object>(apply_op, wrap_tensor, inputs);
        };

        py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef> backward_graph(
                m, "BackwardGraph");
        backward_graph.def("interpret", interpret_graph);
    }

    #include "opdef.py.inl"
}