未验证 提交 08e81475 编写于 作者: W wanghuancoder 提交者: GitHub

use PYTHON_C_API in dygraph (#32524)

* use PYTHON_C_API in dygraph, test=develop
上级 022198c5
...@@ -51,6 +51,8 @@ limitations under the License. */ ...@@ -51,6 +51,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
PyTypeObject *g_varbase_pytype = nullptr;
namespace py = ::pybind11; namespace py = ::pybind11;
class Layer : public imperative::Layer { class Layer : public imperative::Layer {
...@@ -470,9 +472,9 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index, ...@@ -470,9 +472,9 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
} }
template <typename P> template <typename P>
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, // NOLINT
imperative::VarBase &dst, const P &dst_device, imperative::VarBase &dst, // NOLINT
const bool blocking) { const P &dst_device, const bool blocking) {
if (dst.SharedVar()->IsEmpty()) { if (dst.SharedVar()->IsEmpty()) {
VLOG(3) << "deep copy Variable from " << src->Name() << " to " VLOG(3) << "deep copy Variable from " << src->Name() << " to "
<< dst.Name(); << dst.Name();
...@@ -667,9 +669,10 @@ void BindImperative(py::module *m_ptr) { ...@@ -667,9 +669,10 @@ void BindImperative(py::module *m_ptr) {
imperative::SetCurrentTracer(tracer); imperative::SetCurrentTracer(tracer);
}); });
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>( py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>> varbase(
m, "VarBase", R"DOC()DOC") m, "VarBase", R"DOC()DOC");
.def_static("_alive_vars", &imperative::VarBase::AliveVarNames) g_varbase_pytype = (PyTypeObject *)varbase.ptr(); // NOLINT
varbase.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
.def("__init__", .def("__init__",
[](imperative::VarBase &self) { [](imperative::VarBase &self) {
std::string name = std::string name =
...@@ -1468,28 +1471,22 @@ void BindImperative(py::module *m_ptr) { ...@@ -1468,28 +1471,22 @@ void BindImperative(py::module *m_ptr) {
&imperative::VarBase::SetOverridedStopGradient) &imperative::VarBase::SetOverridedStopGradient)
.def_property("persistable", &imperative::VarBase::Persistable, .def_property("persistable", &imperative::VarBase::Persistable,
&imperative::VarBase::SetPersistable) &imperative::VarBase::SetPersistable)
.def_property_readonly("shape", .def_property_readonly(
[](imperative::VarBase &self) { "shape",
if (self.Var().IsType<framework::LoDTensor>()) { [](imperative::VarBase &self) {
return framework::vectorize<int>( if (self.Var().IsType<framework::LoDTensor>()) {
self.Var() return framework::vectorize<int>(
.Get<framework::LoDTensor>() self.Var().Get<framework::LoDTensor>().dims());
.dims()); } else if (self.Var().IsType<framework::SelectedRows>()) {
} else if (self.Var() return framework::vectorize<int>(
.IsType< self.Var().Get<framework::SelectedRows>().value().dims());
framework::SelectedRows>()) { } else {
return framework::vectorize<int>( VLOG(2) << "It is meaningless to get shape of "
self.Var() "variable type "
.Get<framework::SelectedRows>() << GetTypeName(self);
.value() return std::vector<int>();
.dims()); }
} else { })
VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self);
return std::vector<int>();
}
})
.def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf, .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf,
R"DOC( R"DOC(
Whether a Tensor is leaf Tensor. Whether a Tensor is leaf Tensor.
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
...@@ -34,6 +35,28 @@ namespace py = pybind11; ...@@ -34,6 +35,28 @@ namespace py = pybind11;
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
class OpAttrTypeMap {
public:
static OpAttrTypeMap& Instance() {
static OpAttrTypeMap g_op_attr_type_map;
return g_op_attr_type_map;
}
std::unordered_map<
std::string,
std::unordered_map<std::string, paddle::framework::proto::AttrType>>&
Map() {
return ops_attrtype_map_;
}
private:
OpAttrTypeMap() = default;
std::unordered_map<
std::string,
std::unordered_map<std::string, paddle::framework::proto::AttrType>>
ops_attrtype_map_;
};
static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase( static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
const std::string& op_type, const std::string& arg_name, int arg_idx, const std::string& op_type, const std::string& arg_name, int arg_idx,
const py::handle& handle, bool dispensable = false) { const py::handle& handle, bool dispensable = false) {
...@@ -173,6 +196,839 @@ static inline void HandleViewBetweenInputAndOutput( ...@@ -173,6 +196,839 @@ static inline void HandleViewBetweenInputAndOutput(
<< "), share allocation and inplace version."; << "), share allocation and inplace version.";
} }
} }
extern PyTypeObject* g_varbase_pytype;
extern PyTypeObject* g_vartype_pytype;
extern PyTypeObject* g_blockdesc_pytype;
inline bool PyObject_CheckBool(PyObject** obj) { return PyBool_Check(*obj); }
inline bool PyObject_CheckLongOrToLong(PyObject** obj) {
if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) || // NOLINT
PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype)) { // NOLINT
return true;
}
auto to = PyNumber_Long(*obj);
if (to) {
*obj = to;
return true;
}
return false;
}
inline bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
// sometimes users provide PyLong or numpy.int64 but attr is float
if (PyFloat_Check(*obj) || PyLong_Check(*obj) ||
PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype)) { // NOLINT
return true;
}
auto to = PyNumber_Float(*obj);
if (to) {
*obj = to;
return true;
}
return false;
}
inline bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); }
static inline void CastPyArg2AttrBoolean(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (obj == Py_None) {
attrs[key] = false; // To be compatible with QA integration testing. Some
// test case pass in None.
} else if (obj == Py_True) {
attrs[key] = true;
} else if (obj == Py_False) {
attrs[key] = false;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrInt(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int)PyLong_AsLong(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"int, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrLong(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int64_t)PyLong_AsLong(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"long, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrFloat(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckFloatOrToFloat(&obj)) {
attrs[key] = (float)PyFloat_AsDouble(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"float, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrString(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckString(obj)) {
Py_ssize_t size;
const char* data;
data = PyUnicode_AsUTF8AndSize(obj, &size);
attrs[key] = std::string(data, (size_t)size);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"str, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrBooleans(
PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckBool(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of bool, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckBool(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of bool, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrInts(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrLongs(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrFloats(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrFloat64s(
PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrStrings(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckString(item)) {
Py_ssize_t size;
const char* data;
data = PyUnicode_AsUTF8AndSize(item, &size);
value.emplace_back(std::string(data, (size_t)size)); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of str, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckString(item)) {
Py_ssize_t size;
const char* data;
data = PyUnicode_AsUTF8AndSize(item, &size);
value.emplace_back(std::string(data, (size_t)size));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of str, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrBlock(
PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
::pybind11::detail::instance* inst =
(::pybind11::detail::instance*)obj; // NOLINT
if (!PyObject_IsInstance((PyObject*)inst, // NOLINT
(PyObject*)g_blockdesc_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"BlockDesc, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
void** vh = inst->simple_layout ? inst->simple_value_holder
: &inst->nonsimple.values_and_holders[0];
attrs[key] = reinterpret_cast<paddle::framework::BlockDesc*&>(vh[0]);
}
static inline void ConstructAttrMapFromPyArgs(
const std::string& op_type, PyObject* args, ssize_t attr_start,
ssize_t attr_end, paddle::framework::AttributeMap& attrs) { // NOLINT
PADDLE_ENFORCE_EQ(
(attr_end - attr_start) % 2, 0,
platform::errors::InvalidArgument(
"The number of arguments for attributes should be even."));
auto attr_type_map = &(OpAttrTypeMap::Instance().Map()[op_type]);
PyObject* obj = nullptr;
for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) {
Py_ssize_t key_len;
const char* key_ptr;
obj = PyTuple_GET_ITEM(args, arg_pos);
if (PyObject_CheckString(obj)) {
key_ptr = PyUnicode_AsUTF8AndSize(obj, &key_len);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be str, but got "
"%s",
op_type, arg_pos, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
std::string key(key_ptr, (size_t)key_len);
auto iter = attr_type_map->find(key);
if (iter == attr_type_map->end()) {
continue;
}
obj = PyTuple_GET_ITEM(args, arg_pos + 1);
switch (iter->second) {
case paddle::framework::proto::AttrType::INT:
CastPyArg2AttrInt(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::FLOAT:
CastPyArg2AttrFloat(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::STRING:
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::INTS:
CastPyArg2AttrInts(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::FLOATS:
CastPyArg2AttrFloats(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::STRINGS:
CastPyArg2AttrStrings(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::BOOLEAN:
CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::BOOLEANS:
CastPyArg2AttrBooleans(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::LONG:
CastPyArg2AttrLong(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::LONGS:
CastPyArg2AttrLongs(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::FLOAT64S:
CastPyArg2AttrFloat64s(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::BLOCK:
CastPyArg2AttrBlock(obj, attrs, key, op_type, arg_pos);
break;
default:
break;
}
}
}
static inline std::shared_ptr<imperative::VarBase> GetVarBaseFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false) {
::pybind11::detail::instance* inst =
(::pybind11::detail::instance*)PyTuple_GET_ITEM(args, arg_idx);
if (PyTuple_Check((PyObject*)inst)) { // NOLINT
inst = (::pybind11::detail::instance*)PyTuple_GET_ITEM(inst, 0);
}
if (inst == nullptr || (PyObject*)inst == Py_None) { // NOLINT
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got None",
op_type, arg_name, arg_idx));
}
return nullptr;
}
if (!PyObject_IsInstance((PyObject*)inst, // NOLINT
(PyObject*)g_varbase_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s",
op_type, arg_name, arg_idx,
((PyTypeObject*)((PyObject*)inst)->ob_type)->tp_name)); // NOLINT
}
void** vh = inst->simple_layout ? inst->simple_value_holder
: &inst->nonsimple.values_and_holders[0];
return reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(vh[1]);
}
static inline std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromArgs(const std::string& op_type, const std::string& arg_name,
PyObject* args, ssize_t arg_idx,
bool dispensable = false) {
PyObject* list = PyTuple_GET_ITEM(args, arg_idx);
if (list == nullptr) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
"None",
op_type, arg_name, arg_idx)); // NOLINT
}
return {};
}
std::vector<std::shared_ptr<imperative::VarBase>> result;
if (PyList_Check(list)) {
Py_ssize_t len = PyList_Size(list);
if (len == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list",
op_type, arg_name, arg_idx));
}
::pybind11::detail::instance* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = (::pybind11::detail::instance*)PyList_GetItem(list, i);
if (!PyObject_IsInstance((PyObject*)item, // NOLINT
(PyObject*)g_varbase_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s",
op_type, arg_name, arg_idx,
((PyTypeObject*)((PyObject*)item)->ob_type)->tp_name)); // NOLINT
}
void** vh = item->simple_layout ? item->simple_value_holder
: &item->nonsimple.values_and_holders[0];
result.emplace_back(
reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(
vh[1]));
}
} else if (PyTuple_Check(list)) {
Py_ssize_t len = PyTuple_Size(list);
if (len == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list",
op_type, arg_name, arg_idx));
}
::pybind11::detail::instance* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = (::pybind11::detail::instance*)PyTuple_GetItem(list, i); // NOLINT
if (!PyObject_IsInstance((PyObject*)item, // NOLINT
(PyObject*)g_varbase_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s",
op_type, arg_name, arg_idx,
((PyTypeObject*)((PyObject*)item)->ob_type)->tp_name)); // NOLINT
}
void** vh = item->simple_layout ? item->simple_value_holder
: &item->nonsimple.values_and_holders[0];
result.emplace_back(
reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(
vh[1]));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"%s",
op_type, arg_name, arg_idx,
((PyTypeObject*)list->ob_type)->tp_name)); // NOLINT
}
return result;
}
static inline unsigned long GetUnsignedLongFromArgs( // NOLINT
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false) {
PyObject* item = PyTuple_GET_ITEM(args, arg_idx);
if (item == nullptr) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be long, but got None",
op_type, arg_name, arg_idx));
}
return 0;
}
if (PyObject_CheckLongOrToLong(&item)) {
return PyLong_AsUnsignedLong(item);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be "
"long, but got %s",
op_type, arg_name, arg_idx,
((PyTypeObject*)item->ob_type)->tp_name)); // NOLINT
}
}
static inline PyObject* MakeReturnPyObject(
const std::shared_ptr<paddle::imperative::VarBase>& out) {
return ::pybind11::detail::type_caster_base<imperative::VarBase>::cast_holder(
::pybind11::detail::holder_helper<
std::shared_ptr<imperative::VarBase>>::get(out),
&out)
.ptr();
}
static inline PyObject* MakeReturnPyObject(
const std::vector<std::shared_ptr<imperative::VarBase>>& out) {
PyObject* result = PyList_New((Py_ssize_t)out.size());
for (size_t i = 0; i < out.size(); i++) {
PyList_SET_ITEM(
result, (Py_ssize_t)i,
::pybind11::detail::type_caster_base<imperative::VarBase>::cast_holder(
::pybind11::detail::holder_helper<
std::shared_ptr<imperative::VarBase>>::get(out[i]),
&out[i])
.ptr()); // NOLINT
}
return result;
}
template <typename Tuple, size_t N>
struct TupleVarBasesResult {
static void Run(const Tuple& out, PyObject* result) {
TupleVarBasesResult<Tuple, N - 1>::Run(out, result);
PyTuple_SET_ITEM(result, N - 1, MakeReturnPyObject(std::get<N - 1>(out)));
}
};
template <typename Tuple>
struct TupleVarBasesResult<Tuple, 1> {
static void Run(const Tuple& out, PyObject* result) {
PyTuple_SET_ITEM(result, 0, MakeReturnPyObject(std::get<0>(out)));
}
};
template <typename... Args>
static inline PyObject* MakeReturnPyObject(const std::tuple<Args...>& out) {
auto len = sizeof...(Args);
PyObject* result = PyTuple_New(len);
TupleVarBasesResult<decltype(out), sizeof...(Args)>::Run(out, result);
return result;
}
void InitOpsAttrTypeMap() {
auto op_info_map = paddle::framework::OpInfoMap::Instance().map();
for (auto iter = op_info_map.begin(); iter != op_info_map.end(); ++iter) {
auto op_proto = iter->second.proto_;
if (op_proto == nullptr) {
continue;
}
auto attrs_proto = op_proto->attrs();
for (auto& attr : attrs_proto) {
OpAttrTypeMap::Instance().Map()[iter->first][attr.name()] = attr.type();
}
}
}
PyObject* EOFExceptionException =
PyErr_NewException("paddle.EOFException", PyExc_Exception, NULL);
PyObject* EnforceNotMetException =
PyErr_NewException("paddle.EnforceNotMet", PyExc_Exception, NULL);
void ThrowExceptionToPython(std::exception_ptr p) {
try {
if (p) std::rethrow_exception(p);
} catch (const platform::EOFException& e) {
PyErr_SetString(EOFExceptionException, e.what());
} catch (const platform::EnforceNotMet& e) {
switch (e.code()) {
case paddle::platform::error::INVALID_ARGUMENT:
PyErr_SetString(PyExc_ValueError, e.what());
break;
case paddle::platform::error::NOT_FOUND:
case paddle::platform::error::ALREADY_EXISTS:
case paddle::platform::error::PRECONDITION_NOT_MET:
case paddle::platform::error::PERMISSION_DENIED:
case paddle::platform::error::EXECUTION_TIMEOUT:
case paddle::platform::error::UNAVAILABLE:
PyErr_SetString(PyExc_RuntimeError, e.what());
break;
case paddle::platform::error::OUT_OF_RANGE:
PyErr_SetString(PyExc_IndexError, e.what());
break;
case paddle::platform::error::RESOURCE_EXHAUSTED:
PyErr_SetString(PyExc_MemoryError, e.what());
break;
case paddle::platform::error::UNIMPLEMENTED:
PyErr_SetString(PyExc_NotImplementedError, e.what());
break;
case paddle::platform::error::FATAL:
PyErr_SetString(PyExc_SystemError, e.what());
break;
case paddle::platform::error::EXTERNAL:
PyErr_SetString(PyExc_OSError, e.what());
break;
default:
PyErr_SetString(EnforceNotMetException, e.what());
break;
}
}
}
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
......
...@@ -212,16 +212,17 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)"; ...@@ -212,16 +212,17 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)"; const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
const char* CAST_VAR_TEMPLATE = R"( const char* CAST_VAR_TEMPLATE = R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)"; auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)";
const char* CAST_VAR_LIST_TEMPLATE = R"( const char* CAST_VAR_LIST_TEMPLATE = R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)"; auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)";
const char* CAST_SIZE_T_TEMPLATE = R"(
auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
const char* ARG_TEMPLATE = R"(const %s& %s)"; const char* ARG_TEMPLATE = R"(const %s& %s)";
const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)"; const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
const char* RETURN_TYPE = R"(%s)";
const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))"; const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))";
const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])"; const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])";
const char* RETURN_TEMPLATE = R"(outs["%s"][0])"; const char* RETURN_TEMPLATE = R"(outs["%s"][0])";
...@@ -251,23 +252,34 @@ const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})"; ...@@ -251,23 +252,34 @@ const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})";
const char* OP_FUNCTION_TEMPLATE = const char* OP_FUNCTION_TEMPLATE =
R"( R"(
%s %s(%s) static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
{ {
%s PyThreadState *tstate = nullptr;
framework::AttributeMap attrs; try
ConstructAttrMapFromPyArgs("%s", %d, &attrs, args);
{ {
py::gil_scoped_release release; %s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
tstate = PyEval_SaveThread();
%s %s
imperative::NameVarBaseMap outs = %s; imperative::NameVarBaseMap outs = %s;
imperative::NameVarBaseMap ins = %s; imperative::NameVarBaseMap ins = %s;
%s %s
imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s}); imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return %s; return %s;
} }
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
})"; })";
const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", &%s);)"; const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},)";
// clang-format on // clang-format on
static inline bool FindInsMap(const std::string& op_type, static inline bool FindInsMap(const std::string& op_type,
...@@ -326,9 +338,8 @@ std::string GenerateOpFunctionsBody( ...@@ -326,9 +338,8 @@ std::string GenerateOpFunctionsBody(
const auto in_cast_type = const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false"; auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name, in_name, arg_idx++, dispensable);
arg_idx++, TempName(in_name), dispensable);
if (input.dispensable()) { if (input.dispensable()) {
const auto in_template = input.duplicable() const auto in_template = input.duplicable()
...@@ -356,7 +367,6 @@ std::string GenerateOpFunctionsBody( ...@@ -356,7 +367,6 @@ std::string GenerateOpFunctionsBody(
// Generate outs initializer // Generate outs initializer
std::string outs_initializer = "{"; std::string outs_initializer = "{";
std::string outs_initializer_with_null = ""; std::string outs_initializer_with_null = "";
std::string return_type = "";
std::string inplace_mapping_str = ""; std::string inplace_mapping_str = "";
std::string return_str = ""; std::string return_str = "";
...@@ -395,6 +405,12 @@ std::string GenerateOpFunctionsBody( ...@@ -395,6 +405,12 @@ std::string GenerateOpFunctionsBody(
paddle::string::Sprintf(out_template, out_name, out_name); paddle::string::Sprintf(out_template, out_name, out_name);
outs_initializer += ","; outs_initializer += ",";
} }
const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type,
out_name, arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) { } else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
inplace_map[out_name], "", inplace_map[out_name], "",
...@@ -440,6 +456,11 @@ std::string GenerateOpFunctionsBody( ...@@ -440,6 +456,11 @@ std::string GenerateOpFunctionsBody(
input_args_num++; input_args_num++;
outs_initializer += paddle::string::Sprintf( outs_initializer += paddle::string::Sprintf(
OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str); OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str +=
paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type,
out_num_str, arg_idx++, dispensable);
} else { } else {
outs_initializer += outs_initializer +=
paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name); paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);
...@@ -447,15 +468,12 @@ std::string GenerateOpFunctionsBody( ...@@ -447,15 +468,12 @@ std::string GenerateOpFunctionsBody(
outs_initializer += ","; outs_initializer += ",";
} }
return_type += out_type;
return_type += ",";
return_str += paddle::string::Sprintf(return_template, out_name); return_str += paddle::string::Sprintf(return_template, out_name);
return_str += ","; return_str += ",";
outs_num += 1; outs_num += 1;
} }
if (outs_initializer.back() == ',') { if (outs_initializer.back() == ',') {
outs_initializer.pop_back(); outs_initializer.pop_back();
return_type.pop_back();
return_str.pop_back(); return_str.pop_back();
} }
outs_initializer += "}"; outs_initializer += "}";
...@@ -470,11 +488,13 @@ std::string GenerateOpFunctionsBody( ...@@ -470,11 +488,13 @@ std::string GenerateOpFunctionsBody(
viwe_input_name, viwe_output_name); viwe_input_name, viwe_output_name);
} }
if (outs_num == 0) { if (outs_num == 0) {
return_type = "void"; return_str = "Py_None";
} } else if (outs_num == 1) {
if (outs_num > 1) { return_str = "MakeReturnPyObject(" + return_str + ")";
return_str = paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str); } else {
return_type = paddle::string::Sprintf(RETURN_TUPLE_TYPE, return_type); return_str = "MakeReturnPyObject(" +
paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str) +
")";
} }
std::string function_args = ""; std::string function_args = "";
if (input_args == "") { if (input_args == "") {
...@@ -485,17 +505,17 @@ std::string GenerateOpFunctionsBody( ...@@ -485,17 +505,17 @@ std::string GenerateOpFunctionsBody(
// generate op funtcion body // generate op funtcion body
auto op_function_str = paddle::string::Sprintf( auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, ins_cast_str, OP_FUNCTION_TEMPLATE, func_name, ins_cast_str, op_type, input_args_num,
op_type, input_args_num, inplace_strategy_str, outs_initializer, inplace_strategy_str, outs_initializer, ins_initializer,
ins_initializer, ins_initializer_with_null + outs_initializer_with_null + ins_initializer_with_null + outs_initializer_with_null +
view_strategy_str, view_strategy_str,
op_type, inplace_mapping_str, return_str); op_type, inplace_mapping_str, return_str);
return op_function_str; return op_function_str;
} }
static std::tuple<std::vector<std::string>, std::vector<std::string>> static std::tuple<std::vector<std::string>, std::vector<std::string>>
GenerateOpFunctions(const std::string& module_name) { GenerateOpFunctions() {
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map(); auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
std::vector<std::string> op_function_list, bind_function_list; std::vector<std::string> op_function_list, bind_function_list;
...@@ -536,7 +556,7 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -536,7 +556,7 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item // generate pybind item
auto bind_function_str = paddle::string::Sprintf( auto bind_function_str = paddle::string::Sprintf(
PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name); PYBIND_ITEM_TEMPLATE, op_type, func_name, op_type);
op_function_list.emplace_back(std::move(op_function_str)); op_function_list.emplace_back(std::move(op_function_str));
bind_function_list.emplace_back(std::move(bind_function_str)); bind_function_list.emplace_back(std::move(bind_function_str));
...@@ -551,8 +571,8 @@ GenerateOpFunctions(const std::string& module_name) { ...@@ -551,8 +571,8 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item // generate pybind item
auto inplace_bind_function_str = auto inplace_bind_function_str =
paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, module_name, paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, inplace_op_type,
inplace_op_type, inplace_func_name); inplace_func_name, inplace_op_type);
op_function_list.emplace_back(std::move(inplace_op_function_str)); op_function_list.emplace_back(std::move(inplace_op_function_str));
bind_function_list.emplace_back(std::move(inplace_bind_function_str)); bind_function_list.emplace_back(std::move(inplace_bind_function_str));
...@@ -572,7 +592,9 @@ int main(int argc, char* argv[]) { ...@@ -572,7 +592,9 @@ int main(int argc, char* argv[]) {
ascend_ptr->InitGEForUT(); ascend_ptr->InitGEForUT();
#endif #endif
std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\""}; std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\"",
"\"pybind11/detail/common.h\"",
"<Python.h>"};
std::ofstream out(argv[1], std::ios::out); std::ofstream out(argv[1], std::ios::out);
...@@ -582,22 +604,29 @@ int main(int argc, char* argv[]) { ...@@ -582,22 +604,29 @@ int main(int argc, char* argv[]) {
out << "#include " + header + "\n"; out << "#include " + header + "\n";
} }
auto op_funcs = GenerateOpFunctions("m"); out << "\n\n";
auto op_funcs = GenerateOpFunctions();
out << "namespace py = pybind11;"
<< "\n";
out << "namespace paddle {\n" out << "namespace paddle {\n"
<< "namespace pybind {\n\n"; << "namespace pybind {\n\n";
out << "std::atomic<int> VarBaseUniqueNameID{0};\n"; out << "std::atomic<int> VarBaseUniqueNameID{0};\n";
out << paddle::string::join_strings(std::get<0>(op_funcs), '\n'); out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
out << "\n\n"; out << "\n\n";
out << "inline void BindOpFunctions(pybind11::module *module) {\n" out << "static PyMethodDef ExtestMethods[] = {\n"
<< " auto m = module->def_submodule(\"ops\");\n\n"; << paddle::string::join_strings(std::get<1>(op_funcs), '\n')
<< "\n {nullptr,nullptr,0,nullptr}"
<< "};\n\n";
out << paddle::string::join_strings(std::get<1>(op_funcs), '\n'); out << "inline void BindOpFunctions(pybind11::module *module) {\n"
out << "\n"; << " auto m = module->def_submodule(\"ops\");\n"
out << "}\n\n" << " if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.ops failed!\"));\n"
<< " }\n\n"
<< " InitOpsAttrTypeMap();"
<< "}\n\n"
<< "} // namespace pybind\n" << "} // namespace pybind\n"
<< "} // namespace paddle\n"; << "} // namespace paddle\n";
......
...@@ -29,6 +29,9 @@ limitations under the License. */ ...@@ -29,6 +29,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
PyTypeObject *g_vartype_pytype = nullptr;
PyTypeObject *g_blockdesc_pytype = nullptr;
namespace pd = paddle::framework; namespace pd = paddle::framework;
template <typename T> template <typename T>
...@@ -82,8 +85,9 @@ void BindProgramDesc(pybind11::module *m) { ...@@ -82,8 +85,9 @@ void BindProgramDesc(pybind11::module *m) {
} }
void BindBlockDesc(pybind11::module *m) { void BindBlockDesc(pybind11::module *m) {
pybind11::class_<pd::BlockDesc>(*m, "BlockDesc", "") pybind11::class_<pd::BlockDesc> blockdesc(*m, "BlockDesc", "");
.def_property_readonly("id", &pd::BlockDesc::ID) g_blockdesc_pytype = (PyTypeObject *)blockdesc.ptr(); // NOLINT
blockdesc.def_property_readonly("id", &pd::BlockDesc::ID)
.def_property_readonly("parent", &pd::BlockDesc::Parent) .def_property_readonly("parent", &pd::BlockDesc::Parent)
.def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID) .def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID)
.def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID) .def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID)
...@@ -174,8 +178,9 @@ void BindVarDsec(pybind11::module *m) { ...@@ -174,8 +178,9 @@ void BindVarDsec(pybind11::module *m) {
.def("need_check_feed", &pd::VarDesc::NeedCheckFeed) .def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed); .def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed);
pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "") pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
.value("BOOL", pd::proto::VarType::BOOL) g_vartype_pytype = (PyTypeObject *)vartype.ptr(); // NOLINT
vartype.value("BOOL", pd::proto::VarType::BOOL)
.value("UINT8", pd::proto::VarType::UINT8) .value("UINT8", pd::proto::VarType::UINT8)
.value("INT8", pd::proto::VarType::INT8) .value("INT8", pd::proto::VarType::INT8)
.value("INT16", pd::proto::VarType::INT16) .value("INT16", pd::proto::VarType::INT16)
......
...@@ -357,7 +357,7 @@ def convert_shape_to_list(shape): ...@@ -357,7 +357,7 @@ def convert_shape_to_list(shape):
map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x, map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x,
shape)) shape))
else: else:
shape = list(shape.numpy().astype(int)) shape = shape.numpy().astype(int).tolist()
return shape return shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册