未验证 提交 08e81475 编写于 作者: W wanghuancoder 提交者: GitHub

use PYTHON_C_API in dygraph (#32524)

* use PYTHON_C_API in dygraph, test=develop
上级 022198c5
......@@ -51,6 +51,8 @@ limitations under the License. */
namespace paddle {
namespace pybind {
PyTypeObject *g_varbase_pytype = nullptr;
namespace py = ::pybind11;
class Layer : public imperative::Layer {
......@@ -470,9 +472,9 @@ static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
}
template <typename P>
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src,
imperative::VarBase &dst, const P &dst_device,
const bool blocking) {
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src, // NOLINT
imperative::VarBase &dst, // NOLINT
const P &dst_device, const bool blocking) {
if (dst.SharedVar()->IsEmpty()) {
VLOG(3) << "deep copy Variable from " << src->Name() << " to "
<< dst.Name();
......@@ -667,9 +669,10 @@ void BindImperative(py::module *m_ptr) {
imperative::SetCurrentTracer(tracer);
});
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
m, "VarBase", R"DOC()DOC")
.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>> varbase(
m, "VarBase", R"DOC()DOC");
g_varbase_pytype = (PyTypeObject *)varbase.ptr(); // NOLINT
varbase.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
.def("__init__",
[](imperative::VarBase &self) {
std::string name =
......@@ -1468,21 +1471,15 @@ void BindImperative(py::module *m_ptr) {
&imperative::VarBase::SetOverridedStopGradient)
.def_property("persistable", &imperative::VarBase::Persistable,
&imperative::VarBase::SetPersistable)
.def_property_readonly("shape",
.def_property_readonly(
"shape",
[](imperative::VarBase &self) {
if (self.Var().IsType<framework::LoDTensor>()) {
return framework::vectorize<int>(
self.Var()
.Get<framework::LoDTensor>()
.dims());
} else if (self.Var()
.IsType<
framework::SelectedRows>()) {
self.Var().Get<framework::LoDTensor>().dims());
} else if (self.Var().IsType<framework::SelectedRows>()) {
return framework::vectorize<int>(
self.Var()
.Get<framework::SelectedRows>()
.value()
.dims());
self.Var().Get<framework::SelectedRows>().value().dims());
} else {
VLOG(2) << "It is meaningless to get shape of "
"variable type "
......
......@@ -25,6 +25,7 @@
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
......@@ -34,6 +35,28 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {
class OpAttrTypeMap {
public:
static OpAttrTypeMap& Instance() {
static OpAttrTypeMap g_op_attr_type_map;
return g_op_attr_type_map;
}
std::unordered_map<
std::string,
std::unordered_map<std::string, paddle::framework::proto::AttrType>>&
Map() {
return ops_attrtype_map_;
}
private:
OpAttrTypeMap() = default;
std::unordered_map<
std::string,
std::unordered_map<std::string, paddle::framework::proto::AttrType>>
ops_attrtype_map_;
};
static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
const std::string& op_type, const std::string& arg_name, int arg_idx,
const py::handle& handle, bool dispensable = false) {
......@@ -173,6 +196,839 @@ static inline void HandleViewBetweenInputAndOutput(
<< "), share allocation and inplace version.";
}
}
extern PyTypeObject* g_varbase_pytype;
extern PyTypeObject* g_vartype_pytype;
extern PyTypeObject* g_blockdesc_pytype;
inline bool PyObject_CheckBool(PyObject** obj) { return PyBool_Check(*obj); }
inline bool PyObject_CheckLongOrToLong(PyObject** obj) {
if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) ||
PyObject_IsInstance(*obj, (PyObject*)g_vartype_pytype) || // NOLINT
PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype)) { // NOLINT
return true;
}
auto to = PyNumber_Long(*obj);
if (to) {
*obj = to;
return true;
}
return false;
}
inline bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
// sometimes users provide PyLong or numpy.int64 but attr is float
if (PyFloat_Check(*obj) || PyLong_Check(*obj) ||
PyObject_IsInstance(*obj, (PyObject*)g_varbase_pytype)) { // NOLINT
return true;
}
auto to = PyNumber_Float(*obj);
if (to) {
*obj = to;
return true;
}
return false;
}
inline bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); }
static inline void CastPyArg2AttrBoolean(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (obj == Py_None) {
attrs[key] = false; // To be compatible with QA integration testing. Some
// test case pass in None.
} else if (obj == Py_True) {
attrs[key] = true;
} else if (obj == Py_False) {
attrs[key] = false;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"bool, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrInt(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int)PyLong_AsLong(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"int, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrLong(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int64_t)PyLong_AsLong(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"long, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrFloat(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckFloatOrToFloat(&obj)) {
attrs[key] = (float)PyFloat_AsDouble(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"float, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrString(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckString(obj)) {
Py_ssize_t size;
const char* data;
data = PyUnicode_AsUTF8AndSize(obj, &size);
attrs[key] = std::string(data, (size_t)size);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"str, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrBooleans(
PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckBool(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of bool, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckBool(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of bool, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrInts(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrLongs(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
value.emplace_back(PyLong_AsLong(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrFloats(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrFloat64s(
PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
value.emplace_back(PyFloat_AsDouble(item));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of float, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrStrings(
PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckString(item)) {
Py_ssize_t size;
const char* data;
data = PyUnicode_AsUTF8AndSize(item, &size);
value.emplace_back(std::string(data, (size_t)size)); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of str, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckString(item)) {
Py_ssize_t size;
const char* data;
data = PyUnicode_AsUTF8AndSize(item, &size);
value.emplace_back(std::string(data, (size_t)size));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of str, but got %s at pos %d",
op_type, arg_pos + 1,
((PyTypeObject*)item->ob_type)->tp_name, // NOLINT
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list or tuple, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
}
static inline void CastPyArg2AttrBlock(
PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, ssize_t arg_pos) {
::pybind11::detail::instance* inst =
(::pybind11::detail::instance*)obj; // NOLINT
if (!PyObject_IsInstance((PyObject*)inst, // NOLINT
(PyObject*)g_blockdesc_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"BlockDesc, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
void** vh = inst->simple_layout ? inst->simple_value_holder
: &inst->nonsimple.values_and_holders[0];
attrs[key] = reinterpret_cast<paddle::framework::BlockDesc*&>(vh[0]);
}
static inline void ConstructAttrMapFromPyArgs(
const std::string& op_type, PyObject* args, ssize_t attr_start,
ssize_t attr_end, paddle::framework::AttributeMap& attrs) { // NOLINT
PADDLE_ENFORCE_EQ(
(attr_end - attr_start) % 2, 0,
platform::errors::InvalidArgument(
"The number of arguments for attributes should be even."));
auto attr_type_map = &(OpAttrTypeMap::Instance().Map()[op_type]);
PyObject* obj = nullptr;
for (ssize_t arg_pos = attr_start; arg_pos < attr_end; arg_pos += 2) {
Py_ssize_t key_len;
const char* key_ptr;
obj = PyTuple_GET_ITEM(args, arg_pos);
if (PyObject_CheckString(obj)) {
key_ptr = PyUnicode_AsUTF8AndSize(obj, &key_len);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be str, but got "
"%s",
op_type, arg_pos, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
std::string key(key_ptr, (size_t)key_len);
auto iter = attr_type_map->find(key);
if (iter == attr_type_map->end()) {
continue;
}
obj = PyTuple_GET_ITEM(args, arg_pos + 1);
switch (iter->second) {
case paddle::framework::proto::AttrType::INT:
CastPyArg2AttrInt(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::FLOAT:
CastPyArg2AttrFloat(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::STRING:
CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::INTS:
CastPyArg2AttrInts(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::FLOATS:
CastPyArg2AttrFloats(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::STRINGS:
CastPyArg2AttrStrings(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::BOOLEAN:
CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::BOOLEANS:
CastPyArg2AttrBooleans(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::LONG:
CastPyArg2AttrLong(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::LONGS:
CastPyArg2AttrLongs(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::FLOAT64S:
CastPyArg2AttrFloat64s(obj, attrs, key, op_type, arg_pos);
break;
case paddle::framework::proto::AttrType::BLOCK:
CastPyArg2AttrBlock(obj, attrs, key, op_type, arg_pos);
break;
default:
break;
}
}
}
static inline std::shared_ptr<imperative::VarBase> GetVarBaseFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false) {
::pybind11::detail::instance* inst =
(::pybind11::detail::instance*)PyTuple_GET_ITEM(args, arg_idx);
if (PyTuple_Check((PyObject*)inst)) { // NOLINT
inst = (::pybind11::detail::instance*)PyTuple_GET_ITEM(inst, 0);
}
if (inst == nullptr || (PyObject*)inst == Py_None) { // NOLINT
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got None",
op_type, arg_name, arg_idx));
}
return nullptr;
}
if (!PyObject_IsInstance((PyObject*)inst, // NOLINT
(PyObject*)g_varbase_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s",
op_type, arg_name, arg_idx,
((PyTypeObject*)((PyObject*)inst)->ob_type)->tp_name)); // NOLINT
}
void** vh = inst->simple_layout ? inst->simple_value_holder
: &inst->nonsimple.values_and_holders[0];
return reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(vh[1]);
}
static inline std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromArgs(const std::string& op_type, const std::string& arg_name,
PyObject* args, ssize_t arg_idx,
bool dispensable = false) {
PyObject* list = PyTuple_GET_ITEM(args, arg_idx);
if (list == nullptr) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
"None",
op_type, arg_name, arg_idx)); // NOLINT
}
return {};
}
std::vector<std::shared_ptr<imperative::VarBase>> result;
if (PyList_Check(list)) {
Py_ssize_t len = PyList_Size(list);
if (len == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list",
op_type, arg_name, arg_idx));
}
::pybind11::detail::instance* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = (::pybind11::detail::instance*)PyList_GetItem(list, i);
if (!PyObject_IsInstance((PyObject*)item, // NOLINT
(PyObject*)g_varbase_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s",
op_type, arg_name, arg_idx,
((PyTypeObject*)((PyObject*)item)->ob_type)->tp_name)); // NOLINT
}
void** vh = item->simple_layout ? item->simple_value_holder
: &item->nonsimple.values_and_holders[0];
result.emplace_back(
reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(
vh[1]));
}
} else if (PyTuple_Check(list)) {
Py_ssize_t len = PyTuple_Size(list);
if (len == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list",
op_type, arg_name, arg_idx));
}
::pybind11::detail::instance* item = nullptr;
for (Py_ssize_t i = 0; i < len; i++) {
item = (::pybind11::detail::instance*)PyTuple_GetItem(list, i); // NOLINT
if (!PyObject_IsInstance((PyObject*)item, // NOLINT
(PyObject*)g_varbase_pytype)) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s",
op_type, arg_name, arg_idx,
((PyTypeObject*)((PyObject*)item)->ob_type)->tp_name)); // NOLINT
}
void** vh = item->simple_layout ? item->simple_value_holder
: &item->nonsimple.values_and_holders[0];
result.emplace_back(
reinterpret_cast<std::shared_ptr<paddle::imperative::VarBase>&>(
vh[1]));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"%s",
op_type, arg_name, arg_idx,
((PyTypeObject*)list->ob_type)->tp_name)); // NOLINT
}
return result;
}
static inline unsigned long GetUnsignedLongFromArgs( // NOLINT
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false) {
PyObject* item = PyTuple_GET_ITEM(args, arg_idx);
if (item == nullptr) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be long, but got None",
op_type, arg_name, arg_idx));
}
return 0;
}
if (PyObject_CheckLongOrToLong(&item)) {
return PyLong_AsUnsignedLong(item);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be "
"long, but got %s",
op_type, arg_name, arg_idx,
((PyTypeObject*)item->ob_type)->tp_name)); // NOLINT
}
}
static inline PyObject* MakeReturnPyObject(
const std::shared_ptr<paddle::imperative::VarBase>& out) {
return ::pybind11::detail::type_caster_base<imperative::VarBase>::cast_holder(
::pybind11::detail::holder_helper<
std::shared_ptr<imperative::VarBase>>::get(out),
&out)
.ptr();
}
static inline PyObject* MakeReturnPyObject(
const std::vector<std::shared_ptr<imperative::VarBase>>& out) {
PyObject* result = PyList_New((Py_ssize_t)out.size());
for (size_t i = 0; i < out.size(); i++) {
PyList_SET_ITEM(
result, (Py_ssize_t)i,
::pybind11::detail::type_caster_base<imperative::VarBase>::cast_holder(
::pybind11::detail::holder_helper<
std::shared_ptr<imperative::VarBase>>::get(out[i]),
&out[i])
.ptr()); // NOLINT
}
return result;
}
template <typename Tuple, size_t N>
struct TupleVarBasesResult {
static void Run(const Tuple& out, PyObject* result) {
TupleVarBasesResult<Tuple, N - 1>::Run(out, result);
PyTuple_SET_ITEM(result, N - 1, MakeReturnPyObject(std::get<N - 1>(out)));
}
};
template <typename Tuple>
struct TupleVarBasesResult<Tuple, 1> {
static void Run(const Tuple& out, PyObject* result) {
PyTuple_SET_ITEM(result, 0, MakeReturnPyObject(std::get<0>(out)));
}
};
template <typename... Args>
static inline PyObject* MakeReturnPyObject(const std::tuple<Args...>& out) {
auto len = sizeof...(Args);
PyObject* result = PyTuple_New(len);
TupleVarBasesResult<decltype(out), sizeof...(Args)>::Run(out, result);
return result;
}
void InitOpsAttrTypeMap() {
auto op_info_map = paddle::framework::OpInfoMap::Instance().map();
for (auto iter = op_info_map.begin(); iter != op_info_map.end(); ++iter) {
auto op_proto = iter->second.proto_;
if (op_proto == nullptr) {
continue;
}
auto attrs_proto = op_proto->attrs();
for (auto& attr : attrs_proto) {
OpAttrTypeMap::Instance().Map()[iter->first][attr.name()] = attr.type();
}
}
}
PyObject* EOFExceptionException =
PyErr_NewException("paddle.EOFException", PyExc_Exception, NULL);
PyObject* EnforceNotMetException =
PyErr_NewException("paddle.EnforceNotMet", PyExc_Exception, NULL);
void ThrowExceptionToPython(std::exception_ptr p) {
try {
if (p) std::rethrow_exception(p);
} catch (const platform::EOFException& e) {
PyErr_SetString(EOFExceptionException, e.what());
} catch (const platform::EnforceNotMet& e) {
switch (e.code()) {
case paddle::platform::error::INVALID_ARGUMENT:
PyErr_SetString(PyExc_ValueError, e.what());
break;
case paddle::platform::error::NOT_FOUND:
case paddle::platform::error::ALREADY_EXISTS:
case paddle::platform::error::PRECONDITION_NOT_MET:
case paddle::platform::error::PERMISSION_DENIED:
case paddle::platform::error::EXECUTION_TIMEOUT:
case paddle::platform::error::UNAVAILABLE:
PyErr_SetString(PyExc_RuntimeError, e.what());
break;
case paddle::platform::error::OUT_OF_RANGE:
PyErr_SetString(PyExc_IndexError, e.what());
break;
case paddle::platform::error::RESOURCE_EXHAUSTED:
PyErr_SetString(PyExc_MemoryError, e.what());
break;
case paddle::platform::error::UNIMPLEMENTED:
PyErr_SetString(PyExc_NotImplementedError, e.what());
break;
case paddle::platform::error::FATAL:
PyErr_SetString(PyExc_SystemError, e.what());
break;
case paddle::platform::error::EXTERNAL:
PyErr_SetString(PyExc_OSError, e.what());
break;
default:
PyErr_SetString(EnforceNotMetException, e.what());
break;
}
}
}
} // namespace pybind
} // namespace paddle
......
......@@ -212,16 +212,17 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
const char* CAST_VAR_TEMPLATE = R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)";
auto %s = GetVarBaseFromArgs("%s", "%s", args, %d, %s);)";
const char* CAST_VAR_LIST_TEMPLATE = R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)";
auto %s = GetVarBaseListFromArgs("%s", "%s", args, %d, %s);)";
const char* CAST_SIZE_T_TEMPLATE = R"(
auto %s = GetUnsignedLongFromArgs("%s", "%s", args, %d, %s);)";
const char* ARG_TEMPLATE = R"(const %s& %s)";
const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
const char* RETURN_TYPE = R"(%s)";
const char* RETURN_TUPLE_TEMPLATE = R"(std::make_tuple(%s))";
const char* RETURN_LIST_TEMPLATE = R"(outs["%s"])";
const char* RETURN_TEMPLATE = R"(outs["%s"][0])";
......@@ -251,23 +252,34 @@ const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"})";
const char* OP_FUNCTION_TEMPLATE =
R"(
%s %s(%s)
static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs)
{
PyThreadState *tstate = nullptr;
try
{
%s
framework::AttributeMap attrs;
ConstructAttrMapFromPyArgs("%s", %d, &attrs, args);
{
py::gil_scoped_release release;
ConstructAttrMapFromPyArgs("%s", args, %d, PyTuple_GET_SIZE(args) , attrs);
tstate = PyEval_SaveThread();
%s
imperative::NameVarBaseMap outs = %s;
imperative::NameVarBaseMap ins = %s;
%s
imperative::GetCurrentTracer()->TraceOp("%s", ins, outs, attrs, {%s});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return %s;
}
catch(...) {
if (tstate) {
PyEval_RestoreThread(tstate);
}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
})";
const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", &%s);)";
const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},)";
// clang-format on
static inline bool FindInsMap(const std::string& op_type,
......@@ -326,9 +338,8 @@ std::string GenerateOpFunctionsBody(
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
arg_idx++, TempName(in_name), dispensable);
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
in_name, arg_idx++, dispensable);
if (input.dispensable()) {
const auto in_template = input.duplicable()
......@@ -356,7 +367,6 @@ std::string GenerateOpFunctionsBody(
// Generate outs initializer
std::string outs_initializer = "{";
std::string outs_initializer_with_null = "";
std::string return_type = "";
std::string inplace_mapping_str = "";
std::string return_str = "";
......@@ -395,6 +405,12 @@ std::string GenerateOpFunctionsBody(
paddle::string::Sprintf(out_template, out_name, out_name);
outs_initializer += ",";
}
const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type,
out_name, arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE(
inplace_map[out_name], "",
......@@ -440,6 +456,11 @@ std::string GenerateOpFunctionsBody(
input_args_num++;
outs_initializer += paddle::string::Sprintf(
OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str);
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str +=
paddle::string::Sprintf(CAST_SIZE_T_TEMPLATE, out_num_str, op_type,
out_num_str, arg_idx++, dispensable);
} else {
outs_initializer +=
paddle::string::Sprintf(OUT_INITIALIZER_TEMPLATE, out_name);
......@@ -447,15 +468,12 @@ std::string GenerateOpFunctionsBody(
outs_initializer += ",";
}
return_type += out_type;
return_type += ",";
return_str += paddle::string::Sprintf(return_template, out_name);
return_str += ",";
outs_num += 1;
}
if (outs_initializer.back() == ',') {
outs_initializer.pop_back();
return_type.pop_back();
return_str.pop_back();
}
outs_initializer += "}";
......@@ -470,11 +488,13 @@ std::string GenerateOpFunctionsBody(
viwe_input_name, viwe_output_name);
}
if (outs_num == 0) {
return_type = "void";
}
if (outs_num > 1) {
return_str = paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str);
return_type = paddle::string::Sprintf(RETURN_TUPLE_TYPE, return_type);
return_str = "Py_None";
} else if (outs_num == 1) {
return_str = "MakeReturnPyObject(" + return_str + ")";
} else {
return_str = "MakeReturnPyObject(" +
paddle::string::Sprintf(RETURN_TUPLE_TEMPLATE, return_str) +
")";
}
std::string function_args = "";
if (input_args == "") {
......@@ -485,9 +505,9 @@ std::string GenerateOpFunctionsBody(
// generate op funtcion body
auto op_function_str = paddle::string::Sprintf(
OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, ins_cast_str,
op_type, input_args_num, inplace_strategy_str, outs_initializer,
ins_initializer, ins_initializer_with_null + outs_initializer_with_null +
OP_FUNCTION_TEMPLATE, func_name, ins_cast_str, op_type, input_args_num,
inplace_strategy_str, outs_initializer, ins_initializer,
ins_initializer_with_null + outs_initializer_with_null +
view_strategy_str,
op_type, inplace_mapping_str, return_str);
......@@ -495,7 +515,7 @@ std::string GenerateOpFunctionsBody(
}
static std::tuple<std::vector<std::string>, std::vector<std::string>>
GenerateOpFunctions(const std::string& module_name) {
GenerateOpFunctions() {
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
std::vector<std::string> op_function_list, bind_function_list;
......@@ -536,7 +556,7 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item
auto bind_function_str = paddle::string::Sprintf(
PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name);
PYBIND_ITEM_TEMPLATE, op_type, func_name, op_type);
op_function_list.emplace_back(std::move(op_function_str));
bind_function_list.emplace_back(std::move(bind_function_str));
......@@ -551,8 +571,8 @@ GenerateOpFunctions(const std::string& module_name) {
// generate pybind item
auto inplace_bind_function_str =
paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, module_name,
inplace_op_type, inplace_func_name);
paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, inplace_op_type,
inplace_func_name, inplace_op_type);
op_function_list.emplace_back(std::move(inplace_op_function_str));
bind_function_list.emplace_back(std::move(inplace_bind_function_str));
......@@ -572,7 +592,9 @@ int main(int argc, char* argv[]) {
ascend_ptr->InitGEForUT();
#endif
std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\""};
std::vector<std::string> headers{"\"paddle/fluid/imperative/tracer.h\"",
"\"pybind11/detail/common.h\"",
"<Python.h>"};
std::ofstream out(argv[1], std::ios::out);
......@@ -582,22 +604,29 @@ int main(int argc, char* argv[]) {
out << "#include " + header + "\n";
}
auto op_funcs = GenerateOpFunctions("m");
out << "\n\n";
auto op_funcs = GenerateOpFunctions();
out << "namespace py = pybind11;"
<< "\n";
out << "namespace paddle {\n"
<< "namespace pybind {\n\n";
out << "std::atomic<int> VarBaseUniqueNameID{0};\n";
out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
out << "\n\n";
out << "inline void BindOpFunctions(pybind11::module *module) {\n"
<< " auto m = module->def_submodule(\"ops\");\n\n";
out << "static PyMethodDef ExtestMethods[] = {\n"
<< paddle::string::join_strings(std::get<1>(op_funcs), '\n')
<< "\n {nullptr,nullptr,0,nullptr}"
<< "};\n\n";
out << paddle::string::join_strings(std::get<1>(op_funcs), '\n');
out << "\n";
out << "}\n\n"
out << "inline void BindOpFunctions(pybind11::module *module) {\n"
<< " auto m = module->def_submodule(\"ops\");\n"
<< " if (PyModule_AddFunctions(m.ptr(), ExtestMethods) < 0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.ops failed!\"));\n"
<< " }\n\n"
<< " InitOpsAttrTypeMap();"
<< "}\n\n"
<< "} // namespace pybind\n"
<< "} // namespace paddle\n";
......
......@@ -29,6 +29,9 @@ limitations under the License. */
namespace paddle {
namespace pybind {
PyTypeObject *g_vartype_pytype = nullptr;
PyTypeObject *g_blockdesc_pytype = nullptr;
namespace pd = paddle::framework;
template <typename T>
......@@ -82,8 +85,9 @@ void BindProgramDesc(pybind11::module *m) {
}
void BindBlockDesc(pybind11::module *m) {
pybind11::class_<pd::BlockDesc>(*m, "BlockDesc", "")
.def_property_readonly("id", &pd::BlockDesc::ID)
pybind11::class_<pd::BlockDesc> blockdesc(*m, "BlockDesc", "");
g_blockdesc_pytype = (PyTypeObject *)blockdesc.ptr(); // NOLINT
blockdesc.def_property_readonly("id", &pd::BlockDesc::ID)
.def_property_readonly("parent", &pd::BlockDesc::Parent)
.def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID)
.def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID)
......@@ -174,8 +178,9 @@ void BindVarDsec(pybind11::module *m) {
.def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
.def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed);
pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "")
.value("BOOL", pd::proto::VarType::BOOL)
pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
g_vartype_pytype = (PyTypeObject *)vartype.ptr(); // NOLINT
vartype.value("BOOL", pd::proto::VarType::BOOL)
.value("UINT8", pd::proto::VarType::UINT8)
.value("INT8", pd::proto::VarType::INT8)
.value("INT16", pd::proto::VarType::INT16)
......
......@@ -357,7 +357,7 @@ def convert_shape_to_list(shape):
map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x,
shape))
else:
shape = list(shape.numpy().astype(int))
shape = shape.numpy().astype(int).tolist()
return shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册