未验证 提交 aaa4fe49 编写于 作者: L Leo Chen 提交者: GitHub

use function instead of lambda, test=develop (#22348)

* use function instead of lambda, test=develop

* follow comments, test=develop
上级 24f9037e
...@@ -23,45 +23,46 @@ ...@@ -23,45 +23,46 @@
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
// clang-format off
const char* OUT_INITIALIZER_TEMPLATE = const char* OUT_INITIALIZER_TEMPLATE =
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase(tracer->GenerateUniqueName()))}})"; R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase(tracer->GenerateUniqueName()))}})";
const char* OP_FUNCTION_TEMPLATE = const char* OP_FUNCTION_TEMPLATE =
R"([](const imperative::NameVarBaseMap& ins, const framework::AttributeMap& attrs, R"(
imperative::NameVarBaseMap outs, const std::map<std::string, size_t>& out_nums) inline imperative::NameVarBaseMap %s(const imperative::NameVarBaseMap& ins, const framework::AttributeMap& attrs,
{ imperative::NameVarBaseMap outs, const std::map<std::string, size_t>& out_nums)
auto tracer = imperative::GetCurrentTracer(); {
if (outs.size() == 0) { auto tracer = imperative::GetCurrentTracer();
if (out_nums.size() == 0) { if (outs.size() == 0) {
imperative::NameVarBaseMap outs_ = %s; if (out_nums.size() == 0) {
outs = std::move(outs_); imperative::NameVarBaseMap outs_ = %s;
} else { outs = std::move(outs_);
for (auto &pair : out_nums) { } else {
for (size_t i = 0; i < pair.second; i ++) { for (auto &pair : out_nums) {
auto var_base_name = tracer->GenerateUniqueName(); for (size_t i = 0; i < pair.second; i ++) {
auto out = new imperative::VarBase(var_base_name); auto var_base_name = tracer->GenerateUniqueName();
outs[pair.first].emplace_back(std::shared_ptr<imperative::VarBase>(out)); outs[pair.first].emplace_back(new imperative::VarBase(var_base_name));
}
} }
} }
} }
}
{
py::gil_scoped_release release; tracer->TraceOp("%s", std::move(ins), std::move(outs), std::move(attrs));
tracer->TraceOp("%s", std::move(ins), std::move(outs), std::move(attrs)); return outs;
return outs; })";
}
}, py::arg("ins"), py::arg("attrs")=framework::AttributeMap(),
py::arg("outs")=imperative::NameVarBaseMap(),
py::arg("out_nums")=std::map<std::string, size_t>())";
const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", %s);)"; const char* PYBIND_ITEM_TEMPLATE =
R"(
%s.def("%s", &%s, py::arg("ins"), py::arg("attrs")=framework::AttributeMap(), py::arg("outs")=imperative::NameVarBaseMap(),
py::arg("out_nums")=std::map<std::string, size_t>(), py::call_guard<py::gil_scoped_release>());)";
static std::vector<std::string> GenerateOpFunctions( // clang-format on
const std::string& module_name) {
static std::tuple<std::vector<std::string>, std::vector<std::string>>
GenerateOpFunctions(const std::string& module_name) {
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map(); auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
std::vector<std::string> op_function_list; std::vector<std::string> op_function_list, bind_function_list;
for (auto& pair : op_info_map) { for (auto& pair : op_info_map) {
auto& op_info = pair.second; auto& op_info = pair.second;
auto op_proto = op_info.proto_; auto op_proto = op_info.proto_;
...@@ -85,18 +86,21 @@ static std::vector<std::string> GenerateOpFunctions( ...@@ -85,18 +86,21 @@ static std::vector<std::string> GenerateOpFunctions(
} }
outs_initializer += "}"; outs_initializer += "}";
std::string func_name = "imperative_" + op_type;
// generate op funtcion body // generate op funtcion body
auto op_function_str = paddle::string::Sprintf(OP_FUNCTION_TEMPLATE, auto op_function_str = paddle::string::Sprintf(
outs_initializer, op_type); OP_FUNCTION_TEMPLATE, func_name, outs_initializer, op_type);
// generate pybind item // generate pybind item
auto pybind_op_function = paddle::string::Sprintf( auto bind_function_str = paddle::string::Sprintf(
PYBIND_ITEM_TEMPLATE, module_name.c_str(), op_type, op_function_str); PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name);
pybind_op_function += "\n";
op_function_list.emplace_back(std::move(pybind_op_function)); op_function_list.emplace_back(std::move(op_function_str));
bind_function_list.emplace_back(std::move(bind_function_str));
} }
return op_function_list; return std::make_tuple(op_function_list, bind_function_list);
} }
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
...@@ -115,19 +119,21 @@ int main(int argc, char* argv[]) { ...@@ -115,19 +119,21 @@ int main(int argc, char* argv[]) {
out << "#include " + header + "\n"; out << "#include " + header + "\n";
} }
// all op functions
auto op_funcs = GenerateOpFunctions("m");
out << "namespace py = pybind11;" out << "namespace py = pybind11;"
<< "\n"; << "\n";
out << "namespace paddle {\n" out << "namespace paddle {\n"
<< "namespace pybind {\n" << "namespace pybind {\n";
<< "\n" out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
<< "inline void BindOpFunctions(pybind11::module *module) {\n" out << "\n\n";
<< " auto m = module->def_submodule(\"ops\");\n\n";
// all op functions
auto op_funcs = GenerateOpFunctions("m");
out << paddle::string::join_strings(op_funcs, '\n'); out << "inline void BindOpFunctions(pybind11::module *module) {\n"
<< " auto m = module->def_submodule(\"ops\");\n\n";
out << paddle::string::join_strings(std::get<1>(op_funcs), '\n');
out << "\n";
out << "}\n\n" out << "}\n\n"
<< "} // namespace pybind\n" << "} // namespace pybind\n"
<< "} // namespace paddle\n"; << "} // namespace paddle\n";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册