未验证 提交 32d79bb9 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Refactored Python-C Attributes Parsing Functions (#39328)

上级 7b70b792
......@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); }
void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (obj == Py_None) {
attrs[key] = false; // To be compatible with QA integration testing. Some
return false; // To be compatible with QA integration testing. Some
// test case pass in None.
} else if (obj == Py_True) {
attrs[key] = true;
return true;
} else if (obj == Py_False) {
attrs[key] = false;
return false;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -118,14 +116,20 @@ void CastPyArg2AttrBoolean(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return false;
}
void CastPyArg2AttrInt(PyObject* obj,
void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Boolean(obj, op_type, arg_pos);
}
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int)PyLong_AsLong(obj); // NOLINT
return (int)PyLong_AsLong(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -133,14 +137,21 @@ void CastPyArg2AttrInt(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return 0;
}
void CastPyArg2AttrLong(PyObject* obj,
void CastPyArg2AttrInt(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Int(obj, op_type, arg_pos);
}
int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int64_t)PyLong_AsLong(obj); // NOLINT
return (int64_t)PyLong_AsLong(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -148,14 +159,21 @@ void CastPyArg2AttrLong(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return 0;
}
void CastPyArg2AttrFloat(PyObject* obj,
void CastPyArg2AttrLong(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Long(obj, op_type, arg_pos);
}
float CastPyArg2Float(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckFloatOrToFloat(&obj)) {
attrs[key] = (float)PyFloat_AsDouble(obj); // NOLINT
return (float)PyFloat_AsDouble(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -163,17 +181,24 @@ void CastPyArg2AttrFloat(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return 0.0;
}
void CastPyArg2AttrString(PyObject* obj,
void CastPyArg2AttrFloat(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Float(obj, op_type, arg_pos);
}
std::string CastPyArg2String(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckString(obj)) {
Py_ssize_t size;
const char* data;
data = PyUnicode_AsUTF8AndSize(obj, &size);
attrs[key] = std::string(data, (size_t)size); // NOLINT
return std::string(data, (size_t)size); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return "";
}
void CastPyArg2AttrBooleans(PyObject* obj,
void CastPyArg2AttrString(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2String(obj, op_type, arg_pos);
}
std::vector<bool> CastPyArg2Booleans(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<bool> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckBool(&item)) {
......@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckBool(&item)) {
......@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrInts(PyObject* obj,
void CastPyArg2AttrBooleans(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Booleans(obj, op_type, arg_pos);
}
std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<int> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrLongs(PyObject* obj,
void CastPyArg2AttrInts(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Ints(obj, op_type, arg_pos);
}
std::vector<int64_t> CastPyArg2Longs(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<int64_t> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrFloats(PyObject* obj,
void CastPyArg2AttrLongs(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Longs(obj, op_type, arg_pos);
}
std::vector<float> CastPyArg2Floats(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<float> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrFloat64s(PyObject* obj,
void CastPyArg2AttrFloats(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Floats(obj, op_type, arg_pos);
}
std::vector<double> CastPyArg2Float64s(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
std::vector<double> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrStrings(PyObject* obj,
void CastPyArg2AttrFloat64s(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Float64s(obj, op_type, arg_pos);
}
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
std::vector<std::string> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckString(item)) {
......@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckString(item)) {
......@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrStrings(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Strings(obj, op_type, arg_pos);
}
void CastPyArg2AttrBlock(PyObject* obj,
......
......@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj);
bool PyObject_CheckString(PyObject* obj);
bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos);
int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
float CastPyArg2Float(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::string CastPyArg2String(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<bool> CastPyArg2Booleans(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<int64_t> CastPyArg2Longs(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<float> CastPyArg2Floats(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<double> CastPyArg2Float64s(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册