未验证 提交 08e81475 编写于 作者: W wanghuancoder 提交者: GitHub

use PYTHON_C_API in dygraph (#32524)

* use PYTHON_C_API in dygraph, test=develop
上级 022198c5
...@@ -51,6 +51,8 @@ limitations under the License. */ ...@@ -51,6 +51,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
PyTypeObject *g_varbase_pytype = nullptr;
namespace py = ::pybind11; namespace py = ::pybind11;
class Layer : public imperative::Layer { class Layer : public imperative::Layer {
...@@ -470,9 +472,9 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, ...@@ -470,9 +472,9 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
} }
template <typename P> template <typename P>
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, // NOLINT
imperative::VarBase &dst, const P &dst_device, imperative::VarBase &dst, // NOLINT
const bool blocking) { const P &dst_device, const bool blocking) {
if (dst.SharedVar()->IsEmpty()) { if (dst.SharedVar()->IsEmpty()) {
VLOG(3) << "deep copy Variable from " << src->Name() << " to " VLOG(3) << "deep copy Variable from " << src->Name() << " to "
<< dst.Name(); << dst.Name();
...@@ -667,9 +669,10 @@ void BindImperative(py::module *m_ptr) { ...@@ -667,9 +669,10 @@ void BindImperative(py::module *m_ptr) {
imperative::SetCurrentTracer(tracer); imperative::SetCurrentTracer(tracer);
}); });
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>( py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>> varbase(
m, "VarBase", R"DOC()DOC") m, "VarBase", R"DOC()DOC");
.def_static("_alive_vars", &imperative::VarBase::AliveVarNames) g_varbase_pytype = (PyTypeObject *)varbase.ptr(); // NOLINT
varbase.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
.def("__init__", .def("__init__",
[](imperative::VarBase &self) { [](imperative::VarBase &self) {
std::string name = std::string name =
...@@ -1468,28 +1471,22 @@ void BindImperative(py::module *m_ptr) { ...@@ -1468,28 +1471,22 @@ void BindImperative(py::module *m_ptr) {
&imperative::VarBase::SetOverridedStopGradient) &imperative::VarBase::SetOverridedStopGradient)
.def_property("persistable", &imperative::VarBase::Persistable, .def_property("persistable", &imperative::VarBase::Persistable,
&imperative::VarBase::SetPersistable) &imperative::VarBase::SetPersistable)
.def_property_readonly("shape", .def_property_readonly(
[](imperative::VarBase &self) { "shape",
if (self.Var().IsType<framework::LoDTensor>()) { [](imperative::VarBase &self) {
return framework::vectorize<int>( if (self.Var().IsType<framework::LoDTensor>()) {
self.Var() return framework::vectorize<int>(
.Get<framework::LoDTensor>() self.Var().Get<framework::LoDTensor>().dims());
.dims()); } else if (self.Var().IsType<framework::SelectedRows>()) {
} else if (self.Var() return framework::vectorize<int>(
.IsType< self.Var().Get<framework::SelectedRows>().value().dims());
framework::SelectedRows>()) { } else {
return framework::vectorize<int>( VLOG(2) << "It is meaningless to get shape of "
self.Var() "variable type "
.Get<framework::SelectedRows>() << GetTypeName(self);
.value() return std::vector<int>();
.dims()); }
} else { })
VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self);
return std::vector<int>();
}
})
.def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf, .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf,
R"DOC( R"DOC(
Whether a Tensor is leaf Tensor. Whether a Tensor is leaf Tensor.
......
此差异已折叠。
...@@ -212,16 +212,17 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)"; ...@@ -212,16 +212,17 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)"; const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
const char* CAST_VAR_TEMPLATE = R"( const char* CAST_VAR_TEMPLATE = R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)"; auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)";
const char* CAST_VAR_LIST_TEMPLATE = R"( const char* CAST_VAR_LIST_TEMPLATE = R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)"; auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)";
const char* CAST_SIZE_T_TEMPLATE = R"(
auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
const char* ARG_TEMPLATE = R"(const %s& %s)"; const char* ARG_TEMPLATE = R"(const %s& %s)";
const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)"; const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
const char* RETURN_TYPE = R"(%s)";
const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))"; const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))";
const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])"; const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])";
const char* RETURN_TEMPLATE = R"(outs["%s"][0])"; const char* RETURN_TEMPLATE = R"(outs["%s"][0])";
...@@ -251,23 +252,34 @@ const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})"; ...@@ -251,23 +252,34 @@ const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})";
const char* OP_FUNCTION_TEMPLATE = const char* OP_FUNCTION_TEMPLATE =
R"( R"(
%s %s(%s) static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
{ {
%s PyThreadState *tstate = nullptr;
framework::AttributeMap attrs; try
ConstructAttrMapFromPyArgs("%s", %d, &attrs, args);
{ {
py::gil_scoped_release release; %s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
tstate = PyEval_SaveThread();
%s %s
imperative::NameVarBaseMap outs = %s; imperative::NameVarBaseMap outs = %s;
imperative::NameVarBaseMap ins = %s; imperative::NameVarBaseMap ins = %s;
%s %s
imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s}); imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return %s; return %s;
} }
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
})"; })";
const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", &%s);)"; const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},)";
// clang-format on // clang-format on
static inline bool FindInsMap(const std::string& op_type, static inline bool FindInsMap(const std::string& op_type,
...@@ -326,9 +338,8 @@ std::string GenerateOpFunctionsBody( ...@@ -326,9 +338,8 @@ std::string GenerateOpFunctionsBody(
const auto in_cast_type = const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false"; auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name, in_name, arg_idx++, dispensable);
arg_idx++, TempName(in_name), dispensable);
if (input.dispensable()) { if (input.dispensable()) {
const auto in_template = input.duplicable() const auto in_template = input.duplicable()
...@@ -356,7 +367,6 @@ std::string GenerateOpFunctionsBody( ...@@ -356,7 +367,6 @@ std::string GenerateOpFunctionsBody(
// Generate outs initializer // Generate outs initializer
std::string outs_initializer = "{"; std::string outs_initializer = "{";
std::string outs_initializer_with_null = ""; std::string outs_initializer_with_null = "";
std::string return_type = "";
std::string inplace_mapping_str = ""; std::string inplace_mapping_str = "";
std::string return_str = ""; std::string return_str = "";
...@@ -395,6 +405,12 @@ std::string GenerateOpFunctionsBody( ...@@ -395,6 +405,12 @@ std::string GenerateOpFunctionsBody(
paddle::string::Sprintf(out_template, out_name, out_name); paddle::string::Sprintf(out_template, out_name, out_name);
outs_initializer += ","; outs_initializer += ",";
} }
const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type,
out_name, arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) { } else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
inplace_map[out_name], "", inplace_map[out_name], "",
...@@ -440,6 +456,11 @@ std::string GenerateOpFunctionsBody( ...@@ -440,6 +456,11 @@ std::string GenerateOpFunctionsBody(
input_args_num++; input_args_num++;
outs_initializer += paddle::string::Sprintf( outs_initializer += paddle::string::Sprintf(
OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str); OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str +=
paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type,
out_num_str, arg_idx++, dispensable);
} else { } else {
outs_initializer += outs_initializer +=
paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name); paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);
...@@ -447,15 +468,12 @@ std::string GenerateOpFunctionsBody( ...@@ -447,15 +468,12 @@ std::string GenerateOpFunctionsBody(
outs_initializer += ","; outs_initializer += ",";
} }
return_type += out_type;
return_type += ",";
return_str += paddle::string::Sprintf(return_template, out_name); return_str += paddle::string::Sprintf(return_template, out_name);
return_str += ","; return_str += ",";
outs_num += 1; outs_num += 1;
} }
if (outs_initializer.back() == ',') { if (outs_initializer.back() == ',') {
outs_initializer.pop_back(); outs_initializer.pop_back();
return_type.pop_back();
return_str.pop_back(); return_str.pop_back();
} }
outs_initializer += "}"; outs_initializer += "}";
...@@ -470,11 +488,13 @@ std::string GenerateOpFunctionsBody( ...@@ -470,11 +488,13 @@ std::string GenerateOpFunctionsBody(
viwe_input_name, viwe_output_name); viwe_input_name, viwe_output_name);
} }
if (outs_num == 0) { if (outs_num == 0) {
return_type = "void"; return_str = "Py_None";
} } else if (outs_num == 1) {
if (outs_num > 1) { return_str = "MakeReturnPyObject(" + return_str + ")";
return_str = paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str); } else {
return_type = paddle::string::Sprintf(RETURN_TUPLE_TYPE, return_type); return_str = "MakeReturnPyObject(" +
paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str) +
")";
} }
std::string function_args = ""; std::string function_args = "";
if (input_args == "") { if (input_args == "") {
...@@ -485,17 +505,17 @@ std::string GenerateOpFunctionsBody( ...@@ -485,17 +505,17 @@ std::string GenerateOpFunctionsBody(
// generate op funtcion body // generate op funtcion body
auto op_function_str = paddle::string::Sprintf( auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, ins_cast_str, OP_FUNCTION_TEMPLATE, func_name, ins_cast_str, op_type, input_args_num,
op_type, input_args_num, inplace_strategy_str, outs_initializer, inplace_strategy_str, outs_initializer, ins_initializer,
ins_initializer, ins_initializer_with_null + outs_initializer_with_null + ins_initializer_with_null + outs_initializer_with_null +
view_strategy_str, view_strategy_str,
op_type, inplace_mapping_str, return_str); op_type, inplace_mapping_str, return_str);
return op_function_str; return op_function_str;
} }
static std::tuple<std::vector<std::string>, std::vector<std::string>> static std::tuple<std::vector<std::string>, std::vector<std::string>>
GenerateOpFunctions(const std::string& module_name) { GenerateOpFunctions() {
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map(); auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
std::vector<std::string> op_function_list, bind_function_list; std::vector<std::string> op_function_list, bind_function_list;
...@@ -536,7 +556,7 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -536,7 +556,7 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item // generate pybind item
auto bind_function_str = paddle::string::Sprintf( auto bind_function_str = paddle::string::Sprintf(
PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name); PYBIND_ITEM_TEMPLATE, op_type, func_name, op_type);
op_function_list.emplace_back(std::move(op_function_str)); op_function_list.emplace_back(std::move(op_function_str));
bind_function_list.emplace_back(std::move(bind_function_str)); bind_function_list.emplace_back(std::move(bind_function_str));
...@@ -551,8 +571,8 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -551,8 +571,8 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item // generate pybind item
auto inplace_bind_function_str = auto inplace_bind_function_str =
paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, module_name, paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, inplace_op_type,
inplace_op_type, inplace_func_name); inplace_func_name, inplace_op_type);
op_function_list.emplace_back(std::move(inplace_op_function_str)); op_function_list.emplace_back(std::move(inplace_op_function_str));
bind_function_list.emplace_back(std::move(inplace_bind_function_str)); bind_function_list.emplace_back(std::move(inplace_bind_function_str));
...@@ -572,7 +592,9 @@ int main(int argc, char* argv[]) { ...@@ -572,7 +592,9 @@ int main(int argc, char* argv[]) {
ascend_ptr->InitGEForUT(); ascend_ptr->InitGEForUT();
#endif #endif
std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\""}; std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\"",
"\"pybind11/detail/common.h\"",
"<Python.h>"};
std::ofstream out(argv[1], std::ios::out); std::ofstream out(argv[1], std::ios::out);
...@@ -582,22 +604,29 @@ int main(int argc, char* argv[]) { ...@@ -582,22 +604,29 @@ int main(int argc, char* argv[]) {
out << "#include " + header + "\n"; out << "#include " + header + "\n";
} }
auto op_funcs = GenerateOpFunctions("m"); out << "\n\n";
auto op_funcs = GenerateOpFunctions();
out << "namespace py = pybind11;"
<< "\n";
out << "namespace paddle {\n" out << "namespace paddle {\n"
<< "namespace pybind {\n\n"; << "namespace pybind {\n\n";
out << "std::atomic<int> VarBaseUniqueNameID{0};\n"; out << "std::atomic<int> VarBaseUniqueNameID{0};\n";
out << paddle::string::join_strings(std::get<0>(op_funcs), '\n'); out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
out << "\n\n"; out << "\n\n";
out << "inline void BindOpFunctions(pybind11::module *module) {\n" out << "static PyMethodDef ExtestMethods[] = {\n"
<< " auto m = module->def_submodule(\"ops\");\n\n"; << paddle::string::join_strings(std::get<1>(op_funcs), '\n')
<< "\n {nullptr,nullptr,0,nullptr}"
<< "};\n\n";
out << paddle::string::join_strings(std::get<1>(op_funcs), '\n'); out << "inline void BindOpFunctions(pybind11::module *module) {\n"
out << "\n"; << " auto m = module->def_submodule(\"ops\");\n"
out << "}\n\n" << " if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.ops failed!\"));\n"
<< " }\n\n"
<< " InitOpsAttrTypeMap();"
<< "}\n\n"
<< "} // namespace pybind\n" << "} // namespace pybind\n"
<< "} // namespace paddle\n"; << "} // namespace paddle\n";
......
...@@ -29,6 +29,9 @@ limitations under the License. */ ...@@ -29,6 +29,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
PyTypeObject *g_vartype_pytype = nullptr;
PyTypeObject *g_blockdesc_pytype = nullptr;
namespace pd = paddle::framework; namespace pd = paddle::framework;
template <typename T> template <typename T>
...@@ -82,8 +85,9 @@ void BindProgramDesc(pybind11::module *m) { ...@@ -82,8 +85,9 @@ void BindProgramDesc(pybind11::module *m) {
} }
void BindBlockDesc(pybind11::module *m) { void BindBlockDesc(pybind11::module *m) {
pybind11::class_<pd::BlockDesc>(*m, "BlockDesc", "") pybind11::class_<pd::BlockDesc> blockdesc(*m, "BlockDesc", "");
.def_property_readonly("id", &pd::BlockDesc::ID) g_blockdesc_pytype = (PyTypeObject *)blockdesc.ptr(); // NOLINT
blockdesc.def_property_readonly("id", &pd::BlockDesc::ID)
.def_property_readonly("parent", &pd::BlockDesc::Parent) .def_property_readonly("parent", &pd::BlockDesc::Parent)
.def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID) .def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID)
.def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID) .def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID)
...@@ -174,8 +178,9 @@ void BindVarDsec(pybind11::module *m) { ...@@ -174,8 +178,9 @@ void BindVarDsec(pybind11::module *m) {
.def("need_check_feed", &pd::VarDesc::NeedCheckFeed) .def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed); .def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed);
pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "") pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
.value("BOOL", pd::proto::VarType::BOOL) g_vartype_pytype = (PyTypeObject *)vartype.ptr(); // NOLINT
vartype.value("BOOL", pd::proto::VarType::BOOL)
.value("UINT8", pd::proto::VarType::UINT8) .value("UINT8", pd::proto::VarType::UINT8)
.value("INT8", pd::proto::VarType::INT8) .value("INT8", pd::proto::VarType::INT8)
.value("INT16", pd::proto::VarType::INT16) .value("INT16", pd::proto::VarType::INT16)
......
...@@ -357,7 +357,7 @@ def convert_shape_to_list(shape): ...@@ -357,7 +357,7 @@ def convert_shape_to_list(shape):
map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x, map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x,
shape)) shape))
else: else:
shape = list(shape.numpy().astype(int)) shape = shape.numpy().astype(int).tolist()
return shape return shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册