// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Helpers shared by the CINN pybind11 bindings:
//   * ConvertToVar: turns a CINNValue into the ValueVar variant;
//   * DefineShared / DefineExprNode / DefineBinaryOpNode / DefineUnaryOpNode:
//     templates that register Python classes on a py::module;
//   * ObjectWrapper / IrNodeWrapper / _Operation_Wrapper: wrapper classes,
//     two of which forward virtual calls to Python overrides.
//
// NOTE(review): in this copy every non-empty `<...>` span appears to have
// been stripped. Specifically:
//   * the two system #include lines below have no target;
//   * the `template` parameter lists are gone;
//   * absl::variant, static_cast, ::cinn_type_code, CINNValue::TypeCode,
//     py::class_, py::init, py::overload_cast and the ExprNode /
//     ir::BinaryOpNode / ir::UnaryOpNode aliases are missing their template
//     arguments (only empty `<>` lists survived).
// As written the header cannot compile. Restore the argument lists from
// upstream rather than guessing them.

#pragma once
#include
#include

#include "paddle/cinn/common/cinn_value.h"
#include "paddle/cinn/common/shared.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/pybind/bind.h"
#include "paddle/cinn/runtime/cinn_runtime.h"

namespace py = pybind11;

namespace cinn::pybind {
using common::CINNValue;
using common::Shared;
using common::Type;
using ir::Expr;
using ir::ExprNode;

// Variant aliases. BinaryOp and UnaryOp are declared as empty variants.
// ExprOp's alternative list is not visible in this copy.
using ExprOp = absl::variant;
using BinaryOp = absl::variant<>;
using UnaryOp = absl::variant<>;

// hold CINNValue
// ValueVar is the variant that ConvertToVar fills in. Its alternatives must
// accept everything assigned there: three static_cast results, ir::Var,
// ir::Expr and nullptr.
using ValueVar = absl::variant;

// Converts a CINNValue into a ValueVar by dispatching on value.type_code().
// - The first three branches compare against ::cinn_type_code for
//   (presumably scalar -- TODO confirm) types and store a static_cast of the
//   value.
// - The next two branches extract an ir::Var or an ir::Expr.
// - Any other type code yields the nullptr alternative rather than an error,
//   so callers should be prepared to receive it.
inline ValueVar ConvertToVar(const CINNValue &value) {
  auto type_code = value.type_code();
  ValueVar var;
  if (type_code == ::cinn_type_code()) {
    var = static_cast(value);
  } else if (type_code == ::cinn_type_code()) {
    var = static_cast(value);
  } else if (type_code == ::cinn_type_code()) {
    var = static_cast(value);
  } else if (type_code == CINNValue::TypeCode()) {
    var = value.operator ir::Var();
  } else if (type_code == CINNValue::TypeCode()) {
    // Explicit conversion operator, then wrapped in a new ir::Expr.
    var = ir::Expr(value.operator ir::Expr());
  } else {
    // Unrecognized type code: fall back to the nullptr alternative.
    var = nullptr;
  }
  return var;
}

// Registers a Python class named "Shared" + obj_name on *m.
// It gets a default constructor plus two more constructors. Their argument
// types are not visible in this copy; the last one appears to take a
// reference.
// Returns the py::class_ object so callers can chain further .def(...) calls.
template
auto DefineShared(py::module *m, absl::string_view obj_name) {
  std::string name = "Shared" + std::string(obj_name);
  py::class_> shared(*m, name.c_str());
  shared.def(py::init<>()).def(py::init()).def(py::init &>());
  return shared;
}

// Registers the ExprNode class for one node type as "ExprNode" + node_name.
// It exposes:
//   * three constructors;
//   * mutable/const `operands` accessors;
//   * mutable/const single-`operand` accessors;
//   * copy and node_type.
// The single-operand accessors use return_value_policy::reference. Python
// therefore neither copies nor takes ownership of the returned object, which
// stays valid only while the owning C++ node is alive.
//
// NOTE(review): this class is registered with py::module_local(), while the
// BinaryOpNode / UnaryOpNode classes below are not. Confirm the asymmetry is
// intended.
template
void DefineExprNode(py::module *m, absl::string_view node_name) {
  using ExprNodeT = ExprNode;
  std::string prefix{"ExprNode"};
  std::string name = prefix + std::string(node_name);
  py::class_ expr_node(*m, name.c_str(), py::module_local());
  expr_node.def(py::init<>())
      .def(py::init())
      .def(py::init())
      .def("operands_mutable", py::overload_cast<>(&ExprNodeT::operands))
      .def("operands_const",
           py::overload_cast<>(&ExprNodeT::operands, py::const_))
      .def("operand_mutable",
           py::overload_cast(&ExprNodeT::operand),
           py::return_value_policy::reference)
      .def("operand_const",
           py::overload_cast(&ExprNodeT::operand, py::const_),
           py::return_value_policy::reference)
      .def("copy", &ExprNodeT::Copy)
      .def("node_type", &ExprNodeT::node_type);
}

// Registers the BinaryOpNode class for one node type as
// "BinaryOpNode" + node_name.
// DefineExprNode runs first. Presumably this ensures the ExprNode base is
// already known to pybind11, which requires a base class to be registered
// before any class that derives from it.
// It exposes:
//   * a default constructor and one further constructor;
//   * mutable/const accessors for operands `a` and `b` (reference policy);
//   * type;
//   * mutable/const expr_fields.
// NOTE(review): each call also registers "ExprNode" + node_name. Calling
// this twice for the same node_name would register that name twice.
template
void DefineBinaryOpNode(py::module *m, absl::string_view node_name) {
  DefineExprNode(m, node_name);
  std::string prefix{"BinaryOpNode"};
  std::string name = prefix + std::string(node_name);
  using BinaryOpNodeT = ir::BinaryOpNode;
  py::class_> binary_op_node(*m, name.c_str());
  binary_op_node.def(py::init<>())
      .def(py::init())
      .def("a_mutable",
           py::overload_cast<>(&BinaryOpNodeT::a),
           py::return_value_policy::reference)
      .def("a_const",
           py::overload_cast<>(&BinaryOpNodeT::a, py::const_),
           py::return_value_policy::reference)
      .def("b_mutable",
           py::overload_cast<>(&BinaryOpNodeT::b),
           py::return_value_policy::reference)
      .def("b_const",
           py::overload_cast<>(&BinaryOpNodeT::b, py::const_),
           py::return_value_policy::reference)
      .def("type", &BinaryOpNodeT::type)
      .def("expr_fields_mutable",
           py::overload_cast<>(&BinaryOpNodeT::expr_fields))
      .def("expr_fields_const",
           py::overload_cast<>(&BinaryOpNodeT::expr_fields, py::const_));
}

// Registers the UnaryOpNode class for one node type as
// "UnaryOpNode" + node_name. As with DefineBinaryOpNode, the ExprNode class
// is registered first.
// It exposes:
//   * a default constructor and one further constructor;
//   * type;
//   * mutable/const accessors for the single operand `v` (reference policy);
//   * mutable/const expr_fields;
//   * mutable/const operands.
// NOTE(review): operands_* here use return_value_policy::reference, while
// the ExprNode bindings of operands_* use the default policy. Confirm which
// is intended.
template
void DefineUnaryOpNode(py::module *m, absl::string_view node_name) {
  using UnaryOpNodeT = ir::UnaryOpNode;
  DefineExprNode(m, node_name);
  std::string name = "UnaryOpNode" + std::string(node_name);
  py::class_> unary_op_node(*m, name.c_str());
  unary_op_node.def(py::init<>())
      .def(py::init())
      .def("type", &UnaryOpNodeT::type)
      .def("v_mutable",
           py::overload_cast<>(&UnaryOpNodeT::v),
           py::return_value_policy::reference)
      .def("v_const",
           py::overload_cast<>(&UnaryOpNodeT::v, py::const_),
           py::return_value_policy::reference)
      .def("expr_fields_mutable",
           py::overload_cast<>(&UnaryOpNodeT::expr_fields))
      .def("expr_fields_const",
           py::overload_cast<>(&UnaryOpNodeT::expr_fields, py::const_))
      .def("operands_mutable",
           py::overload_cast<>(&UnaryOpNodeT::operands),
           py::return_value_policy::reference)
      .def("operands_const",
           py::overload_cast<>(&UnaryOpNodeT::operands, py::const_),
           py::return_value_policy::reference);
}

// pybind11 trampoline for Object. It forwards type_info() to a Python
// override; the _PURE macro fails if Python does not provide one.
// NOTE(review): pybind11 >= 2.6 spells this macro PYBIND11_OVERRIDE_PURE.
// PYBIND11_OVERLOAD_PURE is the older, deprecated alias.
class ObjectWrapper : public Object {
 public:
  using Object::Object;
  const char *type_info() const override {
    PYBIND11_OVERLOAD_PURE(const char *, Object, type_info);
  }
};

// Wrapper over ir::IrNode that only inherits its constructors.
// NOTE(review): no access specifier is given, so ir::IrNode is a *private*
// base (the class-key default). pybind11 trampolines normally derive
// publicly from the bound type; confirm this is intended.
class IrNodeWrapper : ir::IrNode {
  using ir::IrNode::IrNode;
};

// pybind11 trampoline for ir::_Operation_. It forwards func_type() to a
// Python override.
// NOTE(review): this also inherits privately (no access specifier). Also,
// identifiers starting with an underscore followed by an uppercase letter
// are reserved in C++; consider renaming if the interface allows it.
class _Operation_Wrapper : ir::_Operation_ {
 public:
  const char *func_type() const override {
    PYBIND11_OVERLOAD_PURE(const char *, ir::_Operation_, func_type);
  }
};

}  // namespace cinn::pybind