未验证 提交 08e81475 编写于 作者: W wanghuancoder 提交者: GitHub

use PYTHON_C_API in dygraph (#32524)

* use PYTHON_C_API in dygraph, test=develop
上级 022198c5
......@@ -51,6 +51,8 @@ limitations under the License. */
namespace paddle {
namespace pybind {
PyTypeObject *g_varbase_pytype = nullptr;
namespace py = ::pybind11;
class Layer : public imperative::Layer {
......@@ -470,9 +472,9 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
}
template <typename P>
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src,
imperative::VarBase &dst, const P &dst_device,
const bool blocking) {
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, // NOLINT
imperative::VarBase &dst, // NOLINT
const P &dst_device, const bool blocking) {
if (dst.SharedVar()->IsEmpty()) {
VLOG(3) << "deep copy Variable from " << src->Name() << " to "
<< dst.Name();
......@@ -667,9 +669,10 @@ void BindImperative(py::module *m_ptr) {
imperative::SetCurrentTracer(tracer);
});
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
m, "VarBase", R"DOC()DOC")
.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>> varbase(
m, "VarBase", R"DOC()DOC");
g_varbase_pytype = (PyTypeObject *)varbase.ptr(); // NOLINT
varbase.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
.def("__init__",
[](imperative::VarBase &self) {
std::string name =
......@@ -1468,21 +1471,15 @@ void BindImperative(py::module *m_ptr) {
&imperative::VarBase::SetOverridedStopGradient)
.def_property("persistable", &imperative::VarBase::Persistable,
&imperative::VarBase::SetPersistable)
.def_property_readonly("shape",
.def_property_readonly(
"shape",
[](imperative::VarBase &self) {
if (self.Var().IsType<framework::LoDTensor>()) {
return framework::vectorize<int>(
self.Var()
.Get<framework::LoDTensor>()
.dims());
} else if (self.Var()
.IsType<
framework::SelectedRows>()) {
self.Var().Get<framework::LoDTensor>().dims());
} else if (self.Var().IsType<framework::SelectedRows>()) {
return framework::vectorize<int>(
self.Var()
.Get<framework::SelectedRows>()
.value()
.dims());
self.Var().Get<framework::SelectedRows>().value().dims());
} else {
VLOG(2) << "It is meaningless to get shape of "
"variable type "
......
此差异已折叠。
......@@ -212,16 +212,17 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
const char* CAST_VAR_TEMPLATE = R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)";
auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)";
const char* CAST_VAR_LIST_TEMPLATE = R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)";
auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)";
const char* CAST_SIZE_T_TEMPLATE = R"(
auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
const char* ARG_TEMPLATE = R"(const %s& %s)";
const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
const char* RETURN_TYPE = R"(%s)";
const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))";
const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])";
const char* RETURN_TEMPLATE = R"(outs["%s"][0])";
......@@ -251,23 +252,34 @@ const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})";
const char* OP_FUNCTION_TEMPLATE =
R"(
%s %s(%s)
static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
{
PyThreadState *tstate = nullptr;
try
{
%s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("%s", %d, &attrs, args);
{
py::gil_scoped_release release;
ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
tstate = PyEval_SaveThread();
%s
imperative::NameVarBaseMap outs = %s;
imperative::NameVarBaseMap ins = %s;
%s
imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return %s;
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
})";
const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", &%s);)";
const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},)";
// clang-format on
static inline bool FindInsMap(const std::string& op_type,
......@@ -326,9 +338,8 @@ std::string GenerateOpFunctionsBody(
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
arg_idx++, TempName(in_name), dispensable);
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
in_name, arg_idx++, dispensable);
if (input.dispensable()) {
const auto in_template = input.duplicable()
......@@ -356,7 +367,6 @@ std::string GenerateOpFunctionsBody(
// Generate outs initializer
std::string outs_initializer = "{";
std::string outs_initializer_with_null = "";
std::string return_type = "";
std::string inplace_mapping_str = "";
std::string return_str = "";
......@@ -395,6 +405,12 @@ std::string GenerateOpFunctionsBody(
paddle::string::Sprintf(out_template, out_name, out_name);
outs_initializer += ",";
}
const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type,
out_name, arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE(
inplace_map[out_name], "",
......@@ -440,6 +456,11 @@ std::string GenerateOpFunctionsBody(
input_args_num++;
outs_initializer += paddle::string::Sprintf(
OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str +=
paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type,
out_num_str, arg_idx++, dispensable);
} else {
outs_initializer +=
paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);
......@@ -447,15 +468,12 @@ std::string GenerateOpFunctionsBody(
outs_initializer += ",";
}
return_type += out_type;
return_type += ",";
return_str += paddle::string::Sprintf(return_template, out_name);
return_str += ",";
outs_num += 1;
}
if (outs_initializer.back() == ',') {
outs_initializer.pop_back();
return_type.pop_back();
return_str.pop_back();
}
outs_initializer += "}";
......@@ -470,11 +488,13 @@ std::string GenerateOpFunctionsBody(
viwe_input_name, viwe_output_name);
}
if (outs_num == 0) {
return_type = "void";
}
if (outs_num > 1) {
return_str = paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str);
return_type = paddle::string::Sprintf(RETURN_TUPLE_TYPE, return_type);
return_str = "Py_None";
} else if (outs_num == 1) {
return_str = "MakeReturnPyObject(" + return_str + ")";
} else {
return_str = "MakeReturnPyObject(" +
paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str) +
")";
}
std::string function_args = "";
if (input_args == "") {
......@@ -485,9 +505,9 @@ std::string GenerateOpFunctionsBody(
// generate op funtcion body
auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, ins_cast_str,
op_type, input_args_num, inplace_strategy_str, outs_initializer,
ins_initializer, ins_initializer_with_null + outs_initializer_with_null +
OP_FUNCTION_TEMPLATE, func_name, ins_cast_str, op_type, input_args_num,
inplace_strategy_str, outs_initializer, ins_initializer,
ins_initializer_with_null + outs_initializer_with_null +
view_strategy_str,
op_type, inplace_mapping_str, return_str);
......@@ -495,7 +515,7 @@ std::string GenerateOpFunctionsBody(
}
static std::tuple<std::vector<std::string>, std::vector<std::string>>
GenerateOpFunctions(const std::string& module_name) {
GenerateOpFunctions() {
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
std::vector<std::string> op_function_list, bind_function_list;
......@@ -536,7 +556,7 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item
auto bind_function_str = paddle::string::Sprintf(
PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name);
PYBIND_ITEM_TEMPLATE, op_type, func_name, op_type);
op_function_list.emplace_back(std::move(op_function_str));
bind_function_list.emplace_back(std::move(bind_function_str));
......@@ -551,8 +571,8 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item
auto inplace_bind_function_str =
paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, module_name,
inplace_op_type, inplace_func_name);
paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, inplace_op_type,
inplace_func_name, inplace_op_type);
op_function_list.emplace_back(std::move(inplace_op_function_str));
bind_function_list.emplace_back(std::move(inplace_bind_function_str));
......@@ -572,7 +592,9 @@ int main(int argc, char* argv[]) {
ascend_ptr->InitGEForUT();
#endif
std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\""};
std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\"",
"\"pybind11/detail/common.h\"",
"<Python.h>"};
std::ofstream out(argv[1], std::ios::out);
......@@ -582,22 +604,29 @@ int main(int argc, char* argv[]) {
out << "#include " + header + "\n";
}
auto op_funcs = GenerateOpFunctions("m");
out << "\n\n";
auto op_funcs = GenerateOpFunctions();
out << "namespace py = pybind11;"
<< "\n";
out << "namespace paddle {\n"
<< "namespace pybind {\n\n";
out << "std::atomic<int> VarBaseUniqueNameID{0};\n";
out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
out << "\n\n";
out << "inline void BindOpFunctions(pybind11::module *module) {\n"
<< " auto m = module->def_submodule(\"ops\");\n\n";
out << "static PyMethodDef ExtestMethods[] = {\n"
<< paddle::string::join_strings(std::get<1>(op_funcs), '\n')
<< "\n {nullptr,nullptr,0,nullptr}"
<< "};\n\n";
out << paddle::string::join_strings(std::get<1>(op_funcs), '\n');
out << "\n";
out << "}\n\n"
out << "inline void BindOpFunctions(pybind11::module *module) {\n"
<< " auto m = module->def_submodule(\"ops\");\n"
<< " if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.ops failed!\"));\n"
<< " }\n\n"
<< " InitOpsAttrTypeMap();"
<< "}\n\n"
<< "} // namespace pybind\n"
<< "} // namespace paddle\n";
......
......@@ -29,6 +29,9 @@ limitations under the License. */
namespace paddle {
namespace pybind {
PyTypeObject *g_vartype_pytype = nullptr;
PyTypeObject *g_blockdesc_pytype = nullptr;
namespace pd = paddle::framework;
template <typename T>
......@@ -82,8 +85,9 @@ void BindProgramDesc(pybind11::module *m) {
}
void BindBlockDesc(pybind11::module *m) {
pybind11::class_<pd::BlockDesc>(*m, "BlockDesc", "")
.def_property_readonly("id", &pd::BlockDesc::ID)
pybind11::class_<pd::BlockDesc> blockdesc(*m, "BlockDesc", "");
g_blockdesc_pytype = (PyTypeObject *)blockdesc.ptr(); // NOLINT
blockdesc.def_property_readonly("id", &pd::BlockDesc::ID)
.def_property_readonly("parent", &pd::BlockDesc::Parent)
.def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID)
.def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID)
......@@ -174,8 +178,9 @@ void BindVarDsec(pybind11::module *m) {
.def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed);
pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "")
.value("BOOL", pd::proto::VarType::BOOL)
pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
g_vartype_pytype = (PyTypeObject *)vartype.ptr(); // NOLINT
vartype.value("BOOL", pd::proto::VarType::BOOL)
.value("UINT8", pd::proto::VarType::UINT8)
.value("INT8", pd::proto::VarType::INT8)
.value("INT16", pd::proto::VarType::INT16)
......
......@@ -357,7 +357,7 @@ def convert_shape_to_list(shape):
map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x,
shape))
else:
shape = list(shape.numpy().astype(int))
shape = shape.numpy().astype(int).tolist()
return shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册