From aaa4fe491a0f588f8bfb523210275cc1928e3dff Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Sun, 19 Jan 2020 11:28:56 +0800 Subject: [PATCH] use function instead of lambda, test=develop (#22348) * use function instead of lambda, test=develop * follow comments, test=develop --- paddle/fluid/pybind/op_function_generator.cc | 92 +++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index a5f823f97f..2c112079cc 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -23,45 +23,46 @@ #include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/string/string_helper.h" +// clang-format off const char* OUT_INITIALIZER_TEMPLATE = R"({"%s", {std::shared_ptr(new imperative::VarBase(tracer->GenerateUniqueName()))}})"; const char* OP_FUNCTION_TEMPLATE = - R"([](const imperative::NameVarBaseMap& ins, const framework::AttributeMap& attrs, - imperative::NameVarBaseMap outs, const std::map& out_nums) - { - auto tracer = imperative::GetCurrentTracer(); - if (outs.size() == 0) { - if (out_nums.size() == 0) { - imperative::NameVarBaseMap outs_ = %s; - outs = std::move(outs_); - } else { - for (auto &pair : out_nums) { - for (size_t i = 0; i < pair.second; i ++) { - auto var_base_name = tracer->GenerateUniqueName(); - auto out = new imperative::VarBase(var_base_name); - outs[pair.first].emplace_back(std::shared_ptr(out)); - } +R"( +inline imperative::NameVarBaseMap %s(const imperative::NameVarBaseMap& ins, const framework::AttributeMap& attrs, + imperative::NameVarBaseMap outs, const std::map& out_nums) +{ + auto tracer = imperative::GetCurrentTracer(); + if (outs.size() == 0) { + if (out_nums.size() == 0) { + imperative::NameVarBaseMap outs_ = %s; + outs = std::move(outs_); + } else { + for (auto &pair : out_nums) { + for (size_t i = 0; i < pair.second; i ++) { + auto var_base_name = tracer->GenerateUniqueName(); + outs[pair.first].emplace_back(new imperative::VarBase(var_base_name)); } } } - - { - py::gil_scoped_release release; - tracer->TraceOp("%s", std::move(ins), std::move(outs), std::move(attrs)); - return outs; - } - }, py::arg("ins"), py::arg("attrs")=framework::AttributeMap(), - py::arg("outs")=imperative::NameVarBaseMap(), - py::arg("out_nums")=std::map())"; + } + + tracer->TraceOp("%s", std::move(ins), std::move(outs), std::move(attrs)); + return outs; +})"; -const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", %s);)"; +const char* PYBIND_ITEM_TEMPLATE = +R"( + %s.def("%s", &%s, py::arg("ins"), py::arg("attrs")=framework::AttributeMap(), py::arg("outs")=imperative::NameVarBaseMap(), + py::arg("out_nums")=std::map(), py::call_guard());)"; -static std::vector GenerateOpFunctions( - const std::string& module_name) { +// clang-format on + +static std::tuple, std::vector> +GenerateOpFunctions(const std::string& module_name) { auto& op_info_map = paddle::framework::OpInfoMap::Instance().map(); - std::vector op_function_list; + std::vector op_function_list, bind_function_list; for (auto& pair : op_info_map) { auto& op_info = pair.second; auto op_proto = op_info.proto_; @@ -85,18 +86,21 @@ static std::vector GenerateOpFunctions( } outs_initializer += "}"; + std::string func_name = "imperative_" + op_type; + // generate op funtcion body - auto op_function_str = paddle::string::Sprintf(OP_FUNCTION_TEMPLATE, - outs_initializer, op_type); + auto op_function_str = paddle::string::Sprintf( + OP_FUNCTION_TEMPLATE, func_name, outs_initializer, op_type); // generate pybind item - auto pybind_op_function = paddle::string::Sprintf( - PYBIND_ITEM_TEMPLATE, module_name.c_str(), op_type, op_function_str); - pybind_op_function += "\n"; - op_function_list.emplace_back(std::move(pybind_op_function)); + auto bind_function_str = paddle::string::Sprintf( + PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name); + + op_function_list.emplace_back(std::move(op_function_str)); + bind_function_list.emplace_back(std::move(bind_function_str)); } - return op_function_list; + return std::make_tuple(op_function_list, bind_function_list); } int main(int argc, char* argv[]) { @@ -115,19 +119,21 @@ int main(int argc, char* argv[]) { out << "#include " + header + "\n"; } + // all op functions + auto op_funcs = GenerateOpFunctions("m"); + out << "namespace py = pybind11;" << "\n"; out << "namespace paddle {\n" - << "namespace pybind {\n" - << "\n" - << "inline void BindOpFunctions(pybind11::module *module) {\n" - << " auto m = module->def_submodule(\"ops\");\n\n"; - - // all op functions - auto op_funcs = GenerateOpFunctions("m"); + << "namespace pybind {\n"; + out << paddle::string::join_strings(std::get<0>(op_funcs), '\n'); + out << "\n\n"; - out << paddle::string::join_strings(op_funcs, '\n'); + out << "inline void BindOpFunctions(pybind11::module *module) {\n" + << " auto m = module->def_submodule(\"ops\");\n\n"; + out << paddle::string::join_strings(std::get<1>(op_funcs), '\n'); + out << "\n"; out << "}\n\n" << "} // namespace pybind\n" << "} // namespace paddle\n"; -- GitLab