// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/ir.h"

#include <algorithm>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "paddle/ir/core/block.h"
#include "paddle/ir/core/program.h"
#include "pybind11/stl.h"

namespace py = pybind11;

using ir::Block;
using ir::Operation;
using ir::Program;
using pybind11::return_value_policy;

namespace paddle {
namespace pybind {

// Registers the `Program` class on module `m`.
//
// Exposed Python API:
//   - parameters_num(): forwards to Program::parameters_num.
//   - block(): returns the Block pointer from Program::block. Python does not
//     own it. reference_internal ties the lifetime of the calling Program
//     wrapper to the returned Block wrapper, so the Block cannot outlive its
//     Program on the Python side. (Presumably the Block is owned by the
//     Program; confirm against program.h.)
//   - print(): renders the program with Program::Print into a string stream
//     and emits the text through LOG(INFO). It is not written to Python's
//     stdout.
void BindProgram(py::module *m) {
  py::class_<Program> program(*m, "Program");
  program.def("parameters_num", &Program::parameters_num)
      // `reference` would set no keep-alive, so dropping the Python Program
      // while holding its Block would leave a dangling pointer.
      .def("block", &Program::block, return_value_policy::reference_internal)
      .def("print", [](Program &self) {
        std::ostringstream print_stream;
        self.Print(print_stream);
        LOG(INFO) << print_stream.str();
      });
}

// Registers the `Block` class on module `m`.
//
// Exposed Python API:
//   - front(): returns the first operation of the block, as given by
//     Block::front. Raises IndexError when the block is empty
//     (begin() == end()) instead of calling front() on an empty block.
//   - get_op_list(): returns a Python list with one entry per element produced
//     by iterating the block from begin() to end().
//
// Python does not own any of the returned operations. Each returned wrapper
// uses reference_internal semantics with the Block wrapper as its parent. The
// Block therefore stays alive while any operation handed out from it is still
// referenced in Python.
void BindBlock(py::module *m) {
  py::class_<Block> block(*m, "Block");
  block
      .def(
          "front",
          [](Block &self) -> decltype(self.front()) {
            // NOTE(review): front() presumably forwards to the underlying op
            // container's front(), which is undefined on an empty container.
            if (self.begin() == self.end()) {
              throw py::index_error("front(): block has no operations");
            }
            return self.front();
          },
          return_value_policy::reference_internal)
      .def("get_op_list", [](py::object self) -> py::list {
        // Take `self` as a Python object so it can be the keep-alive parent
        // of every element placed in the returned list.
        Block &cpp_block = self.cast<Block &>();
        py::list op_list;
        for (auto &&op : cpp_block) {
          op_list.append(
              py::cast(op, return_value_policy::reference_internal, self));
        }
        return op_list;
      });
}

// Registers the `Operation` class on module `m`.
// Only one method is exposed: name(), which forwards to Operation::name.
void BindOperation(py::module *m) {
  py::class_<Operation>(*m, "Operation").def("name", &Operation::name);
}

// Entry point that installs all new-IR Python bindings on module `m`.
// The binders are invoked in a fixed order: Program, then Block, then
// Operation.
void BindNewIR(pybind11::module *m) {
  using BinderFn = void (*)(py::module *);
  for (BinderFn bind : {&BindProgram, &BindBlock, &BindOperation}) {
    bind(m);
  }
}

}  // namespace pybind
}  // namespace paddle