// ir.cc: pybind11 bindings that expose the new IR (Program, Block, Operation,
// Value, OpOperand, OpResult, Type) and related helpers to Python.
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/ir.h"

#include <Python.h>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir_adaptor/translator/translate.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/enforce.h"
#include "pybind11/stl.h"

namespace py = pybind11;

using ir::Block;
using ir::Operation;
using ir::OpOperand;
using ir::OpResult;
using ir::Program;
using ir::Type;
using ir::Value;
using paddle::dialect::APIBuilder;
using paddle::dialect::DenseTensorType;
using pybind11::return_value_policy;

namespace paddle {
namespace pybind {

52 53 54 55
PyTypeObject *g_ir_opresult_pytype = nullptr;

void BindOpsAPI(pybind11::module *module);

56 57
/// Binds `ir.Program` on module `m`.
///
/// Exposes:
///   * construction of a new Program in the global `ir::IrContext`;
///   * `str(program)`, produced by `Program::Print`;
///   * `parameters_num()`;
///   * `block()`, the program's top-level block.
void BindProgram(py::module *m) {
  py::class_<Program> program(*m, "Program");
  program
      // Factory constructor. pybind11 stores the returned object in the
      // class holder (std::unique_ptr<Program> by default). This replaces the
      // deprecated placement-new `__init__(self)` form.
      .def(py::init([]() {
        return std::make_unique<Program>(ir::IrContext::Instance());
      }))
      .def("__str__",
           [](Program &self) {
             std::ostringstream print_stream;
             self.Print(print_stream);
             return print_stream.str();
           })
      .def("parameters_num", &Program::parameters_num)
      // The returned Block is not owned by Python. `reference_internal`
      // returns a non-owning reference and also keeps the Program's Python
      // object alive while the Block wrapper exists. Without it, code like
      // `some_factory().block()` could leave a dangling Block.
      //
      // Only the non-const overload is bound. pybind11 dispatches to the
      // first overload whose Python signature matches, and both overloads
      // take only `self`, so a second binding of the const overload under
      // the same name could never be selected.
      .def("block",
           py::overload_cast<>(&Program::block),
           return_value_policy::reference_internal);
}

/// Binds `ir.Block` on module `m`.
///
/// Exposes:
///   * `front()`, the first operation;
///   * `get_parent_program()`;
///   * `get_ops()`, a list of all operations in iteration order;
///   * `remove_op(op)`, which erases `op` from the block.
void BindBlock(py::module *m) {
  py::class_<Block> block(*m, "Block");
  block.def("front", &Block::front, return_value_policy::reference)
      // The Program is not owned by the returned Python object. Under the
      // default policy, a raw pointer return means `take_ownership`, which
      // could let Python delete a Program that is still owned elsewhere.
      // Return a non-owning reference instead.
      .def(
          "get_parent_program",
          [](Block &self) {
            auto *parent_op = self.GetParentOp();
            PADDLE_ENFORCE_NOT_NULL(
                parent_op,
                phi::errors::InvalidArgument(
                    "get_parent_program: the block has no parent operation."));
            return parent_op->GetParentProgram();
          },
          return_value_policy::reference)
      .def("get_ops",
           [](Block &self) -> py::list {
             py::list op_list;
             for (Operation *op : self) {
               op_list.append(op);
             }
             return op_list;
           })
      .def("remove_op", [](Block &self, Operation *op) {
        auto op_iter = std::find(self.begin(), self.end(), op);
        // Erasing end() is undefined behavior. Report a foreign or already
        // removed op instead of corrupting the block.
        if (op_iter == self.end()) {
          PADDLE_THROW(phi::errors::InvalidArgument(
              "remove_op: the given operation does not belong to this "
              "block."));
        }
        self.erase(op_iter);
      });
}

void BindOperation(py::module *m) {
  py::class_<Operation> op(*m, "Operation");
98
  op.def("name", &Operation::name)
99
      .def("get_parent_block",
100 101
           py::overload_cast<>(&Operation::GetParent),
           return_value_policy::reference)
102
      .def("get_parent_block",
103 104
           py::overload_cast<>(&Operation::GetParent, py::const_),
           return_value_policy::reference)
105
      .def("num_operands", &Operation::num_operands)
106
      .def("num_results", &Operation::num_results)
107
      .def("operand", &Operation::operand)
108
      .def("result", &Operation::result)
109
      .def("operand_source", &Operation::operand_source)
110 111 112 113
      .def("operands",
           [](Operation &self) -> py::list {
             py::list op_list;
             for (uint32_t i = 0; i < self.num_operands(); i++) {
114
               op_list.append(self.operand(i));
115 116 117 118 119 120 121 122 123 124 125
             }
             return op_list;
           })
      .def("results",
           [](Operation &self) -> py::list {
             py::list op_list;
             for (uint32_t i = 0; i < self.num_results(); i++) {
               op_list.append(self.result(i));
             }
             return op_list;
           })
126 127 128 129 130 131 132 133
      .def("operands_source",
           [](Operation &self) -> py::list {
             py::list op_list;
             for (uint32_t i = 0; i < self.num_operands(); i++) {
               op_list.append(self.operand_source(i));
             }
             return op_list;
           })
134 135 136 137 138 139
      .def("get_input_names",
           [](Operation &self) -> py::list {
             py::list op_list;
             paddle::dialect::OpYamlInfoInterface yaml_interface =
                 self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
             auto inputs_info = std::get<0>(yaml_interface.GetOpInfo());
140
             for (auto &input_info : inputs_info) {
141 142 143 144 145 146 147 148 149 150
               op_list.append(input_info.name);
             }
             return op_list;
           })
      .def("get_attr_names",
           [](Operation &self) -> py::list {
             py::list op_list;
             paddle::dialect::OpYamlInfoInterface yaml_interface =
                 self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
             auto attrs_info = std::get<1>(yaml_interface.GetOpInfo());
151
             for (auto &attr_info : attrs_info) {
152 153 154 155 156 157 158 159 160 161
               op_list.append(attr_info.name);
             }
             return op_list;
           })
      .def("get_output_names",
           [](Operation &self) -> py::list {
             py::list op_list;
             paddle::dialect::OpYamlInfoInterface yaml_interface =
                 self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
             auto outputs_info = std::get<2>(yaml_interface.GetOpInfo());
162
             for (auto &output_info : outputs_info) {
163 164 165 166 167 168 169 170 171 172 173 174
               op_list.append(output_info.name);
             }
             return op_list;
           })
      .def("replace_all_uses_with",
           [](Operation &self, const std::vector<OpResult> &op_results) {
             self.ReplaceAllUsesWith(op_results);
           });
}

/// Binds `ir.Value` on module `m`.
///
/// Exposes:
///   * `get_defining_op()`;
///   * `__eq__` against another Value, or against an OpResult by comparing
///     the underlying impl pointers;
///   * `__hash__` via std::hash<ir::Value>.
void BindValue(py::module *m) {
  py::class_<Value> value(*m, "Value");

  value.def("get_defining_op",
            &Value::GetDefiningOp,
            return_value_policy::reference);

  // Value == Value.
  value.def("__eq__", &Value::operator==);

  // Value == OpResult: the two are equal when they share the same impl.
  value.def("__eq__", [](Value &lhs, OpResult &rhs) {
    return lhs.impl() == rhs.value_impl();
  });

  value.def("__hash__", [](const Value &v) {
    return std::hash<ir::Value>{}(v);
  });
}

/// Binds `ir.OpOperand` on module `m`. It gives read and write access to the
/// value an operand uses, exposed to Python as an OpResult.
void BindOpOperand(py::module *m) {
  py::class_<OpOperand> op_operand(*m, "OpOperand");

  // Reads the operand's source value, converted with dyn_cast<OpResult>().
  auto read_source = [](OpOperand &operand) {
    return operand.source().dyn_cast<OpResult>();
  };

  // Rebinds the operand to `replacement`.
  auto write_source = [](OpOperand &operand, const OpResult &replacement) {
    operand.set_source(replacement);
  };

  op_operand.def("source", read_source).def("set_source", write_source);
}

/// Binds `ir.OpResult` on module `m` and records the resulting Python type
/// object in `g_ir_opresult_pytype`.
///
/// Stop-gradient flags live on the defining operation, under the
/// `kAttrStopGradients` attribute. The attribute is an ArrayAttribute of
/// BoolAttribute indexed by result index. When it is absent, every result
/// reads as `false`, and `set_stop_gradient` first materializes one `false`
/// entry per result.
void BindOpResult(py::module *m) {
  py::class_<OpResult> op_result(*m, "OpResult");
  g_ir_opresult_pytype = reinterpret_cast<PyTypeObject *>(op_result.ptr());
  op_result
      .def("get_defining_op",
           &OpResult::GetDefiningOp,
           return_value_policy::reference)
      .def("use_empty", &OpResult::use_empty)
      .def("type", &OpResult::type)
      .def("set_stop_gradient",
           [](OpResult &self, bool stop_gradient) {
             auto *defining_op = self.owner();
             auto *ctx = ir::IrContext::Instance();
             std::vector<ir::Attribute> stop_gradients;
             if (defining_op->HasAttribute(kAttrStopGradients)) {
               stop_gradients = defining_op->attribute(kAttrStopGradients)
                                    .dyn_cast<ir::ArrayAttribute>()
                                    .AsVector();
             } else {
               stop_gradients = std::vector<ir::Attribute>(
                   defining_op->num_results(),
                   ir::BoolAttribute::get(ctx, false));
             }
             // The attribute may have been written by other code. Never
             // index past its end.
             const std::size_t index = self.GetResultIndex();
             if (index >= stop_gradients.size()) {
               PADDLE_THROW(phi::errors::InvalidArgument(
                   "set_stop_gradient: result index %d is out of range for "
                   "the %d stop-gradient flags of the defining operation.",
                   index,
                   stop_gradients.size()));
             }
             stop_gradients[index] = ir::BoolAttribute::get(ctx, stop_gradient);
             defining_op->set_attribute(
                 kAttrStopGradients,
                 ir::ArrayAttribute::get(ctx, stop_gradients));
           })
      .def("get_stop_gradient",
           [](OpResult &self) {
             auto *defining_op = self.owner();
             if (!defining_op->HasAttribute(kAttrStopGradients)) {
               return false;
             }
             auto stop_gradients = defining_op->attribute(kAttrStopGradients)
                                       .dyn_cast<ir::ArrayAttribute>()
                                       .AsVector();
             const std::size_t index = self.GetResultIndex();
             if (index >= stop_gradients.size()) {
               PADDLE_THROW(phi::errors::InvalidArgument(
                   "get_stop_gradient: result index %d is out of range for "
                   "the %d stop-gradient flags of the defining operation.",
                   index,
                   stop_gradients.size()));
             }
             return stop_gradients[index]
                 .dyn_cast<ir::BoolAttribute>()
                 .data();
           })
      .def("__eq__", &OpResult::operator==)
      .def("__eq__",
           [](OpResult &self, Value &other) {
             return self.value_impl() == other.impl();
           })
      .def("__hash__", [](OpResult &self) {
        return std::hash<ir::Value>{}(self.dyn_cast<ir::Value>());
      });
}

/// Binds `ir.Type` on module `m`. Adds equality via Type::operator== and a
/// `str()` form produced by streaming the type into an ostringstream.
void BindType(py::module *m) {
  py::class_<Type> ir_type(*m, "Type");

  auto type_equals = [](Type &lhs, Type &rhs) { return lhs == rhs; };

  auto type_to_string = [](Type &self) {
    std::ostringstream out;
    out << self;
    return out.str();
  };

  ir_type.def("__eq__", type_equals).def("__str__", type_to_string);
}

/// Registers module-level helper functions on `m`:
///   * shape and dtype queries for OpResults whose type is DenseTensorType
///     (any other type raises InvalidArgument);
///   * thin wrappers over APIBuilder::Instance() that set the program and
///     insertion point;
///   * `translate_to_new_ir`, which forwards to
///     paddle::TranslateLegacyProgramToProgram.
void BindUtils(pybind11::module *m) {
  m->def("get_op_result_shape", [](const OpResult &op_result) {
    if (!op_result.type().isa<DenseTensorType>()) {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "get_op_result_shape currently only support op_result that is a "
          "DenseTensorType"));
    }
    auto tensor_type = op_result.type().dyn_cast<DenseTensorType>();
    return phi::vectorize(tensor_type.dims());
  });

  m->def("get_op_result_dtype", [](const OpResult &op_result) {
    if (!op_result.type().isa<DenseTensorType>()) {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "get_op_result_dtype currently only support op_result that is a "
          "DenseTensorType"));
    }
    auto tensor_type = op_result.type().dyn_cast<DenseTensorType>();
    return tensor_type.dtype();
  });

  m->def("set_global_program", [](Program *target) {
    APIBuilder::Instance().SetProgram(target);
  });
  m->def("set_insertion_point", [](Operation *anchor) {
    APIBuilder::Instance().SetInsertionPoint(anchor);
  });
  m->def("reset_insertion_point_to_start", []() {
    APIBuilder::Instance().ResetInsertionPointToStart();
  });
  m->def("reset_insertion_point_to_end", []() {
    APIBuilder::Instance().ResetInsertionPointToEnd();
  });

  m->def("translate_to_new_ir", &paddle::TranslateLegacyProgramToProgram);
}

/// Creates the `ir` submodule under `module` and registers every binding
/// function above on it. It then creates `ir.ops` and registers the operator
/// APIs there via BindOpsAPI.
void BindNewIR(pybind11::module *module) {
  auto ir_module = module->def_submodule("ir");

  // The binders run in the order listed.
  using BindFn = void (*)(pybind11::module *);
  const BindFn binders[] = {BindProgram,
                            BindBlock,
                            BindOperation,
                            BindValue,
                            BindOpOperand,
                            BindOpResult,
                            BindType,
                            BindUtils};
  for (BindFn bind : binders) {
    bind(&ir_module);
  }

  auto ops_module = ir_module.def_submodule("ops");
  BindOpsAPI(&ops_module);
}

}  // namespace pybind
}  // namespace paddle