// Copyright (c) 2021 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "paddle/cinn/backends/codegen_c.h" #include "paddle/cinn/common/target.h" #include "paddle/cinn/ir/module.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/lang/buffer.h" #include "paddle/cinn/lang/builtin.h" #include "paddle/cinn/lang/compute.h" #include "paddle/cinn/lang/lower.h" #include "paddle/cinn/lang/placeholder.h" #include "paddle/cinn/pybind/bind.h" #include "paddle/cinn/pybind/bind_utils.h" namespace py = pybind11; namespace cinn::pybind { using common::Type; using lang::Placeholder; using py::arg; using utils::GetStreamCnt; using utils::StringFormat; namespace { void BindBuffer(py::module *); void BindLower(py::module *); void BindLowerVec(py::module *); void BindPlaceholder(py::module *); void BindCompute(py::module *); void BindModule(py::module *); void BindBuiltin(py::module *); void BindBuffer(py::module *m) { py::class_ buffer(*m, "Buffer"); buffer.def(py::init(), py::arg("type"), py::arg("name") = "") .def(py::init()) .def("buffer", &lang::Buffer::buffer); } void BindLower(py::module *m) { using py::arg; m->def("lower", &lang::Lower, arg("name"), arg("stages"), arg("tensor_args"), arg("scalar_args") = std::vector(), arg("temp_tensors") = std::vector(), arg("b") = nullptr, arg("target") = common::DefaultHostTarget(), arg("supprt_ir_schedule") = false); } void BindLowerVec(py::module *m) { using py::arg; m->def("lower_vec", &lang::LowerVec, arg("name"), arg("stages"), arg("tensor_args"), arg("scalar_args") = std::vector(), arg("temp_tensors") = std::vector(), arg("b") = nullptr, arg("target") = common::DefaultHostTarget(), arg("supprt_ir_schedule") = false); } void BindCompute(py::module *m) { #define MAKE_COMPUTE_FN(__fn) \ py::overload_cast &, __fn, const std::string &, const std::vector &>( \ &lang::Compute) #define DEFINE_COMPUTE(__fn) \ m->def("compute", \ MAKE_COMPUTE_FN(__fn), \ arg("domin"), \ arg("fn"), \ arg("name") = "", \ arg("shape") = std::vector()) // DEFINE_COMPUTE(std::function); // DEFINE_COMPUTE(std::function); DEFINE_COMPUTE(std::function &)>); // DEFINE_COMPUTE(std::function); // DEFINE_COMPUTE(std::function); // DEFINE_COMPUTE(std::function); // DEFINE_COMPUTE(std::function); DEFINE_COMPUTE(lang::compute_handler_t); #undef DEFINE_COMPUTE #undef MAKE_COMPUTE_FN py::class_ return_type(*m, "ReturnType"); return_type.def_readwrite("type", &lang::ReturnType::type) .def_readwrite("dims", &lang::ReturnType::dims) .def_readwrite("name", &lang::ReturnType::name); m->def("call_lowered", py::overload_cast &, const std::vector &>( &lang::CallLowered)); m->def("call_extern", py::overload_cast &, const std::map> &>( &lang::CallExtern)); } void BindModule(py::module *m) { py::class_ module(*m, "Module"); module.def("target", &ir::Module::target) .def("buffers", &ir::Module::buffers) .def("functions", &ir::Module::functions) .def("submodules", &ir::Module::submodules) .def("compile", &ir::Module::Compile) .def("get_c_code", [](const ir::Module &self) -> std::string { backends::CodeGenC codegen(common::DefaultHostTarget()); codegen.SetInlineBuiltinCodes(false); return codegen.Compile(self, backends::CodeGenC::OutputKind::CImpl); }); py::class_ builder(module, "Builder"); builder.def(py::init()) .def("add_function", &ir::Module::Builder::AddFunction) .def("add_buffer", &ir::Module::Builder::AddBuffer) .def("build", &ir::Module::Builder::Build); } class PlaceholderWrapper { public: #define DEFINE_PLACEHOLDER(__dtype, __type) \ if (dtype == #__dtype) placeholder_ = std::make_unique>(name, shape) #define INIT_PLACEHOLDER \ DEFINE_PLACEHOLDER(int32, int32_t); \ DEFINE_PLACEHOLDER(int64, int64_t); \ DEFINE_PLACEHOLDER(float32, float); \ DEFINE_PLACEHOLDER(float64, double) PlaceholderWrapper(absl::string_view dtype, const std::string &name, const std::vector &shape) { INIT_PLACEHOLDER; } PlaceholderWrapper(absl::string_view dtype, const std::string &name, const std::vector &shape) { INIT_PLACEHOLDER; } #undef INIT_PLACEHOLDER #undef DEFINE_PLACEHOLDER ir::Type type() const { return absl::visit([](auto &v) { return v->type(); }, placeholder_); } ir::Tensor tensor() const { return absl::visit([](auto &v) { return v->tensor(); }, placeholder_); } ir::Expr operator()(ir::Expr a) const { return absl::visit([&](auto &v) { return (*v)(a); }, placeholder_); } ir::Expr operator()(ir::Expr a, ir::Expr b) const { return absl::visit([&](auto &v) { return (*v)(a, b); }, placeholder_); } ir::Expr operator()(ir::Expr a, ir::Expr b, ir::Expr c) const { return absl::visit([&](auto &v) { return (*v)(a, b, c); }, placeholder_); } ir::Expr operator()(const std::vector &indices) const { return absl::visit([&](auto &v) { return (*v)(indices); }, placeholder_); } operator ir::Tensor() { return absl::visit([&](auto &v) { return ir::Tensor(*v); }, placeholder_); } operator ir::Expr() { return absl::visit([&](auto &v) { return ir::Expr(*v); }, placeholder_); } private: template using PlaceholderVariant = absl::variant>...>; PlaceholderVariant placeholder_; }; void BindPlaceholder(py::module *m) { py::class_ placeholder(*m, "Placeholder"); placeholder.def(py::init &>()) .def(py::init &>()) .def("type", &PlaceholderWrapper::type) .def("tensor", &PlaceholderWrapper::tensor) .def("__call__", [](PlaceholderWrapper &self, ir::Expr a) { return self(std::move(a)); }) .def("__call__", [](PlaceholderWrapper &self, ir::Expr a, ir::Expr b) { return self(std::move(a), std::move(b)); }) .def("__call__", [](PlaceholderWrapper &self, ir::Expr a, ir::Expr b, ir::Expr c) { return self(std::move(a), std::move(b), std::move(c)); }) .def("__call__", [](PlaceholderWrapper &self, const std::vector &indices) { return self(indices); }) .def("to_expr", [](PlaceholderWrapper &self) { return ir::Expr(self); }) .def("to_tensor", [](PlaceholderWrapper &self) { return ir::Tensor(self); }); m->def("create_placeholder", static_cast &, Type, const std::string &)>(&lang::CreatePlaceHolder)); m->def("create_placeholder", static_cast &, Type, const std::string &)>(&lang::CreatePlaceHolder)); } void BindBuiltin(py::module *m) { m->def("reduce_sum", &lang::ReduceSum, py::arg("e"), py::arg("reduce_axis"), py::arg("init") = Expr()); m->def("reduce_mul", &lang::ReduceMul); m->def("reduce_max", &lang::ReduceMax); m->def("reduce_min", &lang::ReduceMin); m->def("reduce_all", &lang::ReduceAll); m->def("reduce_any", &lang::ReduceAny); } } // namespace void BindLang(py::module *m) { BindBuffer(m); BindLower(m); BindLowerVec(m); BindPlaceholder(m); BindCompute(m); BindModule(m); BindBuiltin(m); } } // namespace cinn::pybind