未验证 提交 aaa4fe49 编写于 作者: L Leo Chen 提交者: GitHub

use function instead of lambda, test=develop (#22348)

* use function instead of lambda, test=develop

* follow comments, test=develop
上级 24f9037e
......@@ -23,13 +23,15 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
// clang-format off
const char* OUT_INITIALIZER_TEMPLATE =
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase(tracer->GenerateUniqueName()))}})";
const char* OP_FUNCTION_TEMPLATE =
R"([](const imperative::NameVarBaseMap& ins, const framework::AttributeMap& attrs,
R"(
inline imperative::NameVarBaseMap %s(const imperative::NameVarBaseMap& ins, const framework::AttributeMap& attrs,
imperative::NameVarBaseMap outs, const std::map<std::string, size_t>& out_nums)
{
{
auto tracer = imperative::GetCurrentTracer();
if (outs.size() == 0) {
if (out_nums.size() == 0) {
......@@ -39,29 +41,28 @@ const char* OP_FUNCTION_TEMPLATE =
for (auto &pair : out_nums) {
for (size_t i = 0; i < pair.second; i ++) {
auto var_base_name = tracer->GenerateUniqueName();
auto out = new imperative::VarBase(var_base_name);
outs[pair.first].emplace_back(std::shared_ptr<imperative::VarBase>(out));
outs[pair.first].emplace_back(new imperative::VarBase(var_base_name));
}
}
}
}
{
py::gil_scoped_release release;
tracer->TraceOp("%s", std::move(ins), std::move(outs), std::move(attrs));
return outs;
}
}, py::arg("ins"), py::arg("attrs")=framework::AttributeMap(),
py::arg("outs")=imperative::NameVarBaseMap(),
py::arg("out_nums")=std::map<std::string, size_t>())";
})";
const char* PYBIND_ITEM_TEMPLATE =
R"(
%s.def("%s", &%s, py::arg("ins"), py::arg("attrs")=framework::AttributeMap(), py::arg("outs")=imperative::NameVarBaseMap(),
py::arg("out_nums")=std::map<std::string, size_t>(), py::call_guard<py::gil_scoped_release>());)";
const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", %s);)";
// clang-format on
static std::vector<std::string> GenerateOpFunctions(
const std::string& module_name) {
static std::tuple<std::vector<std::string>, std::vector<std::string>>
GenerateOpFunctions(const std::string& module_name) {
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
std::vector<std::string> op_function_list;
std::vector<std::string> op_function_list, bind_function_list;
for (auto& pair : op_info_map) {
auto& op_info = pair.second;
auto op_proto = op_info.proto_;
......@@ -85,18 +86,21 @@ static std::vector<std::string> GenerateOpFunctions(
}
outs_initializer += "}";
std::string func_name = "imperative_" + op_type;
// generate op funtcion body
auto op_function_str = paddle::string::Sprintf(OP_FUNCTION_TEMPLATE,
outs_initializer, op_type);
auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, func_name, outs_initializer, op_type);
// generate pybind item
auto pybind_op_function = paddle::string::Sprintf(
PYBIND_ITEM_TEMPLATE, module_name.c_str(), op_type, op_function_str);
pybind_op_function += "\n";
op_function_list.emplace_back(std::move(pybind_op_function));
auto bind_function_str = paddle::string::Sprintf(
PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name);
op_function_list.emplace_back(std::move(op_function_str));
bind_function_list.emplace_back(std::move(bind_function_str));
}
return op_function_list;
return std::make_tuple(op_function_list, bind_function_list);
}
int main(int argc, char* argv[]) {
......@@ -115,19 +119,21 @@ int main(int argc, char* argv[]) {
out << "#include " + header + "\n";
}
// all op functions
auto op_funcs = GenerateOpFunctions("m");
out << "namespace py = pybind11;"
<< "\n";
out << "namespace paddle {\n"
<< "namespace pybind {\n"
<< "\n"
<< "inline void BindOpFunctions(pybind11::module *module) {\n"
<< " auto m = module->def_submodule(\"ops\");\n\n";
// all op functions
auto op_funcs = GenerateOpFunctions("m");
<< "namespace pybind {\n";
out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
out << "\n\n";
out << paddle::string::join_strings(op_funcs, '\n');
out << "inline void BindOpFunctions(pybind11::module *module) {\n"
<< " auto m = module->def_submodule(\"ops\");\n\n";
out << paddle::string::join_strings(std::get<1>(op_funcs), '\n');
out << "\n";
out << "}\n\n"
<< "} // namespace pybind\n"
<< "} // namespace paddle\n";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册