// Copyright (c) 2021 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/cinn/ir/ir.h" #include #include #include #include #include #include #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_operators.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/ir_visitor.h" #include "paddle/cinn/ir/lowered_func.h" #include "paddle/cinn/ir/operation.h" #include "paddle/cinn/ir/registry.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/lang/packed_func.h" #include "paddle/cinn/poly/stage.h" #include "paddle/cinn/pybind/bind.h" #include "paddle/cinn/pybind/bind_utils.h" namespace py = pybind11; namespace cinn::pybind { using ir::IrNode; using ir::IrNodeRef; using ir::IrNodeTy; // lowered_func.h using ir::Argument; using ir::Expr; using ir::LoweredFunc; using ir::Var; namespace { void BindLoweredFunc(py::module *); void BindNode(py::module *); void BindIrVisitor(py::module *); void BindIrIr(py::module *); void BindOperation(py::module *); void BindPackedFunc(py::module *); void BindRegistry(py::module *); void BindLoweredFunc(py::module *m) { py::class_ argument(*m, "Argument"); py::enum_ io(argument, "IO"); io.value("kInput", Argument::IO::kInput).value("kOutput", Argument::IO::kOutput); argument.def(py::init(), py::arg("buffer"), py::arg("io") = Argument::IO::kInput) .def(py::init(), py::arg("var"), py::arg("io") = Argument::IO::kInput) .def("set_buffer", &Argument::set_buffer) .def("set_var", &Argument::set_var) .def("is_input", &Argument::is_input) .def("is_output", &Argument::is_output) .def("is_var", &Argument::is_var) .def("is_buffer", &Argument::is_buffer) .def("defined", &Argument::defined) .def("buffer_arg", &Argument::buffer_arg) .def("type", &Argument::type) .def("name", &Argument::name) .def("human_readable", &Argument::human_readable); py::class_ lowered_func(*m, "LoweredFunc"); lowered_func.def(py::init<>()) .def(py::init()) .def("name", [](const ir::LoweredFunc &self) -> std::string { return self->name; }) .def("__str__", [](const ir::LoweredFunc &self) -> std::string { return utils::GetStreamCnt(Expr(self)); }) .def("__repr__", [](const ir::LoweredFunc &self) -> std::string { return llvm::formatv("", self.get(), self->name.c_str()); }); } void BindNode(py::module *m) { // enum class IrNodeTy py::enum_ ir_node_ty(*m, "IrNodeTy"); ir_node_ty.value("kUnk", ir::IrNodeTy::kUnk); #define DECLARE_IR_NODE_TY(__ty) ir_node_ty.value(#__ty, ir::IrNodeTy::__ty); NODETY_FORALL(DECLARE_IR_NODE_TY) #undef DECLARE_IR_NODE_TY // class IrNode py::class_ ir_node(*m, "IrNode", py::module_local()); ir_node.def(py::init<>()) .def(py::init()) .def_readwrite("operands", &ir::IrNode::operands) .def("node_type", &ir::IrNode::node_type) .def("type", &ir::IrNode::type) .def("set_type", &ir::IrNode::set_type) .def("expr_fields_mutable", py::overload_cast<>(&ir::IrNode::expr_fields)) .def("expr_fields_const", py::overload_cast<>(&ir::IrNode::expr_fields, py::const_)) .def("type_info", &ir::IrNode::type_info); // class Shared DefineShared(m, "IrNode"); // class IrNodeRef : public Shared py::class_> ir_node_ref(*m, "IrNodeRef"); ir_node_ref.def(py::init<>()) .def(py::init()) .def(py::init()) .def("node_type", &ir::IrNodeRef::node_type); // struct IntImm : ExprNode DefineExprNode(m, "IntImm"); py::class_> int_imm(*m, "IntImm"); int_imm.def_readwrite("value", &ir::IntImm::value) .def(py::init()) .def("__str__", [](const ir::IntImm &self) { return std::to_string(self.value); }) .def("__repr__", [](ir::IntImm &self) -> std::string { return llvm::formatv("", self.self(), self.value); }); // struct UIntImm : ExprNode DefineExprNode(m, "UIntImm"); py::class_> uint_imm(*m, "UIntImm"); uint_imm.def_readwrite("value", &ir::UIntImm::value).def(py::init()); // struct FloatImm : ExprNode DefineExprNode(m, "FloatImm"); py::class_> float_imm(*m, "FloatImm"); float_imm.def_readwrite("value", &ir::FloatImm::value).def(py::init()); // struct StringImm : ExprNode DefineExprNode(m, "StringImm"); py::class_> string_imm(*m, "StringImm"); string_imm.def_readwrite("value", &ir::StringImm::value).def(py::init()); auto expr = py::class_(*m, "Expr"); expr.def(py::init()); expr.def(py::init()); expr.def(py::init()); expr.def(py::init()); expr.def(py::init()); expr.def(py::init()); expr.def(py::init()); expr.def(py::init()); expr.def(py::init()); expr.def(py::init()); expr.def("as_int32", &ir::Expr::as_int32) .def("as_int64", &ir::Expr::as_int64) .def("as_float", &ir::Expr::as_float) .def("as_double", &ir::Expr::as_double) .def("int", [](ir::Expr &self) { return self.As()->value; }) .def("float", [](ir::Expr &self) { return self.As()->value; }) .def("__str__", [](const Expr &self) { return utils::GetStreamCnt(self); }) .def("__repr__", [](const Expr &self) -> std::string { std::string content = self.get() ? utils::GetStreamCnt(self) : ""; return llvm::formatv("", content); }); expr.def("as_var_mutable", py::overload_cast<>(&ir::Expr::as_var), py::return_value_policy::reference) .def("as_var_const", py::overload_cast<>(&ir::Expr::as_var, py::const_), py::return_value_policy::reference) .def("as_var_ref", &ir::Expr::as_var_ref); expr.def("as_buffer_mutable", py::overload_cast<>(&ir::Expr::as_buffer), py::return_value_policy::reference) .def("as_buffer_const", py::overload_cast<>(&ir::Expr::as_buffer, py::const_), py::return_value_policy::reference) .def("as_buffer_ref", &ir::Expr::as_buffer_ref); expr.def("is_constant", &ir::Expr::is_constant) .def("get_constant", &ir::Expr::get_constant) .def("is_var", &ir::Expr::is_var) .def("type", &ir::Expr::type); // operators #define BIND_POD_BINARY_OP(otype__) \ .def(py::self + otype__) \ .def(py::self - otype__) \ .def(py::self *otype__) \ .def(py::self / otype__) \ .def(py::self % otype__) \ .def(py::self < otype__) \ .def(py::self <= otype__) \ .def(py::self > otype__) \ .def(py::self >= otype__) \ .def(otype__ + py::self) \ .def(otype__ - py::self) \ .def(otype__ *py::self) \ .def(otype__ / py::self) \ .def(otype__ % py::self) \ .def(otype__ < py::self) \ .def(otype__ <= py::self) \ .def(otype__ > py::self) \ .def(otype__ >= py::self) expr // BIND_POD_BINARY_OP(py::self) // BIND_POD_BINARY_OP(int()) // BIND_POD_BINARY_OP(float()); expr.def("__add__", [](const Expr &self, const Var &other) -> Expr { return self + other; }) .def("__sub__", [](const Expr &self, const Var &other) -> Expr { return self - other; }) .def("__mul__", [](const Expr &self, const Var &other) -> Expr { return self * other; }) .def("__div__", [](const Expr &self, const Var &other) -> Expr { return self / other; }); } void BindIrVisitor(py::module *m) { py::class_ ir_visitor(*m, "IRVisitor"); ir_visitor.def(py::init<>()).def("visit", py::overload_cast(&ir::IRVisitor::Visit)); #define DEFINE_VISIT_FN(__ty) ir_visitor.def("visit", py::overload_cast(&ir::IRVisitor::Visit)); NODETY_FORALL(DEFINE_VISIT_FN) #undef DEFINE_VISIT_FN } void BindIrIr(py::module *m) { using ir::Expr; using ir::IrNode; using ir::IrNodeRef; using ir::Var; using py::arg; // struct Cast : ExprNode DefineExprNode(m, "Cast"); py::class_> cast(*m, "Cast"); cast.def(py::init<>()) .def("v_mutable", py::overload_cast<>(&ir::Cast::v), py::return_value_policy::reference) .def("v_const", py::overload_cast<>(&ir::Cast::v, py::const_), py::return_value_policy::reference); // struct Let : ExprNode DefineExprNode(m, "Let"); py::class_> let(*m, "Let"); let.def(py::init<>()) .def_readwrite("symbol", &ir::Let::symbol) .def_readwrite("body", &ir::Let::body) .def_static("make", &ir::Let::Make) .def("type", &ir::Let::type) .def("expr_fields_mutable", py::overload_cast<>(&ir::Let::expr_fields)) .def("expr_fields_const", py::overload_cast<>(&ir::Let::expr_fields, py::const_)); // struct Reduce : ExprNode DefineExprNode(m, "Reduce"); py::class_> reduce(*m, "Reduce"); py::enum_ reduce_type(reduce, "ReduceType"); reduce_type // .value("kSum", ir::Reduce::ReduceType::kSum) .value("kSub", ir::Reduce::ReduceType::kSub) .value("kMul", ir::Reduce::ReduceType::kMul) .value("kDiv", ir::Reduce::ReduceType::kDiv) .value("kMax", ir::Reduce::ReduceType::kMax) .value("kMin", ir::Reduce::ReduceType::kMin) .value("kAll", ir::Reduce::ReduceType::kAll) .value("kAny", ir::Reduce::ReduceType::kAny); reduce.def_readwrite("init", &ir::Reduce::init) .def_readwrite("body", &ir::Reduce::body) .def_readwrite("reduce_type", &ir::Reduce::reduce_type) .def_static("make", &ir::Reduce::Make) .def("type", &ir::Reduce::type) .def("expr_fields_mutable", py::overload_cast<>(&ir::Reduce::expr_fields)) .def("expr_fields_const", py::overload_cast<>(&ir::Reduce::expr_fields, py::const_)); // enum class CallType py::enum_ call_type(*m, "CallType"); call_type.value("Extern", ir::CallType::Extern) .value("CINN", ir::CallType::CINN) .value("Intrinsic", ir::CallType::Intrinsic) .value("ISL", ir::CallType::ISL); // struct Call : ExprNode DefineExprNode(m, "Call"); py::class_> call(*m, "Call"); call.def(py::init()) .def_readwrite("name", &ir::Call::name) .def_readwrite("read_args", &ir::Call::read_args) .def_readwrite("write_args", &ir::Call::write_args) .def_readwrite("call_type", &ir::Call::call_type) .def_readwrite("func", &ir::Call::func) .def_readwrite("value_index", &ir::Call::value_index) .def_static("make", &ir::Call::Make) .def("total_args_count", &ir::Call::total_args_count) .def("is_extern_call", &ir::Call::is_extern_call) .def("is_cinn_call", &ir::Call::is_cinn_call) .def("is_intrinsic_call", &ir::Call::is_intrinsic_call) .def("is_isl_call", &ir::Call::is_isl_call) .def("expr_fields_mutable", py::overload_cast<>(&ir::Call::expr_fields)) .def("expr_fields_const", py::overload_cast<>(&ir::Call::expr_fields, py::const_)); // struct _Var_ : ExprNode<_Var_> DefineExprNode(m, "_Var_"); py::class_> _var_(*m, "_Var_"); _var_.def_readwrite("name", &ir::_Var_::name) .def_readwrite("is_reduce_axis", &ir::_Var_::is_reduce_axis) .def_readwrite("lower_bound", &ir::_Var_::lower_bound) .def_readwrite("upper_bound", &ir::_Var_::upper_bound) .def_readwrite("tag", &ir::_Var_::tag) .def(py::init<>()) .def(py::init()) .def_static("make", py::overload_cast(&ir::_Var_::Make)) .def_static("make", py::overload_cast(&ir::_Var_::Make)) .def("copy", &ir::_Var_::Copy); // struct Select DefineExprNode(m, "Select"); py::class_> select(*m, "Select"); select.def_readwrite("condition", &ir::Select::condition) .def_readwrite("true_value", &ir::Select::true_value) .def_readwrite("false_value", &ir::Select::false_value) .def(py::init()) .def_static("make", &ir::Select::Make) .def("type", &ir::Select::type) .def("expr_fields_mutable", py::overload_cast<>(&ir::Select::expr_fields)) .def("expr_fields_const", py::overload_cast<>(&ir::Select::expr_fields, py::const_)); // struct LoadStoreAddrMnger py::class_ load_store_addr_manager(*m, "LoadStoreAddrMnger"); load_store_addr_manager.def_readwrite("tensor", &ir::LoadStoreAddrMnger::tensor) .def("is_addr_tensor", &ir::LoadStoreAddrMnger::is_addr_tensor) .def("is_addr_scalar", &ir::LoadStoreAddrMnger::is_addr_scalar); // struct Load : ExprNode, LoadStoreAddrMnger DefineExprNode(m, "Load"); py::class_, ir::LoadStoreAddrMnger> load(*m, "Load"); load.def_readwrite("indices", &ir::Load::indices) .def("index", &ir::Load::index) .def_static("make", &ir::Load::Make) .def("expr_fields_mutable", py::overload_cast<>(&ir::Load::expr_fields)) .def("expr_fields_const", py::overload_cast<>(&ir::Load::expr_fields, py::const_)) .def("name", &ir::Load::name) .def("type", &ir::Load::type); // struct Store : ExprNode, LoadStoreAddrMnger DefineExprNode(m, "Store"); py::class_, ir::LoadStoreAddrMnger> store(*m, "Store"); store.def_readwrite("value", &ir::Store::value) .def_readwrite("indices", &ir::Store::indices) .def_static("make", &ir::Store::Make) .def("expr_fields_mutable", py::overload_cast<>(&ir::Store::expr_fields)) .def("expr_fields_const", py::overload_cast<>(&ir::Store::expr_fields, py::const_)) .def("type", &ir::Store::type) .def("index", &ir::Store::index); #define DEFINE_BINARY_NODE(__node) \ DefineBinaryOpNode(m, #__node); \ py::class_> py_##__node(*m, #__node); \ py_##__node.def(py::init()).def_static("make", &ir::__node::Make).def("type", &ir::__node::type) DEFINE_BINARY_NODE(Add); DEFINE_BINARY_NODE(Sub); DEFINE_BINARY_NODE(Mul); DEFINE_BINARY_NODE(Div); DEFINE_BINARY_NODE(Mod); DEFINE_BINARY_NODE(Min); DEFINE_BINARY_NODE(Max); DEFINE_BINARY_NODE(EQ); DEFINE_BINARY_NODE(NE); DEFINE_BINARY_NODE(LT); DEFINE_BINARY_NODE(LE); DEFINE_BINARY_NODE(GT); DEFINE_BINARY_NODE(GE); DEFINE_BINARY_NODE(And); DEFINE_BINARY_NODE(Or); #undef DEFINE_BINARY_NODE // FracOp DefineBinaryOpNode(m, "FracOp"); py::class_> frac_op(*m, "FracOp"); frac_op.def(py::init<>()).def_static("make", &ir::FracOp::Make).def("type", &ir::FracOp::type); #define DEFINE_UNARY_NODE(__node) \ DefineUnaryOpNode(m, #__node); \ py::class_> py_##__node(*m, #__node); \ py_##__node.def(py::init()).def_static("make", &ir::__node::Make) DEFINE_UNARY_NODE(Minus); DEFINE_UNARY_NODE(Not); #undef DEFINE_UNARY_NODE py::class_ var(*m, "Var"); var.def(py::init<>()) .def(py::init()) .def(py::init(), arg("name_hint"), arg("t") = common::type_of()) .def(py::init()) .def(py::init()) .def(py::init()) .def("get_mutable", py::overload_cast<>(&Var::get), py::return_value_policy::reference) .def("get_const", py::overload_cast<>(&Var::get, py::const_), py::return_value_policy::reference) .def("to_expr_mutable", py::overload_cast<>(&Var::operator ir::Expr)) .def("to_expr_const", py::overload_cast<>(&Var::operator ir::Expr, py::const_)) .def("__repr__", [](Var &self) -> std::string { return llvm::formatv("", self->name); }) .def("expr", [](Var &self) -> Expr { return Expr(self->self()); }) BIND_POD_BINARY_OP(int()) // BIND_POD_BINARY_OP(int32_t()) // BIND_POD_BINARY_OP(float()) #define BINARY_OP(type__) \ .def("__add__", [](Var &self, type__ v) -> Expr { return self + v; }) \ .def("__sub__", [](Var &self, type__ v) -> Expr { return self - v; }) \ .def("__truediv__", [](Var &self, type__ v) -> Expr { return self / v; }) \ .def("__mul__", [](Var &self, type__ v) -> Expr { return self * v; }) \ .def("__mod__", [](Var &self, type__ v) -> Expr { return self % v; }) BINARY_OP(int32_t) // BINARY_OP(int64_t) // BINARY_OP(float) // BINARY_OP(double); #undef BINARY_OP DefineExprNode(m, "Product"); py::class_> product(*m, "Product"); product.def_static("make", &ir::Product::Make) .def("type", &ir::Product::type) .def("operand_mutable", py::overload_cast(&ir::Product::operand), py::return_value_policy::reference) .def("operand_const", py::overload_cast(&ir::Product::operand, py::const_), py::return_value_policy::reference); DefineExprNode(m, "Sum"); py::class_> sum(*m, "Sum"); sum.def_static("make", &ir::Sum::Make) .def("operand_mutable", py::overload_cast(&ir::Sum::operand), py::return_value_policy::reference) .def("operand_const", py::overload_cast(&ir::Sum::operand, py::const_), py::return_value_policy::reference) .def("type", &ir::Sum::type); DefineExprNode(m, "Block"); py::class_> block(*m, "Block"); block.def_readwrite("stmts", &ir::Block::stmts) .def(py::init<>()) .def_static("make", &ir::Block::Make) .def("expr_fields_mutable", py::overload_cast<>(&ir::Block::expr_fields)) .def("expr_fields_const", py::overload_cast<>(&ir::Block::expr_fields, py::const_)); DefineExprNode(m, "_Module_"); py::class_> _module_(*m, "_Module_"); _module_.def_readwrite("name", &ir::_Module_::name) .def_readwrite("target", &ir::_Module_::target) .def_readwrite("buffers", &ir::_Module_::buffers) .def_readwrite("functions", &ir::_Module_::functions) .def_readwrite("submodules", &ir::_Module_::submodules); } void BindOperation(py::module *m) { py::class_ placeholder_op(*m, "PlaceholderOp"); placeholder_op.def_readwrite("shape", &ir::PlaceholderOp::shape) .def_readwrite("dtype", &ir::PlaceholderOp::dtype) .def_static("make", &ir::PlaceholderOp::Make) .def("func_type", &ir::PlaceholderOp::func_type); py::class_ call_op(*m, "CallOp"); call_op.def("target", &ir::CallOp::target) .def_readwrite("call_expr", &ir::CallOp::call_expr) .def("read_args_mutable", py::overload_cast<>(&ir::CallOp::read_args)) .def("read_args_const", py::overload_cast<>(&ir::CallOp::read_args, py::const_)) .def("write_args_mutable", py::overload_cast<>(&ir::CallOp::write_args)) .def("write_args_const", py::overload_cast<>(&ir::CallOp::write_args, py::const_)) .def("args", &ir::CallOp::args) .def_readwrite("func", &ir::CallOp::func) .def_readwrite("value_slot", &ir::CallOp::value_slot) .def_readwrite("is_tuple_get", &ir::CallOp::is_tuple_get) .def_readwrite("num_value_slots", &ir::CallOp::num_value_slots) .def(py::init<>()) .def_static("make", &ir::CallOp::Make) .def("func_type", &ir::CallOp::func_type); py::class_ compute_op(*m, "ComputeOp"); compute_op.def_readwrite("reduce_axis", &ir::ComputeOp::reduce_axis) .def_readwrite("shape", &ir::ComputeOp::shape) .def_readwrite("body", &ir::ComputeOp::body) .def_readwrite("producer_fn", &ir::ComputeOp::producer_fn) .def(py::init<>()) .def_static("make", &ir::ComputeOp::Make) .def("func_type", &ir::ComputeOp::func_type); } void BindIrTensor(py::module *m) { py::class_ tensor(*m, "Tensor"); tensor.def(py::init<>()) .def(py::init()) .def("ndims", &ir::Tensor::ndims) .def("__call__", [](ir::Tensor &self, Expr a) { return self(a); }) .def("__call__", [](ir::Tensor &self, Expr a, Expr b) { return self(a, b); }) .def("__call__", [](ir::Tensor &self, Expr a, Expr b, Expr c) { return self(a, b, c); }) .def("__call__", [](ir::Tensor &self, Expr a, Expr b, Expr c, Expr d) { return self(a, b, c, d); }); DefineExprNode(m, "_Tensor_"); py::class_> _tensor_(*m, "_Tensor_"); _tensor_.def_readwrite("shape", &ir::_Tensor_::shape) .def_readwrite("reduce_axis", &ir::_Tensor_::reduce_axis) .def_readwrite("operation", &ir::_Tensor_::operation) .def_readwrite("name", &ir::_Tensor_::name) .def_readwrite("buffer", &ir::_Tensor_::buffer) .def("domain_with_reduce_axis", &ir::_Tensor_::domain_without_reduce_axis) .def("domain_without_reduce_axis", &ir::_Tensor_::domain_without_reduce_axis) .def_static("make", &ir::_Tensor_::Make) .def("is_tuple", &ir::_Tensor_::is_tuple) .def("is_tuple_get", &ir::_Tensor_::is_tuple_get) .def("tuple_get", &ir::_Tensor_::TupleGet) .def("get_depend_tensor_names", &ir::_Tensor_::GetDependTensorNames) .def("is_depend_on_statement", &ir::_Tensor_::IsDependOnStatement) .def("depending_tensor_names", &ir::_Tensor_::DependingTensorNames) .def("same_shape_with", &ir::_Tensor_::HasSameShapeWith) .def("is_compute_node", &ir::_Tensor_::is_compute_node) .def("is_placeholder_node", &ir::_Tensor_::is_placeholder_node) .def("is_call_node", &ir::_Tensor_::is_call_node) .def("is_extern_call_node", &ir::_Tensor_::is_extern_call_node) .def("is_preceding_view_node", &ir::_Tensor_::is_preceding_view_node) .def("is_buffer_shared_node", &ir::_Tensor_::is_buffer_shared_node) .def("operation_type", &ir::_Tensor_::operation_type) .def("get_compute_op", &ir::_Tensor_::get_compute_op) .def("get_placeholder_op", &ir::_Tensor_::get_placeholder_op) .def("body", &ir::_Tensor_::body) .def("tensor_store_expanded_body", &ir::_Tensor_::tensor_store_expanded_body) .def("inline_expanded", &ir::_Tensor_::inline_expanded) .def("contains_reduce_axis", &ir::_Tensor_::contains_reduce_axis) .def("expr_fields_mutable", py::overload_cast<>(&ir::_Tensor_::expr_fields)) .def("expr_fields_const", py::overload_cast<>(&ir::_Tensor_::expr_fields, py::const_)) .def("axis", &ir::_Tensor_::axis) .def("axis_with_reduce", &ir::_Tensor_::axis_with_reduce) .def("buffer_depended_tensor_names", &ir::_Tensor_::buffer_depended_tensor_names) .def(py::init<>()) .def("has_expression", &ir::_Tensor_::has_expression) .def("reshape", &ir::_Tensor_::Reshape) .def("reshape_copied", &ir::_Tensor_::ReshapeCopied) .def("with_buffer", py::overload_cast(&ir::_Tensor_::WithBuffer), py::arg("type") = Type::type_t::Void) .def("with_buffer", py::overload_cast(&ir::_Tensor_::WithBuffer), py::arg("memory_type"), py::arg("buffer_name") = "", py::arg("type") = Type::type_t::Void) .def("bind", py::overload_cast(&ir::_Tensor_::Bind)) .def("bind", py::overload_cast(&ir::_Tensor_::Bind)) .def("__str__", [](const ir::Tensor &self) { return "name + ">"; }); py::class_ operation(*m, "Operation"); operation.def(py::init<>()).def(py::init()).def_readwrite("name", &ir::Operation::name); } auto PackedFuncCall(lang::PackedFunc &self, py::args args) { // NOLINT lang::Args cinn_args; using common::CINNValue; for (auto handle : args) { if (py::isinstance(handle)) { cinn_args.Append(CINNValue(py::cast(handle))); } else if (py::isinstance(handle)) { cinn_args.Append(CINNValue(py::cast(handle))); } else if (py::isinstance(handle)) { cinn_args.Append(CINNValue(py::cast(handle))); } else if (py::isinstance(handle)) { cinn_args.Append(CINNValue(py::cast(handle))); } else { LOG(FATAL) << "unsupported type: " << std::string(py::str(handle.get_type())); } } lang::RetValue ret_value; self.body()(cinn_args, &ret_value); return ConvertToVar(ret_value); } void BindPackedFunc(py::module *m) { py::class_ args(*m, "Args"); args.def(py::init<>()) .def(py::init()) .def("append", &lang::Args::Append) .def("size", &lang::Args::size) .def("__len__", &lang::Args::size) .def( "__getitem__", [](lang::Args &self, int i) { return self[i]; }, py::return_value_policy::reference) .def("__setitem__", [](lang::Args &self, int i, common::CINNValue &v) { self[i] = v; }); py::class_ packed_func(*m, "PackedFunc"); packed_func.def(py::init<>()) .def(py::init()) .def(py::init()) .def("body", &lang::PackedFunc::body) .def("__call__", &PackedFuncCall); } void BindRegistry(py::module *m) { py::class_ registry(*m, "Registry"); registry .def_static("register", &ir::Registry::Register, py::arg("name"), py::arg("override") = false, py::return_value_policy::reference) .def_static("register", &ir::Registry::Register, py::return_value_policy::reference) .def_static("remove", &ir::Registry::Remove) .def_static("get", &ir::Registry::Get, py::return_value_policy::reference) .def_static("list_names", &ir::Registry::ListNames) .def("set_body", py::overload_cast(&ir::Registry::SetBody), py::return_value_policy::reference); #ifdef CINN_WITH_TEST ir::Registry::Register("test_add_int64").SetBody([](lang::Args args, lang::RetValue *rv) { int64_t x = args[0]; int64_t y = args[1]; *rv = x + y; }); ir::Registry::Register("test_add_expr").SetBody([](lang::Args args, lang::RetValue *rv) { ir::Expr x = args[0]; ir::Expr y = args[1]; *rv = x + y; }); ir::Registry::Register("test_mul_float").SetBody([](lang::Args args, lang::RetValue *rv) { float x = args[0]; float y = args[1]; *rv = x * y; }); #endif } } // namespace void BindIr(py::module *m) { BindOperation(m); BindLoweredFunc(m); BindNode(m); BindIrVisitor(m); BindIrIr(m); BindIrTensor(m); BindPackedFunc(m); BindRegistry(m); } } // namespace cinn::pybind