未验证 提交 32d79bb9 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Refactored Python-C Attributes Parsing Functions (#39328)

上级 7b70b792
...@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { ...@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); } bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); }
void CastPyArg2AttrBoolean(PyObject* obj, bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
if (obj == Py_None) { if (obj == Py_None) {
attrs[key] = false; // To be compatible with QA integration testing. Some return false; // To be compatible with QA integration testing. Some
// test case pass in None. // test case pass in None.
} else if (obj == Py_True) { } else if (obj == Py_True) {
attrs[key] = true; return true;
} else if (obj == Py_False) { } else if (obj == Py_False) {
attrs[key] = false; return false;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -118,14 +116,20 @@ void CastPyArg2AttrBoolean(PyObject* obj, ...@@ -118,14 +116,20 @@ void CastPyArg2AttrBoolean(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return false;
} }
void CastPyArg2AttrInt(PyObject* obj, void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Boolean(obj, op_type, arg_pos);
}
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) { if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int)PyLong_AsLong(obj); // NOLINT return (int)PyLong_AsLong(obj); // NOLINT
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -133,14 +137,21 @@ void CastPyArg2AttrInt(PyObject* obj, ...@@ -133,14 +137,21 @@ void CastPyArg2AttrInt(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return 0;
} }
void CastPyArg2AttrLong(PyObject* obj, void CastPyArg2AttrInt(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Int(obj, op_type, arg_pos);
}
int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) { if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int64_t)PyLong_AsLong(obj); // NOLINT return (int64_t)PyLong_AsLong(obj); // NOLINT
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -148,14 +159,21 @@ void CastPyArg2AttrLong(PyObject* obj, ...@@ -148,14 +159,21 @@ void CastPyArg2AttrLong(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return 0;
} }
void CastPyArg2AttrFloat(PyObject* obj, void CastPyArg2AttrLong(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Long(obj, op_type, arg_pos);
}
float CastPyArg2Float(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckFloatOrToFloat(&obj)) { if (PyObject_CheckFloatOrToFloat(&obj)) {
attrs[key] = (float)PyFloat_AsDouble(obj); // NOLINT return (float)PyFloat_AsDouble(obj); // NOLINT
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -163,17 +181,24 @@ void CastPyArg2AttrFloat(PyObject* obj, ...@@ -163,17 +181,24 @@ void CastPyArg2AttrFloat(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return 0.0;
} }
void CastPyArg2AttrString(PyObject* obj, void CastPyArg2AttrFloat(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Float(obj, op_type, arg_pos);
}
std::string CastPyArg2String(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckString(obj)) { if (PyObject_CheckString(obj)) {
Py_ssize_t size; Py_ssize_t size;
const char* data; const char* data;
data = PyUnicode_AsUTF8AndSize(obj, &size); data = PyUnicode_AsUTF8AndSize(obj, &size);
attrs[key] = std::string(data, (size_t)size); // NOLINT return std::string(data, (size_t)size); // NOLINT
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj, ...@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return "";
} }
void CastPyArg2AttrBooleans(PyObject* obj, void CastPyArg2AttrString(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2String(obj, op_type, arg_pos);
}
std::vector<bool> CastPyArg2Booleans(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<bool> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckBool(&item)) { if (PyObject_CheckBool(&item)) {
...@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj, ...@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckBool(&item)) { if (PyObject_CheckBool(&item)) {
...@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj, ...@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj, ...@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrInts(PyObject* obj, void CastPyArg2AttrBooleans(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Booleans(obj, op_type, arg_pos);
}
std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<int> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj, ...@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj, ...@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PySequence_Check(obj)) { } else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj); Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i); item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj, ...@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj, ...@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrLongs(PyObject* obj, void CastPyArg2AttrInts(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Ints(obj, op_type, arg_pos);
}
std::vector<int64_t> CastPyArg2Longs(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<int64_t> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj, ...@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj, ...@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PySequence_Check(obj)) { } else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj); Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i); item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj, ...@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj, ...@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrFloats(PyObject* obj, void CastPyArg2AttrLongs(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Longs(obj, op_type, arg_pos);
}
std::vector<float> CastPyArg2Floats(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<float> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj, ...@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj, ...@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PySequence_Check(obj)) { } else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj); Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i); item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj, ...@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj, ...@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrFloat64s(PyObject* obj, void CastPyArg2AttrFloats(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Floats(obj, op_type, arg_pos);
}
std::vector<double> CastPyArg2Float64s(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
std::vector<double> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj, ...@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj, ...@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PySequence_Check(obj)) { } else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj); Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i); item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj, ...@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj, ...@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrStrings(PyObject* obj, void CastPyArg2AttrFloat64s(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Float64s(obj, op_type, arg_pos);
}
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
std::vector<std::string> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckString(item)) { if (PyObject_CheckString(item)) {
...@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj, ...@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckString(item)) { if (PyObject_CheckString(item)) {
...@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj, ...@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj, ...@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
}
void CastPyArg2AttrStrings(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Strings(obj, op_type, arg_pos);
} }
void CastPyArg2AttrBlock(PyObject* obj, void CastPyArg2AttrBlock(PyObject* obj,
......
...@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj); ...@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj);
bool PyObject_CheckString(PyObject* obj); bool PyObject_CheckString(PyObject* obj);
bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos);
int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
float CastPyArg2Float(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::string CastPyArg2String(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<bool> CastPyArg2Booleans(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<int64_t> CastPyArg2Longs(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<float> CastPyArg2Floats(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<double> CastPyArg2Float64s(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrBoolean(PyObject* obj, void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册